feat: seed PHQ-8 assessment prompt via migration and load evaluator prompts from the database
@@ -0,0 +1,3 @@
+from . import init_db
+
+__all__ = ["init_db"]
12 migrations/init_db.py Normal file
@@ -0,0 +1,12 @@
+import asyncio
+
+from . import init_prompts, init_transcripts
+
+
+async def migrate() -> None:
+    await init_prompts.migrate()
+    await init_transcripts.migrate()
+
+
+if __name__ == "__main__":
+    asyncio.run(migrate())
69 migrations/init_prompts.py Normal file
@@ -0,0 +1,69 @@
+import asyncio
+import logging
+
+from helia.configuration import HeliaConfig
+from helia.db import init_db
+from helia.models.prompt import Prompt
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+DEFAULT_PROMPT = """You are an expert clinical psychologist analyzing a patient interview transcript.
+Your task is to assess the patient according to the PHQ-8 (Patient Health Questionnaire-8) criteria.
+
+The transcript is provided below. You must analyze the ENTIRE transcript to find evidence for each of the 8 criteria.
+
+CRITERIA TO ASSESS:
+1. Little interest or pleasure in doing things (Anhedonia)
+2. Feeling down, depressed, or hopeless
+3. Trouble falling or staying asleep, or sleeping too much
+4. Feeling tired or having little energy
+5. Poor appetite or overeating
+6. Feeling bad about yourself - or that you are a failure or have let yourself or your family down
+7. Trouble concentrating on things, such as reading the newspaper or watching television
+8. Moving or speaking so slowly that other people could have noticed? Or the opposite - being so fidgety or restless that you have been moving around a lot more than usual (Psychomotor agitation/retardation)
+
+SCORING SCALE:
+0 = Not at all
+1 = Several days
+2 = More than half the days
+3 = Nearly every day
+
+INSTRUCTIONS:
+- For EACH of the 8 items, you must provide:
+  - A score (0-3).
+  - A list of DIRECT QUOTES from the transcript that support this score.
+  - A reasoning explanation.
+- If there is NO evidence for a symptom, score it as 0.
+- Be conservative: do not hallucinate symptoms. Only score if the patient explicitly mentions it or strong context implies it.
+
+TRANSCRIPT:
+{transcript_text}
+"""
+
+
+async def migrate() -> None:
+    try:
+        config = HeliaConfig()  # ty:ignore[missing-argument]
+    except Exception:
+        logger.exception("Failed to load configuration.")
+        return
+
+    logger.info("Connecting to database...")
+    await init_db(config)
+
+    prompt_name = "phq8-assessment"
+
+    logger.info("Creating initial prompt '%s'...", prompt_name)
+    new_prompt = Prompt(
+        name=prompt_name,
+        template=DEFAULT_PROMPT,
+        input_variables=["transcript_text"],
+    )
+
+    await new_prompt.insert()
+    logger.info("Prompt created successfully.")
+
+
+if __name__ == "__main__":
+    asyncio.run(migrate())
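
A quick way to sanity-check the seeding migration is to read back what it wrote. A minimal sketch, assuming the same HeliaConfig/init_db bootstrap the migration itself uses:

    import asyncio

    from helia.configuration import HeliaConfig
    from helia.db import init_db
    from helia.models.prompt import Prompt


    async def check() -> None:
        config = HeliaConfig()  # ty:ignore[missing-argument]
        await init_db(config)
        prompt = await Prompt.find_one(Prompt.name == "phq8-assessment")
        # the stored template keeps its placeholder; the evaluator fills it at run time
        assert prompt is not None
        assert "{transcript_text}" in prompt.template


    asyncio.run(check())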
@@ -65,10 +65,7 @@ async def migrate() -> None:
 
     s3_loader = S3DatasetLoader(config.s3)
     keys = await s3_loader.list_transcripts()
-    logger.info("Found %d transcript files in S3.", len(keys))
+    logger.info("%d transcript files in S3.", len(keys))
 
-    await Transcript.delete_all()
-    logger.info("Cleared existing transcripts.")
-
     id_pattern = re.compile(r"/(\d+)_TRANSCRIPT\.csv$")
 
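
The id_pattern used above anchors on the "<digits>_TRANSCRIPT.csv" key layout; group 1 is the transcript id. A small illustration (the key below is hypothetical; real keys come from S3DatasetLoader.list_transcripts()):

    import re

    id_pattern = re.compile(r"/(\d+)_TRANSCRIPT\.csv$")

    match = id_pattern.search("transcripts/300_TRANSCRIPT.csv")  # hypothetical key
    assert match is not None
    assert match.group(1) == "300"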
@@ -8,6 +8,7 @@ from helia.assessment.schema import (
     EvaluatorConfig,
 )
 from helia.llm.client import get_chat_model
+from helia.models.prompt import Prompt
 
 if TYPE_CHECKING:
     from helia.configuration import RunConfig
@@ -19,40 +20,6 @@ SYMPTOM_SEVERITY_THRESHOLD = 2
 MAJOR_DEPRESSION_MIN_SEVERE = 5
 OTHER_DEPRESSION_MIN_SEVERE = 2
 
-DEFAULT_PROMPT = """
-You are an expert clinical psychologist analyzing a patient interview transcript.
-Your task is to assess the patient according to the PHQ-8 (Patient Health Questionnaire-8) criteria.
-
-The transcript is provided below. You must analyze the ENTIRE transcript to find evidence for each of the 8 criteria.
-
-CRITERIA TO ASSESS:
-1. Little interest or pleasure in doing things (Anhedonia)
-2. Feeling down, depressed, or hopeless
-3. Trouble falling or staying asleep, or sleeping too much
-4. Feeling tired or having little energy
-5. Poor appetite or overeating
-6. Feeling bad about yourself - or that you are a failure or have let yourself or your family down
-7. Trouble concentrating on things, such as reading the newspaper or watching television
-8. Moving or speaking so slowly that other people could have noticed? Or the opposite - being so fidgety or restless that you have been moving around a lot more than usual (Psychomotor agitation/retardation)
-
-SCORING SCALE:
-0 = Not at all
-1 = Several days
-2 = More than half the days
-3 = Nearly every day
-
-INSTRUCTIONS:
-- For EACH of the 8 items, you must provide:
-  - A score (0-3).
-  - A list of DIRECT QUOTES from the transcript that support this score.
-  - A reasoning explanation.
-- If there is NO evidence for a symptom, score it as 0.
-- Be conservative: do not hallucinate symptoms. Only score if the patient explicitly mentions it or strong context implies it.
-
-TRANSCRIPT:
-{transcript_text}
-"""
-
 
 class PHQ8Evaluator:
     def __init__(self, config: EvaluatorConfig) -> None:
@@ -65,6 +32,27 @@ class PHQ8Evaluator:
             temperature=self.config.temperature,
         )
 
+    async def _get_prompt_template(self) -> str:
+        """
+        Fetch the prompt template from the database based on configuration.
+
+        Raises ValueError if no matching prompt is found.
+        """
+        query = Prompt.find(Prompt.name == self.config.prompt_id)
+
+        if self.config.prompt_version is not None:
+            query = query.find(Prompt.version == self.config.prompt_version)
+            prompt_doc = await query.first_or_none()
+        else:
+            prompt_doc = await query.sort("-version").first_or_none()
+
+        if prompt_doc:
+            return prompt_doc.template
+
+        raise ValueError(
+            f"Prompt '{self.config.prompt_id}' "
+            f"version '{self.config.prompt_version}' not found in database."
+        )
+
     async def evaluate(
         self, transcript: Transcript, original_run_config: RunConfig
     ) -> AssessmentResult:
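
The lookup rule in _get_prompt_template: a pinned prompt_version must match exactly, with no fallback, while an unpinned config takes the highest version of the named prompt. Restated as a standalone coroutine for clarity (a sketch, not part of the commit):

    from helia.models.prompt import Prompt


    async def resolve_template(name: str, version: int | None) -> str:
        # mirrors PHQ8Evaluator._get_prompt_template
        query = Prompt.find(Prompt.name == name)
        if version is not None:
            # pinned: exact version or nothing
            doc = await query.find(Prompt.version == version).first_or_none()
        else:
            # unpinned: newest version wins
            doc = await query.sort("-version").first_or_none()
        if doc is None:
            raise ValueError(f"Prompt '{name}' version '{version}' not found.")
        return doc.template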
@@ -73,10 +61,9 @@ class PHQ8Evaluator:
         """
         transcript_text = "\n".join([f"{u.speaker}: {u.value}" for u in transcript.utterances])
 
-        # 2. Prepare Prompt
-        final_prompt = DEFAULT_PROMPT.format(transcript_text=transcript_text)
+        template = await self._get_prompt_template()
+        final_prompt = template.format(transcript_text=transcript_text)
 
-        # 3. Call LLM (Async with Structured Output)
         structured_llm = self.llm.with_structured_output(AssessmentResponse)
 
         messages = [
@@ -87,7 +74,6 @@ class PHQ8Evaluator:
         response_obj = cast("AssessmentResponse", await structured_llm.ainvoke(messages))
         items = response_obj.items
 
-        # 4. Calculate Diagnostics
         total_score = sum(item.score for item in items)
         diagnosis_cutpoint = total_score >= DIAGNOSIS_THRESHOLD
 
@@ -97,7 +83,7 @@ class PHQ8Evaluator:
         count_severe = sum(1 for i in items if i.score >= SYMPTOM_SEVERITY_THRESHOLD)
         has_core = (items[0].score >= SYMPTOM_SEVERITY_THRESHOLD) or (
             items[1].score >= SYMPTOM_SEVERITY_THRESHOLD
-        )  # Q1 or Q2
+        )
 
         diagnosis_algorithm = "None"
         if has_core:
@@ -24,6 +24,7 @@ class EvaluatorConfig(BaseModel):
     api_spec: str = "openai"
     temperature: float = 0.0
     prompt_id: str = "default"
+    prompt_version: int | None = None
 
 
 class PHQ8Item(BaseModel):
@@ -50,6 +50,7 @@ class RunConfig(BaseModel):
 
     model: ModelConfig
     prompt_id: str = "default"
+    prompt_version: int | None = None
 
 
 class HeliaConfig(BaseSettings):
@@ -6,6 +6,7 @@ from beanie import init_beanie
 from pymongo.asynchronous.mongo_client import AsyncMongoClient
 
 from helia.assessment.schema import AssessmentResult
+from helia.models.prompt import Prompt
 from helia.models.transcript import Transcript
 
 if TYPE_CHECKING:
@@ -17,6 +18,6 @@ async def init_db(config: HeliaConfig) -> AsyncMongoClient:
     database = client[config.mongo.db_name]
     await init_beanie(
         database=database,
-        document_models=[AssessmentResult, Transcript],
+        document_models=[AssessmentResult, Transcript, Prompt],
     )
     return client
@@ -46,6 +46,7 @@ async def process_run(
         api_key=provider_config.api_key,
         api_spec=provider_config.api_spec,
         prompt_id=run_config.prompt_id,
+        prompt_version=run_config.prompt_version,
         temperature=run_config.model.temperature,
     )
 
34 src/helia/models/prompt.py Normal file
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+from typing import Annotated, Never
+
+from beanie import (
+    Document,
+    Indexed,
+    PydanticObjectId,
+    Replace,
+    Save,
+    SaveChanges,
+    Update,
+    before_event,
+)
+from pydantic import Field
+
+
+class Prompt(Document):
+    name: str
+    version: int = 1  # queried and sorted on by the evaluator's prompt lookup
+    template: Annotated[str, Indexed(unique=True)]
+    input_variables: set[str]
+    parent: Annotated[PydanticObjectId | None, Indexed(sparse=True)] = None
+    children: set[Annotated[PydanticObjectId | None, Indexed(sparse=True)]] = Field(
+        default_factory=set
+    )
+
+    class Settings:
+        name = "prompts"
+        validate_on_save = True
+
+    @before_event([Save, Replace, Update, SaveChanges])
+    def check_immutability(self) -> Never:
+        raise ValueError("Prompts are immutable.")
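
Since check_immutability rejects every Save/Replace/Update/SaveChanges event while insert() is left alone, revising a prompt means inserting a new document that points back at its predecessor. A minimal sketch of that workflow (the helper name is illustrative):

    from helia.models.prompt import Prompt


    async def revise(old: Prompt, new_template: str) -> Prompt:
        # a new document is allowed; mutating `old` would raise ValueError
        revised = Prompt(
            name=old.name,
            version=old.version + 1,
            template=new_template,
            input_variables=old.input_variables,
            parent=old.id,
        )
        await revised.insert()
        return revised

Note that the same hook blocks back-filling old.children (old.save() raises), so the parent link is the only edge that can be recorded after the fact.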
@@ -1,8 +1,8 @@
-from typing import ClassVar, Literal
+from typing import Annotated, Literal
 
-from beanie import Document
+from beanie import Document, Indexed
 from pydantic import BaseModel
-from pymongo import ASCENDING
+from pymongo import ASCENDING, DESCENDING, TEXT, IndexModel
 
 
 class Utterance(BaseModel):
@@ -21,50 +21,20 @@ class Turn(BaseModel):
 
 
 class Transcript(Document):
-    transcript_id: str
-    utterances: list[Utterance]
+    transcript_id: Annotated[str, Indexed(index_type=ASCENDING, unique=True)]
+    utterances: Annotated[list[Utterance], Indexed(unique=True)]
 
-    @property
-    def turns(self) -> list[Turn]:
-        """
-        Aggregates consecutive utterances from the same speaker into a single Turn.
-        """
-        if not self.utterances:
-            return []
-
-        turns: list[Turn] = []
-        current_batch: list[Utterance] = []
-
-        for utterance in self.utterances:
-            if not current_batch:
-                current_batch.append(utterance)
-                continue
-
-            if utterance.speaker == current_batch[-1].speaker:
-                current_batch.append(utterance)
-            else:
-                turns.append(self._create_turn(current_batch))
-                current_batch = [utterance]
-
-        if current_batch:
-            turns.append(self._create_turn(current_batch))
-
-        return turns
-
-    def _create_turn(self, batch: list[Utterance]) -> Turn:
-        return Turn(
-            speaker=batch[0].speaker,
-            value=" ".join(u.value for u in batch),
-            start_time=batch[0].start_time,
-            end_time=batch[-1].end_time,
-            utterance_count=len(batch),
-        )
-
     class Settings:
         name = "transcripts"
-        indexes: ClassVar = [
-            [("transcript_id", ASCENDING)],
-        ]
-        index_defs: ClassVar = [
-            {"key": "transcript_id", "unique": True},
-        ]
+        indexes = (
+            IndexModel([("utterances.value", TEXT)]),
+            IndexModel(
+                [
+                    ("utterances.value", DESCENDING),
+                    ("utterances.start_time", DESCENDING),
+                    ("utterances.end_time", DESCENDING),
+                ],
+                unique=True,
+            ),
+        )
+        validate_on_save = True
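
The TEXT index on utterances.value enables MongoDB $text queries across a transcript's utterances. A minimal sketch, assuming init_beanie has registered Transcript:

    from helia.models.transcript import Transcript


    async def search_transcripts(term: str) -> list[Transcript]:
        # served by IndexModel([("utterances.value", TEXT)])
        return await Transcript.find({"$text": {"$search": term}}).to_list()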