diff --git a/example.config.yaml b/example.config.yaml
index eeaef95..8699efb 100644
--- a/example.config.yaml
+++ b/example.config.yaml
@@ -1,7 +1,9 @@
 # Helia Application Configuration
 # Copy this file to config.yaml and adjust values as needed.
 
+log_level: "INFO"
 patient_limit: 5
+concurrency_limit: 1
 
 mongo:
   uri: "mongodb://localhost:27017"
diff --git a/src/helia/assessment/core.py b/src/helia/assessment/core.py
index f477464..44b8929 100644
--- a/src/helia/assessment/core.py
+++ b/src/helia/assessment/core.py
@@ -2,12 +2,16 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, cast
 
-from helia.assessment.schema import AssessmentResponse, AssessmentResult, RunConfig
-from helia.ingestion.parser import TranscriptParser
+from helia.assessment.schema import (
+    AssessmentResponse,
+    AssessmentResult,
+    EvaluatorConfig,
+)
 from helia.llm.client import get_chat_model
 
 if TYPE_CHECKING:
-    from pathlib import Path
+    from helia.configuration import RunConfig
+    from helia.models.transcript import Transcript
 
 # PHQ-8 Scoring Constants
 DIAGNOSIS_THRESHOLD = 10
@@ -51,24 +55,23 @@ TRANSCRIPT:
 
 
 class PHQ8Evaluator:
-    def __init__(self, config: RunConfig) -> None:
+    def __init__(self, config: EvaluatorConfig) -> None:
         self.config = config
-        self.parser = TranscriptParser()
         self.llm = get_chat_model(
             model_name=self.config.model_name,
             api_key=self.config.api_key,
-            base_url=self.config.api_base or "",
+            base_url=self.config.api_base,
             temperature=self.config.temperature,
         )
 
-    async def evaluate(self, file_path: Path) -> AssessmentResult:
+    async def evaluate(
+        self, transcript: Transcript, original_run_config: RunConfig
+    ) -> AssessmentResult:
         """
         Asynchronously evaluate a transcript using the configured LLM.
         """
 
-        # 1. Parse Transcript
-        utterances = self.parser.parse(file_path)
-        transcript_text = "\n".join([f"{u.speaker}: {u.text}" for u in utterances])
+        transcript_text = "\n".join([f"{u.speaker}: {u.value}" for u in transcript.utterances])
 
         # 2. Prepare Prompt
         final_prompt = DEFAULT_PROMPT.format(transcript_text=transcript_text)
@@ -104,8 +107,8 @@ class PHQ8Evaluator:
             diagnosis_algorithm = "Other Depression"
 
         return AssessmentResult(
-            transcript_id=file_path.stem,
-            config=self.config,
+            transcript_id=transcript.transcript_id,
+            config=original_run_config,
             items=items,
             total_score=total_score,
             diagnosis_algorithm=diagnosis_algorithm,
diff --git a/src/helia/assessment/schema.py b/src/helia/assessment/schema.py
index c248506..3199cca 100644
--- a/src/helia/assessment/schema.py
+++ b/src/helia/assessment/schema.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING
 
 from beanie import Document
@@ -13,6 +15,17 @@ class Evidence(BaseModel):
     reasoning: str
 
 
+class EvaluatorConfig(BaseModel):
+    """Configuration required to initialize the Evaluator."""
+
+    model_name: str
+    api_base: str
+    api_key: str | None = None
+    api_spec: str = "openai"
+    temperature: float = 0.0
+    prompt_id: str = "default"
+
+
 class PHQ8Item(BaseModel):
     question_id: int
     question_text: str
diff --git a/src/helia/configuration.py b/src/helia/configuration.py
index 84e97f0..ea7d0bb 100644
--- a/src/helia/configuration.py
+++ b/src/helia/configuration.py
@@ -61,6 +61,7 @@ class HeliaConfig(BaseSettings):
         extra="ignore",
     )
 
+    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
     patient_limit: int = 5
     concurrency_limit: int = 1
     mongo: MongoConfig
diff --git a/src/helia/main.py b/src/helia/main.py
index 036e2ea..4254471 100644
--- a/src/helia/main.py
+++ b/src/helia/main.py
@@ -1,87 +1,72 @@
 import argparse
 import asyncio
 import logging
-from datetime import UTC, datetime
 from pathlib import Path
 
 from helia.assessment.core import PHQ8Evaluator
-from helia.assessment.schema import RunConfig
+from helia.assessment.schema import EvaluatorConfig
 from helia.configuration import (
     HeliaConfig,
-    RunSpec,
-    S3Config,
-    load_config,
+    RunConfig,
 )
 from helia.db import init_db
-from helia.ingestion.s3 import S3DatasetLoader
-from helia.preflight import check_all_connections
+from helia.models.transcript import Transcript
 
 logger = logging.getLogger(__name__)
 
 
+def load_config(path: Path) -> HeliaConfig:
+    """Load configuration from the specified path."""
+    try:
+        return HeliaConfig(_yaml_file=str(path))  # ty:ignore[missing-argument, unknown-argument]
+    except Exception:
+        return HeliaConfig()  # ty:ignore[missing-argument]
+
+
 async def process_run(
-    run_spec: RunSpec,
-    input_source: str,
+    run_name: str,
+    run_config: RunConfig,
+    transcript: Transcript,
     config: HeliaConfig,
-    s3_config: S3Config,
     semaphore: asyncio.Semaphore,
 ) -> None:
     """
     Process a single run for a single transcript, bounded by a semaphore.
     """
 
     async with semaphore:
-        # Resolve Provider
-        provider_name = run_spec.model.provider
-        if provider_name not in config.providers:
-            logger.error("Run %s refers to unknown provider %s", run_spec.run_name, provider_name)
-            return
+        provider_config = config.providers[run_config.model.provider]
 
-        provider_config = config.providers[provider_name]
+        full_run_name = f"{run_name}_{transcript.transcript_id}"
 
-        # Download from S3 (Async)
-        loader = S3DatasetLoader(s3_config)
-        local_file = Path("data/downloads") / input_source
-        if not local_file.exists():
-            await loader.download_file_async(input_source, local_file)
+        logger.info("--- Processing: %s ---", full_run_name)
 
-        input_path = local_file
-        item_id = Path(input_source).stem.split("_")[0]  # Extract 300 from 300_TRANSCRIPT
-        run_name = f"{run_spec.run_name}_{item_id}"
-
-        logger.info("--- Processing: %s ---", run_name)
-
-        run_config = RunConfig(
-            model_name=run_spec.model.model_name,
+        eval_config = EvaluatorConfig(
+            model_name=run_config.model.model_name,
             api_base=provider_config.api_base,
             api_key=provider_config.api_key,
-            api_format=provider_config.api_format,
-            prompt_id=run_spec.prompt_id,
-            temperature=run_spec.model.temperature,
-            timestamp=datetime.now(tz=UTC).isoformat(),
+            api_spec=provider_config.api_spec,
+            prompt_id=run_config.prompt_id,
+            temperature=run_config.model.temperature,
        )
 
-        # DEBUG LOGGING
         masked_key = provider_config.api_key[:4] + "..." if provider_config.api_key else "None"
-        logger.info("DEBUG: Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
+        logger.debug("Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
 
         try:
-            evaluator = PHQ8Evaluator(run_config)
-            # Await the async evaluation
-            result = await evaluator.evaluate(input_path)
-
-            # Save to DB (Async)
+            evaluator = PHQ8Evaluator(eval_config)
+            result = await evaluator.evaluate(transcript, run_config)
             await result.insert()
 
-            logger.info("Assessment complete for %s.", run_name)
+            logger.info("Assessment complete for %s.", full_run_name)
             logger.info(
                 "ID: %s | Score: %s | Diagnosis: %s",
-                result.id,
+                result.transcript_id,
                 result.total_score,
                 result.diagnosis_algorithm,
             )
         except Exception:
-            logger.exception("Failed to process %s", run_name)
+            logger.exception("Failed to process %s", full_run_name)
 
 
 async def main() -> None:
@@ -114,42 +99,31 @@ async def main() -> None:
         logger.exception("Error loading configuration")
         return
 
-    # Run Pre-flight Checks
-    if not await check_all_connections(config):
-        logger.error("Pre-flight checks failed. Exiting.")
-        return
-
-    # Ensure DB is initialized (redundant if pre-flight passed, but safe)
     await init_db(config)
 
-    # Create S3 config once and reuse
-    s3_config = config.get_s3_config()
-
-    # Discover transcripts (can remain sync or be made async, sync is fine for listing)
-    logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name)
-    loader = S3DatasetLoader(s3_config)
-    keys = loader.list_transcripts()
-
-    # Apply limit if configured (Priority: CLI > Config)
-    limit = args.limit
-    limit = limit if limit is not None else config.limit
+    logger.debug("Discovering transcripts in MongoDB...")
+    limit = args.limit or config.patient_limit
+    query = Transcript.find_all(limit=limit)
 
     if limit is not None:
-        logger.info("Limiting processing to first %d transcripts", limit)
-        keys = keys[:limit]
+        logger.debug("Limiting processing to first %d transcripts", limit)
 
-    # Create task list
-    tasks_data = [(run_spec, key) for run_spec in config.runs for key in keys]
-    logger.info("Starting batch assessment with %d total items...", len(tasks_data))
+    transcripts = await query.to_list()
 
-    # Limit concurrency to 10 parallel requests
-    semaphore = asyncio.Semaphore(10)
+    if not transcripts:
+        logger.warning("No transcripts found in database.")
+        return
 
-    tasks = [
-        process_run(run_spec, key, config, s3_config, semaphore) for run_spec, key in tasks_data
+    tasks_data = [
+        (run_name, run_spec, t) for run_name, run_spec in config.runs.items() for t in transcripts
     ]
 
-    # Run all tasks concurrently
+    logger.info("Starting batch assessment with %d total items...", len(tasks_data))
+
+    semaphore = asyncio.Semaphore(config.concurrency_limit)
+
+    tasks = [process_run(r_name, r_spec, t, config, semaphore) for r_name, r_spec, t in tasks_data]
+
     await asyncio.gather(*tasks)
 
     logger.info("Batch assessment complete.")
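
Reviewer note: the new evaluate() signature and the MongoDB discovery path in main() lean on helia.models.transcript, which is not touched by this diff. The sketch below is only an inference of the shape that module is assumed to expose, pieced together from the call sites above (Transcript.find_all, transcript.transcript_id, transcript.utterances, u.speaker, u.value); the actual model may differ.

    # Hypothetical sketch, not part of this change: the Transcript document as implied
    # by the call sites in core.py and main.py. Class and field names are inferred, not confirmed.
    from beanie import Document
    from pydantic import BaseModel


    class Utterance(BaseModel):
        speaker: str
        value: str  # evaluate() now reads u.value, replacing the parser's former u.text


    class Transcript(Document):
        transcript_id: str
        utterances: list[Utterance]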