Refactor configuration management; update S3 integration, add new migration scripts, and implement pre-flight checks

This commit is contained in:
Santiago Martinez-Avial
2025-12-22 17:46:36 +01:00
parent f860d17206
commit 5c6d87dab7
17 changed files with 598 additions and 450 deletions

0
scripts/__init__.py Normal file
View File

137
scripts/preflight.py Normal file
View File

@@ -0,0 +1,137 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from openai import AsyncOpenAI
from rich.console import Console
from rich.table import Table
from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.ingestion.s3 import S3DatasetLoader
if TYPE_CHECKING:
from helia.configuration import LLMProviderConfig, S3Config
# Shared Rich console; every check in this module prints its status and errors through it.
console = Console()
async def check_mongo(config: HeliaConfig) -> tuple[str, bool]:
    """Verify MongoDB connectivity.

    Awaits ``init_db(config)``. Any exception it raises is printed to the
    console and reported as a failed check instead of propagating.

    Args:
        config: Application configuration handed to ``init_db``.

    Returns:
        ``("MongoDB", True)`` when ``init_db`` completes, otherwise
        ``("MongoDB", False)``.
    """
    try:
        await init_db(config)
    except Exception as e:  # deliberate catch-all: a pre-flight check reports, never crashes
        # Print the message as plain text. Exception strings can contain
        # square brackets, which Rich would otherwise parse as markup tags
        # (dropping text or raising MarkupError inside this handler).
        console.print(f"MongoDB error: {e}", style="red", markup=False)
        return "MongoDB", False
    else:
        return "MongoDB", True
async def check_s3(config: S3Config) -> tuple[str, bool]:
    """Verify S3 connectivity by listing transcripts.

    Builds an ``S3DatasetLoader`` from ``config`` and awaits its
    ``list_transcripts()`` call. Any exception is printed to the console and
    reported as a failed check instead of propagating.

    Args:
        config: S3 settings used to construct the loader.

    Returns:
        ``("S3", True)`` when the listing succeeds, otherwise ``("S3", False)``.
    """
    try:
        loader = S3DatasetLoader(config)
        await loader.list_transcripts()
    except Exception as e:  # deliberate catch-all: a pre-flight check reports, never crashes
        # Print the message as plain text so brackets in the error are not
        # interpreted as Rich markup tags.
        console.print(f"S3 error: {e}", style="red", markup=False)
        return "S3", False
    else:
        return "S3", True
async def check_llm_provider(
    name: str, config: LLMProviderConfig, model_name: str
) -> tuple[str, bool]:
    """Verify connectivity to an LLM provider.

    Sends a single one-token ``"ping"`` chat completion through an
    OpenAI-compatible client built from ``config``. Retries are disabled and
    the request times out after five seconds so an unreachable provider fails
    fast.

    Args:
        name: Provider name; used as the label in the returned tuple and in
            the error message.
        config: Provider settings supplying ``api_key`` and ``api_base``.
        model_name: Model to address in the probe request.

    Returns:
        ``(name, True)`` when the request succeeds, otherwise ``(name, False)``.
    """
    try:
        # The client owns an HTTP connection pool. ``async with`` closes it
        # deterministically instead of leaving cleanup to garbage collection.
        async with AsyncOpenAI(
            api_key=config.api_key,
            base_url=config.api_base,
            max_retries=0,
            timeout=5.0,
        ) as client:
            await client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": "ping"}],
                max_tokens=1,
            )
    except Exception as e:  # deliberate catch-all: a pre-flight check reports, never crashes
        # Print as plain text: both the provider name (from configuration) and
        # the API error body may contain brackets that Rich would otherwise
        # parse as markup tags.
        console.print(f"{name} error: {e}", style="red", markup=False)
        return name, False
    else:
        return name, True
async def check_all_connections(config: HeliaConfig) -> bool:
    """Run all pre-flight checks concurrently and print a summary table.

    Checks MongoDB, S3, and every LLM provider in ``config.providers`` that is
    referenced by at least one entry in ``config.runs``. Providers that no run
    uses are skipped with a notice, because there is no model name to probe.

    Args:
        config: Application configuration providing ``s3``, ``providers`` and
            ``runs``.

    Returns:
        True if every executed check passed, False otherwise.
    """
    console.print("\n[bold cyan]🔍 Starting Pre-flight Checks[/bold cyan]\n")

    checks = [check_mongo(config), check_s3(config.s3)]
    check_names = ["MongoDB", "S3"]

    for name, provider_config in config.providers.items():
        # Probe the provider with the first model any run configures for it.
        model_name = next(
            (
                run.model.model_name
                for run in config.runs.values()
                if run.model.provider == name
            ),
            None,
        )
        if model_name:
            checks.append(check_llm_provider(name, provider_config, model_name))
            check_names.append(f"LLM Provider: {name}")
        else:
            console.print(
                f"[yellow]⊘ Skipping LLM Provider '{name}': "
                f"defined in config but not used in any run[/yellow]"
            )

    # return_exceptions=True keeps one crashing check from hiding the others.
    results = await asyncio.gather(*checks, return_exceptions=True)

    table = Table(title="Pre-flight Check Results", show_header=True, header_style="bold")
    table.add_column("Service", style="cyan")
    table.add_column("Status", justify="center")

    all_passed = True
    for service_name, result in zip(check_names, results):
        if isinstance(result, tuple):
            label, passed = result
        else:
            # Fail closed: a captured exception (including BaseException
            # subclasses such as CancelledError) or any unexpected value must
            # never be counted as a pass.
            label, passed = service_name, False
            if isinstance(result, BaseException):
                console.print(
                    f"{service_name} error: {result!r}", style="red", markup=False
                )

        if passed:
            table.add_row(label, "[bold green]✓ Passed[/bold green]")
        else:
            table.add_row(label, "[bold red]✗ Failed[/bold red]")
            all_passed = False

    console.print(table)
    if all_passed:
        console.print("\n[bold green]✓ All Pre-flight Checks Passed[/bold green]\n")
    else:
        console.print("\n[bold red]✗ Pre-flight Checks Failed[/bold red]\n")
    return all_passed
if __name__ == "__main__":
    # Script entry point: load the configuration, run every check, and exit
    # with status 1 when any of them failed.
    settings = HeliaConfig()  # type: ignore[missing-argument, unknown-argument]
    if not asyncio.run(check_all_connections(settings)):
        raise SystemExit(1)

39
scripts/verify_turns.py Normal file
View File

@@ -0,0 +1,39 @@
import asyncio
import logging
from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.models.transcript import Transcript
# Configure root logging at INFO so the messages logged by verify() are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by verify().
logger = logging.getLogger(__name__)
async def verify(transcript_id: str = "300", *, max_turns: int = 5) -> None:
    """Log how a stored transcript's utterances aggregate into turns.

    Initialises the database from a default ``HeliaConfig``, looks up one
    transcript by id, and logs its utterance count, its turn count, and a
    short preview of the first turns.

    Args:
        transcript_id: Value matched against ``Transcript.transcript_id``.
            Defaults to ``"300"``, the id this script was originally
            hard-coded to inspect.
        max_turns: Maximum number of leading turns to preview in the log.

    Returns:
        None. When no matching transcript exists, an error is logged and the
        function returns early.
    """
    config = HeliaConfig()  # type: ignore[call-arg]
    await init_db(config)

    transcript = await Transcript.find_one(Transcript.transcript_id == transcript_id)
    if not transcript:
        logger.error("Transcript %s not found.", transcript_id)
        return

    logger.info(
        "Transcript %s found with %d utterances.",
        transcript_id,
        len(transcript.utterances),
    )
    turns = transcript.turns
    logger.info("Aggregated into %d turns.", len(turns))

    for i, turn in enumerate(turns[:max_turns], start=1):
        logger.info(
            "Turn %d [%s] (%s - %s): %s... (Merged %d utterances)",
            i,
            turn.speaker,
            turn.start_time,
            turn.end_time,
            turn.value[:50],  # preview only the first 50 characters of the turn text
            turn.utterance_count,
        )
if __name__ == "__main__":
    # Script entry point: drive the verification coroutine to completion.
    verification = verify()
    asyncio.run(verification)