Files
helia/scripts/preflight.py

138 lines
4.0 KiB
Python

from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from openai import AsyncOpenAI
from rich.console import Console
from rich.table import Table
from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.ingestion.s3 import S3DatasetLoader
if TYPE_CHECKING:
from helia.configuration import LLMProviderConfig, S3Config
# Shared Rich console: every check below prints its error messages here, and
# check_all_connections prints its banner and results table through it too.
console = Console()
async def check_mongo(config: HeliaConfig) -> tuple[str, bool]:
    """Probe MongoDB by running ``init_db`` with the given configuration.

    Args:
        config: Full Helia configuration, passed straight to ``init_db``.

    Returns:
        A ``("MongoDB", ok)`` pair; ``ok`` is False when ``init_db`` raised.
        Any failure is printed to the console instead of being propagated.
    """
    label = "MongoDB"
    try:
        await init_db(config)
    except Exception as exc:
        console.print(f"[red]MongoDB error: {exc}[/red]")
        return label, False
    return label, True
async def check_s3(config: S3Config) -> tuple[str, bool]:
    """Probe S3 access by asking ``S3DatasetLoader`` to list transcripts.

    Args:
        config: S3 settings used to construct the dataset loader.

    Returns:
        An ``("S3", ok)`` pair; ``ok`` is False when building the loader or
        listing transcripts raised. The error is printed, not propagated.
    """
    service = "S3"
    try:
        await S3DatasetLoader(config).list_transcripts()
    except Exception as exc:
        console.print(f"[red]S3 error: {exc}[/red]")
        return service, False
    return service, True
async def check_llm_provider(
    name: str,
    config: LLMProviderConfig,
    model_name: str,
    *,
    timeout: float = 5.0,
) -> tuple[str, bool]:
    """Verify connectivity to an LLM provider via an OpenAI-compatible API.

    Sends a minimal one-token chat completion ("ping") to ``model_name`` using
    the provider's ``api_key`` and ``api_base``. Retries are disabled so an
    unreachable endpoint fails fast.

    Args:
        name: Provider key; used as the result label and in the error message.
        config: Provider settings supplying ``api_key`` and ``api_base``.
        model_name: Model the probe completion is requested from.
        timeout: Request timeout in seconds passed to the client (default 5.0).

    Returns:
        ``(name, ok)`` where ``ok`` is True if the completion call succeeded.
        Errors (including client construction errors) are printed to the
        console rather than raised.
    """
    try:
        # ``async with`` closes the client's underlying HTTP connection pool
        # once the probe finishes, so each checked provider does not leave an
        # open client behind.
        async with AsyncOpenAI(
            api_key=config.api_key,
            base_url=config.api_base,
            max_retries=0,
            timeout=timeout,
        ) as client:
            await client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": "ping"}],
                max_tokens=1,
            )
    except Exception as e:
        console.print(f"[red]{name} error: {e}[/red]")
        return name, False
    return name, True
async def check_all_connections(config: HeliaConfig) -> bool:
    """
    Run all pre-flight checks concurrently. Returns True if all checks pass.

    Checks MongoDB, S3, and every provider in ``config.providers`` that at
    least one entry of ``config.runs`` references (via
    ``run.model.provider``). Providers no run uses are skipped with a warning.
    A summary table is printed to the console.

    Args:
        config: Full Helia configuration.

    Returns:
        True only if every executed check reported success.
    """
    console.print("\n[bold cyan]🔍 Starting Pre-flight Checks[/bold cyan]\n")

    # Each check is registered together with the label shown in the table, so
    # labels and results can never drift out of alignment.
    labelled_checks = [
        ("MongoDB", check_mongo(config)),
        ("S3", check_s3(config.s3)),
    ]
    for name, provider_config in config.providers.items():
        # Probe the provider with the first model any run requests from it;
        # without such a run there is no model to ping.
        model_name = next(
            (run.model.model_name for run in config.runs.values() if run.model.provider == name),
            None,
        )
        if model_name:
            labelled_checks.append(
                (f"LLM Provider: {name}", check_llm_provider(name, provider_config, model_name))
            )
        else:
            console.print(
                f"[yellow]⊘ Skipping LLM Provider '{name}': "
                f"defined in config but not used in any run[/yellow]"
            )

    labels = [label for label, _ in labelled_checks]
    results = await asyncio.gather(
        *(check for _, check in labelled_checks), return_exceptions=True
    )

    all_passed = _report_results(labels, results)
    if all_passed:
        console.print("\n[bold green]✓ All Pre-flight Checks Passed[/bold green]\n")
    else:
        console.print("\n[bold red]✗ Pre-flight Checks Failed[/bold red]\n")
    return all_passed


def _report_results(labels: list[str], results: list[object]) -> bool:
    """Print the pre-flight results table and return whether every check passed.

    A check passes only when it returned a ``(name, passed)`` pair with a truthy
    ``passed``. Anything else fails: a falsy ``passed``, an exception captured
    by ``asyncio.gather`` (including ``BaseException`` subclasses such as
    ``asyncio.CancelledError``, which are not ``Exception``), or any other
    unexpected value.
    """
    table = Table(title="Pre-flight Check Results", show_header=True, header_style="bold")
    table.add_column("Service", style="cyan")
    table.add_column("Status", justify="center")

    all_passed = True
    for label, result in zip(labels, results):
        if isinstance(result, BaseException):
            # The checks catch ``Exception`` themselves, so anything arriving
            # here escaped that handler; report it rather than hide it.
            console.print(f"[red]{label} error: {result!r}[/red]")
            passed = False
        elif isinstance(result, tuple) and len(result) == 2:
            passed = bool(result[1])
        else:
            # Fail closed on any result shape the checks are not expected to produce.
            passed = False

        if passed:
            table.add_row(label, "[bold green]✓ Passed[/bold green]")
        else:
            table.add_row(label, "[bold red]✗ Failed[/bold red]")
            all_passed = False

    console.print(table)
    return all_passed
def main() -> int:
    """Run all pre-flight checks and return a process exit code.

    Returns:
        0 if every check passed, 1 otherwise.
    """
    # HeliaConfig is built with no explicit arguments; presumably it loads its
    # values from an external settings source — confirm in helia.configuration.
    config = HeliaConfig()  # type: ignore[missing-argument, unknown-argument]
    return 0 if asyncio.run(check_all_connections(config)) else 1


if __name__ == "__main__":
    raise SystemExit(main())