Refactor configuration management; remove deprecated files and unify settings in YAML format
.env.example (deleted, 17 lines)
@@ -1,17 +0,0 @@
-# Database Configuration
-HELIA_MONGO_URI=mongodb://localhost:27017
-HELIA_DATABASE_NAME=helia
-
-# S3 Configuration (MinIO or AWS)
-# Required for finding and downloading transcripts
-HELIA_S3_ENDPOINT=https://s3.amazonaws.com
-HELIA_S3_ACCESS_KEY=your_access_key
-HELIA_S3_SECRET_KEY=your_secret_key
-HELIA_S3_BUCKET=your-bucket-name
-HELIA_S3_REGION=us-east-1
-
-# LLM API Keys
-# These are used by the run configuration YAML via ${VAR} substitution
-OPENAI_API_KEY=sk-...
-ANTHROPIC_API_KEY=sk-ant-...
-OPENROUTER_API_KEY=sk-or-...
.gitignore (vendored, 1 addition)
@@ -7,6 +7,7 @@ wheels/
 *.egg-info

 .env
+config.yaml

 # Virtual environments
 .venv
README.md (17 changes)
@@ -17,7 +17,7 @@ src/helia/
 ├── analysis/
 │   └── extractor.py   # Metadata extraction (LLM-agnostic)
 ├── assessment/
-│   ├── core.py        # Clinical assessment logic (PHQ-8)
+│   └── core.py        # Clinical assessment logic (PHQ-8)
 │   └── schema.py      # Data models (AssessmentResult, PHQ8Item)
 ├── ingestion/
 │   └── parser.py      # Transcript parsing (DAIC-WOZ support)
@@ -50,6 +50,7 @@ graph TD
 - **Clinical Parsing**: Native support for DAIC-WOZ transcript formats.
 - **Structured Assessment**: Maps unstructured conversation to validatable PHQ-8 scores.
 - **Document Persistence**: Stores full experimental context (config + evidence + scores) in MongoDB using Beanie.
+- **Unified Configuration**: Single YAML config with environment variable fallbacks for system settings.

 ## Roadmap

@@ -67,22 +68,26 @@ uv sync

 ## Quick Start

-1. **Environment Setup**:
+1. **Configuration**:
+   Copy `example.config.yaml` to `config.yaml` and edit it.
    ```sh
-   export OPENAI_API_KEY=sk-...
-   # Ensure MongoDB is running (e.g., via Docker)
+   cp example.config.yaml config.yaml
    ```

+   You can set system settings (MongoDB, S3) in the YAML file, or use environment variables (e.g., `HELIA_MONGO_URI`, `HELIA_S3_BUCKET`).
+   Provider API keys can be set in YAML or via `{PROVIDER_NAME}_API_KEY` (uppercased), e.g. `ANTHROPIC_API_KEY` for `providers.anthropic`.

 2. **Run an Assessment**:
    ```sh
-   python -m helia.main "assess --input data/transcript.tsv"
+   # Ensure MongoDB is running
+   uv run python -m helia.main config.yaml
    ```

 ## Development

 - **Linting**: `uv run ruff check .`
 - **Formatting**: `uv run ruff format .`
-- **Type Checking**: `uv run pyrefly`
+- **Type Checking**: `uv run ty check`

 ## License
example.config.yaml
@@ -1,25 +1,36 @@
-# Helia Run Configuration
-# This file defines the "providers" (LLM connections) and the "runs" (experiments).
-# Environment variables like ${OPENAI_API_KEY} are expanded at runtime.
+# System Configuration
+mongo_uri: "mongodb://localhost:27017"
+database_name: "helia"

+# S3 Configuration (MinIO or AWS)
+s3_endpoint: "https://s3.amazonaws.com"
+s3_access_key: "your_access_key"
+s3_secret_key: "your_secret_key"
+s3_bucket: "your-bucket-name"
+s3_prefix: ""
+s3_region: "us-east-1"
+
+# Run Configuration
+limit: 5
+
 providers:
   openai:
-    api_key: "${OPENAI_API_KEY}"
+    # Set api_key here or export OPENAI_API_KEY
     api_base: "https://api.openai.com/v1"
     api_format: "openai"

   anthropic:
-    api_key: "${ANTHROPIC_API_KEY}"
+    # Set api_key here or export ANTHROPIC_API_KEY
     api_base: "https://api.anthropic.com/v1"
     api_format: "anthropic"

   openrouter:
-    api_key: "${OPENROUTER_API_KEY}"
+    # Set api_key here or export OPENROUTER_API_KEY
     api_base: "https://openrouter.ai/api/v1"
     api_format: "openai"

   local_ollama:
-    api_key: "none"
+    # API key optional for local_* providers
     api_base: "http://localhost:11434/v1"
     api_format: "ollama"

@@ -37,6 +48,3 @@ runs:
       model_name: "llama3"
       temperature: 0.7
     prompt_id: "default"
-
-# Optional: Limit the number of transcripts processed
-limit: 5
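For orientation, a minimal sketch of how this unified file is consumed, using the new `load_config` introduced later in this diff; the `config.yaml` path and printed fields are placeholders:

```python
# Minimal sketch: consuming the unified YAML with the new loader
# (load_config is introduced in src/helia/configuration.py below).
from helia.configuration import load_config

config = load_config("config.yaml")

print(config.mongo_uri)                      # from YAML, or HELIA_MONGO_URI if omitted
print(config.limit)                          # run setting, e.g. 5
print(config.providers["openai"].api_base)   # nested provider settings
```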
config.yaml (deleted, 15 lines; now gitignored)
@@ -1,15 +0,0 @@
-providers:
-  openrouter:
-    api_key: "${OPENROUTER_API_KEY}"
-    api_base: "https://openrouter.ai/api/v1"
-    api_format: "openai"
-
-runs:
-  - run_name: "gemini_flash"
-    model:
-      provider: openrouter
-      model_name: "google/gemini-3-flash-preview"
-    temperature: 1.0
-    prompt_id: "default"
-
-limit: 5
src/helia/agent/workflow.py
@@ -5,9 +5,6 @@ from typing import Any
 from langgraph.graph import END, StateGraph
 from pydantic import BaseModel

-from helia.llm.client import get_chat_model
-from helia.llm.settings import settings
-

 class AgentState(BaseModel):
     """State for the agent workflow."""
@@ -25,9 +22,8 @@ class AgentState(BaseModel):
 # proper protocol matching while maintaining runtime correctness.


-def planner_node(state: AgentState):  # noqa: ANN201
+def planner_node(state: AgentState):  # noqa: ANN201, ARG001
     """Plan the steps to answer the question."""
-    _ = state
     plan: list[str] = ["Understand question", "Retrieve info", "Synthesize answer"]
     return {"plan": plan}

@@ -57,40 +53,17 @@ def vector_tool_node(state: AgentState):  # noqa: ANN201
 def synthesizer_node(state: AgentState):  # noqa: ANN201
     """Synthesize an answer from the gathered context."""
     context_text = "\n".join(state.context)
-    question = state.question

-    prompt = f"""
-    Answer the user's question based on the provided context.
-
-    Context:
-    {context_text}
-
-    Question: {question}
-
-    Answer:
-    """
-
-    try:
-        llm = get_chat_model(
-            model_name=settings.model,
-            api_key=settings.resolve_api_key(),
-            base_url=settings.base_url,
-        )
-        messages = [
-            ("system", "You are a helpful assistant."),
-            ("user", prompt),
-        ]
-        response = llm.invoke(messages)
-        answer = str(response.content)
-    except Exception as e:
-        answer = f"Error generating answer: {e}. Fallback: Based on context: {context_text}, here is the answer."
+    answer = (
+        "Agent workflow is placeholder and no longer loads LLM settings from env vars. "
+        f"Fallback: Based on context: {context_text}, here is the answer."
+    )

     return {"answer": answer}


-def reflector_node(state: AgentState):  # noqa: ANN201
+def reflector_node(state: AgentState):  # noqa: ANN201, ARG001
     """Reflect on the quality of the answer."""
-    _ = state
     return {"critique": "Answer appears sufficient."}
src/helia/assessment/core.py
@@ -55,11 +55,10 @@ class PHQ8Evaluator:
         self.config = config
         self.parser = TranscriptParser()

-        # Initialize LangChain Chat Model
         self.llm = get_chat_model(
             model_name=self.config.model_name,
             api_key=self.config.api_key,
-            base_url=self.config.api_base,
+            base_url=self.config.api_base or "",
             temperature=self.config.temperature,
         )
src/helia/configuration.py
@@ -1,13 +1,18 @@
 from __future__ import annotations

 import os
-import re
-from pathlib import Path
-from typing import Literal, NamedTuple
+from typing import TYPE_CHECKING, Literal, NamedTuple

-import yaml
-from pydantic import BaseModel, Field, TypeAdapter
-from pydantic_settings import BaseSettings, SettingsConfigDict
+from pydantic import BaseModel, Field, model_validator
+from pydantic_settings import (
+    BaseSettings,
+    PydanticBaseSettingsSource,
+    SettingsConfigDict,
+    YamlConfigSettingsSource,
+)
+
+if TYPE_CHECKING:
+    from pathlib import Path


 class S3Config(NamedTuple):
@@ -21,42 +26,12 @@ class S3Config(NamedTuple):
     region_name: str | None = None


-class SystemConfig(BaseSettings):
-    """
-    System-level configuration loaded from environment variables.
-    Includes Database and AWS/S3 settings.
-    """
-
-    model_config = SettingsConfigDict(env_prefix="HELIA_", env_file=".env", extra="ignore")
-
-    mongo_uri: str = Field(..., description="MongoDB connection string")
-    database_name: str = Field("helia", description="MongoDB database name")
-
-    s3_endpoint: str = Field(..., description="S3 endpoint URL")
-    s3_access_key: str = Field(..., description="S3 access key")
-    s3_secret_key: str = Field(..., description="S3 secret key")
-    s3_bucket: str = Field(..., description="S3 bucket containing the dataset")
-    s3_prefix: str = Field("", description="S3 key prefix for dataset files")
-    s3_region: str | None = Field(None, description="S3 region name (optional)")
-
-    def get_s3_config(self) -> S3Config:
-        """Create an S3Config from the system configuration."""
-        return S3Config(
-            bucket_name=self.s3_bucket,
-            endpoint_url=self.s3_endpoint,
-            aws_access_key_id=self.s3_access_key,
-            aws_secret_access_key=self.s3_secret_key,
-            prefix=self.s3_prefix,
-            region_name=self.s3_region,
-        )
-
-
 class ProviderConfig(BaseModel):
     """
-    Configuration for an LLM provider
+    Configuration for an LLM provider.
     """

-    api_key: str
+    api_key: str | None = None
     api_base: str
     api_format: Literal["openai", "anthropic", "ollama"] = "openai"

@@ -81,50 +56,112 @@ class RunSpec(BaseModel):
     prompt_id: str = "default"


-class AssessBatchConfig(BaseModel):
+class HeliaConfig(BaseSettings):
     """
-    Configuration file structure for batch assessment.
+    Unified configuration for Helia.
+    Loads from YAML first, then falls back to Environment variables for system settings.
     """

-    providers: dict[str, ProviderConfig]
-    runs: list[RunSpec]
-    limit: int | None = None
-
-
-class AgentConfig(BaseModel):
-    # Placeholder for future agent config
-    command: Literal["agent"] = "agent"
-
-
-ConfigType = AssessBatchConfig
-
-
-def _expand_env_vars(yaml_content: str) -> str:
-    """
-    Expand environment variables in the format ${VAR} or ${VAR:default}.
-    """
-    pattern = re.compile(r"\$\{([^}^{]+)\}")
-
-    def replace(match: re.Match) -> str:
-        env_var = match.group(1)
-        default_value = ""
-        if ":" in env_var:
-            env_var, default_value = env_var.split(":", 1)
-        return os.environ.get(env_var, default_value)
-
-    return pattern.sub(replace, yaml_content)
-
-
-def load_config(path: str | Path) -> ConfigType:
-    with Path(path).open() as f:
-        content = f.read()
-
-    content = _expand_env_vars(content)
-    data = yaml.safe_load(content)
-
-    adapter = TypeAdapter(ConfigType)
-    return adapter.validate_python(data)
-
-
-def load_system_config() -> SystemConfig:
-    return SystemConfig()
+    # System Settings (Env Fallback available via HELIA_ prefix)
+    mongo_uri: str = Field(..., description="MongoDB connection string")
+    database_name: str = Field("helia", description="MongoDB database name")
+
+    s3_endpoint: str = Field(..., description="S3 endpoint URL")
+    s3_access_key: str = Field(..., description="S3 access key")
+    s3_secret_key: str = Field(..., description="S3 secret key")
+    s3_bucket: str = Field(..., description="S3 bucket containing the dataset")
+    s3_prefix: str = Field("", description="S3 key prefix for dataset files")
+    s3_region: str | None = Field(None, description="S3 region name (optional)")
+
+    # Run Settings
+    limit: int | None = Field(None, description="Limit the number of transcripts processed")
+    providers: dict[str, ProviderConfig] = Field(default_factory=dict)
+    runs: list[RunSpec] = Field(default_factory=list)
+
+    model_config = SettingsConfigDict(
+        env_prefix="HELIA_",
+        env_file=None,  # Disable .env file loading
+        extra="ignore",
+    )
+
+    def get_s3_config(self) -> S3Config:
+        """Create an S3Config from the system configuration."""
+        return S3Config(
+            bucket_name=self.s3_bucket,
+            endpoint_url=self.s3_endpoint,
+            aws_access_key_id=self.s3_access_key,
+            aws_secret_access_key=self.s3_secret_key,
+            prefix=self.s3_prefix,
+            region_name=self.s3_region,
+        )
+
+    @model_validator(mode="after")
+    def resolve_provider_api_keys(self) -> HeliaConfig:
+        """Resolve provider API keys from env vars as fallback.
+
+        - If `providers.<name>.api_key` is set in YAML, it is used.
+        - Otherwise, the loader tries `{NAME}_API_KEY` (uppercased provider name).
+        - Providers whose name starts with `local_` may omit an API key.
+        """
+        for name, provider in self.providers.items():
+            if provider.api_key:
+                continue
+
+            env_var_name = f"{name.upper()}_API_KEY"
+            if env_key := os.environ.get(env_var_name):
+                provider.api_key = env_key
+
+        return self
+
+    @model_validator(mode="after")
+    def validate_provider_api_keys(self) -> HeliaConfig:
+        """Enforce API key presence for non-local providers used by runs."""
+        used_providers = {run.model.provider for run in self.runs}
+
+        for provider_name in used_providers:
+            provider = self.providers.get(provider_name)
+            if provider is None:
+                msg = f"Provider '{provider_name}' is used in runs but not configured under 'providers'."
+                raise ValueError(msg)
+
+            if provider_name.startswith("local_"):
+                continue
+
+            if not provider.api_key:
+                env_var_name = f"{provider_name.upper()}_API_KEY"
+                msg = (
+                    f"Missing API key for provider '{provider_name}'. "
+                    f"Set providers.{provider_name}.api_key in YAML or export {env_var_name}."
+                )
+                raise ValueError(msg)
+
+        return self
+
+
+def load_config(path: str | Path) -> HeliaConfig:
+    """
+    Load configuration from a YAML file, with environment variable fallback.
+    """
+
+    # Create a dynamic subclass to bind the specific yaml_file path
+    class RuntimeHeliaConfig(HeliaConfig):
+        model_config = SettingsConfigDict(yaml_file=path)
+
+        @classmethod
+        def settings_customise_sources(
+            cls,
+            settings_cls: type[BaseSettings],
+            init_settings: PydanticBaseSettingsSource,  # noqa: ARG003
+            env_settings: PydanticBaseSettingsSource,
+            dotenv_settings: PydanticBaseSettingsSource,  # noqa: ARG003
+            file_secret_settings: PydanticBaseSettingsSource,  # noqa: ARG003
+        ) -> tuple[PydanticBaseSettingsSource, ...]:
+            return (
+                YamlConfigSettingsSource(settings_cls),
+                env_settings,
+                # No dotenv, no init_settings needed usually
+            )

+    return RuntimeHeliaConfig()  # type: ignore[call-arg]
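To make the source ordering concrete, here is a small sketch of the fallback behaviour under the sources wired up above; all values are placeholders. YAML values win over `HELIA_*` environment variables, and provider keys fall back to `{NAME}_API_KEY`:

```python
# Sketch: precedence of the settings sources configured above.
import os

from helia.configuration import load_config

# System setting fallback: used only if the YAML omits mongo_uri.
os.environ["HELIA_MONGO_URI"] = "mongodb://db:27017"

# Provider key fallback: picked up for providers.anthropic by
# resolve_provider_api_keys when the YAML leaves api_key unset.
os.environ["ANTHROPIC_API_KEY"] = "sk-ant-placeholder"

config = load_config("config.yaml")
print(config.providers["anthropic"].api_key)  # "sk-ant-placeholder"
```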
src/helia/db.py
@@ -8,10 +8,10 @@ from motor.motor_asyncio import AsyncIOMotorClient
 from helia.assessment.schema import AssessmentResult

 if TYPE_CHECKING:
-    from helia.configuration import SystemConfig
+    from helia.configuration import HeliaConfig


-async def init_db(config: SystemConfig) -> None:
+async def init_db(config: HeliaConfig) -> None:
     client = AsyncIOMotorClient(config.mongo_uri)
     await init_beanie(
         database=client[config.database_name],  # type: ignore[arg-type]
@@ -1,4 +1,3 @@
|
|||||||
from helia.llm.client import get_openai_client
|
from helia.llm.client import get_chat_model
|
||||||
from helia.llm.settings import settings
|
|
||||||
|
|
||||||
__all__ = ["get_openai_client", "settings"]
|
__all__ = ["get_chat_model"]
|
||||||
|
|||||||
src/helia/llm/client.py
@@ -3,55 +3,20 @@ from __future__ import annotations
 from typing import TYPE_CHECKING

 from langchain_openai import ChatOpenAI
-from openai import AsyncOpenAI, OpenAI
 from pydantic import SecretStr

-from helia.llm.settings import settings
-
 if TYPE_CHECKING:
     from langchain_core.language_models.chat_models import BaseChatModel


-def get_openai_client() -> OpenAI:
-    """
-    Returns an configured OpenAI client based on global settings.
-    Defaults to OpenRouter base_url if not specified otherwise.
-    """
-    api_key = settings.resolve_api_key()
-
-    return OpenAI(
-        base_url=settings.base_url,
-        api_key=api_key,
-        timeout=settings.timeout,
-        max_retries=settings.max_retries,
-    )
-
-
-def get_async_openai_client() -> AsyncOpenAI:
-    """
-    Returns a configured AsyncOpenAI client based on global settings.
-    """
-    api_key = settings.resolve_api_key()
-
-    return AsyncOpenAI(
-        base_url=settings.base_url,
-        api_key=api_key,
-        timeout=settings.timeout,
-        max_retries=settings.max_retries,
-    )
-
-
 def get_chat_model(
     model_name: str,
-    api_key: str | None = None,
-    base_url: str | None = None,
+    api_key: str | None,
+    base_url: str,
     temperature: float = 0.0,
     max_retries: int = 3,
 ) -> BaseChatModel:
-    """
-    Returns a configured LangChain ChatOpenAI instance.
-    Supports OpenRouter, Ollama, and OpenAI via base_url.
-    """
+    """Return a configured LangChain ChatOpenAI instance."""
     return ChatOpenAI(
         model_name=model_name,
         openai_api_key=SecretStr(api_key or ""),
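A short usage sketch for the tightened signature: callers now pass `api_key` and `base_url` explicitly instead of relying on the removed global `settings` singleton; the model id and key below are placeholders:

```python
# Sketch: explicit wiring replaces the old settings singleton.
from helia.llm.client import get_chat_model

llm = get_chat_model(
    model_name="gpt-4o-mini",              # placeholder model id
    api_key="sk-placeholder",
    base_url="https://api.openai.com/v1",
    temperature=0.0,
)
print(llm.invoke("Say hello.").content)
```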
src/helia/llm/settings.py (deleted, 65 lines)
@@ -1,65 +0,0 @@
-import os
-from typing import Final
-
-from pydantic import Field
-from pydantic_settings import BaseSettings, SettingsConfigDict
-
-
-class LLMSettings(BaseSettings):
-    """
-    Configuration for LLM clients, defaulting to OpenRouter.
-    """
-
-    api_key: str | None = Field(
-        default=None,
-        description="API key for the LLM provider. Checks HELIA_LLM_API_KEY, OPENROUTER_API_KEY, then OPENAI_API_KEY.",
-    )
-    base_url: str = Field(
-        default="https://openrouter.ai/api/v1",
-        description="Base URL for the LLM provider. Defaults to OpenRouter.",
-    )
-    model: str = Field(
-        default="google/gemini-3.0-pro-preview",
-        description="Model identifier to use.",
-    )
-    timeout: float = Field(
-        default=30.0,
-        description="Request timeout in seconds.",
-    )
-    max_retries: int = Field(
-        default=2,
-        description="Maximum number of retries for failed requests.",
-    )
-
-    model_config = SettingsConfigDict(
-        env_prefix="HELIA_LLM_",
-        case_sensitive=False,
-        extra="ignore",
-    )
-
-    def resolve_api_key(self) -> str:
-        """
-        Resolves the API key with a fallback strategy:
-        1. configured api_key (from HELIA_LLM_API_KEY)
-        2. OPENROUTER_API_KEY env var
-        3. OPENAI_API_KEY env var
-        4. Raise ValueError if none found
-        """
-        if self.api_key:
-            return self.api_key
-
-        # Fallback 1: OpenRouter
-        if key := os.environ.get("OPENROUTER_API_KEY"):
-            return key
-
-        # Fallback 2: OpenAI
-        if key := os.environ.get("OPENAI_API_KEY"):
-            return key
-
-        raise ValueError(
-            "No API key found. Please set HELIA_LLM_API_KEY, OPENROUTER_API_KEY, or OPENAI_API_KEY."
-        )
-
-
-# Singleton instance for easy import
-settings: Final[LLMSettings] = LLMSettings()
src/helia/main.py
@@ -4,15 +4,13 @@ import logging
 from datetime import UTC, datetime
 from pathlib import Path

-from helia.agent.workflow import run_agent
 from helia.assessment.core import PHQ8Evaluator
 from helia.assessment.schema import RunConfig
 from helia.configuration import (
-    AssessBatchConfig,
+    HeliaConfig,
     RunSpec,
     S3Config,
     load_config,
-    load_system_config,
 )
 from helia.db import init_db
 from helia.ingestion.s3 import S3DatasetLoader
@@ -24,7 +22,7 @@ logger = logging.getLogger(__name__)
 async def process_run(
     run_spec: RunSpec,
     input_source: str,
-    run_config_data: AssessBatchConfig,
+    config: HeliaConfig,
     s3_config: S3Config,
     semaphore: asyncio.Semaphore,
 ) -> None:
@@ -34,11 +32,11 @@ async def process_run(
     async with semaphore:
         # Resolve Provider
         provider_name = run_spec.model.provider
-        if provider_name not in run_config_data.providers:
+        if provider_name not in config.providers:
             logger.error("Run %s refers to unknown provider %s", run_spec.run_name, provider_name)
             return

-        provider_config = run_config_data.providers[provider_name]
+        provider_config = config.providers[provider_name]

         # Download from S3 (Async)
         loader = S3DatasetLoader(s3_config)
@@ -111,24 +109,21 @@ async def main() -> None:
         return

     try:
-        run_config_data = load_config(config_path)
-        system_config = load_system_config()
+        config = load_config(config_path)
     except Exception:
         logger.exception("Error loading configuration")
         return

-    # Check the type of configuration
-    if isinstance(run_config_data, AssessBatchConfig):
     # Run Pre-flight Checks
-    if not await check_all_connections(run_config_data, system_config):
+    if not await check_all_connections(config):
         logger.error("Pre-flight checks failed. Exiting.")
         return

     # Ensure DB is initialized (redundant if pre-flight passed, but safe)
-    await init_db(system_config)
+    await init_db(config)

     # Create S3 config once and reuse
-    s3_config = system_config.get_s3_config()
+    s3_config = config.get_s3_config()

     # Discover transcripts (can remain sync or be made async, sync is fine for listing)
     logger.info("Discovering transcripts in S3 bucket: %s", s3_config.bucket_name)
@@ -137,22 +132,21 @@ async def main() -> None:

     # Apply limit if configured (Priority: CLI > Config)
     limit = args.limit
-    limit = limit if limit is not None else run_config_data.limit
+    limit = limit if limit is not None else config.limit

     if limit is not None:
         logger.info("Limiting processing to first %d transcripts", limit)
         keys = keys[:limit]

     # Create task list
-    tasks_data = [(run_spec, key) for run_spec in run_config_data.runs for key in keys]
+    tasks_data = [(run_spec, key) for run_spec in config.runs for key in keys]
     logger.info("Starting batch assessment with %d total items...", len(tasks_data))

     # Limit concurrency to 10 parallel requests
     semaphore = asyncio.Semaphore(10)

     tasks = [
-        process_run(run_spec, key, run_config_data, s3_config, semaphore)
-        for run_spec, key in tasks_data
+        process_run(run_spec, key, config, s3_config, semaphore) for run_spec, key in tasks_data
     ]

     # Run all tasks concurrently
@@ -160,13 +154,6 @@ async def main() -> None:

     logger.info("Batch assessment complete.")

-    else:
-        # Agent command (Placeholder)
-        question = run_config_data.question
-        logger.info("\nRunning Re-Agent with question: '%s'\n", question)
-        result = run_agent(question)
-        logger.info(result["answer"])
-

 if __name__ == "__main__":
     asyncio.run(main())
(pre-flight checks module)
@@ -10,12 +10,12 @@ from helia.db import init_db
 from helia.ingestion.s3 import S3DatasetLoader

 if TYPE_CHECKING:
-    from helia.configuration import AssessBatchConfig, ProviderConfig, S3Config, SystemConfig
+    from helia.configuration import HeliaConfig, ProviderConfig, S3Config

 logger = logging.getLogger(__name__)


-async def check_mongo(config: SystemConfig) -> bool:
+async def check_mongo(config: HeliaConfig) -> bool:
     """Verify MongoDB connectivity."""
     try:
         logger.info("Checking MongoDB connection...")
@@ -67,7 +67,7 @@ async def check_llm_provider(name: str, config: ProviderConfig, model_name: str)
     return True


-async def check_all_connections(run_config: AssessBatchConfig, system_config: SystemConfig) -> bool:
+async def check_all_connections(config: HeliaConfig) -> bool:
     """
     Run all pre-flight checks. Returns True if all checks pass.
     """
@@ -76,16 +76,16 @@ async def check_all_connections(run_config: AssessBatchConfig, system_config: Sy
     checks = []

     # 1. Check MongoDB
-    checks.append(check_mongo(system_config))
+    checks.append(check_mongo(config))

     # 2. Check S3
-    checks.append(check_s3(system_config.get_s3_config()))
+    checks.append(check_s3(config.get_s3_config()))

     # 3. Check All LLM Providers
-    for name, provider_config in run_config.providers.items():
+    for name, provider_config in config.providers.items():
         # Find a model used by this provider in the runs to use for the check
         model_name = next(
-            (run.model.model_name for run in run_config.runs if run.model.provider == name),
+            (run.model.model_name for run in config.runs if run.model.provider == name),
             None,
         )
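Finally, a sketch of driving the simplified pre-flight entry point now that it takes the single `HeliaConfig`; the import path of `check_all_connections` is an assumption, since this diff does not show the module's filename:

```python
# Sketch: pre-flight now needs only the unified config object.
import asyncio

from helia.configuration import load_config
from helia.preflight import check_all_connections  # module path assumed


async def run_preflight() -> bool:
    config = load_config("config.yaml")
    return await check_all_connections(config)


if __name__ == "__main__":
    ok = asyncio.run(run_preflight())
    print("pre-flight passed" if ok else "pre-flight failed")
```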