This commit is contained in:
Santiago Martinez-Avial
2025-12-23 13:35:15 +01:00
parent a9346ccb34
commit 5ce6d7e1d3
12 changed files with 734 additions and 22 deletions

View File

@@ -18,6 +18,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
- **Install Dependencies**: `uv sync`
- **Run Agent**: `python -m helia.main "Your query here"`
- **Verify Prompts**: `python scripts/verify_prompt_db.py`
- **Lint**: `uv run ruff check .`
- **Format**: `uv run ruff format .`
- **Type Check**: `uv run ty check`
@@ -25,32 +26,34 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
## Architecture
Helia is a modular ReAct-style agent framework designed for clinical interview analysis:
Helia is a modular ReAct-style agent framework designed for clinical interview analysis.
### Core Modules
1. **Ingestion** (`src/helia/ingestion/`):
- Parses clinical interview transcripts (e.g., DAIC-WOZ dataset).
- Standardizes raw text/audio into `Utterance` objects.
- **Parser**: `TranscriptParser` parses clinical interview transcripts (e.g., DAIC-WOZ dataset).
- **Loader**: `ClinicalDataLoader` in `loader.py` retrieves `Transcript` documents from MongoDB.
- **Legacy**: `S3DatasetLoader` (deprecated for runtime use, used for initial population).
2. **Analysis & Enrichment** (`src/helia/analysis/`):
- **MetadataExtractor**: Enriches utterances with sentiment, tone, and speech acts.
- **Model Agnostic**: Designed to swap backend LLMs (OpenAI vs. Local/Quantized models).
2. **Data Models** (`src/helia/models/`):
- **Transcript**: Document model for interview transcripts.
- **Utterance/Turn**: Standardized conversation units.
- **Prompt**: Manages prompt templates and versioning.
3. **Assessment** (`src/helia/assessment/`):
- Implements clinical logic for standard instruments (e.g., PHQ-8).
- Maps unstructured dialogue to structured clinical scores.
- **Evaluator**: `PHQ8Evaluator` (in `core.py`) orchestrates the LLM interaction.
- **Logic**: Implements clinical logic for standard instruments (e.g., PHQ-8).
- **Schema**: `src/helia/assessment/schema.py` defines `AssessmentResult` and `Evidence`.
4. **Persistence Layer** (`src/helia/db.py`):
- **Document-Based Storage**: Uses MongoDB with Beanie (ODM).
- **Core Model**: `AssessmentResult` (in `src/helia/assessment/schema.py`) acts as the single source of truth for experimental results.
- **Data Capture**: Stores the full context of each run:
- **Configuration**: Model version, prompts, temperature (critical for comparing tiers).
- **Evidence**: Specific quotes and reasoning supporting each PHQ-8 score.
- **Outcome**: Final diagnosis and total scores.
- **Data Capture**: Stores full context (Configuration, Evidence, Outcomes) to support comparative analysis.
5. **Agent Workflow** (`src/helia/agent/`):
- Built with **LangGraph**.
- **Router Pattern**: Decides when to call specific tools (search, scoring).
- **Tools**: Clinical scoring utilities, Document retrieval.
- **Graph Architecture**: Implements RISEN pattern (Extract -> Map -> Score) using LangGraph in `src/helia/agent/graph.py`.
- **State**: `ClinicalState` (in `state.py`) manages transcript, scores, and execution status.
- **Nodes**: Specialized logic in `src/helia/agent/nodes/` (assessment, persistence).
- **Execution**: Run benchmarks via `python -m helia.agent.runner <run_id>`.
## Development Standards

57
IMPLEMENTATION_SUMMARY.md Normal file
View File

@@ -0,0 +1,57 @@
# Implementation Summary: Modular Agentic Framework
## Overview
We have successfully implemented the core Agentic Framework for the PHQ-8 assessment benchmark. This architecture uses **LangGraph** to orchestrate a multi-stage reasoning process ("RISEN") and supports dynamic switching between Local (Tier 1) and Cloud (Tier 3) models. The system is fully integrated with the MongoDB infrastructure for data and prompts.
## Completed Components
### 1. Agent Architecture (`src/helia/agent/graph.py`)
- Implemented a `StateGraph` that manages the workflow lifecycle.
- **Nodes**: `ingestion`, `extract_evidence`, `map_criteria`, `score_item`, `human_review`, `persistence`.
- **Routing**: Conditional edges loop through the 8 PHQ-8 items before proceeding to review.
- **HITL**: Configured `MemorySaver` to allow human-in-the-loop interrupts at the `human_review` stage.
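Because the graph interrupts before `human_review`, a review session is a two-step invocation against the same checkpoint. A minimal sketch of the intended flow, assuming the database and a default model are configured as in `runner.py` (the `thread_id` value is arbitrary; `MemorySaver` just needs one to key the checkpoint):
```python
import asyncio

from helia.agent.graph import compile_graph
from helia.agent.state import ClinicalState
from helia.configuration import HeliaConfig
from helia.db import init_db


async def review_session() -> None:
    # Same bootstrap as runner.py: config + Mongo connection (prompts live there).
    await init_db(HeliaConfig())  # ty:ignore[missing-argument]
    graph = compile_graph(with_persistence=True)
    config = {"configurable": {"thread_id": "demo-thread"}}

    state = ClinicalState(
        transcript_id="demo",
        transcript_text="Ellie: how have you been feeling?\nParticipant: tired, mostly",
    )
    # Runs ingestion and the 8-item RISEN loop, then pauses before human_review.
    paused = await graph.ainvoke(state, config=config)
    print(paused["status"], len(paused["scores"]))

    # After the clinician signs off, resume from the checkpoint by passing
    # None as the input together with the same thread_id.
    final = await graph.ainvoke(None, config=config)
    print(final["status"])


asyncio.run(review_session())
```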
### 2. State Management (`src/helia/agent/state.py`)
- Created `ClinicalState` Pydantic model.
- Strictly types the workflow memory, including:
- `transcript_text`: The input data.
- `scores`: A list of `PHQ8ItemScore` objects (accumulated via reducer).
- `current_item_index`: Tracks progress through the 8 items.
- `current_evidence` / `current_reasoning`: Transient fields for the RISEN loop.
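The reducer on `scores` is what lets the loop accumulate one `PHQ8ItemScore` per pass instead of overwriting the list: a node returns only the item it just produced and LangGraph merges it in. A small illustrative sketch of that merge (field values are made up):
```python
from helia.agent.state import ClinicalState, PHQ8ItemScore, update_scores

state = ClinicalState(transcript_id="demo", transcript_text="...")
new_item = PHQ8ItemScore(question_id=1, score=2, evidence_quote="...", reasoning="...")

# A node returns only {"scores": [new_item]}; LangGraph merges it through the
# reducer declared on the field, so the list grows instead of being replaced.
merged = update_scores(state.scores, [new_item])
assert merged == [new_item] and state.scores == []
```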
### 3. RISEN Logic (`src/helia/agent/nodes/assessment.py`)
- Refactored the monolithic evaluation logic into three granular nodes:
1. **Extract**: Finds verbatim quotes for the specific symptom.
2. **Map**: Aligns evidence to the 0-3 scoring criteria.
3. **Score**: Assigns the final value and structured reasoning.
- **Prompt Management**: Fetches prompts dynamically from the MongoDB `Prompt` collection using `Prompt.find_one`.
### 4. Runner & Config (`src/helia/agent/runner.py`)
- Created a CLI entry point: `python -m helia.agent.runner <run_id>`.
- Initializes the MongoDB connection via `init_db`.
- Fetches all available `Transcript` documents from the database to run the benchmark.
- Injects the specific `RunConfig` (Tier 1/2/3) into the graph's runtime configuration.
### 5. Ingestion (`src/helia/ingestion/loader.py`)
- Added `ClinicalDataLoader` to abstract transcript fetching.
- Loads directly from the `Transcript` Beanie document model in MongoDB.
### 6. Database Migrations
- Created `migrations/init_risen_prompts.py` to seed the database with the required "RISEN" prompt templates (`phq8-extract`, `phq8-map`, `phq8-score`).
## Usage
1. **Seed Prompts**:
```bash
python migrations/init_risen_prompts.py
```
2. **Run Agent**:
```bash
# Run with the default Tier 3 (Cloud) config (defined in config.yaml)
python -m helia.agent.runner gemini-flash
```
## Next Steps
1. **Safety**: Implement the `Safety Guardrail` (parallel node) as designed in `plans/safety-guardrail-architecture.md`.
2. **Persistence**: Uncomment the DB save logic in `persistence_node` to save `AssessmentResult` documents.

View File

@@ -0,0 +1,69 @@
import asyncio
import logging
from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.models.prompt import Prompt
from helia.prompts.assessment import (
EXTRACT_EVIDENCE_PROMPT,
MAP_CRITERIA_PROMPT,
SCORE_ITEM_PROMPT,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def migrate() -> None:
try:
config = HeliaConfig() # ty:ignore[missing-argument]
except Exception:
logger.exception("Failed to load configuration.")
return
logger.info("Connecting to database...")
await init_db(config)
prompts_to_create = [
{
"name": "phq8-extract",
"template": EXTRACT_EVIDENCE_PROMPT,
"input_variables": ["symptom_name", "symptom_description", "transcript_text"],
},
{
"name": "phq8-map",
"template": MAP_CRITERIA_PROMPT,
"input_variables": ["symptom_name", "evidence_text"],
},
{
"name": "phq8-score",
"template": SCORE_ITEM_PROMPT,
"input_variables": ["symptom_name", "reasoning_text"],
},
]
for p_data in prompts_to_create:
name = p_data["name"]
logger.info("Creating or updating prompt '%s'...", name)
        # Check whether the prompt already exists; update it instead of inserting a duplicate.
existing = await Prompt.find_one(Prompt.name == name)
if existing:
logger.info("Prompt '%s' already exists. Updating template.", name)
existing.template = p_data["template"]
existing.input_variables = p_data["input_variables"]
await existing.save()
else:
new_prompt = Prompt(
name=name,
template=p_data["template"],
input_variables=p_data["input_variables"],
)
await new_prompt.insert()
logger.info("Prompt '%s' created.", name)
logger.info("Migration completed successfully.")
if __name__ == "__main__":
asyncio.run(migrate())

View File

@@ -17,7 +17,7 @@ A **Hierarchical Agent Supervisor** architecture built with **LangGraph**:
* **Extract**: Quote relevant patient text.
* **Map**: Align quotes to PHQ-8 criteria.
* **Score**: Assign 0-3 value.
3. **Ingestion**: Standardizes data from S3/Local into a `ClinicalState`.
3. **Ingestion**: Standardizes data from MongoDB into a `ClinicalState`.
4. **Benchmarking**: Automates the comparison between Generated Scores vs. Ground Truth (DAIC-WOZ labels).
**Note:** A dedicated **Safety Guardrail** agent has been designed but is scoped out of this MVP. See `plans/safety-guardrail-architecture.md` for details.
@@ -50,20 +50,20 @@ graph TD
* **Deliverables**:
* `src/helia/agent/state.py`: Define `ClinicalState` (transcript, current_item, scores).
* `src/helia/agent/graph.py`: Define the main `StateGraph` with Ingestion -> Assessment -> Persistence nodes.
* `src/helia/ingestion/loader.py`: Add "Ground Truth" loading for DAIC-WOZ labels (critical for benchmarking).
* `src/helia/ingestion/loader.py`: Refactor to load Transcript documents from MongoDB.
#### Phase 2: The "RISEN" Assessment Logic
* **Goal**: Replace monolithic `PHQ8Evaluator` with granular nodes.
* **Deliverables**:
* `src/helia/agent/nodes/assessment.py`: Implement `extract_node`, `map_node`, `score_node`.
* `src/helia/prompts/`: Create specialized prompt templates for each stage (optimized for Llama 3).
* `src/helia/agent/nodes/assessment.py`: Implement `extract_node`, `map_node`, `score_node` that fetch prompts from DB.
* `migrations/init_risen_prompts.py`: Database migration to seed the Extract/Map/Score prompts.
* **Refactor**: Update `PHQ8Evaluator` to be callable as a tool/node rather than a standalone class.
#### Phase 3: Tier Switching & Execution
* **Goal**: Implement dynamic model config.
* **Deliverables**:
* `src/helia/configuration.py`: Ensure `RunConfig` (Tier 1/2/3) propagates to LangGraph `configurable` params (see the sketch after this list).
* `src/helia/agent/runner.py`: CLI entry point to run batch benchmarks.
* `src/helia/agent/runner.py`: CLI entry point to run batch benchmarks using MongoDB transcripts.
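The propagation mechanism is simply LangGraph's `configurable` dict: the runner serialises the selected tier into it when invoking the graph, and each node reads it back when building its LLM client. A rough sketch of both halves (`resolve_model` is a hypothetical stand-in for the node-side helper; the tier values shown are illustrative, not taken from `config.yaml`):
```python
# Runner side: translate the selected RunConfig tier into runtime config.
runtime_config = {
    "configurable": {
        "model_name": "llama-3-8b-instruct",  # illustrative Tier 1 value
        "provider_id": "ollama",              # illustrative
        "temperature": 0.0,
    }
}
# final_state = await graph.ainvoke(initial_state, config=runtime_config)


# Node side: start from global settings (passed in here as defaults), then
# apply any runtime overrides found in the configurable dict.
def resolve_model(
    config: dict | None, default_model: str, default_temperature: float = 0.0
) -> tuple[str, float]:
    model_name, temperature = default_model, default_temperature
    if config and "configurable" in config:
        conf = config["configurable"]
        model_name = conf.get("model_name", model_name)
        temperature = conf.get("temperature", temperature)
    return model_name, temperature
```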
#### Phase 4: Human-in-the-Loop & Persistence
* **Goal**: Enable clinician review and data saving.
@@ -87,7 +87,7 @@ graph TD
## Dependencies & Risks
- **Risk**: Local models (Tier 1) may hallucinate formatting in the "Map" stage.
* *Mitigation*: Use `instructor` or constrained decoding (JSON mode) for Tier 1 (see the sketch below).
- **Dependency**: Requires DAIC-WOZ dataset (assumed available locally or mocked).
- **Dependency**: Requires DAIC-WOZ dataset (loaded in MongoDB).
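For the Tier 1 mitigation, the lowest-friction option is schema-constrained output from the chat model itself. A hedged sketch using LangChain's structured-output API (`MappedCriteria` and `map_with_constraints` are hypothetical names, not part of this codebase):
```python
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic import BaseModel, Field


class MappedCriteria(BaseModel):
    """Constrained shape for the Map stage output."""

    bucket: int = Field(ge=0, le=3, description="PHQ-8 frequency bucket (0-3)")
    rationale: str


async def map_with_constraints(llm: BaseChatModel, prompt: str) -> MappedCriteria:
    # with_structured_output uses tool calling / JSON mode depending on the
    # provider, which keeps small local models from drifting out of format.
    structured = llm.with_structured_output(MappedCriteria)
    return await structured.ainvoke(prompt)
```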
## References
- **LangGraph**: [State Management](https://langchain-ai.github.io/langgraph/concepts/high_level/#state)

97
src/helia/agent/graph.py Normal file
View File

@@ -0,0 +1,97 @@
from __future__ import annotations
import logging
from typing import Any, Literal
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from helia.agent.nodes.assessment import (
extract_evidence_node,
map_criteria_node,
score_item_node,
)
from helia.agent.nodes.persistence import persistence_node
from helia.agent.state import ClinicalState
logger = logging.getLogger(__name__)
def ingestion_node(state: ClinicalState) -> dict: # noqa: ARG001
"""
Placeholder for Ingestion Node.
In the runner, we pre-load the transcript, so this is just a pass-through
or could do additional validation.
"""
return {"status": "analyzing"}
PHQ8_ITEM_COUNT = 8
def router_node(state: ClinicalState) -> Literal["extract_evidence", "human_review"]:
"""
Routes between assessment steps or to review based on progress.
"""
# 8 items in PHQ-8 (indices 0-7)
if state.current_item_index < PHQ8_ITEM_COUNT:
return "extract_evidence"
return "human_review"
def human_review_node(state: ClinicalState) -> dict:
"""
Placeholder for Human Review.
This node acts as a breakpoint.
"""
logger.info("Ready for human review. Total scores calculated: %d", len(state.scores))
return {"status": "review_pending"}
# Define the Graph
workflow = StateGraph(ClinicalState)
# Add Nodes
workflow.add_node("ingestion", ingestion_node)
workflow.add_node("extract_evidence", extract_evidence_node)
workflow.add_node("map_criteria", map_criteria_node)
workflow.add_node("score_item", score_item_node)
workflow.add_node("human_review", human_review_node)
workflow.add_node("persistence", persistence_node)
# Set Entry Point
workflow.set_entry_point("ingestion")
# Add Edges
workflow.add_edge("ingestion", "extract_evidence") # Start immediately with first item
# RISEN Loop
workflow.add_edge("extract_evidence", "map_criteria")
workflow.add_edge("map_criteria", "score_item")
# Conditional Routing after scoring
workflow.add_conditional_edges(
"score_item",
router_node,
{
"extract_evidence": "extract_evidence", # Loop back for next item
"human_review": "human_review", # Done with all 8 items
},
)
# Finalize
workflow.add_edge("human_review", "persistence")
workflow.add_edge("persistence", END)
def compile_graph(with_persistence: bool = True) -> Any:
"""
Compiles the graph.
If with_persistence is True, uses MemorySaver to allow interrupts (HITL).
"""
checkpointer = MemorySaver() if with_persistence else None
return workflow.compile(
checkpointer=checkpointer,
interrupt_before=["human_review"] if with_persistence else [],
)

View File

@@ -0,0 +1,175 @@
from __future__ import annotations
import logging
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from helia.agent.state import ClinicalState, PHQ8ItemScore
from helia.configuration import get_settings
from helia.llm.client import get_chat_model
from helia.models.prompt import Prompt
logger = logging.getLogger(__name__)
# PHQ-8 Definitions
PHQ8_ITEMS = [
{"id": 1, "name": "Anhedonia", "desc": "Little interest or pleasure in doing things"},
{"id": 2, "name": "Depressed Mood", "desc": "Feeling down, depressed, or hopeless"},
{
"id": 3,
"name": "Sleep Issues",
"desc": "Trouble falling or staying asleep, or sleeping too much",
},
{"id": 4, "name": "Fatigue", "desc": "Feeling tired or having little energy"},
{"id": 5, "name": "Appetite Issues", "desc": "Poor appetite or overeating"},
{
"id": 6,
"name": "Self-Failure",
"desc": "Feeling bad about yourself - or that you are a failure",
},
{
"id": 7,
"name": "Concentration",
"desc": "Trouble concentrating on things, such as reading or TV",
},
{
"id": 8,
"name": "Psychomotor",
"desc": "Moving/speaking slowly OR being fidgety/restless",
},
]
def _get_llm(config: dict[str, Any] | None = None) -> Any: # noqa: ANN401
"""
Helper to get the LLM client, respecting runtime config if provided.
"""
settings = get_settings()
# Default to global settings
model_name = settings.llm_provider.model_name
api_key = settings.llm_provider.api_key
base_url = settings.llm_provider.base_url
temperature = 0.0
# Override from LangGraph runtime config
if config and "configurable" in config:
conf = config["configurable"]
if "model_name" in conf:
model_name = conf["model_name"]
if "temperature" in conf:
temperature = conf["temperature"]
return get_chat_model(
model_name=model_name,
api_key=api_key,
base_url=base_url,
temperature=temperature,
)
async def _get_prompt_template(name: str) -> str:
"""Fetches a prompt template from the database."""
prompt_doc = await Prompt.find_one(Prompt.name == name)
if not prompt_doc:
logger.warning("Prompt '%s' not found in DB. Using fallback.", name)
if name == "phq8-extract":
return """
SYMPTOM: {symptom_name} ({symptom_description})
TRANSCRIPT: {transcript_text}
Extract relevant quotes.
"""
if name == "phq8-map":
return """
SYMPTOM: {symptom_name}
EVIDENCE: {evidence_text}
Map to 0-3 scale.
"""
if name == "phq8-score":
return """
SYMPTOM: {symptom_name}
REASONING: {reasoning_text}
Return JSON: {{ "score": int, "final_reasoning": str }}
"""
return ""
return prompt_doc.template
async def extract_evidence_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]:
"""
Step 1: Extract quotes relevant to the current symptom.
"""
current_item = PHQ8_ITEMS[state.current_item_index]
llm = _get_llm(config)
template = await _get_prompt_template("phq8-extract")
prompt = template.format(
symptom_name=current_item["name"],
symptom_description=current_item["desc"],
transcript_text=state.transcript_text,
)
response = await llm.ainvoke(
[SystemMessage(content="You are an expert extractor."), HumanMessage(content=prompt)]
)
return {"current_evidence": response.content}
async def map_criteria_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]:
"""
Step 2: Map the extracted evidence to scoring criteria.
"""
current_item = PHQ8_ITEMS[state.current_item_index]
    evidence = state.current_evidence or "No evidence found"
llm = _get_llm(config)
template = await _get_prompt_template("phq8-map")
prompt = template.format(symptom_name=current_item["name"], evidence_text=evidence)
response = await llm.ainvoke(
[SystemMessage(content="You are a clinical reasoner."), HumanMessage(content=prompt)]
)
return {"current_reasoning": response.content}
async def score_item_node(state: ClinicalState, config: dict[str, Any]) -> dict[str, Any]:
"""
Step 3: Assign the final score (0-3).
"""
current_item = PHQ8_ITEMS[state.current_item_index]
    reasoning = state.current_reasoning or "No reasoning provided"
llm = _get_llm(config)
    # Constrain the final score with an explicit JSON schema so the model
    # returns a plain dict with the expected keys (the bare `dict` type is not
    # a valid schema for with_structured_output).
    structured_llm = llm.with_structured_output(
        {
            "title": "PHQ8ItemScoreOutput",
            "type": "object",
            "properties": {
                "score": {"type": "integer", "minimum": 0, "maximum": 3},
                "final_reasoning": {"type": "string"},
            },
            "required": ["score", "final_reasoning"],
        }
    )
template = await _get_prompt_template("phq8-score")
prompt = template.format(symptom_name=current_item["name"], reasoning_text=reasoning)
try:
response = await structured_llm.ainvoke(
[SystemMessage(content="Output valid JSON."), HumanMessage(content=prompt)]
)
score_val = response.get("score", 0)
final_reasoning = response.get("final_reasoning", reasoning)
except Exception:
logger.exception("Failed to parse score, defaulting to 0")
score_val = 0
final_reasoning = "Error parsing model output"
new_score = PHQ8ItemScore(
question_id=current_item["id"],
score=score_val,
        evidence_quote=state.current_evidence or "",
reasoning=final_reasoning,
)
return {
"scores": [new_score],
"current_item_index": state.current_item_index + 1,
"current_evidence": None,
"current_reasoning": None,
}

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from helia.agent.state import ClinicalState
logger = logging.getLogger(__name__)
async def persistence_node(state: ClinicalState) -> dict[str, Any]:
"""
Saves the final assessment results to MongoDB.
"""
logger.info("Persisting results for transcript %s", state.transcript_id)
    # Convert state scores to the schema format. For the MVP we assume a direct
    # mapping; a later pass should look up the real PHQ-8 question metadata.
from helia.assessment.schema import PHQ8Item
items = [
PHQ8Item(
question_id=s.question_id,
question_text=f"Question {s.question_id}", # Lookup actual text
score=s.score,
evidence=[
{
"quote": s.evidence_quote,
"reasoning": s.reasoning,
}
],
)
for s in state.scores
]
total_score = sum(i.score for i in items)
# Save to Beanie
# In Phase 4, we'd actually call .save()
# result = AssessmentResult(
# transcript_id=state.transcript_id,
# items=items,
# total_score=total_score,
# # ... other fields
# )
# await result.save()
logger.info("Saved assessment with total score: %d", total_score)
return {"status": "completed"}

115
src/helia/agent/runner.py Normal file
View File

@@ -0,0 +1,115 @@
from __future__ import annotations
import asyncio
import logging
from typing import Any
from helia.agent.graph import compile_graph
from helia.agent.state import ClinicalState
from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.ingestion.loader import ClinicalDataLoader
from helia.models.transcript import Transcript
logger = logging.getLogger(__name__)
class BenchmarkRunner:
"""
Executes the clinical assessment workflow for a specific configuration tier.
"""
def __init__(self, config: HeliaConfig) -> None:
self.config = config
        # Batch mode: compile without the HITL interrupt/checkpointer so each
        # run proceeds straight through to the persistence node.
        self.graph = compile_graph(with_persistence=False)
self.data_loader = ClinicalDataLoader()
async def run_benchmark(self, run_id: str, transcript_ids: list[str]) -> None:
"""
Runs the benchmark for a specific configuration (Tier 1/2/3).
"""
if run_id not in self.config.runs:
msg = f"Run ID '{run_id}' not found in configuration."
raise ValueError(msg)
run_config = self.config.runs[run_id]
logger.info("Starting benchmark for run: %s", run_id)
logger.info("Model: %s", run_config.model.model_name)
# Prepare runtime config for the graph nodes
# This dict is accessible via `config` in graph nodes
runtime_config = {
"configurable": {
"model_name": run_config.model.model_name,
"provider_id": run_config.model.provider,
"temperature": run_config.model.temperature,
# Pass full provider config if needed, or look it up globally
}
}
# Process transcripts
for tid in transcript_ids:
await self._process_transcript(tid, runtime_config)
async def _process_transcript(self, transcript_id: str, runtime_config: dict[str, Any]) -> None:
"""Runs the workflow for a single transcript."""
logger.info("Processing transcript: %s", transcript_id)
# Load data
transcript_data = await self.data_loader.load_transcript(transcript_id)
# Initialize State
initial_state = ClinicalState(
transcript_id=transcript_id,
transcript_text=transcript_data.text,
# If we had ground truth, we'd store it separately for benchmarking
)
        # Execute the graph. Streaming is possible for interactive use; for a
        # batch benchmark a single ainvoke to completion is enough.
final_state = await self.graph.ainvoke(initial_state, config=runtime_config)
# Check results
scores = final_state.get("scores", [])
logger.info("Finished %s. Generated %d scores.", transcript_id, len(scores))
# Persistence is handled by the last node in the graph, so we are done.
async def main(run_id: str = "tier3_baseline") -> None:
"""CLI Entrypoint."""
    # Load the global config. Pydantic Settings reads the YAML / environment
    # automatically when HeliaConfig is configured for it.
try:
config = HeliaConfig() # type: ignore
except Exception:
logger.exception("Failed to load configuration")
return
# Initialize Database Connection
logger.info("Initializing database connection...")
await init_db(config)
runner = BenchmarkRunner(config)
# Fetch real transcripts from DB
logger.info("Fetching available transcripts...")
transcripts_docs = await Transcript.find_all().to_list()
transcript_ids = [t.transcript_id for t in transcripts_docs]
if not transcript_ids:
logger.warning("No transcripts found in database.")
else:
logger.info("Found %d transcripts. Starting benchmark...", len(transcript_ids))
await runner.run_benchmark(run_id, transcript_ids)
if __name__ == "__main__":
import sys
logging.basicConfig(level=logging.INFO)
    run_arg = sys.argv[1] if len(sys.argv) > 1 else "tier3_baseline"
asyncio.run(main(run_arg))

45
src/helia/agent/state.py Normal file
View File

@@ -0,0 +1,45 @@
from __future__ import annotations
from typing import Annotated, Literal
from pydantic import BaseModel, Field
def update_scores(current: list[PHQ8ItemScore], new: list[PHQ8ItemScore]) -> list[PHQ8ItemScore]:
"""Reducer to append new scores to the list."""
return current + new
class PHQ8ItemScore(BaseModel):
"""Represents a scored PHQ-8 item with evidence."""
question_id: int
score: int
evidence_quote: str
reasoning: str
class ClinicalState(BaseModel):
"""
Working memory for the Clinical Assessment Agent.
Tracks progress through the PHQ-8 items and holds intermediate results.
"""
# Input Data
transcript_id: str
transcript_text: str
# Execution State
current_item_index: int = 0
status: Literal["ingesting", "analyzing", "review_pending", "completed", "halted"] = "ingesting"
# Assessment Results (using reducer to append results from each step)
scores: Annotated[list[PHQ8ItemScore], update_scores] = Field(default_factory=list)
# Safety (Placeholder for future guardrail)
safety_flags: list[str] = Field(default_factory=list)
is_safe: bool = True
# Transient fields for RISEN loop (not persisted in final result)
current_evidence: str | None = None
current_reasoning: str | None = None

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from helia.models.transcript import Transcript
logger = logging.getLogger(__name__)
@dataclass
class TranscriptData:
"""Standardized format for loaded transcript data."""
transcript_id: str
text: str
# Future: Add ground_truth_scores loaded from DB or S3
class ClinicalDataLoader:
"""
Loader that fetches transcript data from MongoDB.
"""
async def load_transcript(self, transcript_id: str) -> TranscriptData:
"""
Loads the transcript text from the MongoDB Transcript collection.
"""
transcript = await Transcript.find_one(Transcript.transcript_id == transcript_id)
        if not transcript:
            msg = f"Transcript '{transcript_id}' not found in database."
            raise ValueError(msg)
# Format utterances into a single text blob
full_text = self._format_transcript_text(transcript.utterances)
return TranscriptData(
transcript_id=transcript_id,
text=full_text,
)
    def _format_transcript_text(self, utterances: list) -> str:
        """Joins utterance objects into a readable dialogue string."""
        # Utterance model exposes 'speaker' and 'value'
        return "\n".join(f"{u.speaker}: {u.value}" for u in utterances)

View File

View File

@@ -0,0 +1,51 @@
# Prompts for RISEN (Extract -> Map -> Score) Architecture
EXTRACT_EVIDENCE_PROMPT = """
You are a clinical expert analyzing a therapy session transcript.
Your goal is to find ALL evidence relevant to the following PHQ-8 symptom.
SYMPTOM: {symptom_name}
DESCRIPTION: {symptom_description}
TRANSCRIPT:
{transcript_text}
INSTRUCTIONS:
1. Extract verbatim quotes from the patient that relate to this symptom.
2. Include context if necessary to understand the quote.
3. If no evidence is found, explicitly state "No evidence found".
Return ONLY the quotes and their context.
"""
MAP_CRITERIA_PROMPT = """
You are a clinical expert. You have extracted the following evidence for a specific PHQ-8 symptom.
Now, map this evidence to the scoring criteria.
SYMPTOM: {symptom_name}
CRITERIA:
- 0: Not at all (0-1 days)
- 1: Several days (2-6 days)
- 2: More than half the days (7-11 days)
- 3: Nearly every day (12-14 days)
EVIDENCE:
{evidence_text}
INSTRUCTIONS:
1. Analyze the frequency and severity implied by the quotes.
2. Compare against the criteria above.
3. Explain your reasoning for which bucket the patient falls into.
Reasoning:
"""
SCORE_ITEM_PROMPT = """
Based on the reasoning below, assign the final PHQ-8 score (0-3) for this item.
SYMPTOM: {symptom_name}
REASONING:
{reasoning_text}
Return JSON: {{ "score": int, "final_reasoning": str }}
"""