feat: enhance configuration management and refactor evaluator logic for improved clarity and functionality
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
# Helia Application Configuration
|
||||
# Copy this file to config.yaml and adjust values as needed.
|
||||
|
||||
log_level: "INFO"
|
||||
patient_limit: 5
|
||||
concurrency_limit: 1
|
||||
|
||||
mongo:
|
||||
uri: "mongodb://localhost:27017"
|
||||
|
||||
@@ -2,12 +2,16 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from helia.assessment.schema import AssessmentResponse, AssessmentResult, RunConfig
|
||||
from helia.ingestion.parser import TranscriptParser
|
||||
from helia.assessment.schema import (
|
||||
AssessmentResponse,
|
||||
AssessmentResult,
|
||||
EvaluatorConfig,
|
||||
)
|
||||
from helia.llm.client import get_chat_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from helia.configuration import RunConfig
|
||||
from helia.models.transcript import Transcript
|
||||
|
||||
# PHQ-8 Scoring Constants
|
||||
DIAGNOSIS_THRESHOLD = 10
|
||||
@@ -51,24 +55,23 @@ TRANSCRIPT:
|
||||
|
||||
|
||||
class PHQ8Evaluator:
|
||||
def __init__(self, config: RunConfig) -> None:
|
||||
def __init__(self, config: EvaluatorConfig) -> None:
|
||||
self.config = config
|
||||
self.parser = TranscriptParser()
|
||||
|
||||
self.llm = get_chat_model(
|
||||
model_name=self.config.model_name,
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.api_base or "",
|
||||
base_url=self.config.api_base,
|
||||
temperature=self.config.temperature,
|
||||
)
|
||||
|
||||
async def evaluate(self, file_path: Path) -> AssessmentResult:
|
||||
async def evaluate(
|
||||
self, transcript: Transcript, original_run_config: RunConfig
|
||||
) -> AssessmentResult:
|
||||
"""
|
||||
Asynchronously evaluate a transcript using the configured LLM.
|
||||
"""
|
||||
# 1. Parse Transcript
|
||||
utterances = self.parser.parse(file_path)
|
||||
transcript_text = "\n".join([f"{u.speaker}: {u.text}" for u in utterances])
|
||||
transcript_text = "\n".join([f"{u.speaker}: {u.value}" for u in transcript.utterances])
|
||||
|
||||
# 2. Prepare Prompt
|
||||
final_prompt = DEFAULT_PROMPT.format(transcript_text=transcript_text)
|
||||
@@ -104,8 +107,8 @@ class PHQ8Evaluator:
|
||||
diagnosis_algorithm = "Other Depression"
|
||||
|
||||
return AssessmentResult(
|
||||
transcript_id=file_path.stem,
|
||||
config=self.config,
|
||||
transcript_id=transcript.transcript_id,
|
||||
config=original_run_config,
|
||||
items=items,
|
||||
total_score=total_score,
|
||||
diagnosis_algorithm=diagnosis_algorithm,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from beanie import Document
|
||||
@@ -13,6 +15,17 @@ class Evidence(BaseModel):
|
||||
reasoning: str
|
||||
|
||||
|
||||
class EvaluatorConfig(BaseModel):
|
||||
"""Configuration required to initialize the Evaluator."""
|
||||
|
||||
model_name: str
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
api_spec: str = "openai"
|
||||
temperature: float = 0.0
|
||||
prompt_id: str = "default"
|
||||
|
||||
|
||||
class PHQ8Item(BaseModel):
|
||||
question_id: int
|
||||
question_text: str
|
||||
|
||||
@@ -61,6 +61,7 @@ class HeliaConfig(BaseSettings):
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
patient_limit: int = 5
|
||||
concurrency_limit: int = 1
|
||||
mongo: MongoConfig
|
||||
|
||||
@@ -1,87 +1,72 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from helia.assessment.core import PHQ8Evaluator
|
||||
from helia.assessment.schema import RunConfig
|
||||
from helia.assessment.schema import EvaluatorConfig
|
||||
from helia.configuration import (
|
||||
HeliaConfig,
|
||||
RunSpec,
|
||||
S3Config,
|
||||
load_config,
|
||||
RunConfig,
|
||||
)
|
||||
from helia.db import init_db
|
||||
from helia.ingestion.s3 import S3DatasetLoader
|
||||
from helia.preflight import check_all_connections
|
||||
from helia.models.transcript import Transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_config(path: Path) -> HeliaConfig:
|
||||
"""Load configuration from the specified path."""
|
||||
try:
|
||||
return HeliaConfig(_yaml_file=str(path)) # ty:ignore[missing-argument, unknown-argument]
|
||||
except Exception:
|
||||
return HeliaConfig() # ty:ignore[missing-argument]
|
||||
|
||||
|
||||
async def process_run(
|
||||
run_spec: RunSpec,
|
||||
input_source: str,
|
||||
run_name: str,
|
||||
run_config: RunConfig,
|
||||
transcript: Transcript,
|
||||
config: HeliaConfig,
|
||||
s3_config: S3Config,
|
||||
semaphore: asyncio.Semaphore,
|
||||
) -> None:
|
||||
"""
|
||||
Process a single run for a single transcript, bounded by a semaphore.
|
||||
"""
|
||||
async with semaphore:
|
||||
# Resolve Provider
|
||||
provider_name = run_spec.model.provider
|
||||
if provider_name not in config.providers:
|
||||
logger.error("Run %s refers to unknown provider %s", run_spec.run_name, provider_name)
|
||||
return
|
||||
provider_config = config.providers[run_config.model.provider]
|
||||
|
||||
provider_config = config.providers[provider_name]
|
||||
full_run_name = f"{run_name}_{transcript.transcript_id}"
|
||||
|
||||
# Download from S3 (Async)
|
||||
loader = S3DatasetLoader(s3_config)
|
||||
local_file = Path("data/downloads") / input_source
|
||||
if not local_file.exists():
|
||||
await loader.download_file_async(input_source, local_file)
|
||||
logger.info("--- Processing: %s ---", full_run_name)
|
||||
|
||||
input_path = local_file
|
||||
item_id = Path(input_source).stem.split("_")[0] # Extract 300 from 300_TRANSCRIPT
|
||||
run_name = f"{run_spec.run_name}_{item_id}"
|
||||
|
||||
logger.info("--- Processing: %s ---", run_name)
|
||||
|
||||
run_config = RunConfig(
|
||||
model_name=run_spec.model.model_name,
|
||||
eval_config = EvaluatorConfig(
|
||||
model_name=run_config.model.model_name,
|
||||
api_base=provider_config.api_base,
|
||||
api_key=provider_config.api_key,
|
||||
api_format=provider_config.api_format,
|
||||
prompt_id=run_spec.prompt_id,
|
||||
temperature=run_spec.model.temperature,
|
||||
timestamp=datetime.now(tz=UTC).isoformat(),
|
||||
api_spec=provider_config.api_spec,
|
||||
prompt_id=run_config.prompt_id,
|
||||
temperature=run_config.model.temperature,
|
||||
)
|
||||
|
||||
# DEBUG LOGGING
|
||||
masked_key = provider_config.api_key[:4] + "..." if provider_config.api_key else "None"
|
||||
logger.info("DEBUG: Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
|
||||
logger.debug("Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
|
||||
|
||||
try:
|
||||
evaluator = PHQ8Evaluator(run_config)
|
||||
# Await the async evaluation
|
||||
result = await evaluator.evaluate(input_path)
|
||||
|
||||
# Save to DB (Async)
|
||||
evaluator = PHQ8Evaluator(eval_config)
|
||||
result = await evaluator.evaluate(transcript, run_config)
|
||||
await result.insert()
|
||||
|
||||
logger.info("Assessment complete for %s.", run_name)
|
||||
logger.info("Assessment complete for %s.", full_run_name)
|
||||
logger.info(
|
||||
"ID: %s | Score: %s | Diagnosis: %s",
|
||||
result.id,
|
||||
result.transcript_id,
|
||||
result.total_score,
|
||||
result.diagnosis_algorithm,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to process %s", run_name)
|
||||
logger.exception("Failed to process %s", full_run_name)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
@@ -114,42 +99,31 @@ async def main() -> None:
|
||||
logger.exception("Error loading configuration")
|
||||
return
|
||||
|
||||
# Run Pre-flight Checks
|
||||
if not await check_all_connections(config):
|
||||
logger.error("Pre-flight checks failed. Exiting.")
|
||||
return
|
||||
|
||||
# Ensure DB is initialized (redundant if pre-flight passed, but safe)
|
||||
await init_db(config)
|
||||
|
||||
# Create S3 config once and reuse
|
||||
s3_config = config.get_s3_config()
|
||||
|
||||
# Discover transcripts (can remain sync or be made async, sync is fine for listing)
|
||||
logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name)
|
||||
loader = S3DatasetLoader(s3_config)
|
||||
keys = loader.list_transcripts()
|
||||
|
||||
# Apply limit if configured (Priority: CLI > Config)
|
||||
limit = args.limit
|
||||
limit = limit if limit is not None else config.limit
|
||||
logger.debug("Discovering transcripts in MongoDB...")
|
||||
limit=args.limit or config.patient_limit
|
||||
query = Transcript.find_all(limit=limit)
|
||||
|
||||
if limit is not None:
|
||||
logger.info("Limiting processing to first %d transcripts", limit)
|
||||
keys = keys[:limit]
|
||||
logger.debug("Limiting processing to first %d transcripts", limit)
|
||||
|
||||
# Create task list
|
||||
tasks_data = [(run_spec, key) for run_spec in config.runs for key in keys]
|
||||
logger.info("Starting batch assessment with %d total items...", len(tasks_data))
|
||||
transcripts = await query.to_list()
|
||||
|
||||
# Limit concurrency to 10 parallel requests
|
||||
semaphore = asyncio.Semaphore(10)
|
||||
if not transcripts:
|
||||
logger.warning("No transcripts found in database.")
|
||||
return
|
||||
|
||||
tasks = [
|
||||
process_run(run_spec, key, config, s3_config, semaphore) for run_spec, key in tasks_data
|
||||
tasks_data = [
|
||||
(run_name, run_spec, t) for run_name, run_spec in config.runs.items() for t in transcripts
|
||||
]
|
||||
|
||||
# Run all tasks concurrently
|
||||
logger.info("Starting batch assessment with %d total items...", len(tasks_data))
|
||||
|
||||
semaphore = asyncio.Semaphore(config.concurrency_limit)
|
||||
|
||||
tasks = [process_run(r_name, r_spec, t, config, semaphore) for r_name, r_spec, t in tasks_data]
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
logger.info("Batch assessment complete.")
|
||||
|
||||
Reference in New Issue
Block a user