diff --git a/migrations/__init__.py b/migrations/__init__.py
index e69de29..13d49ea 100644
--- a/migrations/__init__.py
+++ b/migrations/__init__.py
@@ -0,0 +1,3 @@
+from . import init_db
+
+__all__ = ["init_db"]
diff --git a/migrations/init_db.py b/migrations/init_db.py
new file mode 100644
index 0000000..8e8457c
--- /dev/null
+++ b/migrations/init_db.py
@@ -0,0 +1,12 @@
+import asyncio
+
+from . import init_prompts, init_transcripts
+
+
+async def migrate() -> None:
+    await init_prompts.migrate()
+    await init_transcripts.migrate()
+
+
+if __name__ == "__main__":
+    asyncio.run(migrate())
diff --git a/migrations/init_prompts.py b/migrations/init_prompts.py
new file mode 100644
index 0000000..dbf2f8d
--- /dev/null
+++ b/migrations/init_prompts.py
@@ -0,0 +1,70 @@
+import asyncio
+import logging
+
+from helia.configuration import HeliaConfig
+from helia.db import init_db
+from helia.models.prompt import Prompt
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+DEFAULT_PROMPT = """You are an expert clinical psychologist analyzing a patient interview transcript.
+Your task is to assess the patient according to the PHQ-8 (Patient Health Questionnaire-8) criteria.
+
+The transcript is provided below. You must analyze the ENTIRE transcript to find evidence for each of the 8 criteria.
+
+CRITERIA TO ASSESS:
+1. Little interest or pleasure in doing things (Anhedonia)
+2. Feeling down, depressed, or hopeless
+3. Trouble falling or staying asleep, or sleeping too much
+4. Feeling tired or having little energy
+5. Poor appetite or overeating
+6. Feeling bad about yourself - or that you are a failure or have let yourself or your family down
+7. Trouble concentrating on things, such as reading the newspaper or watching television
+8. Moving or speaking so slowly that other people could have noticed? Or the opposite - being so fidgety or restless that you have been moving around a lot more than usual (Psychomotor agitation/retardation)
+
+SCORING SCALE:
+0 = Not at all
+1 = Several days
+2 = More than half the days
+3 = Nearly every day
+
+INSTRUCTIONS:
+- For EACH of the 8 items, you must provide:
+  - A score (0-3).
+  - A list of DIRECT QUOTES from the transcript that support this score.
+  - A reasoning explanation.
+- If there is NO evidence for a symptom, score it as 0.
+- Be conservative: do not hallucinate symptoms. Only score if the patient explicitly mentions it or strong context implies it.
+
+TRANSCRIPT:
+{transcript_text}
+"""
+
+
+async def migrate() -> None:
+    try:
+        config = HeliaConfig()  # ty:ignore[missing-argument]
+    except Exception:
+        logger.exception("Failed to load configuration.")
+        return
+
+    logger.info("Connecting to database...")
+    await init_db(config)
+
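+    # PHQ8Evaluator looks prompts up by name, so run configs should use this value as prompt_id.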
+    prompt_name = "phq8-assessment"
+
+    logger.info("Creating initial prompt '%s'...", prompt_name)
+    new_prompt = Prompt(
+        name=prompt_name,
+        template=DEFAULT_PROMPT,
+        input_variables=["transcript_text"],
+    )
+
+    await new_prompt.insert()
+    logger.info("Prompt created successfully.")
+
+
+if __name__ == "__main__":
+    asyncio.run(migrate())
diff --git a/migrations/init_transcripts.py b/migrations/init_transcripts.py
index 27a7c1c..f51e61c 100644
--- a/migrations/init_transcripts.py
+++ b/migrations/init_transcripts.py
@@ -65,10 +65,7 @@ async def migrate() -> None:
     s3_loader = S3DatasetLoader(config.s3)
 
     keys = await s3_loader.list_transcripts()
-    logger.info("Found %d transcript files in S3.", len(keys))
-
-    await Transcript.delete_all()
-    logger.info("Cleared existing transcripts.")
+    logger.info("%d transcript files in S3.", len(keys))
 
     id_pattern = re.compile(r"/(\d+)_TRANSCRIPT\.csv$")
 
diff --git a/src/helia/assessment/core.py b/src/helia/assessment/core.py
index 44b8929..0fd6a47 100644
--- a/src/helia/assessment/core.py
+++ b/src/helia/assessment/core.py
@@ -8,6 +8,7 @@ from helia.assessment.schema import (
     EvaluatorConfig,
 )
 from helia.llm.client import get_chat_model
+from helia.models.prompt import Prompt
 
 if TYPE_CHECKING:
     from helia.configuration import RunConfig
@@ -19,40 +20,6 @@ SYMPTOM_SEVERITY_THRESHOLD = 2
 MAJOR_DEPRESSION_MIN_SEVERE = 5
 OTHER_DEPRESSION_MIN_SEVERE = 2
 
-DEFAULT_PROMPT = """
-You are an expert clinical psychologist analyzing a patient interview transcript.
-Your task is to assess the patient according to the PHQ-8 (Patient Health Questionnaire-8) criteria.
-
-The transcript is provided below. You must analyze the ENTIRE transcript to find evidence for each of the 8 criteria.
-
-CRITERIA TO ASSESS:
-1. Little interest or pleasure in doing things (Anhedonia)
-2. Feeling down, depressed, or hopeless
-3. Trouble falling or staying asleep, or sleeping too much
-4. Feeling tired or having little energy
-5. Poor appetite or overeating
-6. Feeling bad about yourself - or that you are a failure or have let yourself or your family down
-7. Trouble concentrating on things, such as reading the newspaper or watching television
-8. Moving or speaking so slowly that other people could have noticed? Or the opposite - being so fidgety or restless that you have been moving around a lot more than usual (Psychomotor agitation/retardation)
-
-SCORING SCALE:
-0 = Not at all
-1 = Several days
-2 = More than half the days
-3 = Nearly every day
-
-INSTRUCTIONS:
-- For EACH of the 8 items, you must provide:
-  - A score (0-3).
-  - A list of DIRECT QUOTES from the transcript that support this score.
-  - A reasoning explanation.
-- If there is NO evidence for a symptom, score it as 0.
-- Be conservative: do not hallucinate symptoms. Only score if the patient explicitly mentions it or strong context implies it.
-
-TRANSCRIPT:
-{transcript_text}
-"""
-
 
 class PHQ8Evaluator:
     def __init__(self, config: EvaluatorConfig) -> None:
@@ -65,6 +32,27 @@ class PHQ8Evaluator:
             temperature=self.config.temperature,
         )
 
+    async def _get_prompt_template(self) -> str:
+        """
+        Fetch the prompt template from the database based on configuration.
+        Raises ValueError if no matching prompt is found.
+ """ + query = Prompt.find(Prompt.name == self.config.prompt_id) + + if self.config.prompt_version is not None: + query = query.find(Prompt.version == self.config.prompt_version) + prompt_doc = await query.first_or_none() + else: + prompt_doc = await query.sort("-version").first_or_none() + + if prompt_doc: + return prompt_doc.template + + raise ValueError( + f"Prompt '{self.config.prompt_id}' " + f"version '{self.config.prompt_version}' not found in database." + ) + async def evaluate( self, transcript: Transcript, original_run_config: RunConfig ) -> AssessmentResult: @@ -73,10 +61,9 @@ class PHQ8Evaluator: """ transcript_text = "\n".join([f"{u.speaker}: {u.value}" for u in transcript.utterances]) - # 2. Prepare Prompt - final_prompt = DEFAULT_PROMPT.format(transcript_text=transcript_text) + template = await self._get_prompt_template() + final_prompt = template.format(transcript_text=transcript_text) - # 3. Call LLM (Async with Structured Output) structured_llm = self.llm.with_structured_output(AssessmentResponse) messages = [ @@ -87,7 +74,6 @@ class PHQ8Evaluator: response_obj = cast("AssessmentResponse", await structured_llm.ainvoke(messages)) items = response_obj.items - # 4. Calculate Diagnostics total_score = sum(item.score for item in items) diagnosis_cutpoint = total_score >= DIAGNOSIS_THRESHOLD @@ -97,7 +83,7 @@ class PHQ8Evaluator: count_severe = sum(1 for i in items if i.score >= SYMPTOM_SEVERITY_THRESHOLD) has_core = (items[0].score >= SYMPTOM_SEVERITY_THRESHOLD) or ( items[1].score >= SYMPTOM_SEVERITY_THRESHOLD - ) # Q1 or Q2 + ) diagnosis_algorithm = "None" if has_core: diff --git a/src/helia/assessment/schema.py b/src/helia/assessment/schema.py index 3199cca..ffc919b 100644 --- a/src/helia/assessment/schema.py +++ b/src/helia/assessment/schema.py @@ -24,6 +24,7 @@ class EvaluatorConfig(BaseModel): api_spec: str = "openai" temperature: float = 0.0 prompt_id: str = "default" + prompt_version: int | None = None class PHQ8Item(BaseModel): diff --git a/src/helia/configuration.py b/src/helia/configuration.py index ea7d0bb..9f05957 100644 --- a/src/helia/configuration.py +++ b/src/helia/configuration.py @@ -50,6 +50,7 @@ class RunConfig(BaseModel): model: ModelConfig prompt_id: str = "default" + prompt_version: int | None = None class HeliaConfig(BaseSettings): diff --git a/src/helia/db.py b/src/helia/db.py index be1b016..aa1ff18 100644 --- a/src/helia/db.py +++ b/src/helia/db.py @@ -6,6 +6,7 @@ from beanie import init_beanie from pymongo.asynchronous.mongo_client import AsyncMongoClient from helia.assessment.schema import AssessmentResult +from helia.models.prompt import Prompt from helia.models.transcript import Transcript if TYPE_CHECKING: @@ -17,6 +18,6 @@ async def init_db(config: HeliaConfig) -> AsyncMongoClient: database = client[config.mongo.db_name] await init_beanie( database=database, - document_models=[AssessmentResult, Transcript], + document_models=[AssessmentResult, Transcript, Prompt], ) return client diff --git a/src/helia/main.py b/src/helia/main.py index 4254471..f4743a0 100644 --- a/src/helia/main.py +++ b/src/helia/main.py @@ -46,6 +46,7 @@ async def process_run( api_key=provider_config.api_key, api_spec=provider_config.api_spec, prompt_id=run_config.prompt_id, + prompt_version=run_config.prompt_version, temperature=run_config.model.temperature, ) @@ -102,7 +103,7 @@ async def main() -> None: await init_db(config) logger.debug("Discovering transcripts in MongoDB...") - limit=args.limit or config.patient_limit + limit = args.limit or config.patient_limit query 
+    @before_event([Save, Replace, Update, SaveChanges])
+    def check_immutability(self) -> Never:
+        raise ValueError("Prompts are immutable.")
diff --git a/src/helia/models/transcript.py b/src/helia/models/transcript.py
index a1a8a03..597c0e0 100644
--- a/src/helia/models/transcript.py
+++ b/src/helia/models/transcript.py
@@ -1,8 +1,8 @@
-from typing import ClassVar, Literal
+from typing import Annotated, Literal
 
-from beanie import Document
+from beanie import Document, Indexed
 from pydantic import BaseModel
-from pymongo import ASCENDING
+from pymongo import ASCENDING, DESCENDING, TEXT, IndexModel
 
 
 class Utterance(BaseModel):
@@ -21,50 +21,20 @@
 
 
 class Transcript(Document):
-    transcript_id: str
-    utterances: list[Utterance]
-
-    @property
-    def turns(self) -> list[Turn]:
-        """
-        Aggregates consecutive utterances from the same speaker into a single Turn.
-        """
-        if not self.utterances:
-            return []
-
-        turns: list[Turn] = []
-        current_batch: list[Utterance] = []
-
-        for utterance in self.utterances:
-            if not current_batch:
-                current_batch.append(utterance)
-                continue
-
-            if utterance.speaker == current_batch[-1].speaker:
-                current_batch.append(utterance)
-            else:
-                turns.append(self._create_turn(current_batch))
-                current_batch = [utterance]
-
-        if current_batch:
-            turns.append(self._create_turn(current_batch))
-
-        return turns
-
-    def _create_turn(self, batch: list[Utterance]) -> Turn:
-        return Turn(
-            speaker=batch[0].speaker,
-            value=" ".join(u.value for u in batch),
-            start_time=batch[0].start_time,
-            end_time=batch[-1].end_time,
-            utterance_count=len(batch),
-        )
+    transcript_id: Annotated[str, Indexed(index_type=ASCENDING, unique=True)]
+    utterances: Annotated[list[Utterance], Indexed(unique=True)]
 
     class Settings:
         name = "transcripts"
-        indexes: ClassVar = [
-            [("transcript_id", ASCENDING)],
-        ]
-        index_defs: ClassVar = [
-            {"key": "transcript_id", "unique": True},
-        ]
+        indexes = (
+            IndexModel([("utterances.value", TEXT)]),
+            IndexModel(
+                [
+                    ("utterances.value", DESCENDING),
+                    ("utterances.start_time", DESCENDING),
+                    ("utterances.end_time", DESCENDING),
+                ],
+                unique=True,
+            ),
+        )
+        validate_on_save = True