diff --git a/example.run_config.yaml b/example.run_config.yaml index f0277d9..d40ef54 100644 --- a/example.run_config.yaml +++ b/example.run_config.yaml @@ -37,3 +37,6 @@ runs: model_name: "llama3" temperature: 0.7 prompt_id: "default" + +# Optional: Limit the number of transcripts processed +limit: 5 diff --git a/run_config.yaml b/run_config.yaml new file mode 100644 index 0000000..436f256 --- /dev/null +++ b/run_config.yaml @@ -0,0 +1,15 @@ +providers: + openrouter: + api_key: "${OPENROUTER_API_KEY}" + api_base: "https://openrouter.ai/api/v1" + api_format: "openai" + +runs: + - run_name: "gemini_flash" + model: + provider: openrouter + model_name: "google/gemini-3-flash-preview" + temperature: 1.0 + prompt_id: "default" + +limit: 5 diff --git a/src/helia/configuration.py b/src/helia/configuration.py index cd8e0c2..7fe3166 100644 --- a/src/helia/configuration.py +++ b/src/helia/configuration.py @@ -88,6 +88,7 @@ class AssessBatchConfig(BaseModel): providers: dict[str, ProviderConfig] runs: list[RunSpec] + limit: int | None = None class AgentConfig(BaseModel): diff --git a/src/helia/main.py b/src/helia/main.py index 11ebd78..8131635 100644 --- a/src/helia/main.py +++ b/src/helia/main.py @@ -16,6 +16,7 @@ from helia.configuration import ( ) from helia.db import init_db from helia.ingestion.s3 import S3DatasetLoader +from helia.preflight import check_all_connections logger = logging.getLogger(__name__) @@ -61,6 +62,10 @@ async def process_run( timestamp=datetime.now(tz=UTC).isoformat(), ) + # DEBUG LOGGING + masked_key = provider_config.api_key[:4] + "..." 
if provider_config.api_key else "None" + logger.info("DEBUG: Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base) + try: evaluator = PHQ8Evaluator(run_config) # Await the async evaluation @@ -91,6 +96,11 @@ async def main() -> None: default="config.yaml", help="Path to YAML configuration file (default: config.yaml)", ) + parser.add_argument( + "--limit", + type=int, + help="Limit the number of transcripts to process (overrides config)", + ) args = parser.parse_args() config_path = Path(args.config) @@ -109,6 +119,12 @@ async def main() -> None: # Check the type of configuration if isinstance(run_config_data, AssessBatchConfig): + # Run Pre-flight Checks + if not await check_all_connections(run_config_data, system_config): + logger.error("Pre-flight checks failed. Exiting.") + return + + # Ensure DB is initialized (redundant if pre-flight passed, but safe) await init_db(system_config) # Create S3 config once and reuse @@ -119,6 +135,14 @@ async def main() -> None: loader = S3DatasetLoader(s3_config) keys = loader.list_transcripts() + # Apply limit if configured (Priority: CLI > Config) + limit = args.limit + limit = limit if limit is not None else run_config_data.limit + + if limit is not None: + logger.info("Limiting processing to first %d transcripts", limit) + keys = keys[:limit] + # Create task list tasks_data = [(run_spec, key) for run_spec in run_config_data.runs for key in keys] logger.info("Starting batch assessment with %d total items...", len(tasks_data)) diff --git a/src/helia/preflight.py b/src/helia/preflight.py new file mode 100644 index 0000000..88bd49d --- /dev/null +++ b/src/helia/preflight.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import TYPE_CHECKING + +from openai import AsyncOpenAI + +from helia.db import init_db +from helia.ingestion.s3 import S3DatasetLoader + +if TYPE_CHECKING: + from helia.configuration import AssessBatchConfig, ProviderConfig, S3Config, 
logger = logging.getLogger(__name__)


async def check_mongo(config: SystemConfig) -> bool:
    """Verify MongoDB connectivity by running the normal DB initialisation.

    Returns:
        True on success; False (with the exception logged) on failure.
    """
    try:
        logger.info("Checking MongoDB connection...")
        await init_db(config)
    except Exception:
        logger.exception("MongoDB connection failed")
        return False
    else:
        logger.info("MongoDB connection successful.")
        return True


async def check_s3(config: S3Config) -> bool:
    """Verify S3 connectivity by listing the transcript objects.

    Returns:
        True on success; False (with the exception logged) on failure.
    """
    try:
        logger.info("Checking S3 connection...")
        loader = S3DatasetLoader(config)
        loader.list_transcripts()
    except Exception:
        logger.exception("S3 connection failed")
        return False
    else:
        logger.info("S3 connection successful.")
        return True


async def check_llm_provider(name: str, config: ProviderConfig, model_name: str) -> bool:
    """Verify connectivity to an LLM provider with a minimal one-token request.

    Args:
        name: Provider key from the run configuration (used for logging only).
        config: Provider credentials and endpoint settings.
        model_name: A model actually served by this provider, used as the probe.

    Returns:
        True on success; False (with the exception logged) on failure.
    """
    try:
        logger.info("Checking LLM provider: %s...", name)

        client = AsyncOpenAI(
            api_key=config.api_key,
            base_url=config.api_base,
            max_retries=0,  # Fail fast
            timeout=5.0,  # Fail fast
        )
        try:
            await client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": "ping"}],
                max_tokens=1,
            )
        finally:
            # FIX: the client owns an HTTP connection pool; the previous version
            # never closed it, leaking one pool per provider checked.
            await client.close()
    except Exception:
        logger.exception("LLM provider %s connection failed", name)
        return False
    else:
        logger.info("LLM provider %s connection successful.", name)
        return True


async def check_all_connections(run_config: AssessBatchConfig, system_config: SystemConfig) -> bool:
    """
    Run all pre-flight checks (MongoDB, S3, every used LLM provider) concurrently.

    Individual check functions log their own failures; this function only
    aggregates the results.

    Returns:
        True only if every check passes.
    """
    logger.info("--- Starting Pre-flight Checks ---")

    # 1. MongoDB and 2. S3 are always checked.
    checks = [check_mongo(system_config), check_s3(system_config.get_s3_config())]

    # 3. Check All LLM Providers that are actually referenced by a run.
    for name, provider_config in run_config.providers.items():
        # Find a model used by this provider in the runs to use for the check.
        model_name = next(
            (run.model.model_name for run in run_config.runs if run.model.provider == name),
            None,
        )

        if model_name:
            checks.append(check_llm_provider(name, provider_config, model_name))
        else:
            logger.warning(
                "Skipping check for provider '%s': defined in config but not used in any run.",
                name,
            )

    # return_exceptions=True: a crash in one check must not cancel the others.
    results = await asyncio.gather(*checks, return_exceptions=True)

    # Analyze results. Checks normally catch their own errors and return bool,
    # but the Exception branch guards against anything gather still surfaces.
    all_passed = True
    for result in results:
        if isinstance(result, Exception):
            logger.error("Check failed with exception: %s", result)
            all_passed = False
        elif result is False:
            all_passed = False

    if all_passed:
        logger.info("--- All Pre-flight Checks Passed ---")
    else:
        logger.error("--- Pre-flight Checks Failed ---")

    return all_passed