feat: implement PHQ-8 assessment prompt and refactor related components for improved functionality

This commit is contained in:
Santiago Martinez-Avial
2025-12-23 01:20:20 +01:00
parent 69fc70ea65
commit a9346ccb34
11 changed files with 166 additions and 92 deletions

View File

@@ -0,0 +1,3 @@
from . import init_db
__all__ = ["init_db"]

12
migrations/init_db.py Normal file
View File

@@ -0,0 +1,12 @@
import asyncio
from . import init_prompts, init_transcripts
async def migrate() -> None:
    """Run every migration step in order: prompts first, then transcripts."""
    for step in (init_prompts, init_transcripts):
        await step.migrate()


if __name__ == "__main__":
    asyncio.run(migrate())

View File

@@ -0,0 +1,69 @@
import asyncio
import logging
from helia.configuration import HeliaConfig
from helia.db import init_db
from helia.models.prompt import Prompt
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
DEFAULT_PROMPT = """You are an expert clinical psychologist analyzing a patient interview transcript.
Your task is to assess the patient according to the PHQ-8 (Patient Health Questionnaire-8) criteria.
The transcript is provided below. You must analyze the ENTIRE transcript to find evidence for each of the 8 criteria.
CRITERIA TO ASSESS:
1. Little interest or pleasure in doing things (Anhedonia)
2. Feeling down, depressed, or hopeless
3. Trouble falling or staying asleep, or sleeping too much
4. Feeling tired or having little energy
5. Poor appetite or overeating
6. Feeling bad about yourself - or that you are a failure or have let yourself or your family down
7. Trouble concentrating on things, such as reading the newspaper or watching television
8. Moving or speaking so slowly that other people could have noticed? Or the opposite - being so fidgety or restless that you have been moving around a lot more than usual (Psychomotor agitation/retardation)
SCORING SCALE:
0 = Not at all
1 = Several days
2 = More than half the days
3 = Nearly every day
INSTRUCTIONS:
- For EACH of the 8 items, you must provide:
- A score (0-3).
- A list of DIRECT QUOTES from the transcript that support this score.
- A reasoning explanation.
- If there is NO evidence for a symptom, score it as 0.
- Be conservative: do not hallucinate symptoms. Only score if the patient explicitly mentions it or strong context implies it.
TRANSCRIPT:
{transcript_text}
"""
async def migrate() -> None:
    """Seed the database with the initial PHQ-8 assessment prompt.

    Loads ``HeliaConfig``, connects to the database via ``init_db``, and
    inserts a ``Prompt`` document named ``phq8-assessment`` whose template
    expects a single ``transcript_text`` input variable. Returns early
    (logging the traceback) when configuration cannot be loaded.
    """
    try:
        config = HeliaConfig()  # ty:ignore[missing-argument]
    except Exception:
        # Fix: the original message carried a dangling "%s" placeholder with
        # no argument, which itself raises a logging format error.
        # logger.exception already appends the traceback, so no arg is needed.
        logger.exception("Failed to load configuration.")
        return

    logger.info("Connecting to database...")
    await init_db(config)

    prompt_name = "phq8-assessment"
    logger.info("Creating initial prompt '%s'...", prompt_name)
    new_prompt = Prompt(
        name=prompt_name,
        template=DEFAULT_PROMPT,
        # Declared as set[str] on the model; pydantic coerces the list.
        input_variables=["transcript_text"],
    )
    await new_prompt.insert()
    logger.info("Prompt created successfully.")


if __name__ == "__main__":
    asyncio.run(migrate())

View File

@@ -65,10 +65,7 @@ async def migrate() -> None:
s3_loader = S3DatasetLoader(config.s3)
keys = await s3_loader.list_transcripts()
logger.info("Found %d transcript files in S3.", len(keys))
await Transcript.delete_all()
logger.info("Cleared existing transcripts.")
logger.info("%d transcript files in S3.", len(keys))
id_pattern = re.compile(r"/(\d+)_TRANSCRIPT\.csv$")

View File

@@ -8,6 +8,7 @@ from helia.assessment.schema import (
EvaluatorConfig,
)
from helia.llm.client import get_chat_model
from helia.models.prompt import Prompt
if TYPE_CHECKING:
from helia.configuration import RunConfig
@@ -19,40 +20,6 @@ SYMPTOM_SEVERITY_THRESHOLD = 2
MAJOR_DEPRESSION_MIN_SEVERE = 5
OTHER_DEPRESSION_MIN_SEVERE = 2
DEFAULT_PROMPT = """
You are an expert clinical psychologist analyzing a patient interview transcript.
Your task is to assess the patient according to the PHQ-8 (Patient Health Questionnaire-8) criteria.
The transcript is provided below. You must analyze the ENTIRE transcript to find evidence for each of the 8 criteria.
CRITERIA TO ASSESS:
1. Little interest or pleasure in doing things (Anhedonia)
2. Feeling down, depressed, or hopeless
3. Trouble falling or staying asleep, or sleeping too much
4. Feeling tired or having little energy
5. Poor appetite or overeating
6. Feeling bad about yourself - or that you are a failure or have let yourself or your family down
7. Trouble concentrating on things, such as reading the newspaper or watching television
8. Moving or speaking so slowly that other people could have noticed? Or the opposite - being so fidgety or restless that you have been moving around a lot more than usual (Psychomotor agitation/retardation)
SCORING SCALE:
0 = Not at all
1 = Several days
2 = More than half the days
3 = Nearly every day
INSTRUCTIONS:
- For EACH of the 8 items, you must provide:
- A score (0-3).
- A list of DIRECT QUOTES from the transcript that support this score.
- A reasoning explanation.
- If there is NO evidence for a symptom, score it as 0.
- Be conservative: do not hallucinate symptoms. Only score if the patient explicitly mentions it or strong context implies it.
TRANSCRIPT:
{transcript_text}
"""
class PHQ8Evaluator:
def __init__(self, config: EvaluatorConfig) -> None:
@@ -65,6 +32,27 @@ class PHQ8Evaluator:
temperature=self.config.temperature,
)
async def _get_prompt_template(self) -> str:
    """Fetch the prompt template from the database.

    Looks up ``Prompt`` documents by ``config.prompt_id`` (the prompt name).
    When ``config.prompt_version`` is set, that exact version is required;
    otherwise the highest version (descending sort) is used.

    Returns:
        The raw template string of the matching prompt document.

    Raises:
        ValueError: If no matching prompt exists in the database.

    Fix: the original docstring claimed a fallback to a hardcoded default,
    but the code raises ``ValueError`` and this commit removes the module's
    ``DEFAULT_PROMPT`` — the documentation now matches the actual contract.
    """
    query = Prompt.find(Prompt.name == self.config.prompt_id)
    # NOTE(review): the Prompt model added in this commit declares no
    # ``version`` field — confirm ``Prompt.version`` resolves as intended.
    if self.config.prompt_version is not None:
        query = query.find(Prompt.version == self.config.prompt_version)
        prompt_doc = await query.first_or_none()
    else:
        # No version pinned: take the latest by descending version.
        prompt_doc = await query.sort("-version").first_or_none()

    if prompt_doc:
        return prompt_doc.template

    raise ValueError(
        f"Prompt '{self.config.prompt_id}' "
        f"version '{self.config.prompt_version}' not found in database."
    )
async def evaluate(
self, transcript: Transcript, original_run_config: RunConfig
) -> AssessmentResult:
@@ -73,10 +61,9 @@ class PHQ8Evaluator:
"""
transcript_text = "\n".join([f"{u.speaker}: {u.value}" for u in transcript.utterances])
# 2. Prepare Prompt
final_prompt = DEFAULT_PROMPT.format(transcript_text=transcript_text)
template = await self._get_prompt_template()
final_prompt = template.format(transcript_text=transcript_text)
# 3. Call LLM (Async with Structured Output)
structured_llm = self.llm.with_structured_output(AssessmentResponse)
messages = [
@@ -87,7 +74,6 @@ class PHQ8Evaluator:
response_obj = cast("AssessmentResponse", await structured_llm.ainvoke(messages))
items = response_obj.items
# 4. Calculate Diagnostics
total_score = sum(item.score for item in items)
diagnosis_cutpoint = total_score >= DIAGNOSIS_THRESHOLD
@@ -97,7 +83,7 @@ class PHQ8Evaluator:
count_severe = sum(1 for i in items if i.score >= SYMPTOM_SEVERITY_THRESHOLD)
has_core = (items[0].score >= SYMPTOM_SEVERITY_THRESHOLD) or (
items[1].score >= SYMPTOM_SEVERITY_THRESHOLD
) # Q1 or Q2
)
diagnosis_algorithm = "None"
if has_core:

View File

@@ -24,6 +24,7 @@ class EvaluatorConfig(BaseModel):
api_spec: str = "openai"
temperature: float = 0.0
prompt_id: str = "default"
prompt_version: int | None = None
class PHQ8Item(BaseModel):

View File

@@ -50,6 +50,7 @@ class RunConfig(BaseModel):
model: ModelConfig
prompt_id: str = "default"
prompt_version: int | None = None
class HeliaConfig(BaseSettings):

View File

@@ -6,6 +6,7 @@ from beanie import init_beanie
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from helia.assessment.schema import AssessmentResult
from helia.models.prompt import Prompt
from helia.models.transcript import Transcript
if TYPE_CHECKING:
@@ -17,6 +18,6 @@ async def init_db(config: HeliaConfig) -> AsyncMongoClient:
database = client[config.mongo.db_name]
await init_beanie(
database=database,
document_models=[AssessmentResult, Transcript],
document_models=[AssessmentResult, Transcript, Prompt],
)
return client

View File

@@ -46,6 +46,7 @@ async def process_run(
api_key=provider_config.api_key,
api_spec=provider_config.api_spec,
prompt_id=run_config.prompt_id,
prompt_version=run_config.prompt_version,
temperature=run_config.model.temperature,
)
@@ -102,7 +103,7 @@ async def main() -> None:
await init_db(config)
logger.debug("Discovering transcripts in MongoDB...")
limit=args.limit or config.patient_limit
limit = args.limit or config.patient_limit
query = Transcript.find_all(limit=limit)
if limit is not None:

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
from typing import Annotated, Never
from beanie import (
Document,
Indexed,
PydanticObjectId,
Replace,
Save,
SaveChanges,
Update,
before_event,
)
from pydantic import Field
class Prompt(Document):
    """Immutable prompt document stored in the ``prompts`` collection.

    Prompts are linked into a derivation tree via ``parent``/``children``;
    any mutation after insert is rejected by ``check_immutability``.
    """

    # Human-readable identifier used for lookups (e.g. "phq8-assessment").
    # NOTE(review): ``name`` carries no index while ``template`` has the
    # unique index — confirm this is intentional and not swapped.
    name: str
    # Full template text; unique index prevents duplicate templates.
    template: Annotated[str, Indexed(unique=True)]
    # Placeholder names the template expects (e.g. {"transcript_text"}).
    input_variables: set[str]
    # Optional reference to the prompt this one was derived from.
    parent: Annotated[PydanticObjectId | None, Indexed(sparse=True)] = None
    # References to prompts derived from this one.
    children: set[Annotated[PydanticObjectId | None, Indexed(sparse=True)]] = Field(
        default_factory=set
    )

    class Settings:
        name = "prompts"
        validate_on_save = True

    @before_event([Save, Replace, Update, SaveChanges])
    def check_immutability(self) -> Never:
        # Reject every update-style persistence event; documents may only
        # be created via insert, never modified in place.
        raise ValueError("Prompts are immutable.")

View File

@@ -1,8 +1,8 @@
from typing import ClassVar, Literal
from typing import Annotated, Literal
from beanie import Document
from beanie import Document, Indexed
from pydantic import BaseModel
from pymongo import ASCENDING
from pymongo import ASCENDING, DESCENDING, TEXT, IndexModel
class Utterance(BaseModel):
@@ -21,50 +21,20 @@ class Turn(BaseModel):
class Transcript(Document):
transcript_id: str
utterances: list[Utterance]
@property
def turns(self) -> list[Turn]:
"""
Aggregates consecutive utterances from the same speaker into a single Turn.
"""
if not self.utterances:
return []
turns: list[Turn] = []
current_batch: list[Utterance] = []
for utterance in self.utterances:
if not current_batch:
current_batch.append(utterance)
continue
if utterance.speaker == current_batch[-1].speaker:
current_batch.append(utterance)
else:
turns.append(self._create_turn(current_batch))
current_batch = [utterance]
if current_batch:
turns.append(self._create_turn(current_batch))
return turns
def _create_turn(self, batch: list[Utterance]) -> Turn:
return Turn(
speaker=batch[0].speaker,
value=" ".join(u.value for u in batch),
start_time=batch[0].start_time,
end_time=batch[-1].end_time,
utterance_count=len(batch),
)
transcript_id: Annotated[str, Indexed(index_type=ASCENDING, unique=True)]
utterances: Annotated[list[Utterance], Indexed(unique=True)]
class Settings:
name = "transcripts"
indexes: ClassVar = [
[("transcript_id", ASCENDING)],
]
index_defs: ClassVar = [
{"key": "transcript_id", "unique": True},
]
indexes = (
IndexModel([("utterances.value", TEXT)]),
IndexModel(
[
("utterances.value", DESCENDING),
("utterances.start_time", DESCENDING),
("utterances.end_time", DESCENDING),
],
unique=True,
),
)
validate_on_save = True