feat: enhance configuration management and refactor evaluator logic for improved clarity and functionality

This commit is contained in:
Santiago Martinez-Avial
2025-12-22 18:44:22 +01:00
parent 572f59e9ce
commit 8efc2b6217
5 changed files with 75 additions and 82 deletions

View File

@@ -1,7 +1,9 @@
# Helia Application Configuration
# Copy this file to config.yaml and adjust values as needed.
log_level: "INFO"
patient_limit: 5
concurrency_limit: 1
mongo:
uri: "mongodb://localhost:27017"

View File

@@ -2,12 +2,16 @@ from __future__ import annotations
from typing import TYPE_CHECKING, cast
from helia.assessment.schema import AssessmentResponse, AssessmentResult, RunConfig
from helia.ingestion.parser import TranscriptParser
from helia.assessment.schema import (
AssessmentResponse,
AssessmentResult,
EvaluatorConfig,
)
from helia.llm.client import get_chat_model
if TYPE_CHECKING:
from pathlib import Path
from helia.configuration import RunConfig
from helia.models.transcript import Transcript
# PHQ-8 Scoring Constants
DIAGNOSIS_THRESHOLD = 10
@@ -51,24 +55,23 @@ TRANSCRIPT:
class PHQ8Evaluator:
def __init__(self, config: RunConfig) -> None:
def __init__(self, config: EvaluatorConfig) -> None:
self.config = config
self.parser = TranscriptParser()
self.llm = get_chat_model(
model_name=self.config.model_name,
api_key=self.config.api_key,
base_url=self.config.api_base or "",
base_url=self.config.api_base,
temperature=self.config.temperature,
)
async def evaluate(self, file_path: Path) -> AssessmentResult:
async def evaluate(
self, transcript: Transcript, original_run_config: RunConfig
) -> AssessmentResult:
"""
Asynchronously evaluate a transcript using the configured LLM.
"""
# 1. Parse Transcript
utterances = self.parser.parse(file_path)
transcript_text = "\n".join([f"{u.speaker}: {u.text}" for u in utterances])
transcript_text = "\n".join([f"{u.speaker}: {u.value}" for u in transcript.utterances])
# 2. Prepare Prompt
final_prompt = DEFAULT_PROMPT.format(transcript_text=transcript_text)
@@ -104,8 +107,8 @@ class PHQ8Evaluator:
diagnosis_algorithm = "Other Depression"
return AssessmentResult(
transcript_id=file_path.stem,
config=self.config,
transcript_id=transcript.transcript_id,
config=original_run_config,
items=items,
total_score=total_score,
diagnosis_algorithm=diagnosis_algorithm,

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from beanie import Document
@@ -13,6 +15,17 @@ class Evidence(BaseModel):
reasoning: str
class EvaluatorConfig(BaseModel):
    """Configuration required to initialize the Evaluator."""

    # Name of the chat model the evaluator will call.
    model_name: str
    # Base URL of the LLM provider's API endpoint.
    api_base: str
    # Credential for the provider; None when the endpoint needs no key.
    api_key: str | None = None
    # Wire format spoken by the provider (defaults to the OpenAI-style API).
    api_spec: str = "openai"
    # Sampling temperature; 0.0 keeps assessments deterministic.
    temperature: float = 0.0
    # Identifier of the prompt template to use for the assessment.
    prompt_id: str = "default"
class PHQ8Item(BaseModel):
question_id: int
question_text: str

View File

@@ -61,6 +61,7 @@ class HeliaConfig(BaseSettings):
extra="ignore",
)
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
patient_limit: int = 5
concurrency_limit: int = 1
mongo: MongoConfig

View File

@@ -1,87 +1,72 @@
import argparse
import asyncio
import logging
from datetime import UTC, datetime
from pathlib import Path
from helia.assessment.core import PHQ8Evaluator
from helia.assessment.schema import RunConfig
from helia.assessment.schema import EvaluatorConfig
from helia.configuration import (
HeliaConfig,
RunSpec,
S3Config,
load_config,
RunConfig,
)
from helia.db import init_db
from helia.ingestion.s3 import S3DatasetLoader
from helia.preflight import check_all_connections
from helia.models.transcript import Transcript
logger = logging.getLogger(__name__)
def load_config(path: Path) -> HeliaConfig:
"""Load configuration from the specified path."""
try:
return HeliaConfig(_yaml_file=str(path)) # ty:ignore[missing-argument, unknown-argument]
except Exception:
return HeliaConfig() # ty:ignore[missing-argument]
async def process_run(
run_spec: RunSpec,
input_source: str,
run_name: str,
run_config: RunConfig,
transcript: Transcript,
config: HeliaConfig,
s3_config: S3Config,
semaphore: asyncio.Semaphore,
) -> None:
"""
Process a single run for a single transcript, bounded by a semaphore.
"""
async with semaphore:
# Resolve Provider
provider_name = run_spec.model.provider
if provider_name not in config.providers:
logger.error("Run %s refers to unknown provider %s", run_spec.run_name, provider_name)
return
provider_config = config.providers[run_config.model.provider]
provider_config = config.providers[provider_name]
full_run_name = f"{run_name}_{transcript.transcript_id}"
# Download from S3 (Async)
loader = S3DatasetLoader(s3_config)
local_file = Path("data/downloads") / input_source
if not local_file.exists():
await loader.download_file_async(input_source, local_file)
logger.info("--- Processing: %s ---", full_run_name)
input_path = local_file
item_id = Path(input_source).stem.split("_")[0] # Extract 300 from 300_TRANSCRIPT
run_name = f"{run_spec.run_name}_{item_id}"
logger.info("--- Processing: %s ---", run_name)
run_config = RunConfig(
model_name=run_spec.model.model_name,
eval_config = EvaluatorConfig(
model_name=run_config.model.model_name,
api_base=provider_config.api_base,
api_key=provider_config.api_key,
api_format=provider_config.api_format,
prompt_id=run_spec.prompt_id,
temperature=run_spec.model.temperature,
timestamp=datetime.now(tz=UTC).isoformat(),
api_spec=provider_config.api_spec,
prompt_id=run_config.prompt_id,
temperature=run_config.model.temperature,
)
# DEBUG LOGGING
masked_key = provider_config.api_key[:4] + "..." if provider_config.api_key else "None"
logger.info("DEBUG: Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
logger.debug("Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
try:
evaluator = PHQ8Evaluator(run_config)
# Await the async evaluation
result = await evaluator.evaluate(input_path)
# Save to DB (Async)
evaluator = PHQ8Evaluator(eval_config)
result = await evaluator.evaluate(transcript, run_config)
await result.insert()
logger.info("Assessment complete for %s.", run_name)
logger.info("Assessment complete for %s.", full_run_name)
logger.info(
"ID: %s | Score: %s | Diagnosis: %s",
result.id,
result.transcript_id,
result.total_score,
result.diagnosis_algorithm,
)
except Exception:
logger.exception("Failed to process %s", run_name)
logger.exception("Failed to process %s", full_run_name)
async def main() -> None:
@@ -114,42 +99,31 @@ async def main() -> None:
logger.exception("Error loading configuration")
return
# Run Pre-flight Checks
if not await check_all_connections(config):
logger.error("Pre-flight checks failed. Exiting.")
return
# Ensure DB is initialized (redundant if pre-flight passed, but safe)
await init_db(config)
# Create S3 config once and reuse
s3_config = config.get_s3_config()
# Discover transcripts (can remain sync or be made async, sync is fine for listing)
logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name)
loader = S3DatasetLoader(s3_config)
keys = loader.list_transcripts()
# Apply limit if configured (Priority: CLI > Config)
limit = args.limit
limit = limit if limit is not None else config.limit
logger.debug("Discovering transcripts in MongoDB...")
limit=args.limit or config.patient_limit
query = Transcript.find_all(limit=limit)
if limit is not None:
logger.info("Limiting processing to first %d transcripts", limit)
keys = keys[:limit]
logger.debug("Limiting processing to first %d transcripts", limit)
# Create task list
tasks_data = [(run_spec, key) for run_spec in config.runs for key in keys]
logger.info("Starting batch assessment with %d total items...", len(tasks_data))
transcripts = await query.to_list()
# Limit concurrency to 10 parallel requests
semaphore = asyncio.Semaphore(10)
if not transcripts:
logger.warning("No transcripts found in database.")
return
tasks = [
process_run(run_spec, key, config, s3_config, semaphore) for run_spec, key in tasks_data
tasks_data = [
(run_name, run_spec, t) for run_name, run_spec in config.runs.items() for t in transcripts
]
# Run all tasks concurrently
logger.info("Starting batch assessment with %d total items...", len(tasks_data))
semaphore = asyncio.Semaphore(config.concurrency_limit)
tasks = [process_run(r_name, r_spec, t, config, semaphore) for r_name, r_spec, t in tasks_data]
await asyncio.gather(*tasks)
logger.info("Batch assessment complete.")