138 lines
4.0 KiB
Python
138 lines
4.0 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import TYPE_CHECKING
|
|
|
|
from openai import AsyncOpenAI
|
|
from rich.console import Console
|
|
from rich.table import Table
|
|
|
|
from helia.configuration import HeliaConfig
|
|
from helia.db import init_db
|
|
from helia.ingestion.s3 import S3DatasetLoader
|
|
|
|
if TYPE_CHECKING:
|
|
from helia.configuration import LLMProviderConfig, S3Config
|
|
|
|
# Module-wide Rich console; every check function and the summary table print through it.
console = Console()
|
|
|
|
|
|
async def check_mongo(config: HeliaConfig) -> tuple[str, bool]:
    """Check that MongoDB is reachable by running ``init_db`` with *config*.

    Errors are not propagated: any ``Exception`` raised by ``init_db`` is
    printed to the console and reported as a failed check.

    Args:
        config: Full application configuration, passed straight to ``init_db``.

    Returns:
        ``("MongoDB", True)`` on success, ``("MongoDB", False)`` otherwise.
    """
    service = "MongoDB"
    try:
        await init_db(config)
    except Exception as exc:
        console.print(f"[red]MongoDB error: {exc}[/red]")
        return service, False
    return service, True
|
|
|
|
|
|
async def check_s3(config: S3Config) -> tuple[str, bool]:
    """Check S3 access by building a loader and listing transcripts.

    Constructing ``S3DatasetLoader`` and calling ``list_transcripts`` are both
    guarded: any ``Exception`` is printed and turned into a failed result.

    Args:
        config: S3 section of the application configuration.

    Returns:
        ``("S3", True)`` on success, ``("S3", False)`` otherwise.
    """
    service = "S3"
    try:
        await S3DatasetLoader(config).list_transcripts()
    except Exception as exc:
        console.print(f"[red]S3 error: {exc}[/red]")
        return service, False
    return service, True
|
|
|
|
|
|
async def check_llm_provider(
    name: str, config: LLMProviderConfig, model_name: str
) -> tuple[str, bool]:
    """Verify connectivity to an LLM provider.

    Sends a single one-token chat completion ("ping") to *model_name* through
    an OpenAI-compatible client. The client is built from ``config.api_key``
    and ``config.api_base``, with retries disabled and a 5 second timeout so
    the check fails fast.

    The client is used as an async context manager so its HTTP connection pool
    is closed when the probe finishes. Previously a client was left open for
    every provider checked.

    Args:
        name: Provider name, used in error output and in the returned tuple.
        config: Provider connection settings (``api_key``, ``api_base``).
        model_name: Model to send the probe request to.

    Returns:
        ``(name, True)`` if the request succeeded, ``(name, False)`` if any
        ``Exception`` was raised (the error is printed to the console).
    """
    try:
        async with AsyncOpenAI(
            api_key=config.api_key,
            base_url=config.api_base,
            max_retries=0,  # a pre-flight check should report, not retry
            timeout=5.0,
        ) as client:
            await client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": "ping"}],
                max_tokens=1,  # cheapest possible request that proves the model answers
            )
    except Exception as e:
        console.print(f"[red]{name} error: {e}[/red]")
        return name, False
    else:
        return name, True
|
|
|
|
|
|
async def check_all_connections(config: HeliaConfig) -> bool:
    """
    Run all pre-flight checks. Returns True if all checks pass.

    Which checks run:
        - MongoDB and S3 are always checked.
        - Each entry in ``config.providers`` is checked only if some run in
          ``config.runs`` has ``run.model.provider`` equal to that provider's
          name. The first matching run's ``model.model_name`` is used as the
          probe model.
        - Providers no run uses are skipped with a warning.

    All checks run concurrently via ``asyncio.gather``. The outcome is printed
    as a Rich table followed by an overall summary line.

    Args:
        config: Application configuration providing ``s3``, ``providers`` and
            ``runs``.

    Returns:
        True only if every scheduled check reported success. A check counts as
        failed if it raised anything, or if it returned something other than a
        ``(name, passed)`` pair. "Raised anything" includes ``BaseException``
        subclasses such as ``asyncio.CancelledError``.
    """
    console.print("\n[bold cyan]🔍 Starting Pre-flight Checks[/bold cyan]\n")

    # Parallel lists: check_names[i] labels checks[i]. The label is shown when
    # the check raises instead of returning its own (name, passed) tuple.
    checks = [check_mongo(config), check_s3(config.s3)]
    check_names = ["MongoDB", "S3"]

    for name, provider_config in config.providers.items():
        model_name = next(
            (run.model.model_name for run in config.runs.values() if run.model.provider == name),
            None,
        )

        if model_name:
            checks.append(check_llm_provider(name, provider_config, model_name))
            check_names.append(f"LLM Provider: {name}")
        else:
            console.print(
                f"[yellow]⊘ Skipping LLM Provider '{name}': "
                f"defined in config but not used in any run[/yellow]"
            )

    results = await asyncio.gather(*checks, return_exceptions=True)

    # Create results table
    table = Table(title="Pre-flight Check Results", show_header=True, header_style="bold")
    table.add_column("Service", style="cyan")
    table.add_column("Status", justify="center")

    all_passed = True
    for label, result in zip(check_names, results):
        if isinstance(result, BaseException):
            # gather(return_exceptions=True) also returns BaseExceptions such as
            # asyncio.CancelledError. Testing only for `Exception` let those
            # fall through every branch and count as an implicit pass.
            console.print(f"[red]{label} error: {result!r}[/red]")
            service, passed = label, False
        elif isinstance(result, tuple) and len(result) == 2:
            service, passed = result[0], bool(result[1])
        else:
            # Unexpected return shape: fail closed instead of silently passing.
            service, passed = label, False

        status = "[bold green]✓ Passed[/bold green]" if passed else "[bold red]✗ Failed[/bold red]"
        table.add_row(service, status)
        if not passed:
            all_passed = False

    console.print(table)

    if all_passed:
        console.print("\n[bold green]✓ All Pre-flight Checks Passed[/bold green]\n")
    else:
        console.print("\n[bold red]✗ Pre-flight Checks Failed[/bold red]\n")

    return all_passed
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load configuration, run every check, and exit with
    # status 1 when any of them failed (a clean run falls through with 0).
    config = HeliaConfig()  # type: ignore[missing-argument, unknown-argument]
    if not asyncio.run(check_all_connections(config)):
        raise SystemExit(1)
|