WIP
This commit is contained in:
@@ -37,3 +37,6 @@ runs:
|
|||||||
model_name: "llama3"
|
model_name: "llama3"
|
||||||
temperature: 0.7
|
temperature: 0.7
|
||||||
prompt_id: "default"
|
prompt_id: "default"
|
||||||
|
|
||||||
|
# Optional: Limit the number of transcripts processed
|
||||||
|
limit: 5
|
||||||
|
|||||||
15
run_config.yaml
Normal file
15
run_config.yaml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
providers:
|
||||||
|
openrouter:
|
||||||
|
api_key: "${OPENROUTER_API_KEY}"
|
||||||
|
api_base: "https://openrouter.ai/api/v1"
|
||||||
|
api_format: "openai"
|
||||||
|
|
||||||
|
runs:
|
||||||
|
- run_name: "gemini_flash"
|
||||||
|
model:
|
||||||
|
provider: openrouter
|
||||||
|
model_name: "google/gemini-3-flash-preview"
|
||||||
|
temperature: 1.0
|
||||||
|
prompt_id: "default"
|
||||||
|
|
||||||
|
limit: 5
|
||||||
@@ -88,6 +88,7 @@ class AssessBatchConfig(BaseModel):
|
|||||||
|
|
||||||
providers: dict[str, ProviderConfig]
|
providers: dict[str, ProviderConfig]
|
||||||
runs: list[RunSpec]
|
runs: list[RunSpec]
|
||||||
|
limit: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(BaseModel):
|
class AgentConfig(BaseModel):
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from helia.configuration import (
|
|||||||
)
|
)
|
||||||
from helia.db import init_db
|
from helia.db import init_db
|
||||||
from helia.ingestion.s3 import S3DatasetLoader
|
from helia.ingestion.s3 import S3DatasetLoader
|
||||||
|
from helia.preflight import check_all_connections
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -61,6 +62,10 @@ async def process_run(
|
|||||||
timestamp=datetime.now(tz=UTC).isoformat(),
|
timestamp=datetime.now(tz=UTC).isoformat(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# DEBUG LOGGING
|
||||||
|
masked_key = provider_config.api_key[:4] + "..." if provider_config.api_key else "None"
|
||||||
|
logger.info("DEBUG: Using API Key: %s, Base URL: %s", masked_key, provider_config.api_base)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
evaluator = PHQ8Evaluator(run_config)
|
evaluator = PHQ8Evaluator(run_config)
|
||||||
# Await the async evaluation
|
# Await the async evaluation
|
||||||
@@ -91,6 +96,11 @@ async def main() -> None:
|
|||||||
default="config.yaml",
|
default="config.yaml",
|
||||||
help="Path to YAML configuration file (default: config.yaml)",
|
help="Path to YAML configuration file (default: config.yaml)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--limit",
|
||||||
|
type=int,
|
||||||
|
help="Limit the number of transcripts to process (overrides config)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
config_path = Path(args.config)
|
config_path = Path(args.config)
|
||||||
@@ -109,6 +119,12 @@ async def main() -> None:
|
|||||||
|
|
||||||
# Check the type of configuration
|
# Check the type of configuration
|
||||||
if isinstance(run_config_data, AssessBatchConfig):
|
if isinstance(run_config_data, AssessBatchConfig):
|
||||||
|
# Run Pre-flight Checks
|
||||||
|
if not await check_all_connections(run_config_data, system_config):
|
||||||
|
logger.error("Pre-flight checks failed. Exiting.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ensure DB is initialized (redundant if pre-flight passed, but safe)
|
||||||
await init_db(system_config)
|
await init_db(system_config)
|
||||||
|
|
||||||
# Create S3 config once and reuse
|
# Create S3 config once and reuse
|
||||||
@@ -119,6 +135,14 @@ async def main() -> None:
|
|||||||
loader = S3DatasetLoader(s3_config)
|
loader = S3DatasetLoader(s3_config)
|
||||||
keys = loader.list_transcripts()
|
keys = loader.list_transcripts()
|
||||||
|
|
||||||
|
# Apply limit if configured (Priority: CLI > Config)
|
||||||
|
limit = args.limit
|
||||||
|
limit = limit if limit is not None else run_config_data.limit
|
||||||
|
|
||||||
|
if limit is not None:
|
||||||
|
logger.info("Limiting processing to first %d transcripts", limit)
|
||||||
|
keys = keys[:limit]
|
||||||
|
|
||||||
# Create task list
|
# Create task list
|
||||||
tasks_data = [(run_spec, key) for run_spec in run_config_data.runs for key in keys]
|
tasks_data = [(run_spec, key) for run_spec in run_config_data.runs for key in keys]
|
||||||
logger.info("Starting batch assessment with %d total items...", len(tasks_data))
|
logger.info("Starting batch assessment with %d total items...", len(tasks_data))
|
||||||
|
|||||||
116
src/helia/preflight.py
Normal file
116
src/helia/preflight.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from helia.db import init_db
|
||||||
|
from helia.ingestion.s3 import S3DatasetLoader
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from helia.configuration import AssessBatchConfig, ProviderConfig, S3Config, SystemConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_mongo(config: SystemConfig) -> bool:
|
||||||
|
"""Verify MongoDB connectivity."""
|
||||||
|
try:
|
||||||
|
logger.info("Checking MongoDB connection...")
|
||||||
|
await init_db(config)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("MongoDB connection failed")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.info("MongoDB connection successful.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def check_s3(config: S3Config) -> bool:
|
||||||
|
"""Verify S3 connectivity by listing objects."""
|
||||||
|
try:
|
||||||
|
logger.info("Checking S3 connection...")
|
||||||
|
loader = S3DatasetLoader(config)
|
||||||
|
loader.list_transcripts()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("S3 connection failed")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.info("S3 connection successful.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def check_llm_provider(name: str, config: ProviderConfig, model_name: str) -> bool:
|
||||||
|
"""Verify connectivity to an LLM provider."""
|
||||||
|
try:
|
||||||
|
logger.info("Checking LLM provider: %s...", name)
|
||||||
|
|
||||||
|
client = AsyncOpenAI(
|
||||||
|
api_key=config.api_key,
|
||||||
|
base_url=config.api_base,
|
||||||
|
max_retries=0, # Fail fast
|
||||||
|
timeout=5.0, # Fail fast
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
|
messages=[{"role": "user", "content": "ping"}],
|
||||||
|
max_tokens=1,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("LLM provider %s connection failed", name)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.info("LLM provider %s connection successful.", name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def check_all_connections(run_config: AssessBatchConfig, system_config: SystemConfig) -> bool:
|
||||||
|
"""
|
||||||
|
Run all pre-flight checks. Returns True if all checks pass.
|
||||||
|
"""
|
||||||
|
logger.info("--- Starting Pre-flight Checks ---")
|
||||||
|
|
||||||
|
checks = []
|
||||||
|
|
||||||
|
# 1. Check MongoDB
|
||||||
|
checks.append(check_mongo(system_config))
|
||||||
|
|
||||||
|
# 2. Check S3
|
||||||
|
checks.append(check_s3(system_config.get_s3_config()))
|
||||||
|
|
||||||
|
# 3. Check All LLM Providers
|
||||||
|
for name, provider_config in run_config.providers.items():
|
||||||
|
# Find a model used by this provider in the runs to use for the check
|
||||||
|
model_name = next(
|
||||||
|
(run.model.model_name for run in run_config.runs if run.model.provider == name),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_name:
|
||||||
|
checks.append(check_llm_provider(name, provider_config, model_name))
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping check for provider '%s': defined in config but not used in any run.",
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await asyncio.gather(*checks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Analyze results
|
||||||
|
all_passed = True
|
||||||
|
for result in results:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error("Check failed with exception: %s", result)
|
||||||
|
all_passed = False
|
||||||
|
elif result is False:
|
||||||
|
all_passed = False
|
||||||
|
|
||||||
|
if all_passed:
|
||||||
|
logger.info("--- All Pre-flight Checks Passed ---")
|
||||||
|
else:
|
||||||
|
logger.error("--- Pre-flight Checks Failed ---")
|
||||||
|
|
||||||
|
return all_passed
|
||||||
Reference in New Issue
Block a user