Refactor configuration management; remove deprecated files and unify settings in YAML format
.env.example (deleted, 17 lines)
@@ -1,17 +0,0 @@
-# Database Configuration
-HELIA_MONGO_URI=mongodb://localhost:27017
-HELIA_DATABASE_NAME=helia
-
-# S3 Configuration (MinIO or AWS)
-# Required for finding and downloading transcripts
-HELIA_S3_ENDPOINT=https://s3.amazonaws.com
-HELIA_S3_ACCESS_KEY=your_access_key
-HELIA_S3_SECRET_KEY=your_secret_key
-HELIA_S3_BUCKET=your-bucket-name
-HELIA_S3_REGION=us-east-1
-
-# LLM API Keys
-# These are used by the run configuration YAML via ${VAR} substitution
-OPENAI_API_KEY=sk-...
-ANTHROPIC_API_KEY=sk-ant-...
-OPENROUTER_API_KEY=sk-or-...

.gitignore (vendored; 1 line added)
@@ -7,6 +7,7 @@ wheels/
 *.egg-info

 .env
+config.yaml

 # Virtual environments
 .venv

README.md (17 changes)
@@ -17,7 +17,7 @@ src/helia/
 ├── analysis/
 │   └── extractor.py   # Metadata extraction (LLM-agnostic)
 ├── assessment/
-│   └── core.py        # Clinical assessment logic (PHQ-8)
+│   ├── core.py        # Clinical assessment logic (PHQ-8)
 │   └── schema.py      # Data models (AssessmentResult, PHQ8Item)
 ├── ingestion/
 │   └── parser.py      # Transcript parsing (DAIC-WOZ support)
@@ -50,6 +50,7 @@ graph TD
 - **Clinical Parsing**: Native support for DAIC-WOZ transcript formats.
 - **Structured Assessment**: Maps unstructured conversation to validatable PHQ-8 scores.
 - **Document Persistence**: Stores full experimental context (config + evidence + scores) in MongoDB using Beanie.
+- **Unified Configuration**: Single YAML config with environment variable fallbacks for system settings.

 ## Roadmap

@@ -67,22 +68,26 @@ uv sync

 ## Quick Start

-1. **Environment Setup**:
+1. **Configuration**:
+   Copy `example.config.yaml` to `config.yaml` and edit it.
    ```sh
-   export OPENAI_API_KEY=sk-...
-   # Ensure MongoDB is running (e.g., via Docker)
+   cp example.config.yaml config.yaml
    ```

+   You can set system settings (MongoDB, S3) in the YAML file, or use environment variables (e.g., `HELIA_MONGO_URI`, `HELIA_S3_BUCKET`).
+   Provider API keys can be set in YAML or via `{PROVIDER_NAME}_API_KEY` (uppercased), e.g. `ANTHROPIC_API_KEY` for `providers.anthropic`.
+
 2. **Run an Assessment**:
    ```sh
-   python -m helia.main "assess --input data/transcript.tsv"
+   # Ensure MongoDB is running
+   uv run python -m helia.main config.yaml
    ```

 ## Development

 - **Linting**: `uv run ruff check .`
 - **Formatting**: `uv run ruff format .`
-- **Type Checking**: `uv run pyrefly`
+- **Type Checking**: `uv run ty check`

 ## License


example.config.yaml
@@ -1,25 +1,36 @@
 # Helia Run Configuration
 # This file defines the "providers" (LLM connections) and the "runs" (experiments).
-# Environment variables like ${OPENAI_API_KEY} are expanded at runtime.
+
+# System Configuration
+mongo_uri: "mongodb://localhost:27017"
+database_name: "helia"
+
+# S3 Configuration (MinIO or AWS)
+s3_endpoint: "https://s3.amazonaws.com"
+s3_access_key: "your_access_key"
+s3_secret_key: "your_secret_key"
+s3_bucket: "your-bucket-name"
+s3_prefix: ""
+s3_region: "us-east-1"
+
+# Run Configuration
+limit: 5

 providers:
   openai:
-    api_key: "${OPENAI_API_KEY}"
+    # Set api_key here or export OPENAI_API_KEY
     api_base: "https://api.openai.com/v1"
     api_format: "openai"

   anthropic:
-    api_key: "${ANTHROPIC_API_KEY}"
+    # Set api_key here or export ANTHROPIC_API_KEY
     api_base: "https://api.anthropic.com/v1"
     api_format: "anthropic"

   openrouter:
-    api_key: "${OPENROUTER_API_KEY}"
+    # Set api_key here or export OPENROUTER_API_KEY
     api_base: "https://openrouter.ai/api/v1"
     api_format: "openai"

   local_ollama:
-    api_key: "none"
+    # API key optional for local_* providers
     api_base: "http://localhost:11434/v1"
     api_format: "ollama"
@@ -37,6 +48,3 @@ runs:
     model_name: "llama3"
     temperature: 0.7
     prompt_id: "default"
-
-# Optional: Limit the number of transcripts processed
-limit: 5
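
A minimal sketch of consuming the unified config from Python, assuming a `config.yaml` copied from the example above; the field names mirror the `HeliaConfig` model introduced later in this commit:

```python
from helia.configuration import load_config

# YAML values win; HELIA_-prefixed env vars only fill in keys
# that are absent from the YAML file.
config = load_config("config.yaml")

print(config.mongo_uri)  # from YAML, or HELIA_MONGO_URI if the key is omitted
print(config.limit)      # 5, per the "# Run Configuration" block above
for name, provider in config.providers.items():
    print(name, provider.api_base, provider.api_format)
```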

(deleted run config; filename not shown in this view)
@@ -1,15 +0,0 @@
-providers:
-  openrouter:
-    api_key: "${OPENROUTER_API_KEY}"
-    api_base: "https://openrouter.ai/api/v1"
-    api_format: "openai"
-
-runs:
-  - run_name: "gemini_flash"
-    model:
-      provider: openrouter
-      model_name: "google/gemini-3-flash-preview"
-    temperature: 1.0
-    prompt_id: "default"
-
-limit: 5

src/helia/agent/workflow.py
@@ -5,9 +5,6 @@ from typing import Any
 from langgraph.graph import END, StateGraph
 from pydantic import BaseModel

-from helia.llm.client import get_chat_model
-from helia.llm.settings import settings
-

 class AgentState(BaseModel):
     """State for the agent workflow."""
@@ -25,9 +22,8 @@ class AgentState(BaseModel):
 # proper protocol matching while maintaining runtime correctness.


-def planner_node(state: AgentState):  # noqa: ANN201
+def planner_node(state: AgentState):  # noqa: ANN201, ARG001
     """Plan the steps to answer the question."""
-    _ = state
     plan: list[str] = ["Understand question", "Retrieve info", "Synthesize answer"]
     return {"plan": plan}

@@ -57,40 +53,17 @@ def vector_tool_node(state: AgentState):  # noqa: ANN201
 def synthesizer_node(state: AgentState):  # noqa: ANN201
     """Synthesize an answer from the gathered context."""
     context_text = "\n".join(state.context)
-    question = state.question
-
-    prompt = f"""
-    Answer the user's question based on the provided context.
-
-    Context:
-    {context_text}
-
-    Question: {question}
-
-    Answer:
-    """
-
-    try:
-        llm = get_chat_model(
-            model_name=settings.model,
-            api_key=settings.resolve_api_key(),
-            base_url=settings.base_url,
-        )
-        messages = [
-            ("system", "You are a helpful assistant."),
-            ("user", prompt),
-        ]
-        response = llm.invoke(messages)
-        answer = str(response.content)
-    except Exception as e:
-        answer = f"Error generating answer: {e}. Fallback: Based on context: {context_text}, here is the answer."
+    answer = (
+        "Agent workflow is placeholder and no longer loads LLM settings from env vars. "
+        f"Fallback: Based on context: {context_text}, here is the answer."
+    )

     return {"answer": answer}


-def reflector_node(state: AgentState):  # noqa: ANN201
+def reflector_node(state: AgentState):  # noqa: ANN201, ARG001
     """Reflect on the quality of the answer."""
-    _ = state
     return {"critique": "Answer appears sufficient."}

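
These hunks only touch the node bodies; the graph wiring itself is outside the diff. A hypothetical sketch of how such nodes could be assembled with the LangGraph API already imported in this module (the node names and edge order are assumptions, not the repository's actual graph):

```python
from langgraph.graph import END, StateGraph

# Hypothetical wiring of the nodes defined above.
graph = StateGraph(AgentState)
graph.add_node("planner", planner_node)
graph.add_node("vector_tool", vector_tool_node)
graph.add_node("synthesizer", synthesizer_node)
graph.add_node("reflector", reflector_node)
graph.set_entry_point("planner")
graph.add_edge("planner", "vector_tool")
graph.add_edge("vector_tool", "synthesizer")
graph.add_edge("synthesizer", "reflector")
graph.add_edge("reflector", END)

app = graph.compile()
# result = app.invoke({"question": "..."})  # returns the accumulated state
```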

src/helia/assessment/core.py
@@ -55,11 +55,10 @@ class PHQ8Evaluator:
         self.config = config
         self.parser = TranscriptParser()

         # Initialize LangChain Chat Model
         self.llm = get_chat_model(
             model_name=self.config.model_name,
             api_key=self.config.api_key,
-            base_url=self.config.api_base,
+            base_url=self.config.api_base or "",
             temperature=self.config.temperature,
         )

src/helia/configuration.py
@@ -1,13 +1,18 @@
 from __future__ import annotations

 import os
-import re
-from pathlib import Path
-from typing import Literal, NamedTuple
+from typing import TYPE_CHECKING, Literal, NamedTuple

-import yaml
-from pydantic import BaseModel, Field, TypeAdapter
-from pydantic_settings import BaseSettings, SettingsConfigDict
+from pydantic import BaseModel, Field, model_validator
+from pydantic_settings import (
+    BaseSettings,
+    PydanticBaseSettingsSource,
+    SettingsConfigDict,
+    YamlConfigSettingsSource,
+)
+
+if TYPE_CHECKING:
+    from pathlib import Path


 class S3Config(NamedTuple):
@@ -21,42 +26,12 @@ class S3Config(NamedTuple):
     region_name: str | None = None


-class SystemConfig(BaseSettings):
-    """
-    System-level configuration loaded from environment variables.
-    Includes Database and AWS/S3 settings.
-    """
-
-    model_config = SettingsConfigDict(env_prefix="HELIA_", env_file=".env", extra="ignore")
-
-    mongo_uri: str = Field(..., description="MongoDB connection string")
-    database_name: str = Field("helia", description="MongoDB database name")
-
-    s3_endpoint: str = Field(..., description="S3 endpoint URL")
-    s3_access_key: str = Field(..., description="S3 access key")
-    s3_secret_key: str = Field(..., description="S3 secret key")
-    s3_bucket: str = Field(..., description="S3 bucket containing the dataset")
-    s3_prefix: str = Field("", description="S3 key prefix for dataset files")
-    s3_region: str | None = Field(None, description="S3 region name (optional)")
-
-    def get_s3_config(self) -> S3Config:
-        """Create an S3Config from the system configuration."""
-        return S3Config(
-            bucket_name=self.s3_bucket,
-            endpoint_url=self.s3_endpoint,
-            aws_access_key_id=self.s3_access_key,
-            aws_secret_access_key=self.s3_secret_key,
-            prefix=self.s3_prefix,
-            region_name=self.s3_region,
-        )
-
-
 class ProviderConfig(BaseModel):
     """
-    Configuration for an LLM provider
+    Configuration for an LLM provider.
     """

-    api_key: str
+    api_key: str | None = None
     api_base: str
     api_format: Literal["openai", "anthropic", "ollama"] = "openai"
@@ -81,50 +56,112 @@ class RunSpec(BaseModel):
     prompt_id: str = "default"


-class AssessBatchConfig(BaseModel):
+class HeliaConfig(BaseSettings):
     """
-    Configuration file structure for batch assessment.
+    Unified configuration for Helia.
+    Loads from YAML first, then falls back to environment variables for system settings.
     """

-    providers: dict[str, ProviderConfig]
-    runs: list[RunSpec]
-    limit: int | None = None
+    # System Settings (env fallback available via HELIA_ prefix)
+    mongo_uri: str = Field(..., description="MongoDB connection string")
+    database_name: str = Field("helia", description="MongoDB database name")
+
+    s3_endpoint: str = Field(..., description="S3 endpoint URL")
+    s3_access_key: str = Field(..., description="S3 access key")
+    s3_secret_key: str = Field(..., description="S3 secret key")
+    s3_bucket: str = Field(..., description="S3 bucket containing the dataset")
+    s3_prefix: str = Field("", description="S3 key prefix for dataset files")
+    s3_region: str | None = Field(None, description="S3 region name (optional)")
+
+    # Run Settings
+    limit: int | None = Field(None, description="Limit the number of transcripts processed")
+    providers: dict[str, ProviderConfig] = Field(default_factory=dict)
+    runs: list[RunSpec] = Field(default_factory=list)
+
+    model_config = SettingsConfigDict(
+        env_prefix="HELIA_",
+        env_file=None,  # Disable .env file loading
+        extra="ignore",
+    )
+
+    def get_s3_config(self) -> S3Config:
+        """Create an S3Config from the system configuration."""
+        return S3Config(
+            bucket_name=self.s3_bucket,
+            endpoint_url=self.s3_endpoint,
+            aws_access_key_id=self.s3_access_key,
+            aws_secret_access_key=self.s3_secret_key,
+            prefix=self.s3_prefix,
+            region_name=self.s3_region,
+        )
+
+    @model_validator(mode="after")
+    def resolve_provider_api_keys(self) -> HeliaConfig:
+        """Resolve provider API keys from env vars as fallback.
+
+        - If `providers.<name>.api_key` is set in YAML, it is used.
+        - Otherwise, the loader tries `{NAME}_API_KEY` (uppercased provider name).
+        - Providers whose name starts with `local_` may omit an API key.
+        """
+        for name, provider in self.providers.items():
+            if provider.api_key:
+                continue
+
+            env_var_name = f"{name.upper()}_API_KEY"
+            if env_key := os.environ.get(env_var_name):
+                provider.api_key = env_key
+
+        return self
+
+    @model_validator(mode="after")
+    def validate_provider_api_keys(self) -> HeliaConfig:
+        """Enforce API key presence for non-local providers used by runs."""
+        used_providers = {run.model.provider for run in self.runs}
+
+        for provider_name in used_providers:
+            provider = self.providers.get(provider_name)
+            if provider is None:
+                msg = f"Provider '{provider_name}' is used in runs but not configured under 'providers'."
+                raise ValueError(msg)
+
+            if provider_name.startswith("local_"):
+                continue
+
+            if not provider.api_key:
+                env_var_name = f"{provider_name.upper()}_API_KEY"
+                msg = (
+                    f"Missing API key for provider '{provider_name}'. "
+                    f"Set providers.{provider_name}.api_key in YAML or export {env_var_name}."
+                )
+                raise ValueError(msg)
+
+        return self


-class AgentConfig(BaseModel):
-    # Placeholder for future agent config
-    command: Literal["agent"] = "agent"
-
-
-ConfigType = AssessBatchConfig
-
-
-def _expand_env_vars(yaml_content: str) -> str:
+def load_config(path: str | Path) -> HeliaConfig:
     """
-    Expand environment variables in the format ${VAR} or ${VAR:default}.
+    Load configuration from a YAML file, with environment variable fallback.
     """
-    pattern = re.compile(r"\$\{([^}^{]+)\}")
-
-    def replace(match: re.Match) -> str:
-        env_var = match.group(1)
-        default_value = ""
-        if ":" in env_var:
-            env_var, default_value = env_var.split(":", 1)
-        return os.environ.get(env_var, default_value)
-
-    return pattern.sub(replace, yaml_content)
+    # Create a dynamic subclass to bind the specific yaml_file path
+    class RuntimeHeliaConfig(HeliaConfig):
+        model_config = SettingsConfigDict(yaml_file=path)

+        @classmethod
+        def settings_customise_sources(
+            cls,
+            settings_cls: type[BaseSettings],
+            init_settings: PydanticBaseSettingsSource,  # noqa: ARG003
+            env_settings: PydanticBaseSettingsSource,
+            dotenv_settings: PydanticBaseSettingsSource,  # noqa: ARG003
+            file_secret_settings: PydanticBaseSettingsSource,  # noqa: ARG003
+        ) -> tuple[PydanticBaseSettingsSource, ...]:
+            return (
+                YamlConfigSettingsSource(settings_cls),
+                env_settings,
+                # No dotenv, no init_settings needed usually
+            )

-def load_config(path: str | Path) -> ConfigType:
-    with Path(path).open() as f:
-        content = f.read()
-
-    content = _expand_env_vars(content)
-    data = yaml.safe_load(content)
-
-    adapter = TypeAdapter(ConfigType)
-    return adapter.validate_python(data)
-
-
-def load_system_config() -> SystemConfig:
-    return SystemConfig()
+    return RuntimeHeliaConfig()  # type: ignore[call-arg]
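
The `settings_customise_sources` override above is what implements "YAML first, env fallback": pydantic-settings consults the returned sources left to right, and the first value found wins. A self-contained sketch of the same mechanism, using a hypothetical `DemoSettings` class rather than the real `HeliaConfig`:

```python
from pydantic_settings import (
    BaseSettings,
    PydanticBaseSettingsSource,
    SettingsConfigDict,
    YamlConfigSettingsSource,
)


class DemoSettings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="HELIA_", yaml_file="config.yaml")

    # Default keeps the sketch runnable even without a config.yaml present.
    mongo_uri: str = "mongodb://localhost:27017"

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        # YAML is consulted first; HELIA_MONGO_URI is used only when the
        # key is absent from config.yaml.
        return (YamlConfigSettingsSource(settings_cls), env_settings)


print(DemoSettings().mongo_uri)
```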

src/helia/db.py
@@ -8,10 +8,10 @@ from motor.motor_asyncio import AsyncIOMotorClient
 from helia.assessment.schema import AssessmentResult

 if TYPE_CHECKING:
-    from helia.configuration import SystemConfig
+    from helia.configuration import HeliaConfig


-async def init_db(config: SystemConfig) -> None:
+async def init_db(config: HeliaConfig) -> None:
     client = AsyncIOMotorClient(config.mongo_uri)
     await init_beanie(
         database=client[config.database_name],  # type: ignore[arg-type]
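
With `init_db` now taking the unified config object, a minimal bootstrap sketch (the `bootstrap` entry-point name is an assumption for illustration):

```python
import asyncio

from helia.configuration import load_config
from helia.db import init_db


async def bootstrap() -> None:
    config = load_config("config.yaml")
    await init_db(config)  # Beanie is initialised against config.database_name


asyncio.run(bootstrap())
```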

src/helia/llm/__init__.py
@@ -1,4 +1,3 @@
-from helia.llm.client import get_openai_client
-from helia.llm.settings import settings
+from helia.llm.client import get_chat_model

-__all__ = ["get_openai_client", "settings"]
+__all__ = ["get_chat_model"]

src/helia/llm/client.py
@@ -3,55 +3,20 @@ from __future__ import annotations
 from typing import TYPE_CHECKING

 from langchain_openai import ChatOpenAI
-from openai import AsyncOpenAI, OpenAI
 from pydantic import SecretStr

-from helia.llm.settings import settings
-
 if TYPE_CHECKING:
     from langchain_core.language_models.chat_models import BaseChatModel


-def get_openai_client() -> OpenAI:
-    """
-    Returns a configured OpenAI client based on global settings.
-    Defaults to OpenRouter base_url if not specified otherwise.
-    """
-    api_key = settings.resolve_api_key()
-
-    return OpenAI(
-        base_url=settings.base_url,
-        api_key=api_key,
-        timeout=settings.timeout,
-        max_retries=settings.max_retries,
-    )
-
-
-def get_async_openai_client() -> AsyncOpenAI:
-    """
-    Returns a configured AsyncOpenAI client based on global settings.
-    """
-    api_key = settings.resolve_api_key()
-
-    return AsyncOpenAI(
-        base_url=settings.base_url,
-        api_key=api_key,
-        timeout=settings.timeout,
-        max_retries=settings.max_retries,
-    )
-
-
 def get_chat_model(
     model_name: str,
-    api_key: str | None = None,
-    base_url: str | None = None,
+    api_key: str | None,
+    base_url: str,
     temperature: float = 0.0,
     max_retries: int = 3,
 ) -> BaseChatModel:
-    """
-    Returns a configured LangChain ChatOpenAI instance.
-    Supports OpenRouter, Ollama, and OpenAI via base_url.
-    """
+    """Return a configured LangChain ChatOpenAI instance."""
     return ChatOpenAI(
         model_name=model_name,
         openai_api_key=SecretStr(api_key or ""),
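
Because `api_key` and `base_url` lose their defaults here, every call site must now pass both explicitly. A sketch using the `local_ollama` values from the example config (assumes an Ollama server is listening locally):

```python
from helia.llm.client import get_chat_model

# base_url is now required; api_key may be None for local_* providers
llm = get_chat_model(
    model_name="llama3",
    api_key=None,
    base_url="http://localhost:11434/v1",
    temperature=0.7,
)
response = llm.invoke([("user", "Hello")])
print(response.content)
```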

src/helia/llm/settings.py (deleted, 65 lines)
@@ -1,65 +0,0 @@
-import os
-from typing import Final
-
-from pydantic import Field
-from pydantic_settings import BaseSettings, SettingsConfigDict
-
-
-class LLMSettings(BaseSettings):
-    """
-    Configuration for LLM clients, defaulting to OpenRouter.
-    """
-
-    api_key: str | None = Field(
-        default=None,
-        description="API key for the LLM provider. Checks HELIA_LLM_API_KEY, OPENROUTER_API_KEY, then OPENAI_API_KEY.",
-    )
-    base_url: str = Field(
-        default="https://openrouter.ai/api/v1",
-        description="Base URL for the LLM provider. Defaults to OpenRouter.",
-    )
-    model: str = Field(
-        default="google/gemini-3.0-pro-preview",
-        description="Model identifier to use.",
-    )
-    timeout: float = Field(
-        default=30.0,
-        description="Request timeout in seconds.",
-    )
-    max_retries: int = Field(
-        default=2,
-        description="Maximum number of retries for failed requests.",
-    )
-
-    model_config = SettingsConfigDict(
-        env_prefix="HELIA_LLM_",
-        case_sensitive=False,
-        extra="ignore",
-    )
-
-    def resolve_api_key(self) -> str:
-        """
-        Resolves the API key with a fallback strategy:
-        1. configured api_key (from HELIA_LLM_API_KEY)
-        2. OPENROUTER_API_KEY env var
-        3. OPENAI_API_KEY env var
-        4. Raise ValueError if none found
-        """
-        if self.api_key:
-            return self.api_key
-
-        # Fallback 1: OpenRouter
-        if key := os.environ.get("OPENROUTER_API_KEY"):
-            return key
-
-        # Fallback 2: OpenAI
-        if key := os.environ.get("OPENAI_API_KEY"):
-            return key
-
-        raise ValueError(
-            "No API key found. Please set HELIA_LLM_API_KEY, OPENROUTER_API_KEY, or OPENAI_API_KEY."
-        )
-
-
-# Singleton instance for easy import
-settings: Final[LLMSettings] = LLMSettings()
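
The deleted global `resolve_api_key()` chain is superseded by per-provider resolution inside `HeliaConfig`. A hypothetical one-function equivalent of the new fallback, for illustration only (`resolve_provider_key` is not part of the codebase):

```python
import os


def resolve_provider_key(provider_name: str, yaml_api_key: str | None) -> str | None:
    """Hypothetical mirror of HeliaConfig.resolve_provider_api_keys for one provider."""
    # YAML value wins; otherwise try {NAME}_API_KEY; local_* providers may stay None.
    return yaml_api_key or os.environ.get(f"{provider_name.upper()}_API_KEY")


key = resolve_provider_key("anthropic", None)  # reads ANTHROPIC_API_KEY if exported
```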

src/helia/main.py
@@ -4,15 +4,13 @@ import logging
 from datetime import UTC, datetime
 from pathlib import Path

-from helia.agent.workflow import run_agent
 from helia.assessment.core import PHQ8Evaluator
 from helia.assessment.schema import RunConfig
 from helia.configuration import (
-    AssessBatchConfig,
+    HeliaConfig,
     RunSpec,
     S3Config,
     load_config,
-    load_system_config,
 )
 from helia.db import init_db
 from helia.ingestion.s3 import S3DatasetLoader
@@ -24,7 +22,7 @@ logger = logging.getLogger(__name__)
 async def process_run(
     run_spec: RunSpec,
     input_source: str,
-    run_config_data: AssessBatchConfig,
+    config: HeliaConfig,
     s3_config: S3Config,
     semaphore: asyncio.Semaphore,
 ) -> None:
@@ -34,11 +32,11 @@ async def process_run(
     async with semaphore:
         # Resolve Provider
         provider_name = run_spec.model.provider
-        if provider_name not in run_config_data.providers:
+        if provider_name not in config.providers:
             logger.error("Run %s refers to unknown provider %s", run_spec.run_name, provider_name)
             return

-        provider_config = run_config_data.providers[provider_name]
+        provider_config = config.providers[provider_name]

         # Download from S3 (Async)
         loader = S3DatasetLoader(s3_config)
@@ -111,61 +109,50 @@ async def main() -> None:
         return

     try:
-        run_config_data = load_config(config_path)
-        system_config = load_system_config()
+        config = load_config(config_path)
     except Exception:
         logger.exception("Error loading configuration")
         return

-    # Check the type of configuration
-    if isinstance(run_config_data, AssessBatchConfig):
-        # Run Pre-flight Checks
-        if not await check_all_connections(run_config_data, system_config):
-            logger.error("Pre-flight checks failed. Exiting.")
-            return
+    # Run Pre-flight Checks
+    if not await check_all_connections(config):
+        logger.error("Pre-flight checks failed. Exiting.")
+        return

-        # Ensure DB is initialized (redundant if pre-flight passed, but safe)
-        await init_db(system_config)
+    # Ensure DB is initialized (redundant if pre-flight passed, but safe)
+    await init_db(config)

-        # Create S3 config once and reuse
-        s3_config = system_config.get_s3_config()
+    # Create S3 config once and reuse
+    s3_config = config.get_s3_config()

-        # Discover transcripts (can remain sync or be made async, sync is fine for listing)
-        logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name)
-        loader = S3DatasetLoader(s3_config)
-        keys = loader.list_transcripts()
+    # Discover transcripts (can remain sync or be made async, sync is fine for listing)
+    logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name)
+    loader = S3DatasetLoader(s3_config)
+    keys = loader.list_transcripts()

-        # Apply limit if configured (Priority: CLI > Config)
-        limit = args.limit
-        limit = limit if limit is not None else run_config_data.limit
+    # Apply limit if configured (Priority: CLI > Config)
+    limit = args.limit
+    limit = limit if limit is not None else config.limit

-        if limit is not None:
-            logger.info("Limiting processing to first %d transcripts", limit)
-            keys = keys[:limit]
+    if limit is not None:
+        logger.info("Limiting processing to first %d transcripts", limit)
+        keys = keys[:limit]

-        # Create task list
-        tasks_data = [(run_spec, key) for run_spec in run_config_data.runs for key in keys]
-        logger.info("Starting batch assessment with %d total items...", len(tasks_data))
+    # Create task list
+    tasks_data = [(run_spec, key) for run_spec in config.runs for key in keys]
+    logger.info("Starting batch assessment with %d total items...", len(tasks_data))

-        # Limit concurrency to 10 parallel requests
-        semaphore = asyncio.Semaphore(10)
+    # Limit concurrency to 10 parallel requests
+    semaphore = asyncio.Semaphore(10)

-        tasks = [
-            process_run(run_spec, key, run_config_data, s3_config, semaphore)
-            for run_spec, key in tasks_data
-        ]
+    tasks = [
+        process_run(run_spec, key, config, s3_config, semaphore) for run_spec, key in tasks_data
+    ]

-        # Run all tasks concurrently
-        await asyncio.gather(*tasks)
+    # Run all tasks concurrently
+    await asyncio.gather(*tasks)

-        logger.info("Batch assessment complete.")
-
-    else:
-        # Agent command (Placeholder)
-        question = run_config_data.question
-        logger.info("\nRunning Re-Agent with question: '%s'\n", question)
-        result = run_agent(question)
-        logger.info(result["answer"])
+    logger.info("Batch assessment complete.")


 if __name__ == "__main__":

(pre-flight checks module; filename not shown in this view)
@@ -10,12 +10,12 @@ from helia.db import init_db
 from helia.ingestion.s3 import S3DatasetLoader

 if TYPE_CHECKING:
-    from helia.configuration import AssessBatchConfig, ProviderConfig, S3Config, SystemConfig
+    from helia.configuration import HeliaConfig, ProviderConfig, S3Config

 logger = logging.getLogger(__name__)


-async def check_mongo(config: SystemConfig) -> bool:
+async def check_mongo(config: HeliaConfig) -> bool:
     """Verify MongoDB connectivity."""
     try:
         logger.info("Checking MongoDB connection...")
@@ -67,7 +67,7 @@ async def check_llm_provider(name: str, config: ProviderConfig, model_name: str)
     return True


-async def check_all_connections(run_config: AssessBatchConfig, system_config: SystemConfig) -> bool:
+async def check_all_connections(config: HeliaConfig) -> bool:
     """
     Run all pre-flight checks. Returns True if all checks pass.
     """
@@ -76,16 +76,16 @@ async def check_all_connections(run_config: AssessBatchConfig, system_config: SystemConfig) -> bool:
     checks = []

     # 1. Check MongoDB
-    checks.append(check_mongo(system_config))
+    checks.append(check_mongo(config))

     # 2. Check S3
-    checks.append(check_s3(system_config.get_s3_config()))
+    checks.append(check_s3(config.get_s3_config()))

     # 3. Check All LLM Providers
-    for name, provider_config in run_config.providers.items():
+    for name, provider_config in config.providers.items():
         # Find a model used by this provider in the runs to use for the check
         model_name = next(
-            (run.model.model_name for run in run_config.runs if run.model.provider == name),
+            (run.model.model_name for run in config.runs if run.model.provider == name),
             None,
         )