"""Initialization migration: load transcript CSVs from S3 and insert them into the database."""
import asyncio
|
|
import logging
|
|
import re
|
|
from typing import Literal
|
|
|
|
from helia.configuration import HeliaConfig
|
|
from helia.db import init_db
|
|
from helia.ingestion.s3 import S3DatasetLoader
|
|
from helia.models.transcript import Transcript, Utterance
|
|
|
|
# Configure root logging at import time so INFO-level progress messages are visible
# when this file is run as a script.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
|
|
|
def normalize_speaker(speaker: str) -> Literal["Interviewer", "Participant"]:
    """Map a raw speaker label onto one of the two canonical roles.

    Args:
        speaker: Raw speaker label from a transcript row.

    Returns:
        ``"Interviewer"`` for ``"ellie"`` and ``"Participant"`` for
        ``"participant"`` (case-insensitive, surrounding whitespace ignored).

    Raises:
        ValueError: If the label is neither of the two known speakers.
    """
    roles: dict[str, Literal["Interviewer", "Participant"]] = {
        "ellie": "Interviewer",
        "participant": "Participant",
    }
    normalized = speaker.lower().strip()
    if normalized in roles:
        return roles[normalized]
    raise ValueError(f"Unknown speaker: {speaker}")
|
|
|
|
|
|
async def process_transcript(
    s3_loader: S3DatasetLoader, key: str, id_pattern: re.Pattern[str]
) -> bool:
    """Load one transcript file from S3, convert its rows, and persist it.

    Args:
        s3_loader: Loader used to fetch the raw transcript rows from S3.
        key: S3 object key of the transcript file.
        id_pattern: Compiled pattern whose first group captures the transcript id.

    Returns:
        True when the transcript was inserted; False when the key did not
        match the pattern, the file contained no utterances, or any error
        occurred during loading/parsing/insertion.
    """
    match = id_pattern.search(key)
    if match is None:
        logger.warning("Skipping unexpected file: %s", key)
        return False

    transcript_id = match.group(1)
    logger.info("Processing transcript %s...", transcript_id)

    try:
        rows = await s3_loader.load_transcript(key)
        utterances = []
        for row in rows:
            utterances.append(
                Utterance(
                    start_time=float(row["start_time"]),
                    end_time=float(row["stop_time"]),
                    speaker=normalize_speaker(row["speaker"]),
                    value=row["value"].strip(),
                )
            )

        if not utterances:
            logger.warning("No utterances found for transcript %s", transcript_id)
            return False

        await Transcript(transcript_id=transcript_id, utterances=utterances).insert()
    except Exception:
        # Deliberate per-file boundary: one bad file must not abort the whole
        # migration, so log with traceback and report failure to the caller.
        logger.exception("Failed to process %s", key)
        return False
    else:
        return True
|
|
|
|
|
|
async def migrate() -> None:
    """Re-ingest every transcript from S3 into a freshly cleared database.

    Wipes all existing ``Transcript`` documents, then processes every
    transcript file found in S3 concurrently and logs the success count.
    """
    logger.info("Starting initialization migration...")

    config = HeliaConfig()  # type: ignore[call-arg]
    await init_db(config)

    s3_loader = S3DatasetLoader(config.s3)
    keys = await s3_loader.list_transcripts()
    logger.info("Found %d transcript files in S3.", len(keys))

    # Destructive step: this migration rebuilds the collection from scratch.
    await Transcript.delete_all()
    logger.info("Cleared existing transcripts.")

    # First capture group is the numeric transcript id, e.g. ".../123_TRANSCRIPT.csv".
    id_pattern = re.compile(r"/(\d+)_TRANSCRIPT\.csv$")

    tasks = [process_transcript(s3_loader, key, id_pattern) for key in keys]
    # return_exceptions=True keeps one unexpected failure from cancelling the rest.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    succeeded = sum(1 for outcome in results if outcome is True)
    logger.info("Successfully processed %d transcripts.", succeeded)
|
|
|
|
|
|
# Script entry point: drive the async migration to completion on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(migrate())
|