diff --git a/.env.example b/.env.example deleted file mode 100644 index a3aba4e..0000000 --- a/.env.example +++ /dev/null @@ -1,17 +0,0 @@ -# Database Configuration -HELIA_MONGO_URI=mongodb://localhost:27017 -HELIA_DATABASE_NAME=helia - -# S3 Configuration (MinIO or AWS) -# Required for finding and downloading transcripts -HELIA_S3_ENDPOINT=https://s3.amazonaws.com -HELIA_S3_ACCESS_KEY=your_access_key -HELIA_S3_SECRET_KEY=your_secret_key -HELIA_S3_BUCKET=your-bucket-name -HELIA_S3_REGION=us-east-1 - -# LLM API Keys -# These are used by the run configuration YAML via ${VAR} substitution -OPENAI_API_KEY=sk-... -ANTHROPIC_API_KEY=sk-ant-... -OPENROUTER_API_KEY=sk-or-... diff --git a/.gitignore b/.gitignore index b525513..6e4e25d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ wheels/ *.egg-info .env +config.yaml # Virtual environments .venv diff --git a/README.md b/README.md index c5bc664..1cc6a94 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ src/helia/ ├── analysis/ │ └── extractor.py # Metadata extraction (LLM-agnostic) ├── assessment/ -│ ├── core.py # Clinical assessment logic (PHQ-8) +│ └── core.py # Clinical assessment logic (PHQ-8) │ └── schema.py # Data models (AssessmentResult, PHQ8Item) ├── ingestion/ │ └── parser.py # Transcript parsing (DAIC-WOZ support) @@ -50,6 +50,7 @@ graph TD - **Clinical Parsing**: Native support for DAIC-WOZ transcript formats. - **Structured Assessment**: Maps unstructured conversation to validatable PHQ-8 scores. - **Document Persistence**: Stores full experimental context (config + evidence + scores) in MongoDB using Beanie. +- **Unified Configuration**: Single YAML config with environment variable fallbacks for system settings. ## Roadmap @@ -67,22 +68,26 @@ uv sync ## Quick Start -1. **Environment Setup**: +1. **Configuration**: + Copy `example.config.yaml` to `config.yaml` and edit it. ```sh - export OPENAI_API_KEY=sk-... - # Ensure MongoDB is running (e.g., via Docker) + cp example.config.yaml config.yaml ``` + + You can set system settings (MongoDB, S3) in the YAML file, or use environment variables (e.g., `HELIA_MONGO_URI`, `HELIA_S3_BUCKET`). + Provider API keys can be set in YAML or via `{PROVIDER_NAME}_API_KEY` (uppercased), e.g. `ANTHROPIC_API_KEY` for `providers.anthropic`. 2. **Run an Assessment**: ```sh - python -m helia.main "assess --input data/transcript.tsv" + # Ensure MongoDB is running + uv run python -m helia.main config.yaml ``` ## Development - **Linting**: `uv run ruff check .` - **Formatting**: `uv run ruff format .` -- **Type Checking**: `uv run pyrefly` +- **Type Checking**: `uv run ty check` ## License diff --git a/example.run_config.yaml b/example.config.yaml similarity index 55% rename from example.run_config.yaml rename to example.config.yaml index d40ef54..e8b2f4f 100644 --- a/example.run_config.yaml +++ b/example.config.yaml @@ -1,25 +1,36 @@ -# Helia Run Configuration -# This file defines the "providers" (LLM connections) and the "runs" (experiments). -# Environment variables like ${OPENAI_API_KEY} are expanded at runtime. +# System Configuration +mongo_uri: "mongodb://localhost:27017" +database_name: "helia" + +# S3 Configuration (MinIO or AWS) +s3_endpoint: "https://s3.amazonaws.com" +s3_access_key: "your_access_key" +s3_secret_key: "your_secret_key" +s3_bucket: "your-bucket-name" +s3_prefix: "" +s3_region: "us-east-1" + +# Run Configuration +limit: 5 providers: openai: - api_key: "${OPENAI_API_KEY}" + # Set api_key here or export OPENAI_API_KEY api_base: "https://api.openai.com/v1" api_format: "openai" anthropic: - api_key: "${ANTHROPIC_API_KEY}" + # Set api_key here or export ANTHROPIC_API_KEY api_base: "https://api.anthropic.com/v1" api_format: "anthropic" openrouter: - api_key: "${OPENROUTER_API_KEY}" + # Set api_key here or export OPENROUTER_API_KEY api_base: "https://openrouter.ai/api/v1" api_format: "openai" local_ollama: - api_key: "none" + # API key optional for local_* providers api_base: "http://localhost:11434/v1" api_format: "ollama" @@ -37,6 +48,3 @@ runs: model_name: "llama3" temperature: 0.7 prompt_id: "default" - -# Optional: Limit the number of transcripts processed -limit: 5 diff --git a/run_config.yaml b/run_config.yaml deleted file mode 100644 index 436f256..0000000 --- a/run_config.yaml +++ /dev/null @@ -1,15 +0,0 @@ -providers: - openrouter: - api_key: "${OPENROUTER_API_KEY}" - api_base: "https://openrouter.ai/api/v1" - api_format: "openai" - -runs: - - run_name: "gemini_flash" - model: - provider: openrouter - model_name: "google/gemini-3-flash-preview" - temperature: 1.0 - prompt_id: "default" - -limit: 5 diff --git a/src/helia/agent/workflow.py b/src/helia/agent/workflow.py index 9757904..b5d2b05 100644 --- a/src/helia/agent/workflow.py +++ b/src/helia/agent/workflow.py @@ -5,9 +5,6 @@ from typing import Any from langgraph.graph import END, StateGraph from pydantic import BaseModel -from helia.llm.client import get_chat_model -from helia.llm.settings import settings - class AgentState(BaseModel): """State for the agent workflow.""" @@ -25,9 +22,8 @@ class AgentState(BaseModel): # proper protocol matching while maintaining runtime correctness. -def planner_node(state: AgentState): # noqa: ANN201 +def planner_node(state: AgentState): # noqa: ANN201, ARG001 """Plan the steps to answer the question.""" - _ = state plan: list[str] = ["Understand question", "Retrieve info", "Synthesize answer"] return {"plan": plan} @@ -57,40 +53,17 @@ def vector_tool_node(state: AgentState): # noqa: ANN201 def synthesizer_node(state: AgentState): # noqa: ANN201 """Synthesize an answer from the gathered context.""" context_text = "\n".join(state.context) - question = state.question - prompt = f""" - Answer the user's question based on the provided context. - - Context: - {context_text} - - Question: {question} - - Answer: - """ - - try: - llm = get_chat_model( - model_name=settings.model, - api_key=settings.resolve_api_key(), - base_url=settings.base_url, - ) - messages = [ - ("system", "You are a helpful assistant."), - ("user", prompt), - ] - response = llm.invoke(messages) - answer = str(response.content) - except Exception as e: - answer = f"Error generating answer: {e}. Fallback: Based on context: {context_text}, here is the answer." + answer = ( + "Agent workflow is placeholder and no longer loads LLM settings from env vars. " + f"Fallback: Based on context: {context_text}, here is the answer." + ) return {"answer": answer} -def reflector_node(state: AgentState): # noqa: ANN201 +def reflector_node(state: AgentState): # noqa: ANN201, ARG001 """Reflect on the quality of the answer.""" - _ = state return {"critique": "Answer appears sufficient."} diff --git a/src/helia/assessment/core.py b/src/helia/assessment/core.py index a5068e1..ee9ec74 100644 --- a/src/helia/assessment/core.py +++ b/src/helia/assessment/core.py @@ -55,11 +55,10 @@ class PHQ8Evaluator: self.config = config self.parser = TranscriptParser() - # Initialize LangChain Chat Model self.llm = get_chat_model( model_name=self.config.model_name, api_key=self.config.api_key, - base_url=self.config.api_base, + base_url=self.config.api_base or "", temperature=self.config.temperature, ) diff --git a/src/helia/configuration.py b/src/helia/configuration.py index 7fe3166..cafa67f 100644 --- a/src/helia/configuration.py +++ b/src/helia/configuration.py @@ -1,13 +1,18 @@ from __future__ import annotations import os -import re -from pathlib import Path -from typing import Literal, NamedTuple +from typing import TYPE_CHECKING, Literal, NamedTuple -import yaml -from pydantic import BaseModel, Field, TypeAdapter -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic import BaseModel, Field, model_validator +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, + YamlConfigSettingsSource, +) + +if TYPE_CHECKING: + from pathlib import Path class S3Config(NamedTuple): @@ -21,42 +26,12 @@ class S3Config(NamedTuple): region_name: str | None = None -class SystemConfig(BaseSettings): - """ - System-level configuration loaded from environment variables. - Includes Database and AWS/S3 settings. - """ - - model_config = SettingsConfigDict(env_prefix="HELIA_", env_file=".env", extra="ignore") - - mongo_uri: str = Field(..., description="MongoDB connection string") - database_name: str = Field("helia", description="MongoDB database name") - - s3_endpoint: str = Field(..., description="S3 endpoint URL") - s3_access_key: str = Field(..., description="S3 access key") - s3_secret_key: str = Field(..., description="S3 secret key") - s3_bucket: str = Field(..., description="S3 bucket containing the dataset") - s3_prefix: str = Field("", description="S3 key prefix for dataset files") - s3_region: str | None = Field(None, description="S3 region name (optional)") - - def get_s3_config(self) -> S3Config: - """Create an S3Config from the system configuration.""" - return S3Config( - bucket_name=self.s3_bucket, - endpoint_url=self.s3_endpoint, - aws_access_key_id=self.s3_access_key, - aws_secret_access_key=self.s3_secret_key, - prefix=self.s3_prefix, - region_name=self.s3_region, - ) - - class ProviderConfig(BaseModel): """ - Configuration for an LLM provider + Configuration for an LLM provider. """ - api_key: str + api_key: str | None = None api_base: str api_format: Literal["openai", "anthropic", "ollama"] = "openai" @@ -81,50 +56,112 @@ class RunSpec(BaseModel): prompt_id: str = "default" -class AssessBatchConfig(BaseModel): +class HeliaConfig(BaseSettings): """ - Configuration file structure for batch assessment. + Unified configuration for Helia. + Loads from YAML first, then falls back to Environment variables for system settings. """ - providers: dict[str, ProviderConfig] - runs: list[RunSpec] - limit: int | None = None + # System Settings (Env Fallback available via HELIA_ prefix) + mongo_uri: str = Field(..., description="MongoDB connection string") + database_name: str = Field("helia", description="MongoDB database name") + + s3_endpoint: str = Field(..., description="S3 endpoint URL") + s3_access_key: str = Field(..., description="S3 access key") + s3_secret_key: str = Field(..., description="S3 secret key") + s3_bucket: str = Field(..., description="S3 bucket containing the dataset") + s3_prefix: str = Field("", description="S3 key prefix for dataset files") + s3_region: str | None = Field(None, description="S3 region name (optional)") + + # Run Settings + limit: int | None = Field(None, description="Limit the number of transcripts processed") + providers: dict[str, ProviderConfig] = Field(default_factory=dict) + runs: list[RunSpec] = Field(default_factory=list) + + model_config = SettingsConfigDict( + env_prefix="HELIA_", + env_file=None, # Disable .env file loading + extra="ignore", + ) + + def get_s3_config(self) -> S3Config: + """Create an S3Config from the system configuration.""" + return S3Config( + bucket_name=self.s3_bucket, + endpoint_url=self.s3_endpoint, + aws_access_key_id=self.s3_access_key, + aws_secret_access_key=self.s3_secret_key, + prefix=self.s3_prefix, + region_name=self.s3_region, + ) + + @model_validator(mode="after") + def resolve_provider_api_keys(self) -> HeliaConfig: + """Resolve provider API keys from env vars as fallback. + + - If `providers..api_key` is set in YAML, it is used. + - Otherwise, the loader tries `{NAME}_API_KEY` (uppercased provider name). + - Providers whose name starts with `local_` may omit an API key. + """ + for name, provider in self.providers.items(): + if provider.api_key: + continue + + env_var_name = f"{name.upper()}_API_KEY" + if env_key := os.environ.get(env_var_name): + provider.api_key = env_key + + return self + + @model_validator(mode="after") + def validate_provider_api_keys(self) -> HeliaConfig: + """Enforce API key presence for non-local providers used by runs.""" + used_providers = {run.model.provider for run in self.runs} + + for provider_name in used_providers: + provider = self.providers.get(provider_name) + if provider is None: + msg = f"Provider '{provider_name}' is used in runs but not configured under 'providers'." + raise ValueError(msg) + + if provider_name.startswith("local_"): + continue + + if not provider.api_key: + env_var_name = f"{provider_name.upper()}_API_KEY" + msg = ( + f"Missing API key for provider '{provider_name}'. " + f"Set providers.{provider_name}.api_key in YAML or export {env_var_name}." + ) + raise ValueError(msg) + + return self -class AgentConfig(BaseModel): - # Placeholder for future agent config - command: Literal["agent"] = "agent" -ConfigType = AssessBatchConfig - - -def _expand_env_vars(yaml_content: str) -> str: +def load_config(path: str | Path) -> HeliaConfig: """ - Expand environment variables in the format ${VAR} or ${VAR:default}. + Load configuration from a YAML file, with environment variable fallback. """ - pattern = re.compile(r"\$\{([^}^{]+)\}") - def replace(match: re.Match) -> str: - env_var = match.group(1) - default_value = "" - if ":" in env_var: - env_var, default_value = env_var.split(":", 1) - return os.environ.get(env_var, default_value) + # Create a dynamic subclass to bind the specific yaml_file path + class RuntimeHeliaConfig(HeliaConfig): + model_config = SettingsConfigDict(yaml_file=path) - return pattern.sub(replace, yaml_content) + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, # noqa: ARG003 + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003 + file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003 + ) -> tuple[PydanticBaseSettingsSource, ...]: + return ( + YamlConfigSettingsSource(settings_cls), + env_settings, + # No dotenv, no init_settings needed usually + ) - -def load_config(path: str | Path) -> ConfigType: - with Path(path).open() as f: - content = f.read() - - content = _expand_env_vars(content) - data = yaml.safe_load(content) - - adapter = TypeAdapter(ConfigType) - return adapter.validate_python(data) - - -def load_system_config() -> SystemConfig: - return SystemConfig() + return RuntimeHeliaConfig() # type: ignore[call-arg] diff --git a/src/helia/db.py b/src/helia/db.py index 2edf920..180df87 100644 --- a/src/helia/db.py +++ b/src/helia/db.py @@ -8,10 +8,10 @@ from motor.motor_asyncio import AsyncIOMotorClient from helia.assessment.schema import AssessmentResult if TYPE_CHECKING: - from helia.configuration import SystemConfig + from helia.configuration import HeliaConfig -async def init_db(config: SystemConfig) -> None: +async def init_db(config: HeliaConfig) -> None: client = AsyncIOMotorClient(config.mongo_uri) await init_beanie( database=client[config.database_name], # type: ignore[arg-type] diff --git a/src/helia/llm/__init__.py b/src/helia/llm/__init__.py index f1aa122..ea20460 100644 --- a/src/helia/llm/__init__.py +++ b/src/helia/llm/__init__.py @@ -1,4 +1,3 @@ -from helia.llm.client import get_openai_client -from helia.llm.settings import settings +from helia.llm.client import get_chat_model -__all__ = ["get_openai_client", "settings"] +__all__ = ["get_chat_model"] diff --git a/src/helia/llm/client.py b/src/helia/llm/client.py index 311b0d1..232bd72 100644 --- a/src/helia/llm/client.py +++ b/src/helia/llm/client.py @@ -3,55 +3,20 @@ from __future__ import annotations from typing import TYPE_CHECKING from langchain_openai import ChatOpenAI -from openai import AsyncOpenAI, OpenAI from pydantic import SecretStr -from helia.llm.settings import settings - if TYPE_CHECKING: from langchain_core.language_models.chat_models import BaseChatModel -def get_openai_client() -> OpenAI: - """ - Returns an configured OpenAI client based on global settings. - Defaults to OpenRouter base_url if not specified otherwise. - """ - api_key = settings.resolve_api_key() - - return OpenAI( - base_url=settings.base_url, - api_key=api_key, - timeout=settings.timeout, - max_retries=settings.max_retries, - ) - - -def get_async_openai_client() -> AsyncOpenAI: - """ - Returns a configured AsyncOpenAI client based on global settings. - """ - api_key = settings.resolve_api_key() - - return AsyncOpenAI( - base_url=settings.base_url, - api_key=api_key, - timeout=settings.timeout, - max_retries=settings.max_retries, - ) - - def get_chat_model( model_name: str, - api_key: str | None = None, - base_url: str | None = None, + api_key: str | None, + base_url: str, temperature: float = 0.0, max_retries: int = 3, ) -> BaseChatModel: - """ - Returns a configured LangChain ChatOpenAI instance. - Supports OpenRouter, Ollama, and OpenAI via base_url. - """ + """Return a configured LangChain ChatOpenAI instance.""" return ChatOpenAI( model_name=model_name, openai_api_key=SecretStr(api_key or ""), diff --git a/src/helia/llm/settings.py b/src/helia/llm/settings.py deleted file mode 100644 index 3c8cdb2..0000000 --- a/src/helia/llm/settings.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -from typing import Final - -from pydantic import Field -from pydantic_settings import BaseSettings, SettingsConfigDict - - -class LLMSettings(BaseSettings): - """ - Configuration for LLM clients, defaulting to OpenRouter. - """ - - api_key: str | None = Field( - default=None, - description="API key for the LLM provider. Checks HELIA_LLM_API_KEY, OPENROUTER_API_KEY, then OPENAI_API_KEY.", - ) - base_url: str = Field( - default="https://openrouter.ai/api/v1", - description="Base URL for the LLM provider. Defaults to OpenRouter.", - ) - model: str = Field( - default="google/gemini-3.0-pro-preview", - description="Model identifier to use.", - ) - timeout: float = Field( - default=30.0, - description="Request timeout in seconds.", - ) - max_retries: int = Field( - default=2, - description="Maximum number of retries for failed requests.", - ) - - model_config = SettingsConfigDict( - env_prefix="HELIA_LLM_", - case_sensitive=False, - extra="ignore", - ) - - def resolve_api_key(self) -> str: - """ - Resolves the API key with a fallback strategy: - 1. configured api_key (from HELIA_LLM_API_KEY) - 2. OPENROUTER_API_KEY env var - 3. OPENAI_API_KEY env var - 4. Raise ValueError if none found - """ - if self.api_key: - return self.api_key - - # Fallback 1: OpenRouter - if key := os.environ.get("OPENROUTER_API_KEY"): - return key - - # Fallback 2: OpenAI - if key := os.environ.get("OPENAI_API_KEY"): - return key - - raise ValueError( - "No API key found. Please set HELIA_LLM_API_KEY, OPENROUTER_API_KEY, or OPENAI_API_KEY." - ) - - -# Singleton instance for easy import -settings: Final[LLMSettings] = LLMSettings() diff --git a/src/helia/main.py b/src/helia/main.py index 8131635..036e2ea 100644 --- a/src/helia/main.py +++ b/src/helia/main.py @@ -4,15 +4,13 @@ import logging from datetime import UTC, datetime from pathlib import Path -from helia.agent.workflow import run_agent from helia.assessment.core import PHQ8Evaluator from helia.assessment.schema import RunConfig from helia.configuration import ( - AssessBatchConfig, + HeliaConfig, RunSpec, S3Config, load_config, - load_system_config, ) from helia.db import init_db from helia.ingestion.s3 import S3DatasetLoader @@ -24,7 +22,7 @@ logger = logging.getLogger(__name__) async def process_run( run_spec: RunSpec, input_source: str, - run_config_data: AssessBatchConfig, + config: HeliaConfig, s3_config: S3Config, semaphore: asyncio.Semaphore, ) -> None: @@ -34,11 +32,11 @@ async def process_run( async with semaphore: # Resolve Provider provider_name = run_spec.model.provider - if provider_name not in run_config_data.providers: + if provider_name not in config.providers: logger.error("Run %s refers to unknown provider %s", run_spec.run_name, provider_name) return - provider_config = run_config_data.providers[provider_name] + provider_config = config.providers[provider_name] # Download from S3 (Async) loader = S3DatasetLoader(s3_config) @@ -111,61 +109,50 @@ async def main() -> None: return try: - run_config_data = load_config(config_path) - system_config = load_system_config() + config = load_config(config_path) except Exception: logger.exception("Error loading configuration") return - # Check the type of configuration - if isinstance(run_config_data, AssessBatchConfig): - # Run Pre-flight Checks - if not await check_all_connections(run_config_data, system_config): - logger.error("Pre-flight checks failed. Exiting.") - return + # Run Pre-flight Checks + if not await check_all_connections(config): + logger.error("Pre-flight checks failed. Exiting.") + return - # Ensure DB is initialized (redundant if pre-flight passed, but safe) - await init_db(system_config) + # Ensure DB is initialized (redundant if pre-flight passed, but safe) + await init_db(config) - # Create S3 config once and reuse - s3_config = system_config.get_s3_config() + # Create S3 config once and reuse + s3_config = config.get_s3_config() - # Discover transcripts (can remain sync or be made async, sync is fine for listing) - logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name) - loader = S3DatasetLoader(s3_config) - keys = loader.list_transcripts() + # Discover transcripts (can remain sync or be made async, sync is fine for listing) + logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name) + loader = S3DatasetLoader(s3_config) + keys = loader.list_transcripts() - # Apply limit if configured (Priority: CLI > Config) - limit = args.limit - limit = limit if limit is not None else run_config_data.limit + # Apply limit if configured (Priority: CLI > Config) + limit = args.limit + limit = limit if limit is not None else config.limit - if limit is not None: - logger.info("Limiting processing to first %d transcripts", limit) - keys = keys[:limit] + if limit is not None: + logger.info("Limiting processing to first %d transcripts", limit) + keys = keys[:limit] - # Create task list - tasks_data = [(run_spec, key) for run_spec in run_config_data.runs for key in keys] - logger.info("Starting batch assessment with %d total items...", len(tasks_data)) + # Create task list + tasks_data = [(run_spec, key) for run_spec in config.runs for key in keys] + logger.info("Starting batch assessment with %d total items...", len(tasks_data)) - # Limit concurrency to 10 parallel requests - semaphore = asyncio.Semaphore(10) + # Limit concurrency to 10 parallel requests + semaphore = asyncio.Semaphore(10) - tasks = [ - process_run(run_spec, key, run_config_data, s3_config, semaphore) - for run_spec, key in tasks_data - ] + tasks = [ + process_run(run_spec, key, config, s3_config, semaphore) for run_spec, key in tasks_data + ] - # Run all tasks concurrently - await asyncio.gather(*tasks) + # Run all tasks concurrently + await asyncio.gather(*tasks) - logger.info("Batch assessment complete.") - - else: - # Agent command (Placeholder) - question = run_config_data.question - logger.info("\nRunning Re-Agent with question: '%s'\n", question) - result = run_agent(question) - logger.info(result["answer"]) + logger.info("Batch assessment complete.") if __name__ == "__main__": diff --git a/src/helia/preflight.py b/src/helia/preflight.py index 88bd49d..1157264 100644 --- a/src/helia/preflight.py +++ b/src/helia/preflight.py @@ -10,12 +10,12 @@ from helia.db import init_db from helia.ingestion.s3 import S3DatasetLoader if TYPE_CHECKING: - from helia.configuration import AssessBatchConfig, ProviderConfig, S3Config, SystemConfig + from helia.configuration import HeliaConfig, ProviderConfig, S3Config logger = logging.getLogger(__name__) -async def check_mongo(config: SystemConfig) -> bool: +async def check_mongo(config: HeliaConfig) -> bool: """Verify MongoDB connectivity.""" try: logger.info("Checking MongoDB connection...") @@ -67,7 +67,7 @@ async def check_llm_provider(name: str, config: ProviderConfig, model_name: str) return True -async def check_all_connections(run_config: AssessBatchConfig, system_config: SystemConfig) -> bool: +async def check_all_connections(config: HeliaConfig) -> bool: """ Run all pre-flight checks. Returns True if all checks pass. """ @@ -76,16 +76,16 @@ async def check_all_connections(run_config: AssessBatchConfig, system_config: Sy checks = [] # 1. Check MongoDB - checks.append(check_mongo(system_config)) + checks.append(check_mongo(config)) # 2. Check S3 - checks.append(check_s3(system_config.get_s3_config())) + checks.append(check_s3(config.get_s3_config())) # 3. Check All LLM Providers - for name, provider_config in run_config.providers.items(): + for name, provider_config in config.providers.items(): # Find a model used by this provider in the runs to use for the check model_name = next( - (run.model.model_name for run in run_config.runs if run.model.provider == name), + (run.model.model_name for run in config.runs if run.model.provider == name), None, )