WIP
This commit is contained in:
@@ -37,3 +37,6 @@ runs:
|
||||
model_name: "llama3"
|
||||
temperature: 0.7
|
||||
prompt_id: "default"
|
||||
|
||||
# Optional: Limit the number of transcripts processed
|
||||
limit: 5
|
||||
|
||||
15
run_config.yaml
Normal file
15
run_config.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
providers:
|
||||
openrouter:
|
||||
api_key: "${OPENROUTER_API_KEY}"
|
||||
api_base: "https://openrouter.ai/api/v1"
|
||||
api_format: "openai"
|
||||
|
||||
runs:
|
||||
- run_name: "gemini_flash"
|
||||
model:
|
||||
provider: openrouter
|
||||
model_name: "google/gemini-3-flash-preview"
|
||||
temperature: 1.0
|
||||
prompt_id: "default"
|
||||
|
||||
limit: 5
|
||||
@@ -88,6 +88,7 @@ class AssessBatchConfig(BaseModel):
|
||||
|
||||
providers: dict[str, ProviderConfig]
|
||||
runs: list[RunSpec]
|
||||
limit: int | None = None
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
|
||||
@@ -16,6 +16,7 @@ from helia.configuration import (
|
||||
)
|
||||
from helia.db import init_db
|
||||
from helia.ingestion.s3 import S3DatasetLoader
|
||||
from helia.preflight import check_all_connections
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -61,6 +62,10 @@ async def process_run(
|
||||
timestamp=datetime.now(tz=UTC).isoformat(),
|
||||
)
|
||||
|
||||
# DEBUG LOGGING
|
||||
masked_key = provider_config.api_key[:4] + "..." if provider_config.api_key else "None"
|
||||
logger.info("DEBUG: Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
|
||||
|
||||
try:
|
||||
evaluator = PHQ8Evaluator(run_config)
|
||||
# Await the async evaluation
|
||||
@@ -91,6 +96,11 @@ async def main() -> None:
|
||||
default="config.yaml",
|
||||
help="Path to YAML configuration file (default: config.yaml)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
type=int,
|
||||
help="Limit the number of transcripts to process (overrides config)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
config_path = Path(args.config)
|
||||
@@ -109,6 +119,12 @@ async def main() -> None:
|
||||
|
||||
# Check the type of configuration
|
||||
if isinstance(run_config_data, AssessBatchConfig):
|
||||
# Run Pre-flight Checks
|
||||
if not await check_all_connections(run_config_data, system_config):
|
||||
logger.error("Pre-flight checks failed. Exiting.")
|
||||
return
|
||||
|
||||
# Ensure DB is initialized (redundant if pre-flight passed, but safe)
|
||||
await init_db(system_config)
|
||||
|
||||
# Create S3 config once and reuse
|
||||
@@ -119,6 +135,14 @@ async def main() -> None:
|
||||
loader = S3DatasetLoader(s3_config)
|
||||
keys = loader.list_transcripts()
|
||||
|
||||
# Apply limit if configured (Priority: CLI > Config)
|
||||
limit = args.limit
|
||||
limit = limit if limit is not None else run_config_data.limit
|
||||
|
||||
if limit is not None:
|
||||
logger.info("Limiting processing to first %d transcripts", limit)
|
||||
keys = keys[:limit]
|
||||
|
||||
# Create task list
|
||||
tasks_data = [(run_spec, key) for run_spec in run_config_data.runs for key in keys]
|
||||
logger.info("Starting batch assessment with %d total items...", len(tasks_data))
|
||||
|
||||
116
src/helia/preflight.py
Normal file
116
src/helia/preflight.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from helia.db import init_db
|
||||
from helia.ingestion.s3 import S3DatasetLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from helia.configuration import AssessBatchConfig, ProviderConfig, S3Config, SystemConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def check_mongo(config: SystemConfig) -> bool:
    """Confirm that MongoDB is reachable.

    Reuses the normal database initialisation path so the check exercises
    exactly the same connection setup the application itself performs.

    Args:
        config: System-level configuration holding the MongoDB settings.

    Returns:
        True when initialisation succeeds, False on any failure.
    """
    try:
        logger.info("Checking MongoDB connection...")
        await init_db(config)
    except Exception:
        # Any failure mode (auth, DNS, timeout, ...) means the check fails;
        # the traceback is logged for diagnosis.
        logger.exception("MongoDB connection failed")
        return False
    logger.info("MongoDB connection successful.")
    return True
|
||||
|
||||
|
||||
async def check_s3(config: S3Config) -> bool:
    """Confirm that S3 is reachable.

    Performs a real listing of transcripts through the same loader the
    batch run uses, so credentials, bucket and prefix are all validated.

    Args:
        config: S3 connection settings.

    Returns:
        True when the listing succeeds, False on any failure.
    """
    try:
        logger.info("Checking S3 connection...")
        S3DatasetLoader(config).list_transcripts()
    except Exception:
        # Covers credential, network and bucket/prefix errors alike.
        logger.exception("S3 connection failed")
        return False
    logger.info("S3 connection successful.")
    return True
|
||||
|
||||
|
||||
async def check_llm_provider(name: str, config: ProviderConfig, model_name: str) -> bool:
    """Verify connectivity to an LLM provider.

    Sends a minimal one-token chat completion ("ping") through the
    provider's OpenAI-compatible endpoint so that authentication, routing
    and the chosen model name are all exercised end-to-end.

    Args:
        name: Provider key from the run config (used for logging only).
        config: Provider credentials and endpoint settings.
        model_name: A model served by this provider to probe with.

    Returns:
        True when the probe request succeeds, False otherwise.
    """
    try:
        logger.info("Checking LLM provider: %s...", name)

        client = AsyncOpenAI(
            api_key=config.api_key,
            base_url=config.api_base,
            max_retries=0,  # Fail fast
            timeout=5.0,  # Fail fast
        )
        try:
            await client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": "ping"}],
                max_tokens=1,
            )
        finally:
            # Fix: the client was previously never closed, leaking its
            # underlying HTTP connection pool for every provider checked.
            await client.close()
    except Exception:
        logger.exception("LLM provider %s connection failed", name)
        return False
    else:
        logger.info("LLM provider %s connection successful.", name)
        return True
|
||||
|
||||
|
||||
async def check_all_connections(run_config: AssessBatchConfig, system_config: SystemConfig) -> bool:
    """Run every pre-flight check concurrently.

    Checks MongoDB, S3, and each LLM provider that is actually referenced
    by at least one run. Providers declared but unused are skipped with a
    warning.

    Args:
        run_config: Batch configuration (providers and run specs).
        system_config: System-level configuration (DB, S3).

    Returns:
        True only when every scheduled check passes.
    """
    logger.info("--- Starting Pre-flight Checks ---")

    # Infrastructure checks first, then one probe per used provider.
    pending = [
        check_mongo(system_config),
        check_s3(system_config.get_s3_config()),
    ]

    for provider_name, provider_cfg in run_config.providers.items():
        # Pick any model this provider serves in the configured runs; the
        # probe needs a concrete model name to call.
        probe_model = next(
            (spec.model.model_name for spec in run_config.runs if spec.model.provider == provider_name),
            None,
        )
        if not probe_model:
            logger.warning(
                "Skipping check for provider '%s': defined in config but not used in any run.",
                provider_name,
            )
        else:
            pending.append(check_llm_provider(provider_name, provider_cfg, probe_model))

    # return_exceptions=True keeps one crashing check from hiding the rest.
    outcomes = await asyncio.gather(*pending, return_exceptions=True)

    all_passed = True
    for outcome in outcomes:
        if isinstance(outcome, Exception):
            logger.error("Check failed with exception: %s", outcome)
            all_passed = False
        elif outcome is False:
            all_passed = False

    if all_passed:
        logger.info("--- All Pre-flight Checks Passed ---")
    else:
        logger.error("--- Pre-flight Checks Failed ---")

    return all_passed
|
||||
Reference in New Issue
Block a user