feat: enhance configuration management and refactor evaluator logic for improved clarity and functionality

Santiago Martinez-Avial
2025-12-22 18:44:22 +01:00
parent 572f59e9ce
commit 8efc2b6217
5 changed files with 75 additions and 82 deletions

View File

@@ -1,7 +1,9 @@
 # Helia Application Configuration
 # Copy this file to config.yaml and adjust values as needed.

+log_level: "INFO"
 patient_limit: 5
+concurrency_limit: 1

 mongo:
   uri: "mongodb://localhost:27017"

View File

@@ -2,12 +2,16 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, cast

-from helia.assessment.schema import AssessmentResponse, AssessmentResult, RunConfig
-from helia.ingestion.parser import TranscriptParser
+from helia.assessment.schema import (
+    AssessmentResponse,
+    AssessmentResult,
+    EvaluatorConfig,
+)
 from helia.llm.client import get_chat_model

 if TYPE_CHECKING:
-    from pathlib import Path
+    from helia.configuration import RunConfig
+    from helia.models.transcript import Transcript

 # PHQ-8 Scoring Constants
 DIAGNOSIS_THRESHOLD = 10
@@ -51,24 +55,23 @@ TRANSCRIPT:
 class PHQ8Evaluator:
-    def __init__(self, config: RunConfig) -> None:
+    def __init__(self, config: EvaluatorConfig) -> None:
         self.config = config
-        self.parser = TranscriptParser()
         self.llm = get_chat_model(
             model_name=self.config.model_name,
             api_key=self.config.api_key,
-            base_url=self.config.api_base or "",
+            base_url=self.config.api_base,
             temperature=self.config.temperature,
         )

-    async def evaluate(self, file_path: Path) -> AssessmentResult:
+    async def evaluate(
+        self, transcript: Transcript, original_run_config: RunConfig
+    ) -> AssessmentResult:
         """
         Asynchronously evaluate a transcript using the configured LLM.
         """
-        # 1. Parse Transcript
-        utterances = self.parser.parse(file_path)
-        transcript_text = "\n".join([f"{u.speaker}: {u.text}" for u in utterances])
+        transcript_text = "\n".join([f"{u.speaker}: {u.value}" for u in transcript.utterances])

         # 2. Prepare Prompt
         final_prompt = DEFAULT_PROMPT.format(transcript_text=transcript_text)
@@ -104,8 +107,8 @@ class PHQ8Evaluator:
             diagnosis_algorithm = "Other Depression"

         return AssessmentResult(
-            transcript_id=file_path.stem,
-            config=self.config,
+            transcript_id=transcript.transcript_id,
+            config=original_run_config,
             items=items,
             total_score=total_score,
             diagnosis_algorithm=diagnosis_algorithm,
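With parsing moved out of the evaluator, callers now hand over an already-ingested `Transcript` plus the original `RunConfig`, which is kept only for provenance on the stored result. A sketch of the new call pattern — the model name and base URL are illustrative placeholders, not values from this commit:

```python
from helia.assessment.core import PHQ8Evaluator
from helia.assessment.schema import EvaluatorConfig

async def assess_one(transcript, run_config):
    # transcript: helia.models.transcript.Transcript loaded from MongoDB;
    # run_config: helia.configuration.RunConfig, recorded on the result.
    evaluator = PHQ8Evaluator(
        EvaluatorConfig(
            model_name="some-model",              # illustrative
            api_base="http://localhost:8000/v1",  # illustrative
        )
    )
    return await evaluator.evaluate(transcript, run_config)
```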

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING

 from beanie import Document
@@ -13,6 +15,17 @@ class Evidence(BaseModel):
     reasoning: str

+
+class EvaluatorConfig(BaseModel):
+    """Configuration required to initialize the Evaluator."""
+
+    model_name: str
+    api_base: str
+    api_key: str | None = None
+    api_spec: str = "openai"
+    temperature: float = 0.0
+    prompt_id: str = "default"
+
 class PHQ8Item(BaseModel):
     question_id: int
     question_text: str
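`EvaluatorConfig` is a plain `BaseModel`, so only `model_name` and `api_base` must be supplied; the remaining fields fall back to the defaults declared on the model. A quick check (argument values illustrative):

```python
from helia.assessment.schema import EvaluatorConfig

cfg = EvaluatorConfig(model_name="some-model", api_base="http://localhost:8000/v1")
assert cfg.api_key is None
assert cfg.api_spec == "openai"
assert cfg.temperature == 0.0
assert cfg.prompt_id == "default"
```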

View File

@@ -61,6 +61,7 @@ class HeliaConfig(BaseSettings):
         extra="ignore",
     )

+    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
     patient_limit: int = 5
     concurrency_limit: int = 1
     mongo: MongoConfig
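Because `log_level` is a `Literal` field on a `BaseSettings` model, pydantic rejects unknown level names when the config is loaded rather than at the first logging call. A sketch, assuming `MongoConfig` is importable from the same module and takes a `uri` field as in the example config:

```python
from pydantic import ValidationError

from helia.configuration import HeliaConfig, MongoConfig  # MongoConfig import assumed

try:
    HeliaConfig(
        log_level="TRACE",  # not one of the five allowed names
        mongo=MongoConfig(uri="mongodb://localhost:27017"),
    )
except ValidationError as err:
    print(err)  # reports that "TRACE" is not a permitted value
```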

View File

@@ -1,87 +1,72 @@
import argparse import argparse
import asyncio import asyncio
import logging import logging
from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from helia.assessment.core import PHQ8Evaluator from helia.assessment.core import PHQ8Evaluator
from helia.assessment.schema import RunConfig from helia.assessment.schema import EvaluatorConfig
from helia.configuration import ( from helia.configuration import (
HeliaConfig, HeliaConfig,
RunSpec, RunConfig,
S3Config,
load_config,
) )
from helia.db import init_db from helia.db import init_db
from helia.ingestion.s3 import S3DatasetLoader from helia.models.transcript import Transcript
from helia.preflight import check_all_connections
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def load_config(path: Path) -> HeliaConfig:
"""Load configuration from the specified path."""
try:
return HeliaConfig(_yaml_file=str(path)) # ty:ignore[missing-argument, unknown-argument]
except Exception:
return HeliaConfig() # ty:ignore[missing-argument]
async def process_run( async def process_run(
run_spec: RunSpec, run_name: str,
input_source: str, run_config: RunConfig,
transcript: Transcript,
config: HeliaConfig, config: HeliaConfig,
s3_config: S3Config,
semaphore: asyncio.Semaphore, semaphore: asyncio.Semaphore,
) -> None: ) -> None:
""" """
Process a single run for a single transcript, bounded by a semaphore. Process a single run for a single transcript, bounded by a semaphore.
""" """
async with semaphore: async with semaphore:
# Resolve Provider provider_config = config.providers[run_config.model.provider]
provider_name = run_spec.model.provider
if provider_name not in config.providers:
logger.error("Run %s refers to unknown provider %s", run_spec.run_name, provider_name)
return
provider_config = config.providers[provider_name] full_run_name = f"{run_name}_{transcript.transcript_id}"
# Download from S3 (Async) logger.info("--- Processing: %s ---", full_run_name)
loader = S3DatasetLoader(s3_config)
local_file = Path("data/downloads") / input_source
if not local_file.exists():
await loader.download_file_async(input_source, local_file)
input_path = local_file eval_config = EvaluatorConfig(
item_id = Path(input_source).stem.split("_")[0] # Extract 300 from 300_TRANSCRIPT model_name=run_config.model.model_name,
run_name = f"{run_spec.run_name}_{item_id}"
logger.info("--- Processing: %s ---", run_name)
run_config = RunConfig(
model_name=run_spec.model.model_name,
api_base=provider_config.api_base, api_base=provider_config.api_base,
api_key=provider_config.api_key, api_key=provider_config.api_key,
api_format=provider_config.api_format, api_spec=provider_config.api_spec,
prompt_id=run_spec.prompt_id, prompt_id=run_config.prompt_id,
temperature=run_spec.model.temperature, temperature=run_config.model.temperature,
timestamp=datetime.now(tz=UTC).isoformat(),
) )
# DEBUG LOGGING
masked_key = provider_config.api_key[:4] + "..." if provider_config.api_key else "None" masked_key = provider_config.api_key[:4] + "..." if provider_config.api_key else "None"
logger.info("DEBUG: Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base) logger.debug("Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
try: try:
evaluator = PHQ8Evaluator(run_config) evaluator = PHQ8Evaluator(eval_config)
# Await the async evaluation result = await evaluator.evaluate(transcript, run_config)
result = await evaluator.evaluate(input_path)
# Save to DB (Async)
await result.insert() await result.insert()
logger.info("Assessment complete for %s.", run_name) logger.info("Assessment complete for %s.", full_run_name)
logger.info( logger.info(
"ID: %s | Score: %s | Diagnosis: %s", "ID: %s | Score: %s | Diagnosis: %s",
result.id, result.transcript_id,
result.total_score, result.total_score,
result.diagnosis_algorithm, result.diagnosis_algorithm,
) )
except Exception: except Exception:
logger.exception("Failed to process %s", run_name) logger.exception("Failed to process %s", full_run_name)
async def main() -> None: async def main() -> None:
@@ -114,42 +99,31 @@ async def main() -> None:
logger.exception("Error loading configuration") logger.exception("Error loading configuration")
return return
# Run Pre-flight Checks
if not await check_all_connections(config):
logger.error("Pre-flight checks failed. Exiting.")
return
# Ensure DB is initialized (redundant if pre-flight passed, but safe)
await init_db(config) await init_db(config)
# Create S3 config once and reuse logger.debug("Discovering transcripts in MongoDB...")
s3_config = config.get_s3_config() limit=args.limit or config.patient_limit
query = Transcript.find_all(limit=limit)
# Discover transcripts (can remain sync or be made async, sync is fine for listing)
logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name)
loader = S3DatasetLoader(s3_config)
keys = loader.list_transcripts()
# Apply limit if configured (Priority: CLI > Config)
limit = args.limit
limit = limit if limit is not None else config.limit
if limit is not None: if limit is not None:
logger.info("Limiting processing to first %d transcripts", limit) logger.debug("Limiting processing to first %d transcripts", limit)
keys = keys[:limit]
# Create task list transcripts = await query.to_list()
tasks_data = [(run_spec, key) for run_spec in config.runs for key in keys]
logger.info("Starting batch assessment with %d total items...", len(tasks_data))
# Limit concurrency to 10 parallel requests if not transcripts:
semaphore = asyncio.Semaphore(10) logger.warning("No transcripts found in database.")
return
tasks = [ tasks_data = [
process_run(run_spec, key, config, s3_config, semaphore) for run_spec, key in tasks_data (run_name, run_spec, t) for run_name, run_spec in config.runs.items() for t in transcripts
] ]
# Run all tasks concurrently logger.info("Starting batch assessment with %d total items...", len(tasks_data))
semaphore = asyncio.Semaphore(config.concurrency_limit)
tasks = [process_run(r_name, r_spec, t, config, semaphore) for r_name, r_spec, t in tasks_data]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
logger.info("Batch assessment complete.") logger.info("Batch assessment complete.")