# helia/migrations/init_transcripts.py
"""Initial migration: import transcript CSVs from S3 into the database.

Run this module directly to execute ``migrate``.
"""

import asyncio
import logging
import re
from typing import Literal
from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.ingestion.s3 import S3DatasetLoader
from helia.models.transcript import Transcript, Utterance
# Configure root logging when the module is loaded so INFO-level progress
# messages are emitted; logging.basicConfig is a no-op if the root logger
# already has handlers configured.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def normalize_speaker(speaker: str) -> Literal["Interviewer", "Participant"]:
    """Map a raw transcript speaker label onto a canonical role.

    Matching is case-insensitive and ignores surrounding whitespace.

    Args:
        speaker: Speaker label as read from the transcript row.

    Returns:
        ``"Interviewer"`` for the label ``ellie`` and ``"Participant"`` for
        the label ``participant``.

    Raises:
        ValueError: If the label matches neither known speaker.
    """
    roles: dict[str, Literal["Interviewer", "Participant"]] = {
        "ellie": "Interviewer",
        "participant": "Participant",
    }
    role = roles.get(speaker.lower().strip())
    if role is None:
        raise ValueError(f"Unknown speaker: {speaker}")
    return role
def _row_to_utterance(row) -> Utterance:
    """Build an ``Utterance`` from one transcript row.

    Reads the ``start_time``, ``stop_time``, ``speaker`` and ``value`` fields.
    NOTE(review): presumably ``row`` is a str-keyed mapping of CSV cells;
    confirm against ``S3DatasetLoader.load_transcript``.
    """
    return Utterance(
        start_time=float(row["start_time"]),
        end_time=float(row["stop_time"]),
        speaker=normalize_speaker(row["speaker"]),
        value=row["value"].strip(),
    )


async def process_transcript(
    s3_loader: S3DatasetLoader, key: str, id_pattern: re.Pattern[str]
) -> bool:
    """Import one transcript file from S3 into the database.

    The transcript id is capture group 1 of ``id_pattern`` searched in ``key``.
    Keys that do not match are skipped with a warning. Any ``Exception`` raised
    while loading, converting or inserting is logged and swallowed, so a single
    bad file does not abort a batch run.

    Returns:
        True if a ``Transcript`` document was inserted; False if the key was
        skipped, the file yielded no rows, or an error occurred.
    """
    match = id_pattern.search(key)
    if match is None:
        logger.warning("Skipping unexpected file: %s", key)
        return False

    transcript_id = match.group(1)
    logger.info("Processing transcript %s...", transcript_id)

    try:
        rows = await s3_loader.load_transcript(key)
        utterances = [_row_to_utterance(row) for row in rows]
        if not utterances:
            logger.warning("No utterances found for transcript %s", transcript_id)
            return False
        await Transcript(transcript_id=transcript_id, utterances=utterances).insert()
    except Exception:
        # Per-file boundary: record the traceback and report failure to the caller.
        logger.exception("Failed to process %s", key)
        return False
    return True
async def migrate() -> None:
    """Import every transcript CSV listed in S3 into the database.

    Builds ``HeliaConfig``, initializes the database via ``init_db``, lists
    transcript keys with ``S3DatasetLoader`` and runs ``process_transcript``
    for all keys concurrently. Individual failures do not abort the run; the
    number of successes, and of skipped/failed files, is logged at the end.
    """
    logger.info("Starting initialization migration...")
    config = HeliaConfig()  # type: ignore[call-arg]
    await init_db(config)
    s3_loader = S3DatasetLoader(config.s3)
    keys = await s3_loader.list_transcripts()
    logger.info("%d transcript files in S3.", len(keys))

    # Matches keys ending in "/<digits>_TRANSCRIPT.csv"; group 1 is the id.
    id_pattern = re.compile(r"/(\d+)_TRANSCRIPT\.csv$")
    results = await asyncio.gather(
        *(process_transcript(s3_loader, key, id_pattern) for key in keys),
        return_exceptions=True,
    )

    count = 0
    # gather() returns results in the same order as the awaitables passed in.
    for key, result in zip(keys, results):
        if isinstance(result, BaseException):
            # process_transcript already handles ``Exception``; anything that
            # still escapes (e.g. asyncio.CancelledError) is returned here by
            # return_exceptions=True and would otherwise be silently dropped.
            logger.error("Unhandled error while processing %s", key, exc_info=result)
        elif result is True:
            count += 1

    logger.info("Successfully processed %d transcripts.", count)
    if count < len(keys):
        logger.warning(
            "%d of %d transcript files were skipped or failed.",
            len(keys) - count,
            len(keys),
        )
if __name__ == "__main__":
    # Run the migration only when this module is executed directly, not on import.
    migration = migrate()
    asyncio.run(migration)