feat: add EvaluatorConfig, source transcripts from MongoDB, and make concurrency configurable

```diff
@@ -1,7 +1,9 @@
 # Helia Application Configuration
 # Copy this file to config.yaml and adjust values as needed.
 
+log_level: "INFO"
 patient_limit: 5
+concurrency_limit: 1
 
 mongo:
   uri: "mongodb://localhost:27017"
```

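Both new keys surface as validated fields on `HeliaConfig` (see the configuration hunk further down). A minimal sketch of how startup code might consume them, assuming a `config.yaml` copied from the example above; the `bootstrap` helper is illustrative, while the `_yaml_file` loading pattern mirrors the `load_config` helper added later in this commit:

```python
# Illustrative bootstrap, not part of this commit. The _yaml_file pattern
# and field names mirror the diff; everything else is an assumption.
import logging
from pathlib import Path

from helia.configuration import HeliaConfig


def bootstrap(path: Path) -> HeliaConfig:
    config = HeliaConfig(_yaml_file=str(path))   # same pattern as load_config below
    logging.basicConfig(level=config.log_level)  # validated level name, e.g. "INFO"
    return config


config = bootstrap(Path("config.yaml"))
print(config.patient_limit, config.concurrency_limit)  # 5, 1 by default
```
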
```diff
@@ -2,12 +2,16 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, cast
 
-from helia.assessment.schema import AssessmentResponse, AssessmentResult, RunConfig
-from helia.ingestion.parser import TranscriptParser
+from helia.assessment.schema import (
+    AssessmentResponse,
+    AssessmentResult,
+    EvaluatorConfig,
+)
 from helia.llm.client import get_chat_model
 
 if TYPE_CHECKING:
-    from pathlib import Path
+    from helia.configuration import RunConfig
+    from helia.models.transcript import Transcript
 
 # PHQ-8 Scoring Constants
 DIAGNOSIS_THRESHOLD = 10
```

```diff
@@ -51,24 +55,23 @@ TRANSCRIPT:
 
 
 class PHQ8Evaluator:
-    def __init__(self, config: RunConfig) -> None:
+    def __init__(self, config: EvaluatorConfig) -> None:
         self.config = config
-        self.parser = TranscriptParser()
 
         self.llm = get_chat_model(
             model_name=self.config.model_name,
             api_key=self.config.api_key,
-            base_url=self.config.api_base or "",
+            base_url=self.config.api_base,
             temperature=self.config.temperature,
         )
 
-    async def evaluate(self, file_path: Path) -> AssessmentResult:
+    async def evaluate(
+        self, transcript: Transcript, original_run_config: RunConfig
+    ) -> AssessmentResult:
         """
         Asynchronously evaluate a transcript using the configured LLM.
         """
-        # 1. Parse Transcript
-        utterances = self.parser.parse(file_path)
-        transcript_text = "\n".join([f"{u.speaker}: {u.text}" for u in utterances])
+        transcript_text = "\n".join([f"{u.speaker}: {u.value}" for u in transcript.utterances])
 
         # 2. Prepare Prompt
         final_prompt = DEFAULT_PROMPT.format(transcript_text=transcript_text)
```

```diff
@@ -104,8 +107,8 @@ class PHQ8Evaluator:
             diagnosis_algorithm = "Other Depression"
 
         return AssessmentResult(
-            transcript_id=file_path.stem,
-            config=self.config,
+            transcript_id=transcript.transcript_id,
+            config=original_run_config,
             items=items,
             total_score=total_score,
             diagnosis_algorithm=diagnosis_algorithm,
```

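Net effect of the two evaluator hunks: `PHQ8Evaluator` no longer reads or parses files itself. It takes an already-parsed `Transcript` plus the originating `RunConfig`, and stamps the result with `transcript.transcript_id` and that run config rather than its own client settings. A hedged sketch of the new call flow; the model name and endpoint are hypothetical placeholders:

```python
# Sketch of the new call flow. EvaluatorConfig fields and the evaluate()
# signature match this diff; the model name and endpoint are hypothetical.
from helia.assessment.core import PHQ8Evaluator
from helia.assessment.schema import EvaluatorConfig
from helia.configuration import RunConfig
from helia.models.transcript import Transcript


async def score_one(transcript: Transcript, run_config: RunConfig) -> None:
    eval_config = EvaluatorConfig(
        model_name="some-model",               # hypothetical
        api_base="http://localhost:8000/v1",   # hypothetical
    )
    evaluator = PHQ8Evaluator(eval_config)
    # The result records the *run's* config, not the evaluator's client config.
    result = await evaluator.evaluate(transcript, run_config)
    print(result.transcript_id, result.total_score, result.diagnosis_algorithm)
```
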
```diff
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING
 
 from beanie import Document
```

```diff
@@ -13,6 +15,17 @@ class Evidence(BaseModel):
     reasoning: str
 
 
+class EvaluatorConfig(BaseModel):
+    """Configuration required to initialize the Evaluator."""
+
+    model_name: str
+    api_base: str
+    api_key: str | None = None
+    api_spec: str = "openai"
+    temperature: float = 0.0
+    prompt_id: str = "default"
+
+
 class PHQ8Item(BaseModel):
     question_id: int
     question_text: str
```

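Since `EvaluatorConfig` is a plain pydantic model, the defaults above are enforced at construction time. A small sketch; the model name and endpoint are illustrative:

```python
# Sketch: EvaluatorConfig defaults, per the model definition above.
from helia.assessment.schema import EvaluatorConfig

cfg = EvaluatorConfig(model_name="some-model", api_base="http://localhost:11434/v1")
assert cfg.api_key is None        # optional: unauthenticated endpoints work
assert cfg.api_spec == "openai"   # default wire format
assert cfg.temperature == 0.0     # deterministic scoring by default
assert cfg.prompt_id == "default"
```
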
```diff
@@ -61,6 +61,7 @@ class HeliaConfig(BaseSettings):
         extra="ignore",
     )
 
+    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
     patient_limit: int = 5
     concurrency_limit: int = 1
     mongo: MongoConfig
```

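The `Literal` annotation means a typo in `config.yaml` fails at load time instead of being passed through to `logging`. A standalone sketch of that behavior, using a minimal stand-in model rather than the real `HeliaConfig`:

```python
# Standalone sketch of the Literal validation on log_level; a minimal
# stand-in model, not the real HeliaConfig.
from typing import Literal

from pydantic import BaseModel, ValidationError


class LogSettings(BaseModel):
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"


print(LogSettings().log_level)                   # "INFO"
print(LogSettings(log_level="DEBUG").log_level)  # "DEBUG"
try:
    LogSettings(log_level="VERBOSE")             # rejected at load time
except ValidationError as exc:
    print(exc.errors()[0]["type"])               # "literal_error"
```
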
```diff
@@ -1,87 +1,72 @@
 import argparse
 import asyncio
 import logging
-from datetime import UTC, datetime
 from pathlib import Path
 
 from helia.assessment.core import PHQ8Evaluator
-from helia.assessment.schema import RunConfig
+from helia.assessment.schema import EvaluatorConfig
 from helia.configuration import (
     HeliaConfig,
-    RunSpec,
-    S3Config,
-    load_config,
+    RunConfig,
 )
 from helia.db import init_db
-from helia.ingestion.s3 import S3DatasetLoader
-from helia.preflight import check_all_connections
+from helia.models.transcript import Transcript
 
 logger = logging.getLogger(__name__)
 
 
+def load_config(path: Path) -> HeliaConfig:
+    """Load configuration from the specified path."""
+    try:
+        return HeliaConfig(_yaml_file=str(path))  # ty:ignore[missing-argument, unknown-argument]
+    except Exception:
+        return HeliaConfig()  # ty:ignore[missing-argument]
+
+
 async def process_run(
-    run_spec: RunSpec,
-    input_source: str,
+    run_name: str,
+    run_config: RunConfig,
+    transcript: Transcript,
     config: HeliaConfig,
-    s3_config: S3Config,
     semaphore: asyncio.Semaphore,
 ) -> None:
     """
     Process a single run for a single transcript, bounded by a semaphore.
     """
     async with semaphore:
-        # Resolve Provider
-        provider_name = run_spec.model.provider
-        if provider_name not in config.providers:
-            logger.error("Run %s refers to unknown provider %s", run_spec.run_name, provider_name)
-            return
-
-        provider_config = config.providers[provider_name]
+        provider_config = config.providers[run_config.model.provider]
 
-        # Download from S3 (Async)
-        loader = S3DatasetLoader(s3_config)
-        local_file = Path("data/downloads") / input_source
-        if not local_file.exists():
-            await loader.download_file_async(input_source, local_file)
-
-        input_path = local_file
-        item_id = Path(input_source).stem.split("_")[0]  # Extract 300 from 300_TRANSCRIPT
-        run_name = f"{run_spec.run_name}_{item_id}"
-
-        logger.info("--- Processing: %s ---", run_name)
-
-        run_config = RunConfig(
-            model_name=run_spec.model.model_name,
+        full_run_name = f"{run_name}_{transcript.transcript_id}"
+        logger.info("--- Processing: %s ---", full_run_name)
+
+        eval_config = EvaluatorConfig(
+            model_name=run_config.model.model_name,
             api_base=provider_config.api_base,
             api_key=provider_config.api_key,
-            api_format=provider_config.api_format,
-            prompt_id=run_spec.prompt_id,
-            temperature=run_spec.model.temperature,
-            timestamp=datetime.now(tz=UTC).isoformat(),
+            api_spec=provider_config.api_spec,
+            prompt_id=run_config.prompt_id,
+            temperature=run_config.model.temperature,
         )
 
-        # DEBUG LOGGING
         masked_key = provider_config.api_key[:4] + "..." if provider_config.api_key else "None"
-        logger.info("DEBUG: Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
+        logger.debug("Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
 
         try:
-            evaluator = PHQ8Evaluator(run_config)
-            # Await the async evaluation
-            result = await evaluator.evaluate(input_path)
+            evaluator = PHQ8Evaluator(eval_config)
+            result = await evaluator.evaluate(transcript, run_config)
 
-            # Save to DB (Async)
             await result.insert()
 
-            logger.info("Assessment complete for %s.", run_name)
+            logger.info("Assessment complete for %s.", full_run_name)
             logger.info(
                 "ID: %s | Score: %s | Diagnosis: %s",
-                result.id,
+                result.transcript_id,
                 result.total_score,
                 result.diagnosis_algorithm,
             )
 
         except Exception:
-            logger.exception("Failed to process %s", run_name)
+            logger.exception("Failed to process %s", full_run_name)
 
 
 async def main() -> None:
```

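One small behavior change worth calling out: the key dump moves from an INFO line with a `DEBUG:` prefix to an actual `logger.debug`, so it only appears when `log_level: "DEBUG"` is configured. A standalone sketch of the masking behavior; the key value is illustrative:

```python
# Sketch of the masking now behind logger.debug; key value is illustrative.
import logging

logging.basicConfig(level=logging.DEBUG)  # only at DEBUG does the line appear
logger = logging.getLogger(__name__)

api_key = "sk-abc123xyz"  # illustrative secret
masked_key = api_key[:4] + "..." if api_key else "None"
logger.debug("Using API Key: %s, Base URL: %s", masked_key, "http://localhost:8000/v1")
# logs: Using API Key: sk-a..., Base URL: http://localhost:8000/v1
```
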
```diff
@@ -114,42 +99,31 @@ async def main() -> None:
         logger.exception("Error loading configuration")
         return
 
-    # Run Pre-flight Checks
-    if not await check_all_connections(config):
-        logger.error("Pre-flight checks failed. Exiting.")
-        return
-
-    # Ensure DB is initialized (redundant if pre-flight passed, but safe)
     await init_db(config)
 
-    # Create S3 config once and reuse
-    s3_config = config.get_s3_config()
-
-    # Discover transcripts (can remain sync or be made async, sync is fine for listing)
-    logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name)
-    loader = S3DatasetLoader(s3_config)
-    keys = loader.list_transcripts()
-
-    # Apply limit if configured (Priority: CLI > Config)
-    limit = args.limit
-    limit = limit if limit is not None else config.limit
+    logger.debug("Discovering transcripts in MongoDB...")
+    limit = args.limit or config.patient_limit
+    query = Transcript.find_all(limit=limit)
 
     if limit is not None:
-        logger.info("Limiting processing to first %d transcripts", limit)
-        keys = keys[:limit]
+        logger.debug("Limiting processing to first %d transcripts", limit)
 
-    # Create task list
-    tasks_data = [(run_spec, key) for run_spec in config.runs for key in keys]
-    logger.info("Starting batch assessment with %d total items...", len(tasks_data))
+    transcripts = await query.to_list()
 
-    # Limit concurrency to 10 parallel requests
-    semaphore = asyncio.Semaphore(10)
+    if not transcripts:
+        logger.warning("No transcripts found in database.")
+        return
 
-    tasks = [
-        process_run(run_spec, key, config, s3_config, semaphore) for run_spec, key in tasks_data
-    ]
+    tasks_data = [
+        (run_name, run_spec, t) for run_name, run_spec in config.runs.items() for t in transcripts
+    ]
 
-    # Run all tasks concurrently
+    logger.info("Starting batch assessment with %d total items...", len(tasks_data))
+
+    semaphore = asyncio.Semaphore(config.concurrency_limit)
+
+    tasks = [process_run(r_name, r_spec, t, config, semaphore) for r_name, r_spec, t in tasks_data]
+
     await asyncio.gather(*tasks)
 
     logger.info("Batch assessment complete.")
```

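The new batching shape in `main()`: build every (run, transcript) pair up front, schedule them all with `gather`, and let a semaphore sized by `concurrency_limit` gate actual parallelism. A self-contained toy version of that pattern, with `sleep` standing in for the LLM call:

```python
# Toy version of the semaphore-bounded fan-out used in main();
# asyncio.sleep stands in for the real evaluation work.
import asyncio


async def process_one(name: str, semaphore: asyncio.Semaphore) -> None:
    async with semaphore:             # at most concurrency_limit run at once
        await asyncio.sleep(0.1)      # placeholder for the LLM call
        print(f"done: {name}")


async def main() -> None:
    semaphore = asyncio.Semaphore(1)  # mirrors the concurrency_limit default
    tasks = [process_one(f"run_{i}", semaphore) for i in range(5)]
    await asyncio.gather(*tasks)      # all scheduled, execution gated


asyncio.run(main())
```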