"""Initial migration: read transcript CSV files from S3 and insert them as Transcript records."""

import asyncio
import logging
import re
from typing import Literal

from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.ingestion.s3 import S3DatasetLoader
from helia.models.transcript import Transcript, Utterance

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def normalize_speaker(speaker: str) -> Literal["Interviewer", "Participant"]:
    """Map a raw speaker label onto one of the two canonical roles.

    The label is lower-cased and stripped before comparison: "ellie" becomes
    "Interviewer" and "participant" becomes "Participant".

    Raises:
        ValueError: If the label is neither of the two known speakers.
    """
    cleaned = speaker.lower().strip()
    if cleaned not in ("ellie", "participant"):
        raise ValueError(f"Unknown speaker: {speaker}")
    return "Interviewer" if cleaned == "ellie" else "Participant"


async def process_transcript(
    s3_loader: S3DatasetLoader, key: str, id_pattern: re.Pattern[str]
) -> bool:
    """Load one transcript from S3 and insert it; report whether it succeeded.

    Args:
        s3_loader: Loader used to fetch the rows stored under ``key``.
        key: S3 object key of the transcript file.
        id_pattern: Compiled pattern whose first group captures the transcript id.

    Returns:
        True when the transcript was inserted. False when the key does not
        match ``id_pattern``, when the file yields no utterances, or when any
        exception is raised while loading, converting or inserting.
    """
    found = id_pattern.search(key)
    if found is None:
        logger.warning("Skipping unexpected file: %s", key)
        return False

    transcript_id = found.group(1)
    logger.info("Processing transcript %s...", transcript_id)

    try:
        rows = await s3_loader.load_transcript(key)

        utterances = []
        for row in rows:
            utterances.append(
                Utterance(
                    start_time=float(row["start_time"]),
                    end_time=float(row["stop_time"]),
                    speaker=normalize_speaker(row["speaker"]),
                    value=row["value"].strip(),
                )
            )

        if len(utterances) == 0:
            logger.warning("No utterances found for transcript %s", transcript_id)
            return False

        document = Transcript(transcript_id=transcript_id, utterances=utterances)
        await document.insert()
    except Exception:
        # One bad file must not abort the whole run: log with traceback and
        # report the failure to the caller.
        logger.exception("Failed to process %s", key)
        return False

    return True


async def migrate() -> None:
    """Run the migration: initialise the DB, then process every listed transcript.

    All transcripts are processed concurrently through ``asyncio.gather``; the
    number of calls that returned True is logged at the end.
    """
    logger.info("Starting initialization migration...")

    config = HeliaConfig()  # type: ignore[call-arg]
    await init_db(config)

    s3_loader = S3DatasetLoader(config.s3)
    keys = await s3_loader.list_transcripts()
    logger.info("%d transcript files in S3.", len(keys))

    id_pattern = re.compile(r"/(\d+)_TRANSCRIPT\.csv$")

    jobs = [process_transcript(s3_loader, key, id_pattern) for key in keys]
    outcomes = await asyncio.gather(*jobs, return_exceptions=True)

    # With return_exceptions=True an outcome may be an exception object, so
    # only a literal True counts as a success.
    count = len([outcome for outcome in outcomes if outcome is True])
    logger.info("Successfully processed %d transcripts.", count)


if __name__ == "__main__":
    asyncio.run(migrate())