Refactor configuration management; update S3 integration, add new migration scripts, and implement pre-flight checks
This commit is contained in:
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
137
scripts/preflight.py
Normal file
137
scripts/preflight.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from helia.configuration import HeliaConfig
|
||||
from helia.db import init_db
|
||||
from helia.ingestion.s3 import S3DatasetLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from helia.configuration import LLMProviderConfig, S3Config
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
async def check_mongo(config: HeliaConfig) -> tuple[str, bool]:
|
||||
"""Verify MongoDB connectivity."""
|
||||
try:
|
||||
await init_db(config)
|
||||
except Exception as e:
|
||||
console.print(f"[red]MongoDB error: {e}[/red]")
|
||||
return "MongoDB", False
|
||||
else:
|
||||
return "MongoDB", True
|
||||
|
||||
|
||||
async def check_s3(config: S3Config) -> tuple[str, bool]:
|
||||
"""Verify S3 connectivity by listing objects."""
|
||||
try:
|
||||
loader = S3DatasetLoader(config)
|
||||
await loader.list_transcripts()
|
||||
except Exception as e:
|
||||
console.print(f"[red]S3 error: {e}[/red]")
|
||||
return "S3", False
|
||||
else:
|
||||
return "S3", True
|
||||
|
||||
|
||||
async def check_llm_provider(
|
||||
name: str, config: LLMProviderConfig, model_name: str
|
||||
) -> tuple[str, bool]:
|
||||
"""Verify connectivity to an LLM provider."""
|
||||
try:
|
||||
client = AsyncOpenAI(
|
||||
api_key=config.api_key,
|
||||
base_url=config.api_base,
|
||||
max_retries=0,
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
max_tokens=1,
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(f"[red]{name} error: {e}[/red]")
|
||||
return name, False
|
||||
else:
|
||||
return name, True
|
||||
|
||||
|
||||
async def check_all_connections(config: HeliaConfig) -> bool:
|
||||
"""
|
||||
Run all pre-flight checks. Returns True if all checks pass.
|
||||
"""
|
||||
console.print("\n[bold cyan]🔍 Starting Pre-flight Checks[/bold cyan]\n")
|
||||
|
||||
checks = []
|
||||
check_names = []
|
||||
|
||||
checks.append(check_mongo(config))
|
||||
check_names.append("MongoDB")
|
||||
|
||||
checks.append(check_s3(config.s3))
|
||||
check_names.append("S3")
|
||||
|
||||
for name, provider_config in config.providers.items():
|
||||
model_name = next(
|
||||
(run.model.model_name for run in config.runs.values() if run.model.provider == name),
|
||||
None,
|
||||
)
|
||||
|
||||
if model_name:
|
||||
checks.append(check_llm_provider(name, provider_config, model_name))
|
||||
check_names.append(f"LLM Provider: {name}")
|
||||
else:
|
||||
console.print(
|
||||
f"[yellow]⊘ Skipping LLM Provider '{name}': "
|
||||
f"defined in config but not used in any run[/yellow]"
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*checks, return_exceptions=True)
|
||||
|
||||
# Create results table
|
||||
table = Table(title="Pre-flight Check Results", show_header=True, header_style="bold")
|
||||
table.add_column("Service", style="cyan")
|
||||
table.add_column("Status", justify="center")
|
||||
|
||||
all_passed = True
|
||||
for i, result in enumerate(results):
|
||||
service_name = check_names[i] if i < len(check_names) else "Unknown"
|
||||
|
||||
if isinstance(result, Exception):
|
||||
table.add_row(service_name, "[bold red]✗ Failed[/bold red]")
|
||||
all_passed = False
|
||||
elif isinstance(result, tuple):
|
||||
name, passed = result
|
||||
if passed:
|
||||
table.add_row(name, "[bold green]✓ Passed[/bold green]")
|
||||
else:
|
||||
table.add_row(name, "[bold red]✗ Failed[/bold red]")
|
||||
all_passed = False
|
||||
elif result is False:
|
||||
table.add_row(service_name, "[bold red]✗ Failed[/bold red]")
|
||||
all_passed = False
|
||||
|
||||
console.print(table)
|
||||
|
||||
if all_passed:
|
||||
console.print("\n[bold green]✓ All Pre-flight Checks Passed[/bold green]\n")
|
||||
else:
|
||||
console.print("\n[bold red]✗ Pre-flight Checks Failed[/bold red]\n")
|
||||
|
||||
return all_passed
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = HeliaConfig() # type: ignore[missing-argument, unknown-argument]
|
||||
|
||||
all_ok = asyncio.run(check_all_connections(config))
|
||||
if not all_ok:
|
||||
raise SystemExit(1)
|
||||
39
scripts/verify_turns.py
Normal file
39
scripts/verify_turns.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from helia.configuration import HeliaConfig
|
||||
from helia.db import init_db
|
||||
from helia.models.transcript import Transcript
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def verify() -> None:
|
||||
config = HeliaConfig() # type: ignore[call-arg]
|
||||
await init_db(config)
|
||||
|
||||
transcript = await Transcript.find_one(Transcript.transcript_id == "300")
|
||||
if not transcript:
|
||||
logger.error("Transcript 300 not found.")
|
||||
return
|
||||
|
||||
logger.info("Transcript 300 found with %d utterances.", len(transcript.utterances))
|
||||
|
||||
turns = transcript.turns
|
||||
logger.info("Aggregated into %d turns.", len(turns))
|
||||
|
||||
for i, turn in enumerate(turns[:5]):
|
||||
logger.info(
|
||||
"Turn %d [%s] (%s - %s): %s... (Merged %d utterances)",
|
||||
i + 1,
|
||||
turn.speaker,
|
||||
turn.start_time,
|
||||
turn.end_time,
|
||||
turn.value[:50],
|
||||
turn.utterance_count,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(verify())
|
||||
Reference in New Issue
Block a user