This commit is contained in:
Santiago Martinez-Avial
2025-12-21 16:54:30 +01:00
parent 4a340a9661
commit 966e9d547a
5 changed files with 159 additions and 0 deletions

View File

@@ -37,3 +37,6 @@ runs:
model_name: "llama3"
temperature: 0.7
prompt_id: "default"
# Optional: Limit the number of transcripts processed
limit: 5

15
run_config.yaml Normal file
View File

@@ -0,0 +1,15 @@
providers:
openrouter:
api_key: "${OPENROUTER_API_KEY}"
api_base: "https://openrouter.ai/api/v1"
api_format: "openai"
runs:
- run_name: "gemini_flash"
model:
provider: openrouter
model_name: "google/gemini-3-flash-preview"
temperature: 1.0
prompt_id: "default"
limit: 5

View File

@@ -88,6 +88,7 @@ class AssessBatchConfig(BaseModel):
providers: dict[str, ProviderConfig]
runs: list[RunSpec]
limit: int | None = None
class AgentConfig(BaseModel):

View File

@@ -16,6 +16,7 @@ from helia.configuration import (
)
from helia.db import init_db
from helia.ingestion.s3 import S3DatasetLoader
from helia.preflight import check_all_connections
logger = logging.getLogger(__name__)
@@ -61,6 +62,10 @@ async def process_run(
timestamp=datetime.now(tz=UTC).isoformat(),
)
# DEBUG LOGGING
masked_key = provider_config.api_key[:4] + "..." if provider_config.api_key else "None"
logger.info("DEBUG: Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
try:
evaluator = PHQ8Evaluator(run_config)
# Await the async evaluation
@@ -91,6 +96,11 @@ async def main() -> None:
default="config.yaml",
help="Path to YAML configuration file (default: config.yaml)",
)
parser.add_argument(
"--limit",
type=int,
help="Limit the number of transcripts to process (overrides config)",
)
args = parser.parse_args()
config_path = Path(args.config)
@@ -109,6 +119,12 @@ async def main() -> None:
# Check the type of configuration
if isinstance(run_config_data, AssessBatchConfig):
# Run Pre-flight Checks
if not await check_all_connections(run_config_data, system_config):
logger.error("Pre-flight checks failed. Exiting.")
return
# Ensure DB is initialized (redundant if pre-flight passed, but safe)
await init_db(system_config)
# Create S3 config once and reuse
@@ -119,6 +135,14 @@ async def main() -> None:
loader = S3DatasetLoader(s3_config)
keys = loader.list_transcripts()
# Apply limit if configured (Priority: CLI > Config)
limit = args.limit
limit = limit if limit is not None else run_config_data.limit
if limit is not None:
logger.info("Limiting processing to first %d transcripts", limit)
keys = keys[:limit]
# Create task list
tasks_data = [(run_spec, key) for run_spec in run_config_data.runs for key in keys]
logger.info("Starting batch assessment with %d total items...", len(tasks_data))

116
src/helia/preflight.py Normal file
View File

@@ -0,0 +1,116 @@
from __future__ import annotations
import asyncio
import logging
from typing import TYPE_CHECKING
from openai import AsyncOpenAI
from helia.db import init_db
from helia.ingestion.s3 import S3DatasetLoader
if TYPE_CHECKING:
from helia.configuration import AssessBatchConfig, ProviderConfig, S3Config, SystemConfig
logger = logging.getLogger(__name__)
async def check_mongo(config: SystemConfig) -> bool:
"""Verify MongoDB connectivity."""
try:
logger.info("Checking MongoDB connection...")
await init_db(config)
except Exception:
logger.exception("MongoDB connection failed")
return False
else:
logger.info("MongoDB connection successful.")
return True
async def check_s3(config: S3Config) -> bool:
"""Verify S3 connectivity by listing objects."""
try:
logger.info("Checking S3 connection...")
loader = S3DatasetLoader(config)
loader.list_transcripts()
except Exception:
logger.exception("S3 connection failed")
return False
else:
logger.info("S3 connection successful.")
return True
async def check_llm_provider(name: str, config: ProviderConfig, model_name: str) -> bool:
"""Verify connectivity to an LLM provider."""
try:
logger.info("Checking LLM provider: %s...", name)
client = AsyncOpenAI(
api_key=config.api_key,
base_url=config.api_base,
max_retries=0, # Fail fast
timeout=5.0, # Fail fast
)
await client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": "ping"}],
max_tokens=1,
)
except Exception:
logger.exception("LLM provider %s connection failed", name)
return False
else:
logger.info("LLM provider %s connection successful.", name)
return True
async def check_all_connections(run_config: AssessBatchConfig, system_config: SystemConfig) -> bool:
"""
Run all pre-flight checks. Returns True if all checks pass.
"""
logger.info("--- Starting Pre-flight Checks ---")
checks = []
# 1. Check MongoDB
checks.append(check_mongo(system_config))
# 2. Check S3
checks.append(check_s3(system_config.get_s3_config()))
# 3. Check All LLM Providers
for name, provider_config in run_config.providers.items():
# Find a model used by this provider in the runs to use for the check
model_name = next(
(run.model.model_name for run in run_config.runs if run.model.provider == name),
None,
)
if model_name:
checks.append(check_llm_provider(name, provider_config, model_name))
else:
logger.warning(
"Skipping check for provider '%s': defined in config but not used in any run.",
name,
)
results = await asyncio.gather(*checks, return_exceptions=True)
# Analyze results
all_passed = True
for result in results:
if isinstance(result, Exception):
logger.error("Check failed with exception: %s", result)
all_passed = False
elif result is False:
all_passed = False
if all_passed:
logger.info("--- All Pre-flight Checks Passed ---")
else:
logger.error("--- Pre-flight Checks Failed ---")
return all_passed