from __future__ import annotations import asyncio from typing import TYPE_CHECKING from openai import AsyncOpenAI from rich.console import Console from rich.table import Table from helia.configuration import HeliaConfig from helia.db import init_db from helia.ingestion.s3 import S3DatasetLoader if TYPE_CHECKING: from helia.configuration import LLMProviderConfig, S3Config console = Console() async def check_mongo(config: HeliaConfig) -> tuple[str, bool]: """Verify MongoDB connectivity.""" try: await init_db(config) except Exception as e: console.print(f"[red]MongoDB error: {e}[/red]") return "MongoDB", False else: return "MongoDB", True async def check_s3(config: S3Config) -> tuple[str, bool]: """Verify S3 connectivity by listing objects.""" try: loader = S3DatasetLoader(config) await loader.list_transcripts() except Exception as e: console.print(f"[red]S3 error: {e}[/red]") return "S3", False else: return "S3", True async def check_llm_provider( name: str, config: LLMProviderConfig, model_name: str ) -> tuple[str, bool]: """Verify connectivity to an LLM provider.""" try: client = AsyncOpenAI( api_key=config.api_key, base_url=config.api_base, max_retries=0, timeout=5.0, ) await client.chat.completions.create( model=model_name, messages=[{"role": "user", "content": "ping"}], max_tokens=1, ) except Exception as e: console.print(f"[red]{name} error: {e}[/red]") return name, False else: return name, True async def check_all_connections(config: HeliaConfig) -> bool: """ Run all pre-flight checks. Returns True if all checks pass. """ console.print("\n[bold cyan]🔍 Starting Pre-flight Checks[/bold cyan]\n") checks = [] check_names = [] checks.append(check_mongo(config)) check_names.append("MongoDB") checks.append(check_s3(config.s3)) check_names.append("S3") for name, provider_config in config.providers.items(): model_name = next( (run.model.model_name for run in config.runs.values() if run.model.provider == name), None, ) if model_name: checks.append(check_llm_provider(name, provider_config, model_name)) check_names.append(f"LLM Provider: {name}") else: console.print( f"[yellow]⊘ Skipping LLM Provider '{name}': " f"defined in config but not used in any run[/yellow]" ) results = await asyncio.gather(*checks, return_exceptions=True) # Create results table table = Table(title="Pre-flight Check Results", show_header=True, header_style="bold") table.add_column("Service", style="cyan") table.add_column("Status", justify="center") all_passed = True for i, result in enumerate(results): service_name = check_names[i] if i < len(check_names) else "Unknown" if isinstance(result, Exception): table.add_row(service_name, "[bold red]✗ Failed[/bold red]") all_passed = False elif isinstance(result, tuple): name, passed = result if passed: table.add_row(name, "[bold green]✓ Passed[/bold green]") else: table.add_row(name, "[bold red]✗ Failed[/bold red]") all_passed = False elif result is False: table.add_row(service_name, "[bold red]✗ Failed[/bold red]") all_passed = False console.print(table) if all_passed: console.print("\n[bold green]✓ All Pre-flight Checks Passed[/bold green]\n") else: console.print("\n[bold red]✗ Pre-flight Checks Failed[/bold red]\n") return all_passed if __name__ == "__main__": config = HeliaConfig() # type: ignore[missing-argument, unknown-argument] all_ok = asyncio.run(check_all_connections(config)) if not all_ok: raise SystemExit(1)