WIP
CLAUDE.md (35)

@@ -18,6 +18,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
 - **Install Dependencies**: `uv sync`
 - **Run Agent**: `python -m helia.main "Your query here"`
+- **Verify Prompts**: `python scripts/verify_prompt_db.py`
 - **Lint**: `uv run ruff check .`
 - **Format**: `uv run ruff format .`
 - **Type Check**: `uv run ty check`

@@ -25,32 +26,34 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
 
 ## Architecture
 
-Helia is a modular ReAct-style agent framework designed for clinical interview analysis:
+Helia is a modular ReAct-style agent framework designed for clinical interview analysis.
 
+### Core Modules
+
 1. **Ingestion** (`src/helia/ingestion/`):
-   - Parses clinical interview transcripts (e.g., DAIC-WOZ dataset).
-   - Standardizes raw text/audio into `Utterance` objects.
+   - **Parser**: `TranscriptParser` parses clinical interview transcripts (e.g., DAIC-WOZ dataset).
+   - **Loader**: `ClinicalDataLoader` in `loader.py` retrieves `Transcript` documents from MongoDB.
+   - **Legacy**: `S3DatasetLoader` (deprecated for runtime use, used for initial population).
 
-2. **Analysis & Enrichment** (`src/helia/analysis/`):
-   - **MetadataExtractor**: Enriches utterances with sentiment, tone, and speech acts.
-   - **Model Agnostic**: Designed to swap backend LLMs (OpenAI vs. Local/Quantized models).
+2. **Data Models** (`src/helia/models/`):
+   - **Transcript**: Document model for interview transcripts.
+   - **Utterance/Turn**: Standardized conversation units.
+   - **Prompt**: Manages prompt templates and versioning.
 
 3. **Assessment** (`src/helia/assessment/`):
-   - Implements clinical logic for standard instruments (e.g., PHQ-8).
-   - Maps unstructured dialogue to structured clinical scores.
+   - **Evaluator**: `PHQ8Evaluator` (in `core.py`) orchestrates the LLM interaction.
+   - **Logic**: Implements clinical logic for standard instruments (e.g., PHQ-8).
+   - **Schema**: `src/helia/assessment/schema.py` defines `AssessmentResult` and `Evidence`.
 
 4. **Persistence Layer** (`src/helia/db.py`):
    - **Document-Based Storage**: Uses MongoDB with Beanie (ODM).
-   - **Core Model**: `AssessmentResult` (in `src/helia/assessment/schema.py`) acts as the single source of truth for experimental results.
-   - **Data Capture**: Stores the full context of each run:
-     - **Configuration**: Model version, prompts, temperature (critical for comparing tiers).
-     - **Evidence**: Specific quotes and reasoning supporting each PHQ-8 score.
-     - **Outcome**: Final diagnosis and total scores.
+   - **Data Capture**: Stores full context (Configuration, Evidence, Outcomes) to support comparative analysis.
 
 5. **Agent Workflow** (`src/helia/agent/`):
-   - Built with **LangGraph**.
-   - **Router Pattern**: Decides when to call specific tools (search, scoring).
-   - **Tools**: Clinical scoring utilities, Document retrieval.
+   - **Graph Architecture**: Implements RISEN pattern (Extract -> Map -> Score) using LangGraph in `src/helia/agent/graph.py`.
+   - **State**: `ClinicalState` (in `state.py`) manages transcript, scores, and execution status.
+   - **Nodes**: Specialized logic in `src/helia/agent/nodes/` (assessment, persistence).
+   - **Execution**: Run benchmarks via `python -m helia.agent.runner <run_id>`.
 
 ## Development Standards
IMPLEMENTATION_SUMMARY.md (new file, 57 lines)

# Implementation Summary: Modular Agentic Framework

## Overview

We have implemented the core Agentic Framework for the PHQ-8 assessment benchmark. The architecture uses **LangGraph** to orchestrate a multi-stage reasoning process ("RISEN") and supports dynamic switching between Local (Tier 1) and Cloud (Tier 3) models. The system is fully integrated with the MongoDB infrastructure for data and prompts.

## Completed Components

### 1. Agent Architecture (`src/helia/agent/graph.py`)

- Implemented a `StateGraph` that manages the workflow lifecycle.
- **Nodes**: `ingestion`, `extract_evidence`, `map_criteria`, `score_item`, `human_review`, `persistence`.
- **Routing**: Conditional edges loop through the 8 PHQ-8 items before proceeding to review.
- **HITL**: Configured `MemorySaver` to allow human-in-the-loop interrupts at the `human_review` stage.
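Because the graph is compiled with `interrupt_before=["human_review"]`, a run pauses once all eight items are scored and resumes from the checkpoint after review. A minimal sketch of that pause-and-resume cycle, assuming standard LangGraph checkpointer semantics (a `thread_id` is required once `MemorySaver` is attached; the id scheme below is hypothetical, and `runner.py` does not pass one yet):

```python
import asyncio

from helia.agent.graph import compile_graph
from helia.agent.state import ClinicalState


async def review_one(transcript_id: str, text: str) -> None:
    graph = compile_graph(with_persistence=True)

    # MemorySaver needs a thread_id so the paused run can be located again.
    config = {"configurable": {"thread_id": f"review-{transcript_id}"}}

    # First call runs ingestion and the extract/map/score loop for all 8 items,
    # then stops just before the "human_review" node (interrupt_before).
    await graph.ainvoke(
        ClinicalState(transcript_id=transcript_id, transcript_text=text),
        config=config,
    )

    # A clinician could now inspect (and correct) the checkpointed state,
    # e.g. via graph.aget_state(config).

    # Passing None as input resumes from the checkpoint and executes the
    # human_review and persistence nodes.
    await graph.ainvoke(None, config=config)


# asyncio.run(review_one("303", "Ellie: how are you doing today?\nParticipant: ..."))
```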
### 2. State Management (`src/helia/agent/state.py`)

- Created `ClinicalState` Pydantic model.
- Strictly types the workflow memory, including:
  - `transcript_text`: The input data.
  - `scores`: A list of `PHQ8ItemScore` objects (accumulated via reducer).
  - `current_item_index`: Tracks progress through the 8 items.
  - `current_evidence` / `current_reasoning`: Transient fields for the RISEN loop.
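Because `scores` carries the `update_scores` reducer, a node only returns the item it just scored and LangGraph appends it to the accumulated list instead of overwriting the field. A small illustration of that merge behaviour (values are made up):

```python
from helia.agent.state import PHQ8ItemScore, update_scores

existing = [PHQ8ItemScore(question_id=1, score=2, evidence_quote="...", reasoning="...")]

# score_item_node returns {"scores": [new_score]}; LangGraph routes that partial
# update through the reducer rather than replacing the whole list.
new = [PHQ8ItemScore(question_id=2, score=1, evidence_quote="...", reasoning="...")]

merged = update_scores(existing, new)
assert [s.question_id for s in merged] == [1, 2]
```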
### 3. RISEN Logic (`src/helia/agent/nodes/assessment.py`)

- Refactored the monolithic evaluation logic into three granular nodes:
  1. **Extract**: Finds verbatim quotes for the specific symptom.
  2. **Map**: Aligns evidence to the 0-3 scoring criteria.
  3. **Score**: Assigns the final value and structured reasoning.
- **Prompt Management**: Fetches prompts dynamically from the MongoDB `Prompt` collection using `Prompt.find_one`.

### 4. Runner & Config (`src/helia/agent/runner.py`)

- Created a CLI entry point: `python -m helia.agent.runner <run_id>`.
- Initializes the MongoDB connection via `init_db`.
- Fetches all available `Transcript` documents from the database to run the benchmark.
- Injects the specific `RunConfig` (Tier 1/2/3) into the graph's runtime configuration.
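The tier hand-off is plain LangGraph `configurable` data: the runner flattens the selected `RunConfig` into a dict, and `_get_llm()` in the assessment nodes reads the same keys back out. A condensed sketch of that hand-off (the model and provider values are placeholders, not the project's actual config):

```python
from helia.agent.graph import compile_graph
from helia.agent.state import ClinicalState

graph = compile_graph(with_persistence=False)

# What BenchmarkRunner builds from config.runs["<run_id>"]; values are placeholders.
tier1_config = {
    "configurable": {
        "model_name": "llama3",  # hypothetical local (Tier 1) model
        "provider_id": "local",  # hypothetical provider id
        "temperature": 0.0,
    }
}

# Every node receives this dict as its `config` argument, so the backend LLM can
# be swapped per run without touching the graph definition, e.g.:
# final_state = await graph.ainvoke(
#     ClinicalState(transcript_id="t1", transcript_text="..."), config=tier1_config
# )
```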
### 5. Ingestion (`src/helia/ingestion/loader.py`)

- Added `ClinicalDataLoader` to abstract transcript fetching.
- Loads directly from the `Transcript` Beanie document model in MongoDB.

### 6. Database Migrations

- Created `migrations/init_risen_prompts.py` to seed the database with the required "RISEN" prompt templates (`phq8-extract`, `phq8-map`, `phq8-score`).
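Once seeded, the templates are plain `str.format` strings that the assessment nodes look up by name at run time. A minimal sketch of that consumption path (the symptom values are illustrative; the nodes take them from `PHQ8_ITEMS`):

```python
from helia.models.prompt import Prompt


async def render_extract_prompt(transcript_text: str) -> str:
    # Same lookup _get_prompt_template() performs in nodes/assessment.py.
    doc = await Prompt.find_one(Prompt.name == "phq8-extract")
    if doc is None:
        raise RuntimeError("Run migrations/init_risen_prompts.py first.")
    return doc.template.format(
        symptom_name="Anhedonia",
        symptom_description="Little interest or pleasure in doing things",
        transcript_text=transcript_text,
    )
```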
## Usage

1. **Seed Prompts**:

   ```bash
   python migrations/init_risen_prompts.py
   ```

2. **Run Agent**:

   ```bash
   # Run with the default Tier 3 (Cloud) config (defined in config.yaml)
   python -m helia.agent.runner gemini-flash
   ```

## Next Steps

1. **Safety**: Implement the `Safety Guardrail` (parallel node) as designed in `plans/safety-guardrail-architecture.md`.
2. **Persistence**: Uncomment the DB save logic in `persistence_node` to save `AssessmentResult` documents.
migrations/init_risen_prompts.py (new file, 69 lines)

import asyncio
import logging

from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.models.prompt import Prompt
from helia.prompts.assessment import (
    EXTRACT_EVIDENCE_PROMPT,
    MAP_CRITERIA_PROMPT,
    SCORE_ITEM_PROMPT,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def migrate() -> None:
    try:
        config = HeliaConfig()  # ty:ignore[missing-argument]
    except Exception:
        logger.exception("Failed to load configuration.")
        return

    logger.info("Connecting to database...")
    await init_db(config)

    prompts_to_create = [
        {
            "name": "phq8-extract",
            "template": EXTRACT_EVIDENCE_PROMPT,
            "input_variables": ["symptom_name", "symptom_description", "transcript_text"],
        },
        {
            "name": "phq8-map",
            "template": MAP_CRITERIA_PROMPT,
            "input_variables": ["symptom_name", "evidence_text"],
        },
        {
            "name": "phq8-score",
            "template": SCORE_ITEM_PROMPT,
            "input_variables": ["symptom_name", "reasoning_text"],
        },
    ]

    for p_data in prompts_to_create:
        name = p_data["name"]
        logger.info("Creating or updating prompt '%s'...", name)

        # Check if exists to avoid duplicates or update if needed
        existing = await Prompt.find_one(Prompt.name == name)
        if existing:
            logger.info("Prompt '%s' already exists. Updating template.", name)
            existing.template = p_data["template"]
            existing.input_variables = p_data["input_variables"]
            await existing.save()
        else:
            new_prompt = Prompt(
                name=name,
                template=p_data["template"],
                input_variables=p_data["input_variables"],
            )
            await new_prompt.insert()
            logger.info("Prompt '%s' created.", name)

    logger.info("Migration completed successfully.")


if __name__ == "__main__":
    asyncio.run(migrate())
(agent plan document, filename not shown)

@@ -17,7 +17,7 @@ A **Hierarchical Agent Supervisor** architecture built with **LangGraph**:
     * **Extract**: Quote relevant patient text.
     * **Map**: Align quotes to PHQ-8 criteria.
     * **Score**: Assign 0-3 value.
-3. **Ingestion**: Standardizes data from S3/Local into a `ClinicalState`.
+3. **Ingestion**: Standardizes data from MongoDB into a `ClinicalState`.
 4. **Benchmarking**: Automates the comparison between Generated Scores vs. Ground Truth (DAIC-WOZ labels).
 
 **Note:** A dedicated **Safety Guardrail** agent has been designed but is scoped out of this MVP. See `plans/safety-guardrail-architecture.md` for details.

@@ -50,20 +50,20 @@ graph TD
 * **Deliverables**:
     * `src/helia/agent/state.py`: Define `ClinicalState` (transcript, current_item, scores).
     * `src/helia/agent/graph.py`: Define the main `StateGraph` with Ingestion -> Assessment -> Persistence nodes.
-    * `src/helia/ingestion/loader.py`: Add "Ground Truth" loading for DAIC-WOZ labels (critical for benchmarking).
+    * `src/helia/ingestion/loader.py`: Refactor to load Transcript documents from MongoDB.
 
 #### Phase 2: The "RISEN" Assessment Logic
 * **Goal**: Replace monolithic `PHQ8Evaluator` with granular nodes.
 * **Deliverables**:
-    * `src/helia/agent/nodes/assessment.py`: Implement `extract_node`, `map_node`, `score_node`.
-    * `src/helia/prompts/`: Create specialized prompt templates for each stage (optimized for Llama 3).
+    * `src/helia/agent/nodes/assessment.py`: Implement `extract_node`, `map_node`, `score_node` that fetch prompts from DB.
+    * `migrations/init_risen_prompts.py`: Database migration to seed the Extract/Map/Score prompts.
     * **Refactor**: Update `PHQ8Evaluator` to be callable as a tool/node rather than a standalone class.
 
 #### Phase 3: Tier Switching & Execution
 * **Goal**: Implement dynamic model config.
 * **Deliverables**:
     * `src/helia/configuration.py`: Ensure `RunConfig` (Tier 1/2/3) propagates to LangGraph `configurable` params.
-    * `src/helia/agent/runner.py`: CLI entry point to run batch benchmarks.
+    * `src/helia/agent/runner.py`: CLI entry point to run batch benchmarks using MongoDB transcripts.
 
 #### Phase 4: Human-in-the-Loop & Persistence
 * **Goal**: Enable clinician review and data saving.

@@ -87,7 +87,7 @@ graph TD
 ## Dependencies & Risks
 - **Risk**: Local models (Tier 1) may hallucinate formatting in the "Map" stage.
     * *Mitigation*: Use `instructor` or constrained decoding (JSON mode) for Tier 1.
-- **Dependency**: Requires DAIC-WOZ dataset (assumed available locally or mocked).
+- **Dependency**: Requires DAIC-WOZ dataset (loaded in MongoDB).
 
 ## References
 - **LangGraph**: [State Management](https://langchain-ai.github.io/langgraph/concepts/high_level/#state)
src/helia/agent/graph.py (new file, 97 lines)

from __future__ import annotations

import logging
from typing import Any, Literal

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph

from helia.agent.nodes.assessment import (
    extract_evidence_node,
    map_criteria_node,
    score_item_node,
)
from helia.agent.nodes.persistence import persistence_node
from helia.agent.state import ClinicalState

logger = logging.getLogger(__name__)


def ingestion_node(state: ClinicalState) -> dict:  # noqa: ARG001
    """
    Placeholder for Ingestion Node.

    In the runner, we pre-load the transcript, so this is just a pass-through
    or could do additional validation.
    """
    return {"status": "analyzing"}


PHQ8_ITEM_COUNT = 8


def router_node(state: ClinicalState) -> Literal["extract_evidence", "human_review"]:
    """
    Routes between assessment steps or to review based on progress.
    """
    # 8 items in PHQ-8 (indices 0-7)
    if state.current_item_index < PHQ8_ITEM_COUNT:
        return "extract_evidence"
    return "human_review"


def human_review_node(state: ClinicalState) -> dict:
    """
    Placeholder for Human Review.

    This node acts as a breakpoint.
    """
    logger.info("Ready for human review. Total scores calculated: %d", len(state.scores))
    return {"status": "review_pending"}


# Define the Graph
workflow = StateGraph(ClinicalState)

# Add Nodes
workflow.add_node("ingestion", ingestion_node)
workflow.add_node("extract_evidence", extract_evidence_node)
workflow.add_node("map_criteria", map_criteria_node)
workflow.add_node("score_item", score_item_node)
workflow.add_node("human_review", human_review_node)
workflow.add_node("persistence", persistence_node)

# Set Entry Point
workflow.set_entry_point("ingestion")

# Add Edges
workflow.add_edge("ingestion", "extract_evidence")  # Start immediately with first item

# RISEN Loop
workflow.add_edge("extract_evidence", "map_criteria")
workflow.add_edge("map_criteria", "score_item")

# Conditional Routing after scoring
workflow.add_conditional_edges(
    "score_item",
    router_node,
    {
        "extract_evidence": "extract_evidence",  # Loop back for next item
        "human_review": "human_review",  # Done with all 8 items
    },
)

# Finalize
workflow.add_edge("human_review", "persistence")
workflow.add_edge("persistence", END)


def compile_graph(with_persistence: bool = True) -> Any:
    """
    Compiles the graph.

    If with_persistence is True, uses MemorySaver to allow interrupts (HITL).
    """
    checkpointer = MemorySaver() if with_persistence else None

    return workflow.compile(
        checkpointer=checkpointer,
        interrupt_before=["human_review"] if with_persistence else [],
    )
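To sanity-check the wiring above without calling an LLM, the compiled graph can be rendered. A small sketch, assuming the Mermaid drawing helper shipped with the installed LangGraph/langchain-core version is available:

```python
from helia.agent.graph import compile_graph

app = compile_graph(with_persistence=False)
# Prints a Mermaid definition of ingestion -> extract/map/score loop -> review -> persistence;
# paste the output into any Mermaid viewer to see the routing.
print(app.get_graph().draw_mermaid())
```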
src/helia/agent/nodes/assessment.py (new file, 175 lines)

from __future__ import annotations

import logging
from typing import Any

from langchain_core.messages import HumanMessage, SystemMessage

from helia.agent.state import ClinicalState, PHQ8ItemScore
from helia.configuration import get_settings
from helia.llm.client import get_chat_model
from helia.models.prompt import Prompt

logger = logging.getLogger(__name__)

# PHQ-8 Definitions
PHQ8_ITEMS = [
    {"id": 1, "name": "Anhedonia", "desc": "Little interest or pleasure in doing things"},
    {"id": 2, "name": "Depressed Mood", "desc": "Feeling down, depressed, or hopeless"},
    {
        "id": 3,
        "name": "Sleep Issues",
        "desc": "Trouble falling or staying asleep, or sleeping too much",
    },
    {"id": 4, "name": "Fatigue", "desc": "Feeling tired or having little energy"},
    {"id": 5, "name": "Appetite Issues", "desc": "Poor appetite or overeating"},
    {
        "id": 6,
        "name": "Self-Failure",
        "desc": "Feeling bad about yourself - or that you are a failure",
    },
    {
        "id": 7,
        "name": "Concentration",
        "desc": "Trouble concentrating on things, such as reading or TV",
    },
    {
        "id": 8,
        "name": "Psychomotor",
        "desc": "Moving/speaking slowly OR being fidgety/restless",
    },
]


def _get_llm(config: dict[str, Any] | None = None) -> Any:  # noqa: ANN401
    """
    Helper to get the LLM client, respecting runtime config if provided.
    """
    settings = get_settings()

    # Default to global settings
    model_name = settings.llm_provider.model_name
    api_key = settings.llm_provider.api_key
    base_url = settings.llm_provider.base_url
    temperature = 0.0

    # Override from LangGraph runtime config
    if config and "configurable" in config:
        conf = config["configurable"]
        if "model_name" in conf:
            model_name = conf["model_name"]
        if "temperature" in conf:
            temperature = conf["temperature"]

    return get_chat_model(
        model_name=model_name,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
    )


async def _get_prompt_template(name: str) -> str:
    """Fetches a prompt template from the database."""
    prompt_doc = await Prompt.find_one(Prompt.name == name)
    if not prompt_doc:
        logger.warning("Prompt '%s' not found in DB. Using fallback.", name)
        if name == "phq8-extract":
            return """
            SYMPTOM: {symptom_name} ({symptom_description})
            TRANSCRIPT: {transcript_text}
            Extract relevant quotes.
            """
        if name == "phq8-map":
            return """
            SYMPTOM: {symptom_name}
            EVIDENCE: {evidence_text}
            Map to 0-3 scale.
            """
        if name == "phq8-score":
            return """
            SYMPTOM: {symptom_name}
            REASONING: {reasoning_text}
            Return JSON: {{ "score": int, "final_reasoning": str }}
            """
        return ""
    return prompt_doc.template


async def extract_evidence_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]:
    """
    Step 1: Extract quotes relevant to the current symptom.
    """
    current_item = PHQ8_ITEMS[state.current_item_index]
    llm = _get_llm(config)

    template = await _get_prompt_template("phq8-extract")
    prompt = template.format(
        symptom_name=current_item["name"],
        symptom_description=current_item["desc"],
        transcript_text=state.transcript_text,
    )

    response = await llm.ainvoke(
        [SystemMessage(content="You are an expert extractor."), HumanMessage(content=prompt)]
    )

    return {"current_evidence": response.content}


async def map_criteria_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]:
    """
    Step 2: Map the extracted evidence to scoring criteria.
    """
    current_item = PHQ8_ITEMS[state.current_item_index]
    evidence = getattr(state, "current_evidence", "No evidence found")

    llm = _get_llm(config)
    template = await _get_prompt_template("phq8-map")
    prompt = template.format(symptom_name=current_item["name"], evidence_text=evidence)

    response = await llm.ainvoke(
        [SystemMessage(content="You are a clinical reasoner."), HumanMessage(content=prompt)]
    )

    return {"current_reasoning": response.content}


async def score_item_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]:
    """
    Step 3: Assign the final score (0-3).
    """
    current_item = PHQ8_ITEMS[state.current_item_index]
    reasoning = getattr(state, "current_reasoning", "No reasoning provided")
    llm = _get_llm(config)

    # Use structured output for the final score
    structured_llm = llm.with_structured_output(dict)

    template = await _get_prompt_template("phq8-score")
    prompt = template.format(symptom_name=current_item["name"], reasoning_text=reasoning)

    try:
        response = await structured_llm.ainvoke(
            [SystemMessage(content="Output valid JSON."), HumanMessage(content=prompt)]
        )
        score_val = response.get("score", 0)
        final_reasoning = response.get("final_reasoning", reasoning)
    except Exception:
        logger.exception("Failed to parse score, defaulting to 0")
        score_val = 0
        final_reasoning = "Error parsing model output"

    new_score = PHQ8ItemScore(
        question_id=current_item["id"],
        score=score_val,
        evidence_quote=getattr(state, "current_evidence", ""),
        reasoning=final_reasoning,
    )

    return {
        "scores": [new_score],
        "current_item_index": state.current_item_index + 1,
        "current_evidence": None,
        "current_reasoning": None,
    }
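`score_item_node` above uses `with_structured_output(dict)`. The plan's Tier 1 mitigation ("use `instructor` or constrained decoding / JSON mode") points at a stricter variant; a hedged sketch of what that could look like with a Pydantic schema, not part of this commit, and dependent on the chat model returned by `get_chat_model()` supporting structured output:

```python
from pydantic import BaseModel, Field


class ScoreOutput(BaseModel):
    """Constrained shape for the final scoring step."""

    score: int = Field(ge=0, le=3, description="PHQ-8 item score")
    final_reasoning: str


async def score_with_schema(llm, prompt: str) -> ScoreOutput:
    # with_structured_output(ScoreOutput) asks the model for JSON matching the
    # schema and parses/validates the reply into a ScoreOutput instance.
    structured_llm = llm.with_structured_output(ScoreOutput)
    return await structured_llm.ainvoke(prompt)
```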
src/helia/agent/nodes/persistence.py (new file, 53 lines)

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from helia.agent.state import ClinicalState

logger = logging.getLogger(__name__)


async def persistence_node(state: ClinicalState) -> dict[str, Any]:
    """
    Saves the final assessment results to MongoDB.
    """
    logger.info("Persisting results for transcript %s", state.transcript_id)

    # Convert state scores to schema format
    # Note: We need to reconstruct the original items list
    # In a real implementation, we'd map this properly
    # For MVP, assuming direct mapping
    from helia.assessment.schema import PHQ8Item

    items = [
        PHQ8Item(
            question_id=s.question_id,
            question_text=f"Question {s.question_id}",  # Lookup actual text
            score=s.score,
            evidence=[
                {
                    "quote": s.evidence_quote,
                    "reasoning": s.reasoning,
                }
            ],
        )
        for s in state.scores
    ]

    total_score = sum(i.score for i in items)

    # Save to Beanie
    # In Phase 4, we'd actually call .save()
    # result = AssessmentResult(
    #     transcript_id=state.transcript_id,
    #     items=items,
    #     total_score=total_score,
    #     # ... other fields
    # )
    # await result.save()

    logger.info("Saved assessment with total score: %d", total_score)
    return {"status": "completed"}
src/helia/agent/runner.py (new file, 115 lines)

from __future__ import annotations

import asyncio
import logging
from typing import Any

from helia.agent.graph import compile_graph
from helia.agent.state import ClinicalState
from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.ingestion.loader import ClinicalDataLoader
from helia.models.transcript import Transcript

logger = logging.getLogger(__name__)


class BenchmarkRunner:
    """
    Executes the clinical assessment workflow for a specific configuration tier.
    """

    def __init__(self, config: HeliaConfig) -> None:
        self.config = config
        self.graph = compile_graph()
        self.data_loader = ClinicalDataLoader()

    async def run_benchmark(self, run_id: str, transcript_ids: list[str]) -> None:
        """
        Runs the benchmark for a specific configuration (Tier 1/2/3).
        """
        if run_id not in self.config.runs:
            msg = f"Run ID '{run_id}' not found in configuration."
            raise ValueError(msg)

        run_config = self.config.runs[run_id]
        logger.info("Starting benchmark for run: %s", run_id)
        logger.info("Model: %s", run_config.model.model_name)

        # Prepare runtime config for the graph nodes
        # This dict is accessible via `config` in graph nodes
        runtime_config = {
            "configurable": {
                "model_name": run_config.model.model_name,
                "provider_id": run_config.model.provider,
                "temperature": run_config.model.temperature,
                # Pass full provider config if needed, or look it up globally
            }
        }

        # Process transcripts
        for tid in transcript_ids:
            await self._process_transcript(tid, runtime_config)

    async def _process_transcript(self, transcript_id: str, runtime_config: dict[str, Any]) -> None:
        """Runs the workflow for a single transcript."""
        logger.info("Processing transcript: %s", transcript_id)

        # Load data
        transcript_data = await self.data_loader.load_transcript(transcript_id)

        # Initialize State
        initial_state = ClinicalState(
            transcript_id=transcript_id,
            transcript_text=transcript_data.text,
            # If we had ground truth, we'd store it separately for benchmarking
        )

        # Execute Graph
        # ainvoke runs the graph to completion; for batch processing that is fine.
        final_state = await self.graph.ainvoke(initial_state, config=runtime_config)

        # Check results
        scores = final_state.get("scores", [])
        logger.info("Finished %s. Generated %d scores.", transcript_id, len(scores))
        # Persistence is handled by the last node in the graph, so we are done.


async def main(run_id: str = "tier3_baseline") -> None:
    """CLI Entrypoint."""
    # Load global config
    # In a real app, we'd use a proper config loader
    # For now, assuming environment variables or default file
    # We need a way to load the YAML.
    # Pydantic Settings does this automatically if configured.
    try:
        config = HeliaConfig()  # type: ignore
    except Exception:
        logger.exception("Failed to load configuration")
        return

    # Initialize Database Connection
    logger.info("Initializing database connection...")
    await init_db(config)

    runner = BenchmarkRunner(config)

    # Fetch real transcripts from DB
    logger.info("Fetching available transcripts...")
    transcripts_docs = await Transcript.find_all().to_list()
    transcript_ids = [t.transcript_id for t in transcripts_docs]

    if not transcript_ids:
        logger.warning("No transcripts found in database.")
    else:
        logger.info("Found %d transcripts. Starting benchmark...", len(transcript_ids))
        await runner.run_benchmark(run_id, transcript_ids)


if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)
    run_arg = sys.argv[1] if len(sys.argv) > 1 else "default_run"
    asyncio.run(main(run_arg))
src/helia/agent/state.py (new file, 45 lines)

from __future__ import annotations

from typing import Annotated, Literal

from pydantic import BaseModel, Field


def update_scores(current: list[PHQ8ItemScore], new: list[PHQ8ItemScore]) -> list[PHQ8ItemScore]:
    """Reducer to append new scores to the list."""
    return current + new


class PHQ8ItemScore(BaseModel):
    """Represents a scored PHQ-8 item with evidence."""

    question_id: int
    score: int
    evidence_quote: str
    reasoning: str


class ClinicalState(BaseModel):
    """
    Working memory for the Clinical Assessment Agent.
    Tracks progress through the PHQ-8 items and holds intermediate results.
    """

    # Input Data
    transcript_id: str
    transcript_text: str

    # Execution State
    current_item_index: int = 0
    status: Literal["ingesting", "analyzing", "review_pending", "completed", "halted"] = "ingesting"

    # Assessment Results (using reducer to append results from each step)
    scores: Annotated[list[PHQ8ItemScore], update_scores] = Field(default_factory=list)

    # Safety (Placeholder for future guardrail)
    safety_flags: list[str] = Field(default_factory=list)
    is_safe: bool = True

    # Transient fields for RISEN loop (not persisted in final result)
    current_evidence: str | None = None
    current_reasoning: str | None = None
src/helia/ingestion/loader.py (new file, 47 lines)

from __future__ import annotations

import logging
from dataclasses import dataclass

from helia.models.transcript import Transcript

logger = logging.getLogger(__name__)


@dataclass
class TranscriptData:
    """Standardized format for loaded transcript data."""

    transcript_id: str
    text: str
    # Future: Add ground_truth_scores loaded from DB or S3


class ClinicalDataLoader:
    """
    Loader that fetches transcript data from MongoDB.
    """

    async def load_transcript(self, transcript_id: str) -> TranscriptData:
        """
        Loads the transcript text from the MongoDB Transcript collection.
        """
        transcript = await Transcript.find_one(Transcript.transcript_id == transcript_id)
        if not transcript:
            raise ValueError(f"Transcript '{transcript_id}' not found in database.")

        # Format utterances into a single text blob
        full_text = self._format_transcript_text(transcript.utterances)

        return TranscriptData(
            transcript_id=transcript_id,
            text=full_text,
        )

    def _format_transcript_text(self, utterances: list) -> str:
        """Joins utterance objects into a readable dialogue string."""
        lines = []
        for u in utterances:
            # Utterance model has 'speaker' and 'value'
            lines.append(f"{u.speaker}: {u.value}")
        return "\n".join(lines)
src/helia/prompts/__init__.py (new file, empty)

src/helia/prompts/assessment.py (new file, 51 lines)

# Prompts for RISEN (Extract -> Map -> Score) Architecture

EXTRACT_EVIDENCE_PROMPT = """
You are a clinical expert analyzing a therapy session transcript.
Your goal is to find ALL evidence relevant to the following PHQ-8 symptom.

SYMPTOM: {symptom_name}
DESCRIPTION: {symptom_description}

TRANSCRIPT:
{transcript_text}

INSTRUCTIONS:
1. Extract verbatim quotes from the patient that relate to this symptom.
2. Include context if necessary to understand the quote.
3. If no evidence is found, explicitly state "No evidence found".

Return ONLY the quotes and their context.
"""

MAP_CRITERIA_PROMPT = """
You are a clinical expert. You have extracted the following evidence for a specific PHQ-8 symptom.
Now, map this evidence to the scoring criteria.

SYMPTOM: {symptom_name}
CRITERIA:
- 0: Not at all (0-1 days)
- 1: Several days (2-6 days)
- 2: More than half the days (7-11 days)
- 3: Nearly every day (12-14 days)

EVIDENCE:
{evidence_text}

INSTRUCTIONS:
1. Analyze the frequency and severity implied by the quotes.
2. Compare against the criteria above.
3. Explain your reasoning for which bucket the patient falls into.

Reasoning:
"""

SCORE_ITEM_PROMPT = """
Based on the reasoning below, assign the final PHQ-8 score (0-3) for this item.

SYMPTOM: {symptom_name}
REASONING:
{reasoning_text}

Return JSON: {{ "score": int, "final_reasoning": str }}
"""