"""
LLM service for converting natural language to structured task data.
"""
import json
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union
from app.core.config import settings
logger = logging.getLogger(__name__)
class LLMService(ABC):
    """Common interface for natural-language → task-extraction backends."""

    @abstractmethod
    async def chat_to_tasks(self, prompt: str) -> List[Dict]:
        """Turn a free-form user message into structured task data.

        Implementations return a list of task dictionaries.
        """
        ...
def extract_json_from_response(text: str) -> Union[Dict, List]:
    """Parse JSON from an LLM reply, unwrapping Markdown code fences first.

    Raises whatever ``json.loads`` raises when no valid JSON can be found;
    the raw response is logged before re-raising.
    """
    try:
        # Prefer an explicit ```json fence; otherwise fall back to a bare ``` fence.
        if "```json" in text:
            _, _, tail = text.partition("```json")
            text = tail.split("```")[0].strip()
        elif "```" in text:
            text = text.split("```")[1].strip()
        return json.loads(text)
    except Exception as e:
        logger.error(f"Failed to parse JSON: {e}\nRaw response: {text}")
        raise
class OpenAIService(LLMService):
    """OpenAI implementation of the LLM service."""

    def __init__(self):
        """Build an async OpenAI client from application settings.

        Raises:
            RuntimeError: if the ``openai`` package is missing or the
                required settings attributes are absent.
        """
        try:
            import openai
            self.client = openai.AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
            self.model = settings.OPENAI_MODEL
        except (ImportError, AttributeError) as e:
            logger.error(f"OpenAI service initialization failed: {e}")
            # Chain the cause so the original failure survives the re-raise.
            raise RuntimeError("OpenAI client setup failed.") from e

    async def chat_to_tasks(self, prompt: str) -> List[Dict]:
        """Convert *prompt* into a list of structured task dicts via OpenAI.

        The model is instructed to reply with ``{"tasks": [...]}``; the list
        under the "tasks" key is returned.

        Raises:
            RuntimeError: on API failure or a malformed model reply.
        """
        system_prompt = (
            "You are a task extraction assistant. Convert the user's message into structured task objects. "
            "Each task must include:\n"
            "- title: short title\n"
            "- description: detailed description\n"
            "- due_date: ISO 8601 date (YYYY-MM-DD) or null\n"
            "- priority: high, medium, or low\n"
            "- status: set to \"pending\"\n\n"
            "Respond ONLY with a JSON object in the format:\n"
            "{ \"tasks\": [ { ... }, { ... } ] }\n"
            "No extra commentary or text."
        )
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                temperature=0.2,  # low temperature for more deterministic extraction
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ]
            )
            raw = response.choices[0].message.content.strip()
            result = extract_json_from_response(raw)
            # Expect a dict with a "tasks" key holding a list.
            if isinstance(result, dict) and "tasks" in result and isinstance(result["tasks"], list):
                return result["tasks"]
            raise ValueError("Missing or invalid 'tasks' key in response.")
        except Exception as e:
            logger.error(f"OpenAI task extraction failed: {e}")
            # Chain the cause so callers can inspect the underlying error.
            raise RuntimeError("Failed to extract tasks from OpenAI response.") from e
class GeminiService(LLMService):
    """Google Gemini implementation of the LLM service."""

    def __init__(self):
        """Configure the Gemini SDK and model from application settings.

        Raises:
            RuntimeError: if the ``google.generativeai`` package is missing
                or the required settings attributes are absent.
        """
        try:
            import google.generativeai as genai
            genai.configure(api_key=settings.GEMINI_API_KEY)
            self.model = genai.GenerativeModel(settings.GEMINI_MODEL)
        except (ImportError, AttributeError) as e:
            logger.error(f"Gemini service initialization failed: {e}")
            # Chain the cause so the original failure survives the re-raise.
            raise RuntimeError("Gemini client setup failed.") from e

    async def chat_to_tasks(self, prompt: str) -> List[Dict]:
        """Convert *prompt* into a list of structured task dicts via Gemini.

        Note: unlike OpenAIService, Gemini is instructed to return a bare
        JSON LIST of tasks, not a dict wrapping one.

        Raises:
            RuntimeError: on API failure or a malformed model reply.
        """
        system_prompt = (
            "You are a task extraction assistant. Convert the user's message into structured task objects. "
            "Each task must include:\n"
            "- title: short title\n"
            "- description: detailed description\n"
            "- due_date: ISO 8601 date (YYYY-MM-DD) or null\n"
            "- priority: high, medium, or low\n"
            "- status: set to \"pending\"\n\n"
            "Return ONLY a JSON array like this:\n"
            "[ { ... }, { ... } ]\n"
            "No explanations, no markdown, no formatting just pure JSON."
        )
        try:
            # Seed the chat with the instructions, then send the user prompt.
            chat = self.model.start_chat(history=[
                {"role": "user", "parts": [system_prompt]}
            ])
            response = await chat.send_message_async(prompt)
            raw = response.text.strip()
            result = extract_json_from_response(raw)
            # Expect a LIST of task dicts directly.
            if isinstance(result, list):
                return result
            raise ValueError("Expected a JSON list of tasks from Gemini response.")
        except Exception as e:
            logger.error(f"Gemini task extraction failed: {e}")
            # Chain the cause so callers can inspect the underlying error.
            raise RuntimeError("Failed to extract tasks from Gemini response.") from e
class MockLLMService(LLMService):
    """Deterministic stand-in for tests: no external API calls are made."""

    async def chat_to_tasks(self, prompt: str) -> List[Dict]:
        """Build one pseudo-task from *prompt* using keyword heuristics."""
        tokens = prompt.lower().split()
        # Keyword-driven priority: urgency words beat low-priority words.
        if "urgent" in tokens or "important" in tokens:
            level = "high"
        elif "minor" in tokens or "low" in tokens:
            level = "low"
        else:
            level = "medium"
        # Titles are capped at 50 characters plus an ellipsis.
        title = prompt if len(prompt) <= 50 else prompt[:50] + "..."
        return [{
            "title": title,
            "description": prompt,
            "due_date": None,
            "priority": level,
            "status": "pending",
        }]
def get_llm_service() -> LLMService:
    """Factory: pick an LLM service implementation from application settings.

    Falls back to the mock service when the configured provider is missing
    its API key or is unrecognized.
    """
    provider = settings.LLM_PROVIDER.lower()
    if provider == "openai" and settings.OPENAI_API_KEY:
        return OpenAIService()
    if provider == "gemini" and settings.GEMINI_API_KEY:
        return GeminiService()
    # NOTE(review): `settings.environment` is lower-case while the other
    # settings attributes are UPPER_SNAKE — confirm the attribute name.
    if provider == "mock" or settings.environment == "test":
        return MockLLMService()
    logger.warning(
        f"LLM provider '{provider}' is not properly configured. Falling back to mock."
    )
    return MockLLMService()