WIP
This commit is contained in:
69
migrations/init_risen_prompts.py
Normal file
69
migrations/init_risen_prompts.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from helia.configuration import HeliaConfig
|
||||
from helia.db import init_db
|
||||
from helia.models.prompt import Prompt
|
||||
from helia.prompts.assessment import (
|
||||
EXTRACT_EVIDENCE_PROMPT,
|
||||
MAP_CRITERIA_PROMPT,
|
||||
SCORE_ITEM_PROMPT,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def migrate() -> None:
|
||||
try:
|
||||
config = HeliaConfig() # ty:ignore[missing-argument]
|
||||
except Exception:
|
||||
logger.exception("Failed to load configuration.")
|
||||
return
|
||||
|
||||
logger.info("Connecting to database...")
|
||||
await init_db(config)
|
||||
|
||||
prompts_to_create = [
|
||||
{
|
||||
"name": "phq8-extract",
|
||||
"template": EXTRACT_EVIDENCE_PROMPT,
|
||||
"input_variables": ["symptom_name", "symptom_description", "transcript_text"],
|
||||
},
|
||||
{
|
||||
"name": "phq8-map",
|
||||
"template": MAP_CRITERIA_PROMPT,
|
||||
"input_variables": ["symptom_name", "evidence_text"],
|
||||
},
|
||||
{
|
||||
"name": "phq8-score",
|
||||
"template": SCORE_ITEM_PROMPT,
|
||||
"input_variables": ["symptom_name", "reasoning_text"],
|
||||
},
|
||||
]
|
||||
|
||||
for p_data in prompts_to_create:
|
||||
name = p_data["name"]
|
||||
logger.info("Creating or updating prompt '%s'...", name)
|
||||
|
||||
# Check if exists to avoid duplicates or update if needed
|
||||
existing = await Prompt.find_one(Prompt.name == name)
|
||||
if existing:
|
||||
logger.info("Prompt '%s' already exists. Updating template.", name)
|
||||
existing.template = p_data["template"]
|
||||
existing.input_variables = p_data["input_variables"]
|
||||
await existing.save()
|
||||
else:
|
||||
new_prompt = Prompt(
|
||||
name=name,
|
||||
template=p_data["template"],
|
||||
input_variables=p_data["input_variables"],
|
||||
)
|
||||
await new_prompt.insert()
|
||||
logger.info("Prompt '%s' created.", name)
|
||||
|
||||
logger.info("Migration completed successfully.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(migrate())
|
||||
Reference in New Issue
Block a user