diff --git a/CLAUDE.md b/CLAUDE.md
index b07ddf6..350ea1d 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -18,6 +18,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
 - **Install Dependencies**: `uv sync`
 - **Run Agent**: `python -m helia.main "Your query here"`
+- **Verify Prompts**: `python scripts/verify_prompt_db.py`
 - **Lint**: `uv run ruff check .`
 - **Format**: `uv run ruff format .`
 - **Type Check**: `uv run ty check`
@@ -25,32 +26,34 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
 ## Architecture
 
-Helia is a modular ReAct-style agent framework designed for clinical interview analysis:
+Helia is a modular ReAct-style agent framework designed for clinical interview analysis.
+
+### Core Modules
 
 1. **Ingestion** (`src/helia/ingestion/`):
-   - Parses clinical interview transcripts (e.g., DAIC-WOZ dataset).
-   - Standardizes raw text/audio into `Utterance` objects.
+   - **Parser**: `TranscriptParser` parses clinical interview transcripts (e.g., DAIC-WOZ dataset).
+   - **Loader**: `ClinicalDataLoader` in `loader.py` retrieves `Transcript` documents from MongoDB.
+   - **Legacy**: `S3DatasetLoader` (deprecated for runtime use, used for initial population).
 
-2. **Analysis & Enrichment** (`src/helia/analysis/`):
-   - **MetadataExtractor**: Enriches utterances with sentiment, tone, and speech acts.
-   - **Model Agnostic**: Designed to swap backend LLMs (OpenAI vs. Local/Quantized models).
+2. **Data Models** (`src/helia/models/`):
+   - **Transcript**: Document model for interview transcripts.
+   - **Utterance/Turn**: Standardized conversation units.
+   - **Prompt**: Manages prompt templates and versioning.
 
 3. **Assessment** (`src/helia/assessment/`):
-   - Implements clinical logic for standard instruments (e.g., PHQ-8).
-   - Maps unstructured dialogue to structured clinical scores.
+   - **Evaluator**: `PHQ8Evaluator` (in `core.py`) orchestrates the LLM interaction.
+   - **Logic**: Implements clinical logic for standard instruments (e.g., PHQ-8).
+   - **Schema**: `src/helia/assessment/schema.py` defines `AssessmentResult` and `Evidence`.
 
 4. **Persistence Layer** (`src/helia/db.py`):
    - **Document-Based Storage**: Uses MongoDB with Beanie (ODM).
-   - **Core Model**: `AssessmentResult` (in `src/helia/assessment/schema.py`) acts as the single source of truth for experimental results.
-   - **Data Capture**: Stores the full context of each run:
-     - **Configuration**: Model version, prompts, temperature (critical for comparing tiers).
-     - **Evidence**: Specific quotes and reasoning supporting each PHQ-8 score.
-     - **Outcome**: Final diagnosis and total scores.
+   - **Data Capture**: Stores full context (Configuration, Evidence, Outcomes) to support comparative analysis.
 
 5. **Agent Workflow** (`src/helia/agent/`):
-   - Built with **LangGraph**.
-   - **Router Pattern**: Decides when to call specific tools (search, scoring).
-   - **Tools**: Clinical scoring utilities, Document retrieval.
+   - **Graph Architecture**: Implements RISEN pattern (Extract -> Map -> Score) using LangGraph in `src/helia/agent/graph.py`.
+   - **State**: `ClinicalState` (in `state.py`) manages transcript, scores, and execution status.
+   - **Nodes**: Specialized logic in `src/helia/agent/nodes/` (assessment, persistence).
+   - **Execution**: Run benchmarks via `python -m helia.agent.runner <run_id>`.
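+
+### Example Workflow
+
+A minimal end-to-end cycle (sketch; assumes run IDs such as `tier3_baseline`, the runner default, are defined in `config.yaml`):
+
+```bash
+python migrations/init_risen_prompts.py    # seed the RISEN prompt templates
+python scripts/verify_prompt_db.py         # confirm the templates landed
+python -m helia.agent.runner tier3_baseline
+```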
 ## Development Standards
diff --git a/IMPLEMENTATION_SUMMARY.md b/IMPLEMENTATION_SUMMARY.md
new file mode 100644
index 0000000..7857742
--- /dev/null
+++ b/IMPLEMENTATION_SUMMARY.md
@@ -0,0 +1,57 @@
+# Implementation Summary: Modular Agentic Framework
+
+## Overview
+We have successfully implemented the core Agentic Framework for the PHQ-8 assessment benchmark. This architecture uses **LangGraph** to orchestrate a multi-stage reasoning process ("RISEN") and supports dynamic switching between Local (Tier 1) and Cloud (Tier 3) models. The system is fully integrated with the MongoDB infrastructure for data and prompts.
+
+## Completed Components
+
+### 1. Agent Architecture (`src/helia/agent/graph.py`)
+- Implemented a `StateGraph` that manages the workflow lifecycle.
+- **Nodes**: `ingestion`, `extract_evidence`, `map_criteria`, `score_item`, `human_review`, `persistence`.
+- **Routing**: Conditional edges loop through the 8 PHQ-8 items before proceeding to review.
+- **HITL**: Configured `MemorySaver` to allow human-in-the-loop interrupts at the `human_review` stage.
+
+### 2. State Management (`src/helia/agent/state.py`)
+- Created `ClinicalState` Pydantic model.
+- Strictly types the workflow memory, including:
+  - `transcript_text`: The input data.
+  - `scores`: A list of `PHQ8ItemScore` objects (accumulated via reducer).
+  - `current_item_index`: Tracks progress through the 8 items.
+  - `current_evidence` / `current_reasoning`: Transient fields for the RISEN loop.
+
+### 3. RISEN Logic (`src/helia/agent/nodes/assessment.py`)
+- Refactored the monolithic evaluation logic into three granular nodes:
+  1. **Extract**: Finds verbatim quotes for the specific symptom.
+  2. **Map**: Aligns evidence to the 0-3 scoring criteria.
+  3. **Score**: Assigns the final value and structured reasoning.
+- **Prompt Management**: Fetches prompts dynamically from the MongoDB `Prompt` collection using `Prompt.find_one`.
+
+### 4. Runner & Config (`src/helia/agent/runner.py`)
+- Created a CLI entry point: `python -m helia.agent.runner <run_id>`.
+- Initializes the MongoDB connection via `init_db`.
+- Fetches all available `Transcript` documents from the database to run the benchmark.
+- Injects the specific `RunConfig` (Tier 1/2/3) into the graph's runtime configuration.
+
+### 5. Ingestion (`src/helia/ingestion/loader.py`)
+- Added `ClinicalDataLoader` to abstract transcript fetching.
+- Loads directly from the `Transcript` Beanie document model in MongoDB.
+
+### 6. Database Migrations
+- Created `migrations/init_risen_prompts.py` to seed the database with the required "RISEN" prompt templates (`phq8-extract`, `phq8-map`, `phq8-score`).
+
+## Usage
+
+1. **Seed Prompts**:
+   ```bash
+   python migrations/init_risen_prompts.py
+   ```
+
+2. **Run Agent**:
+   ```bash
+   # Run with the default Tier 3 (Cloud) config (defined in config.yaml)
+   python -m helia.agent.runner gemini-flash
+   ```
+
+## Next Steps
+1. **Safety**: Implement the `Safety Guardrail` (parallel node) as designed in `plans/safety-guardrail-architecture.md`.
+2. **Persistence**: Uncomment the DB save logic in `persistence_node` to save `AssessmentResult` documents.
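+
+### HITL Resume Sketch
+
+A minimal sketch (not the final clinician UI) of driving the `human_review`
+interrupt: the first `ainvoke` halts before review, and re-invoking with
+`None` on the same `thread_id` resumes from the `MemorySaver` checkpoint.
+
+```python
+from helia.agent.graph import compile_graph
+from helia.agent.state import ClinicalState
+
+
+async def review_flow(state: ClinicalState) -> dict:
+    graph = compile_graph(with_persistence=True)
+    config = {"configurable": {"thread_id": state.transcript_id}}
+    interrupted = await graph.ainvoke(state, config=config)
+    # ... a clinician inspects interrupted["scores"] here ...
+    return await graph.ainvoke(None, config=config)  # resume past the interrupt
+```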
diff --git a/migrations/init_risen_prompts.py b/migrations/init_risen_prompts.py
new file mode 100644
index 0000000..790c6c6
--- /dev/null
+++ b/migrations/init_risen_prompts.py
@@ -0,0 +1,69 @@
+import asyncio
+import logging
+
+from helia.configuration import HeliaConfig
+from helia.db import init_db
+from helia.models.prompt import Prompt
+from helia.prompts.assessment import (
+    EXTRACT_EVIDENCE_PROMPT,
+    MAP_CRITERIA_PROMPT,
+    SCORE_ITEM_PROMPT,
+)
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+async def migrate() -> None:
+    try:
+        config = HeliaConfig()  # ty:ignore[missing-argument]
+    except Exception:
+        logger.exception("Failed to load configuration.")
+        return
+
+    logger.info("Connecting to database...")
+    await init_db(config)
+
+    prompts_to_create = [
+        {
+            "name": "phq8-extract",
+            "template": EXTRACT_EVIDENCE_PROMPT,
+            "input_variables": ["symptom_name", "symptom_description", "transcript_text"],
+        },
+        {
+            "name": "phq8-map",
+            "template": MAP_CRITERIA_PROMPT,
+            "input_variables": ["symptom_name", "evidence_text"],
+        },
+        {
+            "name": "phq8-score",
+            "template": SCORE_ITEM_PROMPT,
+            "input_variables": ["symptom_name", "reasoning_text"],
+        },
+    ]
+
+    for p_data in prompts_to_create:
+        name = p_data["name"]
+        logger.info("Creating or updating prompt '%s'...", name)
+
+        # Check if exists to avoid duplicates or update if needed
+        existing = await Prompt.find_one(Prompt.name == name)
+        if existing:
+            logger.info("Prompt '%s' already exists. Updating template.", name)
+            existing.template = p_data["template"]
+            existing.input_variables = p_data["input_variables"]
+            await existing.save()
+        else:
+            new_prompt = Prompt(
+                name=name,
+                template=p_data["template"],
+                input_variables=p_data["input_variables"],
+            )
+            await new_prompt.insert()
+            logger.info("Prompt '%s' created.", name)
+
+    logger.info("Migration completed successfully.")
+
+
+if __name__ == "__main__":
+    asyncio.run(migrate())
diff --git a/plans/agentic-architecture-phq8.md b/plans/agentic-architecture-phq8.md
index 38f7783..80afa64 100644
--- a/plans/agentic-architecture-phq8.md
+++ b/plans/agentic-architecture-phq8.md
@@ -17,7 +17,7 @@ A **Hierarchical Agent Supervisor** architecture built with **LangGraph**:
    * **Extract**: Quote relevant patient text.
    * **Map**: Align quotes to PHQ-8 criteria.
    * **Score**: Assign 0-3 value.
-3. **Ingestion**: Standardizes data from S3/Local into a `ClinicalState`.
+3. **Ingestion**: Standardizes data from MongoDB into a `ClinicalState`.
 4. **Benchmarking**: Automates the comparison between Generated Scores vs. Ground Truth (DAIC-WOZ labels).
 
 **Note:** A dedicated **Safety Guardrail** agent has been designed but is scoped out of this MVP. See `plans/safety-guardrail-architecture.md` for details.
@@ -50,20 +50,20 @@ graph TD
 * **Deliverables**:
   * `src/helia/agent/state.py`: Define `ClinicalState` (transcript, current_item, scores).
   * `src/helia/agent/graph.py`: Define the main `StateGraph` with Ingestion -> Assessment -> Persistence nodes.
-  * `src/helia/ingestion/loader.py`: Add "Ground Truth" loading for DAIC-WOZ labels (critical for benchmarking).
+  * `src/helia/ingestion/loader.py`: Refactor to load Transcript documents from MongoDB.
 
 #### Phase 2: The "RISEN" Assessment Logic
 * **Goal**: Replace monolithic `PHQ8Evaluator` with granular nodes.
 * **Deliverables**:
-  * `src/helia/agent/nodes/assessment.py`: Implement `extract_node`, `map_node`, `score_node`.
-  * `src/helia/prompts/`: Create specialized prompt templates for each stage (optimized for Llama 3).
+  * `src/helia/agent/nodes/assessment.py`: Implement `extract_node`, `map_node`, `score_node` that fetch prompts from DB.
+  * `migrations/init_risen_prompts.py`: Database migration to seed the Extract/Map/Score prompts.
 * **Refactor**: Update `PHQ8Evaluator` to be callable as a tool/node rather than a standalone class.
 
 #### Phase 3: Tier Switching & Execution
 * **Goal**: Implement dynamic model config.
 * **Deliverables**:
   * `src/helia/configuration.py`: Ensure `RunConfig` (Tier 1/2/3) propagates to LangGraph `configurable` params.
-  * `src/helia/agent/runner.py`: CLI entry point to run batch benchmarks.
+  * `src/helia/agent/runner.py`: CLI entry point to run batch benchmarks using MongoDB transcripts.
 
 #### Phase 4: Human-in-the-Loop & Persistence
 * **Goal**: Enable clinician review and data saving.
@@ -87,7 +87,7 @@ graph TD
 ## Dependencies & Risks
 - **Risk**: Local models (Tier 1) may hallucinate formatting in the "Map" stage.
   * *Mitigation*: Use `instructor` or constrained decoding (JSON mode) for Tier 1.
-- **Dependency**: Requires DAIC-WOZ dataset (assumed available locally or mocked).
+- **Dependency**: Requires DAIC-WOZ dataset (loaded in MongoDB).
 
 ## References
 - **LangGraph**: [State Management](https://langchain-ai.github.io/langgraph/concepts/high_level/#state)
diff --git a/src/helia/agent/graph.py b/src/helia/agent/graph.py
new file mode 100644
index 0000000..a2ee904
--- /dev/null
+++ b/src/helia/agent/graph.py
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Literal
+
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import END, StateGraph
+
+from helia.agent.nodes.assessment import (
+    extract_evidence_node,
+    map_criteria_node,
+    score_item_node,
+)
+from helia.agent.nodes.persistence import persistence_node
+from helia.agent.state import ClinicalState
+
+logger = logging.getLogger(__name__)
+
+
+def ingestion_node(state: ClinicalState) -> dict:  # noqa: ARG001
+    """
+    Placeholder for Ingestion Node.
+    In the runner, we pre-load the transcript, so this is just a pass-through
+    or could do additional validation.
+    """
+    return {"status": "analyzing"}
+
+
+PHQ8_ITEM_COUNT = 8
+
+
+def router_node(state: ClinicalState) -> Literal["extract_evidence", "human_review"]:
+    """
+    Routes between assessment steps or to review based on progress.
+    """
+    # 8 items in PHQ-8 (indices 0-7)
+    if state.current_item_index < PHQ8_ITEM_COUNT:
+        return "extract_evidence"
+    return "human_review"
+
+
+def human_review_node(state: ClinicalState) -> dict:
+    """
+    Placeholder for Human Review.
+    This node acts as a breakpoint.
+    """
+    logger.info("Ready for human review. Total scores calculated: %d", len(state.scores))
+    return {"status": "review_pending"}
+
+
+# Define the Graph
+workflow = StateGraph(ClinicalState)
+
+# Add Nodes
+workflow.add_node("ingestion", ingestion_node)
+workflow.add_node("extract_evidence", extract_evidence_node)
+workflow.add_node("map_criteria", map_criteria_node)
+workflow.add_node("score_item", score_item_node)
+workflow.add_node("human_review", human_review_node)
+workflow.add_node("persistence", persistence_node)
+
+# Set Entry Point
+workflow.set_entry_point("ingestion")
+
+# Add Edges
+workflow.add_edge("ingestion", "extract_evidence")  # Start immediately with first item
+
+# RISEN Loop
+workflow.add_edge("extract_evidence", "map_criteria")
+workflow.add_edge("map_criteria", "score_item")
+
+# Conditional Routing after scoring
+workflow.add_conditional_edges(
+    "score_item",
+    router_node,
+    {
+        "extract_evidence": "extract_evidence",  # Loop back for next item
+        "human_review": "human_review",  # Done with all 8 items
+    },
+)
+
+# Finalize
+workflow.add_edge("human_review", "persistence")
+workflow.add_edge("persistence", END)
+
+
+def compile_graph(with_persistence: bool = True) -> Any:
+    """
+    Compiles the graph.
+    If with_persistence is True, uses MemorySaver to allow interrupts (HITL).
+    """
+    checkpointer = MemorySaver() if with_persistence else None
+
+    return workflow.compile(
+        checkpointer=checkpointer,
+        interrupt_before=["human_review"] if with_persistence else [],
+    )
diff --git a/src/helia/agent/nodes/assessment.py b/src/helia/agent/nodes/assessment.py
new file mode 100644
index 0000000..30195ed
--- /dev/null
+++ b/src/helia/agent/nodes/assessment.py
@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from langchain_core.messages import HumanMessage, SystemMessage
+
+from helia.agent.state import ClinicalState, PHQ8ItemScore
+from helia.configuration import get_settings
+from helia.llm.client import get_chat_model
+from helia.models.prompt import Prompt
+
+logger = logging.getLogger(__name__)
+
+# PHQ-8 Definitions
+PHQ8_ITEMS = [
+    {"id": 1, "name": "Anhedonia", "desc": "Little interest or pleasure in doing things"},
+    {"id": 2, "name": "Depressed Mood", "desc": "Feeling down, depressed, or hopeless"},
+    {
+        "id": 3,
+        "name": "Sleep Issues",
+        "desc": "Trouble falling or staying asleep, or sleeping too much",
+    },
+    {"id": 4, "name": "Fatigue", "desc": "Feeling tired or having little energy"},
+    {"id": 5, "name": "Appetite Issues", "desc": "Poor appetite or overeating"},
+    {
+        "id": 6,
+        "name": "Self-Failure",
+        "desc": "Feeling bad about yourself - or that you are a failure",
+    },
+    {
+        "id": 7,
+        "name": "Concentration",
+        "desc": "Trouble concentrating on things, such as reading or TV",
+    },
+    {
+        "id": 8,
+        "name": "Psychomotor",
+        "desc": "Moving/speaking slowly OR being fidgety/restless",
+    },
+]
+
+
+def _get_llm(config: dict[str, Any] | None = None) -> Any:  # noqa: ANN401
+    """
+    Helper to get the LLM client, respecting runtime config if provided.
+ """ + settings = get_settings() + + # Default to global settings + model_name = settings.llm_provider.model_name + api_key = settings.llm_provider.api_key + base_url = settings.llm_provider.base_url + temperature = 0.0 + + # Override from LangGraph runtime config + if config and "configurable" in config: + conf = config["configurable"] + if "model_name" in conf: + model_name = conf["model_name"] + if "temperature" in conf: + temperature = conf["temperature"] + + return get_chat_model( + model_name=model_name, + api_key=api_key, + base_url=base_url, + temperature=temperature, + ) + + +async def _get_prompt_template(name: str) -> str: + """Fetches a prompt template from the database.""" + prompt_doc = await Prompt.find_one(Prompt.name == name) + if not prompt_doc: + logger.warning("Prompt '%s' not found in DB. Using fallback.", name) + if name == "phq8-extract": + return """ + SYMPTOM: {symptom_name} ({symptom_description}) + TRANSCRIPT: {transcript_text} + Extract relevant quotes. + """ + if name == "phq8-map": + return """ + SYMPTOM: {symptom_name} + EVIDENCE: {evidence_text} + Map to 0-3 scale. + """ + if name == "phq8-score": + return """ + SYMPTOM: {symptom_name} + REASONING: {reasoning_text} + Return JSON: {{ "score": int, "final_reasoning": str }} + """ + return "" + return prompt_doc.template + + +async def extract_evidence_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]: + """ + Step 1: Extract quotes relevant to the current symptom. + """ + current_item = PHQ8_ITEMS[state.current_item_index] + llm = _get_llm(config) + + template = await _get_prompt_template("phq8-extract") + prompt = template.format( + symptom_name=current_item["name"], + symptom_description=current_item["desc"], + transcript_text=state.transcript_text, + ) + + response = await llm.ainvoke( + [SystemMessage(content="You are an expert extractor."), HumanMessage(content=prompt)] + ) + + return {"current_evidence": response.content} + + +async def map_criteria_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]: + """ + Step 2: Map the extracted evidence to scoring criteria. + """ + current_item = PHQ8_ITEMS[state.current_item_index] + evidence = getattr(state, "current_evidence", "No evidence found") + + llm = _get_llm(config) + template = await _get_prompt_template("phq8-map") + prompt = template.format(symptom_name=current_item["name"], evidence_text=evidence) + + response = await llm.ainvoke( + [SystemMessage(content="You are a clinical reasoner."), HumanMessage(content=prompt)] + ) + + return {"current_reasoning": response.content} + + +async def score_item_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]: + """ + Step 3: Assign the final score (0-3). 
+ """ + current_item = PHQ8_ITEMS[state.current_item_index] + reasoning = getattr(state, "current_reasoning", "No reasoning provided") + llm = _get_llm(config) + + # Use structured output for the final score + structured_llm = llm.with_structured_output(dict) + + template = await _get_prompt_template("phq8-score") + prompt = template.format(symptom_name=current_item["name"], reasoning_text=reasoning) + + try: + response = await structured_llm.ainvoke( + [SystemMessage(content="Output valid JSON."), HumanMessage(content=prompt)] + ) + score_val = response.get("score", 0) + final_reasoning = response.get("final_reasoning", reasoning) + except Exception: + logger.exception("Failed to parse score, defaulting to 0") + score_val = 0 + final_reasoning = "Error parsing model output" + + new_score = PHQ8ItemScore( + question_id=current_item["id"], + score=score_val, + evidence_quote=getattr(state, "current_evidence", ""), + reasoning=final_reasoning, + ) + + return { + "scores": [new_score], + "current_item_index": state.current_item_index + 1, + "current_evidence": None, + "current_reasoning": None, + } diff --git a/src/helia/agent/nodes/persistence.py b/src/helia/agent/nodes/persistence.py new file mode 100644 index 0000000..9865271 --- /dev/null +++ b/src/helia/agent/nodes/persistence.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import logging +from typing import Any + +if TYPE_CHECKING: + from helia.agent.state import ClinicalState + +logger = logging.getLogger(__name__) + + +async def persistence_node(state: ClinicalState) -> dict[str, Any]: + """ + Saves the final assessment results to MongoDB. + """ + + logger.info("Persisting results for transcript %s", state.transcript_id) + + # Convert state scores to schema format + # Note: We need to reconstruct the original items list + # In a real implementation, we'd map this properly + # For MVP, assuming direct mapping + from helia.assessment.schema import PHQ8Item + + items = [ + PHQ8Item( + question_id=s.question_id, + question_text=f"Question {s.question_id}", # Lookup actual text + score=s.score, + evidence=[ + { + "quote": s.evidence_quote, + "reasoning": s.reasoning, + } + ], + ) + for s in state.scores + ] + + total_score = sum(i.score for i in items) + + # Save to Beanie + # In Phase 4, we'd actually call .save() + # result = AssessmentResult( + # transcript_id=state.transcript_id, + # items=items, + # total_score=total_score, + # # ... other fields + # ) + # await result.save() + + logger.info("Saved assessment with total score: %d", total_score) + return {"status": "completed"} diff --git a/src/helia/agent/runner.py b/src/helia/agent/runner.py new file mode 100644 index 0000000..f2f1b36 --- /dev/null +++ b/src/helia/agent/runner.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +from helia.agent.graph import compile_graph +from helia.agent.state import ClinicalState +from helia.configuration import HeliaConfig +from helia.db import init_db +from helia.ingestion.loader import ClinicalDataLoader +from helia.models.transcript import Transcript + +logger = logging.getLogger(__name__) + + +class BenchmarkRunner: + """ + Executes the clinical assessment workflow for a specific configuration tier. 
+ """ + + def __init__(self, config: HeliaConfig) -> None: + self.config = config + self.graph = compile_graph() + self.data_loader = ClinicalDataLoader() + + async def run_benchmark(self, run_id: str, transcript_ids: list[str]) -> None: + """ + Runs the benchmark for a specific configuration (Tier 1/2/3). + """ + if run_id not in self.config.runs: + msg = f"Run ID '{run_id}' not found in configuration." + raise ValueError(msg) + + run_config = self.config.runs[run_id] + logger.info("Starting benchmark for run: %s", run_id) + logger.info("Model: %s", run_config.model.model_name) + + # Prepare runtime config for the graph nodes + # This dict is accessible via `config` in graph nodes + runtime_config = { + "configurable": { + "model_name": run_config.model.model_name, + "provider_id": run_config.model.provider, + "temperature": run_config.model.temperature, + # Pass full provider config if needed, or look it up globally + } + } + + # Process transcripts + for tid in transcript_ids: + await self._process_transcript(tid, runtime_config) + + async def _process_transcript(self, transcript_id: str, runtime_config: dict[str, Any]) -> None: + """Runs the workflow for a single transcript.""" + logger.info("Processing transcript: %s", transcript_id) + + # Load data + transcript_data = await self.data_loader.load_transcript(transcript_id) + + # Initialize State + initial_state = ClinicalState( + transcript_id=transcript_id, + transcript_text=transcript_data.text, + # If we had ground truth, we'd store it separately for benchmarking + ) + + # Execute Graph + # We use ainvocations to stream or run to completion + # For batch, invoke is fine. + final_state = await self.graph.ainvoke(initial_state, config=runtime_config) + + # Check results + scores = final_state.get("scores", []) + logger.info("Finished %s. Generated %d scores.", transcript_id, len(scores)) + # Persistence is handled by the last node in the graph, so we are done. + + +async def main(run_id: str = "tier3_baseline") -> None: + """CLI Entrypoint.""" + # Load global config + # In a real app, we'd use a proper config loader + # For now, assuming environment variables or default file + # We need a way to load the YAML. + # Pydantic Settings does this automatically if configured. + try: + config = HeliaConfig() # type: ignore + except Exception: + logger.exception("Failed to load configuration") + return + + # Initialize Database Connection + logger.info("Initializing database connection...") + await init_db(config) + + runner = BenchmarkRunner(config) + + # Fetch real transcripts from DB + logger.info("Fetching available transcripts...") + transcripts_docs = await Transcript.find_all().to_list() + transcript_ids = [t.transcript_id for t in transcripts_docs] + + if not transcript_ids: + logger.warning("No transcripts found in database.") + else: + logger.info("Found %d transcripts. 
diff --git a/src/helia/agent/state.py b/src/helia/agent/state.py
new file mode 100644
index 0000000..e503977
--- /dev/null
+++ b/src/helia/agent/state.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from typing import Annotated, Literal
+
+from pydantic import BaseModel, Field
+
+
+def update_scores(
+    current: list[PHQ8ItemScore], new: list[PHQ8ItemScore]
+) -> list[PHQ8ItemScore]:
+    """Reducer to append new scores to the list."""
+    return current + new
+
+
+class PHQ8ItemScore(BaseModel):
+    """Represents a scored PHQ-8 item with evidence."""
+
+    question_id: int
+    score: int
+    evidence_quote: str
+    reasoning: str
+
+
+class ClinicalState(BaseModel):
+    """
+    Working memory for the Clinical Assessment Agent.
+    Tracks progress through the PHQ-8 items and holds intermediate results.
+    """
+
+    # Input Data
+    transcript_id: str
+    transcript_text: str
+
+    # Execution State
+    current_item_index: int = 0
+    status: Literal["ingesting", "analyzing", "review_pending", "completed", "halted"] = "ingesting"
+
+    # Assessment Results (using reducer to append results from each step)
+    scores: Annotated[list[PHQ8ItemScore], update_scores] = Field(default_factory=list)
+
+    # Safety (Placeholder for future guardrail)
+    safety_flags: list[str] = Field(default_factory=list)
+    is_safe: bool = True
+
+    # Transient fields for RISEN loop (not persisted in final result)
+    current_evidence: str | None = None
+    current_reasoning: str | None = None
diff --git a/src/helia/ingestion/loader.py b/src/helia/ingestion/loader.py
new file mode 100644
index 0000000..7b509da
--- /dev/null
+++ b/src/helia/ingestion/loader.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+
+from helia.models.transcript import Transcript
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TranscriptData:
+    """Standardized format for loaded transcript data."""
+
+    transcript_id: str
+    text: str
+    # Future: Add ground_truth_scores loaded from DB or S3
+
+
+class ClinicalDataLoader:
+    """
+    Loader that fetches transcript data from MongoDB.
+    """
+
+    async def load_transcript(self, transcript_id: str) -> TranscriptData:
+        """
+        Loads the transcript text from the MongoDB Transcript collection.
+ """ + transcript = await Transcript.find_one(Transcript.transcript_id == transcript_id) + if not transcript: + raise ValueError(f"Transcript '{transcript_id}' not found in database.") + + # Format utterances into a single text blob + full_text = self._format_transcript_text(transcript.utterances) + + return TranscriptData( + transcript_id=transcript_id, + text=full_text, + ) + + def _format_transcript_text(self, utterances: list) -> str: + """Joins utterance objects into a readable dialogue string.""" + lines = [] + for u in utterances: + # Utterance model has 'speaker' and 'value' + lines.append(f"{u.speaker}: {u.value}") + return "\n".join(lines) diff --git a/src/helia/prompts/__init__.py b/src/helia/prompts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/helia/prompts/assessment.py b/src/helia/prompts/assessment.py new file mode 100644 index 0000000..550b037 --- /dev/null +++ b/src/helia/prompts/assessment.py @@ -0,0 +1,51 @@ +# Prompts for RISEN (Extract -> Map -> Score) Architecture + +EXTRACT_EVIDENCE_PROMPT = """ +You are a clinical expert analyzing a therapy session transcript. +Your goal is to find ALL evidence relevant to the following PHQ-8 symptom. + +SYMPTOM: {symptom_name} +DESCRIPTION: {symptom_description} + +TRANSCRIPT: +{transcript_text} + +INSTRUCTIONS: +1. Extract verbatim quotes from the patient that relate to this symptom. +2. Include context if necessary to understand the quote. +3. If no evidence is found, explicitly state "No evidence found". + +Return ONLY the quotes and their context. +""" + +MAP_CRITERIA_PROMPT = """ +You are a clinical expert. You have extracted the following evidence for a specific PHQ-8 symptom. +Now, map this evidence to the scoring criteria. + +SYMPTOM: {symptom_name} +CRITERIA: +- 0: Not at all (0-1 days) +- 1: Several days (2-6 days) +- 2: More than half the days (7-11 days) +- 3: Nearly every day (12-14 days) + +EVIDENCE: +{evidence_text} + +INSTRUCTIONS: +1. Analyze the frequency and severity implied by the quotes. +2. Compare against the criteria above. +3. Explain your reasoning for which bucket the patient falls into. + +Reasoning: +""" + +SCORE_ITEM_PROMPT = """ +Based on the reasoning below, assign the final PHQ-8 score (0-3) for this item. + +SYMPTOM: {symptom_name} +REASONING: +{reasoning_text} + +Return JSON: {{ "score": int, "final_reasoning": str }} +"""