WIP

CLAUDE.md
@@ -18,6 +18,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
 - **Install Dependencies**: `uv sync`
 - **Run Agent**: `python -m helia.main "Your query here"`
+- **Verify Prompts**: `python scripts/verify_prompt_db.py`
 - **Lint**: `uv run ruff check .`
 - **Format**: `uv run ruff format .`
 - **Type Check**: `uv run ty check`
@@ -25,32 +26,34 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
 ## Architecture

-Helia is a modular ReAct-style agent framework designed for clinical interview analysis:
+Helia is a modular ReAct-style agent framework designed for clinical interview analysis.

 ### Core Modules

 1. **Ingestion** (`src/helia/ingestion/`):
-   - Parses clinical interview transcripts (e.g., DAIC-WOZ dataset).
-   - Standardizes raw text/audio into `Utterance` objects.
+   - **Parser**: `TranscriptParser` parses clinical interview transcripts (e.g., DAIC-WOZ dataset).
+   - **Loader**: `ClinicalDataLoader` in `loader.py` retrieves `Transcript` documents from MongoDB.
+   - **Legacy**: `S3DatasetLoader` (deprecated for runtime use, used for initial population).

-2. **Analysis & Enrichment** (`src/helia/analysis/`):
-   - **MetadataExtractor**: Enriches utterances with sentiment, tone, and speech acts.
-   - **Model Agnostic**: Designed to swap backend LLMs (OpenAI vs. Local/Quantized models).
+2. **Data Models** (`src/helia/models/`):
+   - **Transcript**: Document model for interview transcripts.
+   - **Utterance/Turn**: Standardized conversation units.
+   - **Prompt**: Manages prompt templates and versioning.

 3. **Assessment** (`src/helia/assessment/`):
-   - Implements clinical logic for standard instruments (e.g., PHQ-8).
-   - Maps unstructured dialogue to structured clinical scores.
+   - **Evaluator**: `PHQ8Evaluator` (in `core.py`) orchestrates the LLM interaction.
+   - **Logic**: Implements clinical logic for standard instruments (e.g., PHQ-8).
+   - **Schema**: `src/helia/assessment/schema.py` defines `AssessmentResult` and `Evidence`.

 4. **Persistence Layer** (`src/helia/db.py`):
    - **Document-Based Storage**: Uses MongoDB with Beanie (ODM).
    - **Core Model**: `AssessmentResult` (in `src/helia/assessment/schema.py`) acts as the single source of truth for experimental results.
-   - **Data Capture**: Stores the full context of each run:
-     - **Configuration**: Model version, prompts, temperature (critical for comparing tiers).
-     - **Evidence**: Specific quotes and reasoning supporting each PHQ-8 score.
-     - **Outcome**: Final diagnosis and total scores.
+   - **Data Capture**: Stores full context (configuration, evidence, outcomes) to support comparative analysis.

 5. **Agent Workflow** (`src/helia/agent/`):
-   - Built with **LangGraph**.
-   - **Router Pattern**: Decides when to call specific tools (search, scoring).
-   - **Tools**: Clinical scoring utilities, document retrieval.
+   - **Graph Architecture**: Implements the RISEN pattern (Extract -> Map -> Score) using LangGraph in `src/helia/agent/graph.py`.
+   - **State**: `ClinicalState` (in `state.py`) manages the transcript, scores, and execution status.
+   - **Nodes**: Specialized logic in `src/helia/agent/nodes/` (assessment, persistence).
+   - **Execution**: Run benchmarks via `python -m helia.agent.runner <run_id>`.

 ## Development Standards

IMPLEMENTATION_SUMMARY.md (new file)
@@ -0,0 +1,57 @@
# Implementation Summary: Modular Agentic Framework

## Overview
We have implemented the core Agentic Framework for the PHQ-8 assessment benchmark. The architecture uses **LangGraph** to orchestrate a multi-stage reasoning process ("RISEN") and supports dynamic switching between Local (Tier 1) and Cloud (Tier 3) models. The system is fully integrated with the MongoDB infrastructure for data and prompts.
## Completed Components

### 1. Agent Architecture (`src/helia/agent/graph.py`)
- Implemented a `StateGraph` that manages the workflow lifecycle.
- **Nodes**: `ingestion`, `extract_evidence`, `map_criteria`, `score_item`, `human_review`, `persistence`.
- **Routing**: Conditional edges loop through the 8 PHQ-8 items before proceeding to review.
- **HITL**: Configured `MemorySaver` to allow human-in-the-loop interrupts at the `human_review` stage (see the sketch below).
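
A minimal sketch of the review-and-resume flow, assuming a seeded transcript (the `thread_id` value here is hypothetical):

```python
import asyncio

from helia.agent.graph import compile_graph
from helia.agent.state import ClinicalState


async def review_and_resume() -> None:
    graph = compile_graph(with_persistence=True)
    config = {"configurable": {"thread_id": "transcript-301"}}

    # First invocation halts before the human_review node (interrupt_before).
    state = ClinicalState(transcript_id="transcript-301", transcript_text="...")
    await graph.ainvoke(state, config=config)

    # A clinician can inspect graph.get_state(config) here; invoking with
    # None as the input resumes the checkpointed thread through persistence.
    final = await graph.ainvoke(None, config=config)
    print(final["status"])


asyncio.run(review_and_resume())
```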

### 2. State Management (`src/helia/agent/state.py`)
- Created the `ClinicalState` Pydantic model.
- Strictly types the workflow memory, including:
  - `transcript_text`: The input data.
  - `scores`: A list of `PHQ8ItemScore` objects (accumulated via a reducer; see the sketch below).
  - `current_item_index`: Tracks progress through the 8 items.
  - `current_evidence` / `current_reasoning`: Transient fields for the RISEN loop.
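
The reducer semantics matter here: each `score_item` pass returns a one-element `scores` list, and LangGraph appends it rather than overwriting. A minimal sketch of the assumed merge behavior:

```python
from helia.agent.state import ClinicalState, PHQ8ItemScore, update_scores

state = ClinicalState(transcript_id="t1", transcript_text="...")
update = [PHQ8ItemScore(question_id=1, score=2, evidence_quote="...", reasoning="...")]

# update_scores is the Annotated reducer on ClinicalState.scores:
merged = update_scores(state.scores, update)
assert len(merged) == 1  # appended, not replaced
```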

### 3. RISEN Logic (`src/helia/agent/nodes/assessment.py`)
- Refactored the monolithic evaluation logic into three granular nodes:
  1. **Extract**: Finds verbatim quotes for the specific symptom.
  2. **Map**: Aligns evidence to the 0-3 scoring criteria.
  3. **Score**: Assigns the final value and structured reasoning.
- **Prompt Management**: Fetches prompts dynamically from the MongoDB `Prompt` collection using `Prompt.find_one` (see the sketch below).
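
A sketch of the fetch-and-render step, assuming the migration described below has already seeded the `phq8-map` template:

```python
from helia.models.prompt import Prompt


async def render_map_prompt(symptom_name: str, evidence_text: str) -> str:
    # Prompts live in MongoDB, so templates can change without a redeploy.
    doc = await Prompt.find_one(Prompt.name == "phq8-map")
    if doc is None:
        msg = "Prompt 'phq8-map' not seeded; run migrations/init_risen_prompts.py."
        raise RuntimeError(msg)
    return doc.template.format(symptom_name=symptom_name, evidence_text=evidence_text)
```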

### 4. Runner & Config (`src/helia/agent/runner.py`)
- Created a CLI entry point: `python -m helia.agent.runner <run_id>`.
- Initializes the MongoDB connection via `init_db`.
- Fetches all available `Transcript` documents from the database to run the benchmark.
- Injects the specific `RunConfig` (Tier 1/2/3) into the graph's runtime configuration (see the sketch below).
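
The runner can also be driven programmatically. A minimal sketch, assuming `gemini-flash` exists under `runs:` in `config.yaml`; the transcript id is hypothetical:

```python
import asyncio

from helia.agent.runner import BenchmarkRunner
from helia.configuration import HeliaConfig
from helia.db import init_db


async def run_one() -> None:
    config = HeliaConfig()  # type: ignore
    await init_db(config)
    runner = BenchmarkRunner(config)
    await runner.run_benchmark("gemini-flash", ["301"])


asyncio.run(run_one())
```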

### 5. Ingestion (`src/helia/ingestion/loader.py`)
- Added `ClinicalDataLoader` to abstract transcript fetching.
- Loads directly from the `Transcript` Beanie document model in MongoDB.

### 6. Database Migrations
- Created `migrations/init_risen_prompts.py` to seed the database with the required "RISEN" prompt templates (`phq8-extract`, `phq8-map`, `phq8-score`).

## Usage

1. **Seed Prompts**:
   ```bash
   python migrations/init_risen_prompts.py
   ```

2. **Run Agent**:
   ```bash
   # Run with the default Tier 3 (Cloud) config (defined in config.yaml)
   python -m helia.agent.runner gemini-flash
   ```
## Next Steps
1. **Safety**: Implement the `Safety Guardrail` (parallel node) as designed in `plans/safety-guardrail-architecture.md`.
2. **Persistence**: Uncomment the DB save logic in `persistence_node` to save `AssessmentResult` documents.

migrations/init_risen_prompts.py (new file)
@@ -0,0 +1,69 @@
import asyncio
import logging

from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.models.prompt import Prompt
from helia.prompts.assessment import (
    EXTRACT_EVIDENCE_PROMPT,
    MAP_CRITERIA_PROMPT,
    SCORE_ITEM_PROMPT,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def migrate() -> None:
    try:
        config = HeliaConfig()  # ty:ignore[missing-argument]
    except Exception:
        logger.exception("Failed to load configuration.")
        return

    logger.info("Connecting to database...")
    await init_db(config)

    prompts_to_create = [
        {
            "name": "phq8-extract",
            "template": EXTRACT_EVIDENCE_PROMPT,
            "input_variables": ["symptom_name", "symptom_description", "transcript_text"],
        },
        {
            "name": "phq8-map",
            "template": MAP_CRITERIA_PROMPT,
            "input_variables": ["symptom_name", "evidence_text"],
        },
        {
            "name": "phq8-score",
            "template": SCORE_ITEM_PROMPT,
            "input_variables": ["symptom_name", "reasoning_text"],
        },
    ]

    for p_data in prompts_to_create:
        name = p_data["name"]
        logger.info("Creating or updating prompt '%s'...", name)

        # Check if exists to avoid duplicates or update if needed
        existing = await Prompt.find_one(Prompt.name == name)
        if existing:
            logger.info("Prompt '%s' already exists. Updating template.", name)
            existing.template = p_data["template"]
            existing.input_variables = p_data["input_variables"]
            await existing.save()
        else:
            new_prompt = Prompt(
                name=name,
                template=p_data["template"],
                input_variables=p_data["input_variables"],
            )
            await new_prompt.insert()
            logger.info("Prompt '%s' created.", name)

    logger.info("Migration completed successfully.")


if __name__ == "__main__":
    asyncio.run(migrate())

@@ -17,7 +17,7 @@ A **Hierarchical Agent Supervisor** architecture built with **LangGraph**:
 * **Extract**: Quote relevant patient text.
 * **Map**: Align quotes to PHQ-8 criteria.
 * **Score**: Assign 0-3 value.
-3. **Ingestion**: Standardizes data from S3/Local into a `ClinicalState`.
+3. **Ingestion**: Standardizes data from MongoDB into a `ClinicalState`.
 4. **Benchmarking**: Automates the comparison of Generated Scores vs. Ground Truth (DAIC-WOZ labels).

 **Note:** A dedicated **Safety Guardrail** agent has been designed but is scoped out of this MVP. See `plans/safety-guardrail-architecture.md` for details.

@@ -50,20 +50,20 @@ graph TD
 * **Deliverables**:
   * `src/helia/agent/state.py`: Define `ClinicalState` (transcript, current_item, scores).
   * `src/helia/agent/graph.py`: Define the main `StateGraph` with Ingestion -> Assessment -> Persistence nodes.
-  * `src/helia/ingestion/loader.py`: Add "Ground Truth" loading for DAIC-WOZ labels (critical for benchmarking).
+  * `src/helia/ingestion/loader.py`: Refactor to load Transcript documents from MongoDB.

 #### Phase 2: The "RISEN" Assessment Logic
 * **Goal**: Replace monolithic `PHQ8Evaluator` with granular nodes.
 * **Deliverables**:
-  * `src/helia/agent/nodes/assessment.py`: Implement `extract_node`, `map_node`, `score_node`.
-  * `src/helia/prompts/`: Create specialized prompt templates for each stage (optimized for Llama 3).
+  * `src/helia/agent/nodes/assessment.py`: Implement `extract_node`, `map_node`, `score_node` that fetch prompts from DB.
+  * `migrations/init_risen_prompts.py`: Database migration to seed the Extract/Map/Score prompts.
 * **Refactor**: Update `PHQ8Evaluator` to be callable as a tool/node rather than a standalone class.

 #### Phase 3: Tier Switching & Execution
 * **Goal**: Implement dynamic model config.
 * **Deliverables**:
   * `src/helia/configuration.py`: Ensure `RunConfig` (Tier 1/2/3) propagates to LangGraph `configurable` params.
-  * `src/helia/agent/runner.py`: CLI entry point to run batch benchmarks.
+  * `src/helia/agent/runner.py`: CLI entry point to run batch benchmarks using MongoDB transcripts.

 #### Phase 4: Human-in-the-Loop & Persistence
 * **Goal**: Enable clinician review and data saving.

@@ -87,7 +87,7 @@ graph TD
 ## Dependencies & Risks
 - **Risk**: Local models (Tier 1) may hallucinate formatting in the "Map" stage.
   * *Mitigation*: Use `instructor` or constrained decoding (JSON mode) for Tier 1.
-- **Dependency**: Requires DAIC-WOZ dataset (assumed available locally or mocked).
+- **Dependency**: Requires DAIC-WOZ dataset (loaded in MongoDB).

 ## References
 - **LangGraph**: [State Management](https://langchain-ai.github.io/langgraph/concepts/high_level/#state)

src/helia/agent/graph.py (new file)
@@ -0,0 +1,97 @@
from __future__ import annotations

import logging
from typing import Any, Literal

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph

from helia.agent.nodes.assessment import (
    extract_evidence_node,
    map_criteria_node,
    score_item_node,
)
from helia.agent.nodes.persistence import persistence_node
from helia.agent.state import ClinicalState

logger = logging.getLogger(__name__)


def ingestion_node(state: ClinicalState) -> dict:  # noqa: ARG001
    """
    Placeholder for Ingestion Node.
    In the runner, we pre-load the transcript, so this is just a pass-through
    or could do additional validation.
    """
    return {"status": "analyzing"}


PHQ8_ITEM_COUNT = 8


def router_node(state: ClinicalState) -> Literal["extract_evidence", "human_review"]:
    """
    Routes between assessment steps or to review based on progress.
    """
    # 8 items in PHQ-8 (indices 0-7)
    if state.current_item_index < PHQ8_ITEM_COUNT:
        return "extract_evidence"
    return "human_review"


def human_review_node(state: ClinicalState) -> dict:
    """
    Placeholder for Human Review.
    This node acts as a breakpoint.
    """
    logger.info("Ready for human review. Total scores calculated: %d", len(state.scores))
    return {"status": "review_pending"}


# Define the Graph
workflow = StateGraph(ClinicalState)

# Add Nodes
workflow.add_node("ingestion", ingestion_node)
workflow.add_node("extract_evidence", extract_evidence_node)
workflow.add_node("map_criteria", map_criteria_node)
workflow.add_node("score_item", score_item_node)
workflow.add_node("human_review", human_review_node)
workflow.add_node("persistence", persistence_node)

# Set Entry Point
workflow.set_entry_point("ingestion")

# Add Edges
workflow.add_edge("ingestion", "extract_evidence")  # Start immediately with first item

# RISEN Loop
workflow.add_edge("extract_evidence", "map_criteria")
workflow.add_edge("map_criteria", "score_item")

# Conditional Routing after scoring
workflow.add_conditional_edges(
    "score_item",
    router_node,
    {
        "extract_evidence": "extract_evidence",  # Loop back for next item
        "human_review": "human_review",  # Done with all 8 items
    },
)

# Finalize
workflow.add_edge("human_review", "persistence")
workflow.add_edge("persistence", END)


def compile_graph(with_persistence: bool = True) -> Any:
    """
    Compiles the graph.
    If with_persistence is True, uses MemorySaver to allow interrupts (HITL).
    """
    checkpointer = MemorySaver() if with_persistence else None

    return workflow.compile(
        checkpointer=checkpointer,
        interrupt_before=["human_review"] if with_persistence else [],
    )

src/helia/agent/nodes/assessment.py (new file)
@@ -0,0 +1,175 @@
from __future__ import annotations

import logging
from typing import Any

from langchain_core.messages import HumanMessage, SystemMessage

from helia.agent.state import ClinicalState, PHQ8ItemScore
from helia.configuration import get_settings
from helia.llm.client import get_chat_model
from helia.models.prompt import Prompt

logger = logging.getLogger(__name__)

# PHQ-8 Definitions
PHQ8_ITEMS = [
    {"id": 1, "name": "Anhedonia", "desc": "Little interest or pleasure in doing things"},
    {"id": 2, "name": "Depressed Mood", "desc": "Feeling down, depressed, or hopeless"},
    {
        "id": 3,
        "name": "Sleep Issues",
        "desc": "Trouble falling or staying asleep, or sleeping too much",
    },
    {"id": 4, "name": "Fatigue", "desc": "Feeling tired or having little energy"},
    {"id": 5, "name": "Appetite Issues", "desc": "Poor appetite or overeating"},
    {
        "id": 6,
        "name": "Self-Failure",
        "desc": "Feeling bad about yourself - or that you are a failure",
    },
    {
        "id": 7,
        "name": "Concentration",
        "desc": "Trouble concentrating on things, such as reading or TV",
    },
    {
        "id": 8,
        "name": "Psychomotor",
        "desc": "Moving/speaking slowly OR being fidgety/restless",
    },
]


def _get_llm(config: dict[str, Any] | None = None) -> Any:  # noqa: ANN401
    """
    Helper to get the LLM client, respecting runtime config if provided.
    """
    settings = get_settings()

    # Default to global settings
    model_name = settings.llm_provider.model_name
    api_key = settings.llm_provider.api_key
    base_url = settings.llm_provider.base_url
    temperature = 0.0

    # Override from LangGraph runtime config
    if config and "configurable" in config:
        conf = config["configurable"]
        if "model_name" in conf:
            model_name = conf["model_name"]
        if "temperature" in conf:
            temperature = conf["temperature"]

    return get_chat_model(
        model_name=model_name,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
    )


async def _get_prompt_template(name: str) -> str:
    """Fetches a prompt template from the database."""
    prompt_doc = await Prompt.find_one(Prompt.name == name)
    if not prompt_doc:
        logger.warning("Prompt '%s' not found in DB. Using fallback.", name)
        if name == "phq8-extract":
            return """
SYMPTOM: {symptom_name} ({symptom_description})
TRANSCRIPT: {transcript_text}
Extract relevant quotes.
"""
        if name == "phq8-map":
            return """
SYMPTOM: {symptom_name}
EVIDENCE: {evidence_text}
Map to 0-3 scale.
"""
        if name == "phq8-score":
            return """
SYMPTOM: {symptom_name}
REASONING: {reasoning_text}
Return JSON: {{ "score": int, "final_reasoning": str }}
"""
        return ""
    return prompt_doc.template


async def extract_evidence_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]:
    """
    Step 1: Extract quotes relevant to the current symptom.
    """
    current_item = PHQ8_ITEMS[state.current_item_index]
    llm = _get_llm(config)

    template = await _get_prompt_template("phq8-extract")
    prompt = template.format(
        symptom_name=current_item["name"],
        symptom_description=current_item["desc"],
        transcript_text=state.transcript_text,
    )

    response = await llm.ainvoke(
        [SystemMessage(content="You are an expert extractor."), HumanMessage(content=prompt)]
    )

    return {"current_evidence": response.content}


async def map_criteria_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]:
    """
    Step 2: Map the extracted evidence to scoring criteria.
    """
    current_item = PHQ8_ITEMS[state.current_item_index]
    # current_evidence is reset to None between items, so fall back explicitly.
    evidence = state.current_evidence or "No evidence found"

    llm = _get_llm(config)
    template = await _get_prompt_template("phq8-map")
    prompt = template.format(symptom_name=current_item["name"], evidence_text=evidence)

    response = await llm.ainvoke(
        [SystemMessage(content="You are a clinical reasoner."), HumanMessage(content=prompt)]
    )

    return {"current_reasoning": response.content}


async def score_item_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]:
    """
    Step 3: Assign the final score (0-3).
    """
    current_item = PHQ8_ITEMS[state.current_item_index]
    reasoning = state.current_reasoning or "No reasoning provided"
    llm = _get_llm(config)

    # Use structured output for the final score
    structured_llm = llm.with_structured_output(dict)

    template = await _get_prompt_template("phq8-score")
    prompt = template.format(symptom_name=current_item["name"], reasoning_text=reasoning)

    try:
        response = await structured_llm.ainvoke(
            [SystemMessage(content="Output valid JSON."), HumanMessage(content=prompt)]
        )
        score_val = response.get("score", 0)
        final_reasoning = response.get("final_reasoning", reasoning)
    except Exception:
        logger.exception("Failed to parse score, defaulting to 0")
        score_val = 0
        final_reasoning = "Error parsing model output"

    new_score = PHQ8ItemScore(
        question_id=current_item["id"],
        score=score_val,
        evidence_quote=state.current_evidence or "",
        reasoning=final_reasoning,
    )

    return {
        "scores": [new_score],
        "current_item_index": state.current_item_index + 1,
        "current_evidence": None,
        "current_reasoning": None,
    }

src/helia/agent/nodes/persistence.py (new file)
@@ -0,0 +1,53 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from helia.agent.state import ClinicalState

logger = logging.getLogger(__name__)


async def persistence_node(state: ClinicalState) -> dict[str, Any]:
    """
    Saves the final assessment results to MongoDB.
    """
    logger.info("Persisting results for transcript %s", state.transcript_id)

    # Convert state scores to the schema format.
    # Note: we still need to map back to the original item text;
    # for the MVP we assume a direct mapping.
    from helia.assessment.schema import PHQ8Item

    items = [
        PHQ8Item(
            question_id=s.question_id,
            question_text=f"Question {s.question_id}",  # TODO: look up the actual text
            score=s.score,
            evidence=[
                {
                    "quote": s.evidence_quote,
                    "reasoning": s.reasoning,
                }
            ],
        )
        for s in state.scores
    ]

    total_score = sum(i.score for i in items)

    # Save via Beanie (the actual .save() call is enabled in Phase 4):
    # result = AssessmentResult(
    #     transcript_id=state.transcript_id,
    #     items=items,
    #     total_score=total_score,
    #     # ... other fields
    # )
    # await result.save()

    logger.info("Saved assessment with total score: %d", total_score)
    return {"status": "completed"}

src/helia/agent/runner.py (new file)
@@ -0,0 +1,115 @@
from __future__ import annotations

import asyncio
import logging
from typing import Any

from helia.agent.graph import compile_graph
from helia.agent.state import ClinicalState
from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.ingestion.loader import ClinicalDataLoader
from helia.models.transcript import Transcript

logger = logging.getLogger(__name__)


class BenchmarkRunner:
    """
    Executes the clinical assessment workflow for a specific configuration tier.
    """

    def __init__(self, config: HeliaConfig) -> None:
        self.config = config
        self.graph = compile_graph()
        self.data_loader = ClinicalDataLoader()

    async def run_benchmark(self, run_id: str, transcript_ids: list[str]) -> None:
        """
        Runs the benchmark for a specific configuration (Tier 1/2/3).
        """
        if run_id not in self.config.runs:
            msg = f"Run ID '{run_id}' not found in configuration."
            raise ValueError(msg)

        run_config = self.config.runs[run_id]
        logger.info("Starting benchmark for run: %s", run_id)
        logger.info("Model: %s", run_config.model.model_name)

        # Prepare runtime config for the graph nodes.
        # This dict is accessible via `config` in graph nodes.
        runtime_config = {
            "configurable": {
                "model_name": run_config.model.model_name,
                "provider_id": run_config.model.provider,
                "temperature": run_config.model.temperature,
                # Pass full provider config if needed, or look it up globally
            }
        }

        # Process transcripts
        for tid in transcript_ids:
            await self._process_transcript(tid, runtime_config)

    async def _process_transcript(self, transcript_id: str, runtime_config: dict[str, Any]) -> None:
        """Runs the workflow for a single transcript."""
        logger.info("Processing transcript: %s", transcript_id)

        # Load data
        transcript_data = await self.data_loader.load_transcript(transcript_id)

        # Initialize state
        initial_state = ClinicalState(
            transcript_id=transcript_id,
            transcript_text=transcript_data.text,
            # If we had ground truth, we'd store it separately for benchmarking
        )

        # The graph is compiled with a checkpointer, which requires a thread_id;
        # use the transcript id so each transcript gets its own thread.
        config = {
            **runtime_config,
            "configurable": {**runtime_config["configurable"], "thread_id": transcript_id},
        }

        # Execute the graph. For batch runs, ainvoke (run to completion or to
        # the next interrupt) is sufficient; streaming is not needed here.
        final_state = await self.graph.ainvoke(initial_state, config=config)

        # Check results
        scores = final_state.get("scores", [])
        logger.info("Finished %s. Generated %d scores.", transcript_id, len(scores))
        # Persistence is handled by the last node in the graph, so we are done.


async def main(run_id: str = "tier3_baseline") -> None:
    """CLI entrypoint."""
    # Load the global config. Pydantic Settings reads environment variables
    # (and the configured YAML file) automatically.
    try:
        config = HeliaConfig()  # type: ignore
    except Exception:
        logger.exception("Failed to load configuration")
        return

    # Initialize database connection
    logger.info("Initializing database connection...")
    await init_db(config)

    runner = BenchmarkRunner(config)

    # Fetch real transcripts from the DB
    logger.info("Fetching available transcripts...")
    transcripts_docs = await Transcript.find_all().to_list()
    transcript_ids = [t.transcript_id for t in transcripts_docs]

    if not transcript_ids:
        logger.warning("No transcripts found in database.")
    else:
        logger.info("Found %d transcripts. Starting benchmark...", len(transcript_ids))
        await runner.run_benchmark(run_id, transcript_ids)


if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)
    run_arg = sys.argv[1] if len(sys.argv) > 1 else "default_run"
    asyncio.run(main(run_arg))

src/helia/agent/state.py (new file)
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import Annotated, Literal

from pydantic import BaseModel, Field


def update_scores(
    current: list[PHQ8ItemScore], new: list[PHQ8ItemScore]
) -> list[PHQ8ItemScore]:
    """Reducer to append new scores to the list."""
    return current + new


class PHQ8ItemScore(BaseModel):
    """Represents a scored PHQ-8 item with evidence."""

    question_id: int
    score: int
    evidence_quote: str
    reasoning: str


class ClinicalState(BaseModel):
    """
    Working memory for the Clinical Assessment Agent.
    Tracks progress through the PHQ-8 items and holds intermediate results.
    """

    # Input Data
    transcript_id: str
    transcript_text: str

    # Execution State
    current_item_index: int = 0
    status: Literal["ingesting", "analyzing", "review_pending", "completed", "halted"] = "ingesting"

    # Assessment Results (using reducer to append results from each step)
    scores: Annotated[list[PHQ8ItemScore], update_scores] = Field(default_factory=list)

    # Safety (Placeholder for future guardrail)
    safety_flags: list[str] = Field(default_factory=list)
    is_safe: bool = True

    # Transient fields for RISEN loop (not persisted in final result)
    current_evidence: str | None = None
    current_reasoning: str | None = None

src/helia/ingestion/loader.py (new file)
@@ -0,0 +1,47 @@
from __future__ import annotations

import logging
from dataclasses import dataclass

from helia.models.transcript import Transcript

logger = logging.getLogger(__name__)


@dataclass
class TranscriptData:
    """Standardized format for loaded transcript data."""

    transcript_id: str
    text: str
    # Future: Add ground_truth_scores loaded from DB or S3


class ClinicalDataLoader:
    """
    Loader that fetches transcript data from MongoDB.
    """

    async def load_transcript(self, transcript_id: str) -> TranscriptData:
        """
        Loads the transcript text from the MongoDB Transcript collection.
        """
        transcript = await Transcript.find_one(Transcript.transcript_id == transcript_id)
        if not transcript:
            msg = f"Transcript '{transcript_id}' not found in database."
            raise ValueError(msg)

        # Format utterances into a single text blob
        full_text = self._format_transcript_text(transcript.utterances)

        return TranscriptData(
            transcript_id=transcript_id,
            text=full_text,
        )

    def _format_transcript_text(self, utterances: list) -> str:
        """Joins utterance objects into a readable dialogue string."""
        # Utterance model has 'speaker' and 'value'
        return "\n".join(f"{u.speaker}: {u.value}" for u in utterances)
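
A minimal usage sketch for `ClinicalDataLoader` (not part of this commit; assumes `init_db` has been called and the transcript id below is hypothetical):

```python
import asyncio

from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.ingestion.loader import ClinicalDataLoader


async def demo() -> None:
    await init_db(HeliaConfig())  # type: ignore
    data = await ClinicalDataLoader().load_transcript("301")
    print(data.text[:200])


asyncio.run(demo())
```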

src/helia/prompts/__init__.py (new, empty file)

src/helia/prompts/assessment.py (new file)
@@ -0,0 +1,51 @@
# Prompts for RISEN (Extract -> Map -> Score) Architecture

EXTRACT_EVIDENCE_PROMPT = """
You are a clinical expert analyzing a therapy session transcript.
Your goal is to find ALL evidence relevant to the following PHQ-8 symptom.

SYMPTOM: {symptom_name}
DESCRIPTION: {symptom_description}

TRANSCRIPT:
{transcript_text}

INSTRUCTIONS:
1. Extract verbatim quotes from the patient that relate to this symptom.
2. Include context if necessary to understand the quote.
3. If no evidence is found, explicitly state "No evidence found".

Return ONLY the quotes and their context.
"""

MAP_CRITERIA_PROMPT = """
You are a clinical expert. You have extracted the following evidence for a specific PHQ-8 symptom.
Now, map this evidence to the scoring criteria.

SYMPTOM: {symptom_name}
CRITERIA:
- 0: Not at all (0-1 days)
- 1: Several days (2-6 days)
- 2: More than half the days (7-11 days)
- 3: Nearly every day (12-14 days)

EVIDENCE:
{evidence_text}

INSTRUCTIONS:
1. Analyze the frequency and severity implied by the quotes.
2. Compare against the criteria above.
3. Explain your reasoning for which bucket the patient falls into.

Reasoning:
"""

SCORE_ITEM_PROMPT = """
Based on the reasoning below, assign the final PHQ-8 score (0-3) for this item.

SYMPTOM: {symptom_name}
REASONING:
{reasoning_text}

Return JSON: {{ "score": int, "final_reasoning": str }}
"""