Refactor configuration management; remove deprecated files and unify settings in YAML format

commit f860d17206 (parent 966e9d547a)
Author: Santiago Martinez-Avial
Date: 2025-12-21 17:26:33 +01:00
14 changed files with 195 additions and 318 deletions


@@ -1,17 +0,0 @@
-# Database Configuration
-HELIA_MONGO_URI=mongodb://localhost:27017
-HELIA_DATABASE_NAME=helia
-
-# S3 Configuration (MinIO or AWS)
-# Required for finding and downloading transcripts
-HELIA_S3_ENDPOINT=https://s3.amazonaws.com
-HELIA_S3_ACCESS_KEY=your_access_key
-HELIA_S3_SECRET_KEY=your_secret_key
-HELIA_S3_BUCKET=your-bucket-name
-HELIA_S3_REGION=us-east-1
-
-# LLM API Keys
-# These are used by the run configuration YAML via ${VAR} substitution
-OPENAI_API_KEY=sk-...
-ANTHROPIC_API_KEY=sk-ant-...
-OPENROUTER_API_KEY=sk-or-...

.gitignore

@@ -7,6 +7,7 @@ wheels/
 *.egg-info
 .env
+config.yaml

 # Virtual environments
 .venv


@@ -17,7 +17,7 @@ src/helia/
 ├── analysis/
 │   └── extractor.py   # Metadata extraction (LLM-agnostic)
 ├── assessment/
 │   ├── core.py        # Clinical assessment logic (PHQ-8)
 │   └── schema.py      # Data models (AssessmentResult, PHQ8Item)
 ├── ingestion/
 │   └── parser.py      # Transcript parsing (DAIC-WOZ support)
@@ -50,6 +50,7 @@ graph TD
 - **Clinical Parsing**: Native support for DAIC-WOZ transcript formats.
 - **Structured Assessment**: Maps unstructured conversation to validatable PHQ-8 scores.
 - **Document Persistence**: Stores full experimental context (config + evidence + scores) in MongoDB using Beanie.
+- **Unified Configuration**: Single YAML config with environment variable fallbacks for system settings.

 ## Roadmap
@@ -67,22 +68,26 @@ uv sync
 ## Quick Start

-1. **Environment Setup**:
+1. **Configuration**:
+   Copy `example.config.yaml` to `config.yaml` and edit it.
    ```sh
-   export OPENAI_API_KEY=sk-...
-   # Ensure MongoDB is running (e.g., via Docker)
+   cp example.config.yaml config.yaml
    ```
+   You can set system settings (MongoDB, S3) in the YAML file, or use environment variables (e.g., `HELIA_MONGO_URI`, `HELIA_S3_BUCKET`).
+   Provider API keys can be set in YAML or via `{PROVIDER_NAME}_API_KEY` (uppercased), e.g. `ANTHROPIC_API_KEY` for `providers.anthropic`.

 2. **Run an Assessment**:
    ```sh
-   python -m helia.main "assess --input data/transcript.tsv"
+   # Ensure MongoDB is running
+   uv run python -m helia.main config.yaml
    ```

 ## Development

 - **Linting**: `uv run ruff check .`
 - **Formatting**: `uv run ruff format .`
-- **Type Checking**: `uv run pyrefly`
+- **Type Checking**: `uv run ty check`

 ## License


@@ -1,25 +1,36 @@
-# Helia Run Configuration
-# This file defines the "providers" (LLM connections) and the "runs" (experiments).
-# Environment variables like ${OPENAI_API_KEY} are expanded at runtime.
+# System Configuration
+mongo_uri: "mongodb://localhost:27017"
+database_name: "helia"
+
+# S3 Configuration (MinIO or AWS)
+s3_endpoint: "https://s3.amazonaws.com"
+s3_access_key: "your_access_key"
+s3_secret_key: "your_secret_key"
+s3_bucket: "your-bucket-name"
+s3_prefix: ""
+s3_region: "us-east-1"
+
+# Run Configuration
+limit: 5

 providers:
   openai:
-    api_key: "${OPENAI_API_KEY}"
+    # Set api_key here or export OPENAI_API_KEY
     api_base: "https://api.openai.com/v1"
     api_format: "openai"
   anthropic:
-    api_key: "${ANTHROPIC_API_KEY}"
+    # Set api_key here or export ANTHROPIC_API_KEY
     api_base: "https://api.anthropic.com/v1"
     api_format: "anthropic"
   openrouter:
-    api_key: "${OPENROUTER_API_KEY}"
+    # Set api_key here or export OPENROUTER_API_KEY
     api_base: "https://openrouter.ai/api/v1"
     api_format: "openai"
   local_ollama:
-    api_key: "none"
+    # API key optional for local_* providers
    api_base: "http://localhost:11434/v1"
    api_format: "ollama"
@@ -37,6 +48,3 @@ runs:
     model_name: "llama3"
     temperature: 0.7
     prompt_id: "default"
-
-# Optional: Limit the number of transcripts processed
-limit: 5
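
For illustration, a minimal sketch of the env fallback this config relies on (assumes the `HeliaConfig` loader introduced below in this commit; `eu-west-1` is a placeholder value):

```python
import os

from helia.configuration import load_config

# config.yaml omits s3_region, so the HELIA_-prefixed env source supplies it;
# values present in the YAML itself always take precedence over the environment.
os.environ["HELIA_S3_REGION"] = "eu-west-1"

config = load_config("config.yaml")
assert config.s3_region == "eu-west-1"
```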


@@ -1,15 +0,0 @@
-providers:
-  openrouter:
-    api_key: "${OPENROUTER_API_KEY}"
-    api_base: "https://openrouter.ai/api/v1"
-    api_format: "openai"
-
-runs:
-  - run_name: "gemini_flash"
-    model:
-      provider: openrouter
-      model_name: "google/gemini-3-flash-preview"
-      temperature: 1.0
-    prompt_id: "default"
-    limit: 5


@@ -5,9 +5,6 @@ from typing import Any
 from langgraph.graph import END, StateGraph
 from pydantic import BaseModel

-from helia.llm.client import get_chat_model
-from helia.llm.settings import settings


 class AgentState(BaseModel):
     """State for the agent workflow."""
@@ -25,9 +22,8 @@ class AgentState(BaseModel):
 # proper protocol matching while maintaining runtime correctness.
-def planner_node(state: AgentState):  # noqa: ANN201
+def planner_node(state: AgentState):  # noqa: ANN201, ARG001
     """Plan the steps to answer the question."""
-    _ = state
     plan: list[str] = ["Understand question", "Retrieve info", "Synthesize answer"]
     return {"plan": plan}
@@ -57,40 +53,17 @@ def vector_tool_node(state: AgentState): # noqa: ANN201
 def synthesizer_node(state: AgentState):  # noqa: ANN201
     """Synthesize an answer from the gathered context."""
     context_text = "\n".join(state.context)
-    question = state.question

-    prompt = f"""
-    Answer the user's question based on the provided context.
-
-    Context:
-    {context_text}
-
-    Question: {question}
-
-    Answer:
-    """
-
-    try:
-        llm = get_chat_model(
-            model_name=settings.model,
-            api_key=settings.resolve_api_key(),
-            base_url=settings.base_url,
-        )
-        messages = [
-            ("system", "You are a helpful assistant."),
-            ("user", prompt),
-        ]
-        response = llm.invoke(messages)
-        answer = str(response.content)
-    except Exception as e:
-        answer = f"Error generating answer: {e}. Fallback: Based on context: {context_text}, here is the answer."
+    answer = (
+        "Agent workflow is placeholder and no longer loads LLM settings from env vars. "
+        f"Fallback: Based on context: {context_text}, here is the answer."
+    )

     return {"answer": answer}


-def reflector_node(state: AgentState):  # noqa: ANN201
+def reflector_node(state: AgentState):  # noqa: ANN201, ARG001
     """Reflect on the quality of the answer."""
-    _ = state
     return {"critique": "Answer appears sufficient."}


@@ -55,11 +55,10 @@ class PHQ8Evaluator:
         self.config = config
         self.parser = TranscriptParser()

-        # Initialize LangChain Chat Model
         self.llm = get_chat_model(
             model_name=self.config.model_name,
             api_key=self.config.api_key,
-            base_url=self.config.api_base,
+            base_url=self.config.api_base or "",
             temperature=self.config.temperature,
         )


@@ -1,13 +1,18 @@
 from __future__ import annotations

 import os
-import re
-from pathlib import Path
-from typing import Literal, NamedTuple
+from typing import TYPE_CHECKING, Literal, NamedTuple

-import yaml
-from pydantic import BaseModel, Field, TypeAdapter
-from pydantic_settings import BaseSettings, SettingsConfigDict
+from pydantic import BaseModel, Field, model_validator
+from pydantic_settings import (
+    BaseSettings,
+    PydanticBaseSettingsSource,
+    SettingsConfigDict,
+    YamlConfigSettingsSource,
+)
+
+if TYPE_CHECKING:
+    from pathlib import Path


 class S3Config(NamedTuple):
@@ -21,42 +26,12 @@ class S3Config(NamedTuple):
     region_name: str | None = None


-class SystemConfig(BaseSettings):
-    """
-    System-level configuration loaded from environment variables.
-
-    Includes Database and AWS/S3 settings.
-    """
-
-    model_config = SettingsConfigDict(env_prefix="HELIA_", env_file=".env", extra="ignore")
-
-    mongo_uri: str = Field(..., description="MongoDB connection string")
-    database_name: str = Field("helia", description="MongoDB database name")
-    s3_endpoint: str = Field(..., description="S3 endpoint URL")
-    s3_access_key: str = Field(..., description="S3 access key")
-    s3_secret_key: str = Field(..., description="S3 secret key")
-    s3_bucket: str = Field(..., description="S3 bucket containing the dataset")
-    s3_prefix: str = Field("", description="S3 key prefix for dataset files")
-    s3_region: str | None = Field(None, description="S3 region name (optional)")
-
-    def get_s3_config(self) -> S3Config:
-        """Create an S3Config from the system configuration."""
-        return S3Config(
-            bucket_name=self.s3_bucket,
-            endpoint_url=self.s3_endpoint,
-            aws_access_key_id=self.s3_access_key,
-            aws_secret_access_key=self.s3_secret_key,
-            prefix=self.s3_prefix,
-            region_name=self.s3_region,
-        )
-
-
 class ProviderConfig(BaseModel):
     """
-    Configuration for an LLM provider
+    Configuration for an LLM provider.
     """

-    api_key: str
+    api_key: str | None = None
     api_base: str
     api_format: Literal["openai", "anthropic", "ollama"] = "openai"
@@ -81,50 +56,112 @@ class RunSpec(BaseModel):
prompt_id: str = "default" prompt_id: str = "default"
class AssessBatchConfig(BaseModel): class HeliaConfig(BaseSettings):
""" """
Configuration file structure for batch assessment. Unified configuration for Helia.
Loads from YAML first, then falls back to Environment variables for system settings.
""" """
providers: dict[str, ProviderConfig] # System Settings (Env Fallback available via HELIA_ prefix)
runs: list[RunSpec] mongo_uri: str = Field(..., description="MongoDB connection string")
limit: int | None = None database_name: str = Field("helia", description="MongoDB database name")
s3_endpoint: str = Field(..., description="S3 endpoint URL")
s3_access_key: str = Field(..., description="S3 access key")
s3_secret_key: str = Field(..., description="S3 secret key")
s3_bucket: str = Field(..., description="S3 bucket containing the dataset")
s3_prefix: str = Field("", description="S3 key prefix for dataset files")
s3_region: str | None = Field(None, description="S3 region name (optional)")
# Run Settings
limit: int | None = Field(None, description="Limit the number of transcripts processed")
providers: dict[str, ProviderConfig] = Field(default_factory=dict)
runs: list[RunSpec] = Field(default_factory=list)
model_config = SettingsConfigDict(
env_prefix="HELIA_",
env_file=None, # Disable .env file loading
extra="ignore",
)
def get_s3_config(self) -> S3Config:
"""Create an S3Config from the system configuration."""
return S3Config(
bucket_name=self.s3_bucket,
endpoint_url=self.s3_endpoint,
aws_access_key_id=self.s3_access_key,
aws_secret_access_key=self.s3_secret_key,
prefix=self.s3_prefix,
region_name=self.s3_region,
)
@model_validator(mode="after")
def resolve_provider_api_keys(self) -> HeliaConfig:
"""Resolve provider API keys from env vars as fallback.
- If `providers.<name>.api_key` is set in YAML, it is used.
- Otherwise, the loader tries `{NAME}_API_KEY` (uppercased provider name).
- Providers whose name starts with `local_` may omit an API key.
"""
for name, provider in self.providers.items():
if provider.api_key:
continue
env_var_name = f"{name.upper()}_API_KEY"
if env_key := os.environ.get(env_var_name):
provider.api_key = env_key
return self
@model_validator(mode="after")
def validate_provider_api_keys(self) -> HeliaConfig:
"""Enforce API key presence for non-local providers used by runs."""
used_providers = {run.model.provider for run in self.runs}
for provider_name in used_providers:
provider = self.providers.get(provider_name)
if provider is None:
msg = f"Provider '{provider_name}' is used in runs but not configured under 'providers'."
raise ValueError(msg)
if provider_name.startswith("local_"):
continue
if not provider.api_key:
env_var_name = f"{provider_name.upper()}_API_KEY"
msg = (
f"Missing API key for provider '{provider_name}'. "
f"Set providers.{provider_name}.api_key in YAML or export {env_var_name}."
)
raise ValueError(msg)
return self
class AgentConfig(BaseModel):
# Placeholder for future agent config
command: Literal["agent"] = "agent"
ConfigType = AssessBatchConfig def load_config(path: str | Path) -> HeliaConfig:
def _expand_env_vars(yaml_content: str) -> str:
""" """
Expand environment variables in the format ${VAR} or ${VAR:default}. Load configuration from a YAML file, with environment variable fallback.
""" """
pattern = re.compile(r"\$\{([^}^{]+)\}")
def replace(match: re.Match) -> str: # Create a dynamic subclass to bind the specific yaml_file path
env_var = match.group(1) class RuntimeHeliaConfig(HeliaConfig):
default_value = "" model_config = SettingsConfigDict(yaml_file=path)
if ":" in env_var:
env_var, default_value = env_var.split(":", 1)
return os.environ.get(env_var, default_value)
return pattern.sub(replace, yaml_content) @classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource, # noqa: ARG003
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003
file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
YamlConfigSettingsSource(settings_cls),
env_settings,
# No dotenv, no init_settings needed usually
)
return RuntimeHeliaConfig() # type: ignore[call-arg]
def load_config(path: str | Path) -> ConfigType:
with Path(path).open() as f:
content = f.read()
content = _expand_env_vars(content)
data = yaml.safe_load(content)
adapter = TypeAdapter(ConfigType)
return adapter.validate_python(data)
def load_system_config() -> SystemConfig:
return SystemConfig()
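
A usage sketch of the new loader and the provider-key fallback (assumes this module is importable as `helia.configuration` and that `pydantic-settings` with YAML support is installed; all values are placeholders):

```python
import os
from pathlib import Path
from tempfile import TemporaryDirectory

from helia.configuration import load_config

with TemporaryDirectory() as tmp:
    cfg = Path(tmp) / "config.yaml"
    cfg.write_text(
        "mongo_uri: mongodb://localhost:27017\n"
        "s3_endpoint: https://s3.amazonaws.com\n"
        "s3_access_key: key\n"
        "s3_secret_key: secret\n"
        "s3_bucket: my-bucket\n"
        "providers:\n"
        "  anthropic:\n"
        "    api_base: https://api.anthropic.com/v1\n"
        "    api_format: anthropic\n"
    )
    # api_key is absent from the YAML, so resolve_provider_api_keys falls back
    # to the {NAME}_API_KEY environment variable.
    os.environ["ANTHROPIC_API_KEY"] = "sk-ant-placeholder"
    config = load_config(cfg)
    assert config.providers["anthropic"].api_key == "sk-ant-placeholder"
```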


@@ -8,10 +8,10 @@ from motor.motor_asyncio import AsyncIOMotorClient
 from helia.assessment.schema import AssessmentResult

 if TYPE_CHECKING:
-    from helia.configuration import SystemConfig
+    from helia.configuration import HeliaConfig


-async def init_db(config: SystemConfig) -> None:
+async def init_db(config: HeliaConfig) -> None:
     client = AsyncIOMotorClient(config.mongo_uri)
     await init_beanie(
         database=client[config.database_name],  # type: ignore[arg-type]


@@ -1,4 +1,3 @@
-from helia.llm.client import get_openai_client
-from helia.llm.settings import settings
+from helia.llm.client import get_chat_model

-__all__ = ["get_openai_client", "settings"]
+__all__ = ["get_chat_model"]


@@ -3,55 +3,20 @@ from __future__ import annotations
 from typing import TYPE_CHECKING

 from langchain_openai import ChatOpenAI
-from openai import AsyncOpenAI, OpenAI
 from pydantic import SecretStr

-from helia.llm.settings import settings

 if TYPE_CHECKING:
     from langchain_core.language_models.chat_models import BaseChatModel


-def get_openai_client() -> OpenAI:
-    """
-    Returns a configured OpenAI client based on global settings.
-
-    Defaults to OpenRouter base_url if not specified otherwise.
-    """
-    api_key = settings.resolve_api_key()
-    return OpenAI(
-        base_url=settings.base_url,
-        api_key=api_key,
-        timeout=settings.timeout,
-        max_retries=settings.max_retries,
-    )
-
-
-def get_async_openai_client() -> AsyncOpenAI:
-    """
-    Returns a configured AsyncOpenAI client based on global settings.
-    """
-    api_key = settings.resolve_api_key()
-    return AsyncOpenAI(
-        base_url=settings.base_url,
-        api_key=api_key,
-        timeout=settings.timeout,
-        max_retries=settings.max_retries,
-    )
-
-
 def get_chat_model(
     model_name: str,
-    api_key: str | None = None,
-    base_url: str | None = None,
+    api_key: str | None,
+    base_url: str,
     temperature: float = 0.0,
     max_retries: int = 3,
 ) -> BaseChatModel:
-    """
-    Returns a configured LangChain ChatOpenAI instance.
-
-    Supports OpenRouter, Ollama, and OpenAI via base_url.
-    """
+    """Return a configured LangChain ChatOpenAI instance."""
     return ChatOpenAI(
         model_name=model_name,
         openai_api_key=SecretStr(api_key or ""),
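
With `base_url` now required, call sites pass the provider's endpoint explicitly; a minimal sketch against a local Ollama endpoint (requires a running server; the model name is illustrative):

```python
from helia.llm.client import get_chat_model

llm = get_chat_model(
    model_name="llama3",
    api_key=None,  # local_* providers may omit a key
    base_url="http://localhost:11434/v1",
    temperature=0.7,
)
# llm.invoke("...") would issue a real request to the configured endpoint.
```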


@@ -1,65 +0,0 @@
-import os
-from typing import Final
-
-from pydantic import Field
-from pydantic_settings import BaseSettings, SettingsConfigDict
-
-
-class LLMSettings(BaseSettings):
-    """
-    Configuration for LLM clients, defaulting to OpenRouter.
-    """
-
-    api_key: str | None = Field(
-        default=None,
-        description="API key for the LLM provider. Checks HELIA_LLM_API_KEY, OPENROUTER_API_KEY, then OPENAI_API_KEY.",
-    )
-    base_url: str = Field(
-        default="https://openrouter.ai/api/v1",
-        description="Base URL for the LLM provider. Defaults to OpenRouter.",
-    )
-    model: str = Field(
-        default="google/gemini-3.0-pro-preview",
-        description="Model identifier to use.",
-    )
-    timeout: float = Field(
-        default=30.0,
-        description="Request timeout in seconds.",
-    )
-    max_retries: int = Field(
-        default=2,
-        description="Maximum number of retries for failed requests.",
-    )
-
-    model_config = SettingsConfigDict(
-        env_prefix="HELIA_LLM_",
-        case_sensitive=False,
-        extra="ignore",
-    )
-
-    def resolve_api_key(self) -> str:
-        """
-        Resolves the API key with a fallback strategy:
-        1. configured api_key (from HELIA_LLM_API_KEY)
-        2. OPENROUTER_API_KEY env var
-        3. OPENAI_API_KEY env var
-        4. Raise ValueError if none found
-        """
-        if self.api_key:
-            return self.api_key
-
-        # Fallback 1: OpenRouter
-        if key := os.environ.get("OPENROUTER_API_KEY"):
-            return key
-
-        # Fallback 2: OpenAI
-        if key := os.environ.get("OPENAI_API_KEY"):
-            return key
-
-        raise ValueError(
-            "No API key found. Please set HELIA_LLM_API_KEY, OPENROUTER_API_KEY, or OPENAI_API_KEY."
-        )
-
-
-# Singleton instance for easy import
-settings: Final[LLMSettings] = LLMSettings()


@@ -4,15 +4,13 @@ import logging
 from datetime import UTC, datetime
 from pathlib import Path

-from helia.agent.workflow import run_agent
 from helia.assessment.core import PHQ8Evaluator
 from helia.assessment.schema import RunConfig
 from helia.configuration import (
-    AssessBatchConfig,
+    HeliaConfig,
     RunSpec,
     S3Config,
     load_config,
-    load_system_config,
 )
 from helia.db import init_db
 from helia.ingestion.s3 import S3DatasetLoader
@@ -24,7 +22,7 @@ logger = logging.getLogger(__name__)
 async def process_run(
     run_spec: RunSpec,
     input_source: str,
-    run_config_data: AssessBatchConfig,
+    config: HeliaConfig,
     s3_config: S3Config,
     semaphore: asyncio.Semaphore,
 ) -> None:
@@ -34,11 +32,11 @@ async def process_run(
     async with semaphore:
         # Resolve Provider
         provider_name = run_spec.model.provider
-        if provider_name not in run_config_data.providers:
+        if provider_name not in config.providers:
             logger.error("Run %s refers to unknown provider %s", run_spec.run_name, provider_name)
             return

-        provider_config = run_config_data.providers[provider_name]
+        provider_config = config.providers[provider_name]

         # Download from S3 (Async)
         loader = S3DatasetLoader(s3_config)
@@ -111,61 +109,50 @@ async def main() -> None:
         return

     try:
-        run_config_data = load_config(config_path)
-        system_config = load_system_config()
+        config = load_config(config_path)
     except Exception:
         logger.exception("Error loading configuration")
         return

-    # Check the type of configuration
-    if isinstance(run_config_data, AssessBatchConfig):
-        # Run Pre-flight Checks
-        if not await check_all_connections(run_config_data, system_config):
-            logger.error("Pre-flight checks failed. Exiting.")
-            return
-
-        # Ensure DB is initialized (redundant if pre-flight passed, but safe)
-        await init_db(system_config)
-
-        # Create S3 config once and reuse
-        s3_config = system_config.get_s3_config()
-
-        # Discover transcripts (can remain sync or be made async, sync is fine for listing)
-        logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name)
-        loader = S3DatasetLoader(s3_config)
-        keys = loader.list_transcripts()
-
-        # Apply limit if configured (Priority: CLI > Config)
-        limit = args.limit
-        limit = limit if limit is not None else run_config_data.limit
-        if limit is not None:
-            logger.info("Limiting processing to first %d transcripts", limit)
-            keys = keys[:limit]
-
-        # Create task list
-        tasks_data = [(run_spec, key) for run_spec in run_config_data.runs for key in keys]
-        logger.info("Starting batch assessment with %d total items...", len(tasks_data))
-
-        # Limit concurrency to 10 parallel requests
-        semaphore = asyncio.Semaphore(10)
-        tasks = [
-            process_run(run_spec, key, run_config_data, s3_config, semaphore)
-            for run_spec, key in tasks_data
-        ]
-
-        # Run all tasks concurrently
-        await asyncio.gather(*tasks)
-        logger.info("Batch assessment complete.")
-    else:
-        # Agent command (Placeholder)
-        question = run_config_data.question
-        logger.info("\nRunning Re-Agent with question: '%s'\n", question)
-        result = run_agent(question)
-        logger.info(result["answer"])
+    # Run Pre-flight Checks
+    if not await check_all_connections(config):
+        logger.error("Pre-flight checks failed. Exiting.")
+        return
+
+    # Ensure DB is initialized (redundant if pre-flight passed, but safe)
+    await init_db(config)
+
+    # Create S3 config once and reuse
+    s3_config = config.get_s3_config()
+
+    # Discover transcripts (can remain sync or be made async, sync is fine for listing)
+    logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name)
+    loader = S3DatasetLoader(s3_config)
+    keys = loader.list_transcripts()
+
+    # Apply limit if configured (Priority: CLI > Config)
+    limit = args.limit
+    limit = limit if limit is not None else config.limit
+    if limit is not None:
+        logger.info("Limiting processing to first %d transcripts", limit)
+        keys = keys[:limit]
+
+    # Create task list
+    tasks_data = [(run_spec, key) for run_spec in config.runs for key in keys]
+    logger.info("Starting batch assessment with %d total items...", len(tasks_data))
+
+    # Limit concurrency to 10 parallel requests
+    semaphore = asyncio.Semaphore(10)
+    tasks = [
+        process_run(run_spec, key, config, s3_config, semaphore) for run_spec, key in tasks_data
+    ]
+
+    # Run all tasks concurrently
+    await asyncio.gather(*tasks)
+    logger.info("Batch assessment complete.")


 if __name__ == "__main__":


@@ -10,12 +10,12 @@ from helia.db import init_db
 from helia.ingestion.s3 import S3DatasetLoader

 if TYPE_CHECKING:
-    from helia.configuration import AssessBatchConfig, ProviderConfig, S3Config, SystemConfig
+    from helia.configuration import HeliaConfig, ProviderConfig, S3Config

 logger = logging.getLogger(__name__)


-async def check_mongo(config: SystemConfig) -> bool:
+async def check_mongo(config: HeliaConfig) -> bool:
     """Verify MongoDB connectivity."""
     try:
         logger.info("Checking MongoDB connection...")
@@ -67,7 +67,7 @@ async def check_llm_provider(name: str, config: ProviderConfig, model_name: str)
     return True


-async def check_all_connections(run_config: AssessBatchConfig, system_config: SystemConfig) -> bool:
+async def check_all_connections(config: HeliaConfig) -> bool:
     """
     Run all pre-flight checks. Returns True if all checks pass.
     """
@@ -76,16 +76,16 @@ async def check_all_connections(run_config: AssessBatchConfig, system_config: Sy
     checks = []

     # 1. Check MongoDB
-    checks.append(check_mongo(system_config))
+    checks.append(check_mongo(config))

     # 2. Check S3
-    checks.append(check_s3(system_config.get_s3_config()))
+    checks.append(check_s3(config.get_s3_config()))

     # 3. Check All LLM Providers
-    for name, provider_config in run_config.providers.items():
+    for name, provider_config in config.providers.items():
         # Find a model used by this provider in the runs to use for the check
         model_name = next(
-            (run.model.model_name for run in run_config.runs if run.model.provider == name),
+            (run.model.model_name for run in config.runs if run.model.provider == name),
             None,
         )
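
End to end, the unified config now flows from `load_config` straight into the pre-flight checks; a sketch under stated assumptions (the `helia.preflight` module path is inferred from context, not shown in this commit):

```python
import asyncio

from helia.configuration import load_config
from helia.preflight import check_all_connections  # module path assumed

async def preflight() -> None:
    config = load_config("config.yaml")
    if not await check_all_connections(config):
        raise SystemExit("Pre-flight checks failed")

asyncio.run(preflight())
```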