"""
LLM service for converting natural language to structured task data.
"""
import json
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union

from app.core.config import settings

logger = logging.getLogger(__name__)


class LLMService(ABC):
    """Abstract base class for LLM service implementations."""

    @abstractmethod
    async def chat_to_tasks(self, prompt: str) -> List[Dict]:
        """Convert natural language input to structured task objects.

        Implementations must return a list of task dictionaries.
        """
        ...


def extract_json_from_response(text: str) -> Union[Dict, List]:
    """Extract valid JSON from possibly markdown-wrapped LLM responses.

    Handles payloads wrapped in ```json ... ``` fences, bare ``` ... ```
    fences, or raw JSON text.

    Args:
        text: The raw LLM response.

    Returns:
        The parsed JSON value (a dict or a list).

    Raises:
        json.JSONDecodeError: If the unwrapped text is not valid JSON
            (logged with the raw response before re-raising).
    """
    # Strip markdown code fences before parsing, preferring an explicit
    # ```json fence over a generic one.
    if "```json" in text:
        text = text.split("```json")[1].split("```")[0].strip()
    elif "```" in text:
        text = text.split("```")[1].strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        # Lazy %-formatting; only the parse itself sits inside the try.
        logger.error("Failed to parse JSON: %s\nRaw response: %s", e, text)
        raise


class OpenAIService(LLMService):
    """OpenAI implementation of the LLM service."""

    def __init__(self):
        """Set up the async OpenAI client from application settings.

        Raises:
            RuntimeError: If the openai package is not installed or the
                required settings attributes are missing.
        """
        try:
            import openai

            self.client = openai.AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
            self.model = settings.OPENAI_MODEL
        except (ImportError, AttributeError) as e:
            logger.error("OpenAI service initialization failed: %s", e)
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError("OpenAI client setup failed.") from e

    async def chat_to_tasks(self, prompt: str) -> List[Dict]:
        """Convert a natural-language prompt into a list of task dicts.

        The model is instructed to respond with {"tasks": [...]}; the inner
        list is returned.

        Args:
            prompt: The user's natural-language message.

        Returns:
            A list of task dictionaries.

        Raises:
            RuntimeError: If the API call fails or the response does not
                contain a valid "tasks" list.
        """
        system_prompt = (
            "You are a task extraction assistant. Convert the user's message into structured task objects. "
            "Each task must include:\n"
            "- title: short title\n"
            "- description: detailed description\n"
            "- due_date: ISO 8601 date (YYYY-MM-DD) or null\n"
            "- priority: high, medium, or low\n"
            "- status: set to \"pending\"\n\n"
            "Respond ONLY with a JSON object in the format:\n"
            "{ \"tasks\": [ { ... }, { ... } ] }\n"
            "No extra commentary or text."
        )
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                temperature=0.2,  # low temperature for stable, parseable JSON
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
            )
            raw = response.choices[0].message.content.strip()
            result = extract_json_from_response(raw)

            # Expect a dict with a "tasks" key holding a list.
            if isinstance(result, dict) and isinstance(result.get("tasks"), list):
                return result["tasks"]
            raise ValueError("Missing or invalid 'tasks' key in response.")

        except Exception as e:
            logger.error("OpenAI task extraction failed: %s", e)
            # Preserve the underlying error as __cause__ for debugging.
            raise RuntimeError("Failed to extract tasks from OpenAI response.") from e


class GeminiService(LLMService):
    """Google Gemini implementation of the LLM service."""

    def __init__(self):
        """Configure the Gemini client from application settings.

        Raises:
            RuntimeError: If the google-generativeai package is not
                installed or the required settings attributes are missing.
        """
        try:
            import google.generativeai as genai

            genai.configure(api_key=settings.GEMINI_API_KEY)
            self.model = genai.GenerativeModel(settings.GEMINI_MODEL)
        except (ImportError, AttributeError) as e:
            logger.error("Gemini service initialization failed: %s", e)
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError("Gemini client setup failed.") from e

    async def chat_to_tasks(self, prompt: str) -> List[Dict]:
        """Convert a natural-language prompt into a list of task dicts.

        Note: unlike the OpenAI implementation, Gemini is instructed to
        return a JSON LIST of tasks, not a dict wrapper.

        Args:
            prompt: The user's natural-language message.

        Returns:
            A list of task dictionaries.

        Raises:
            RuntimeError: If the API call fails or the response is not a
                JSON list.
        """
        system_prompt = (
            "You are a task extraction assistant. Convert the user's message into structured task objects. "
            "Each task must include:\n"
            "- title: short title\n"
            "- description: detailed description\n"
            "- due_date: ISO 8601 date (YYYY-MM-DD) or null\n"
            "- priority: high, medium, or low\n"
            "- status: set to \"pending\"\n\n"
            "Return ONLY a JSON array like this:\n"
            "[ { ... }, { ... } ]\n"
            "No explanations, no markdown, no formatting – just pure JSON."
        )
        try:
            # Seed the chat with the extraction instructions, then send the
            # user's message as the actual turn.
            chat = self.model.start_chat(history=[
                {"role": "user", "parts": [system_prompt]}
            ])
            response = await chat.send_message_async(prompt)
            raw = response.text.strip()
            result = extract_json_from_response(raw)

            # Expect a LIST of task dicts directly.
            if isinstance(result, list):
                return result
            raise ValueError("Expected a JSON list of tasks from Gemini response.")

        except Exception as e:
            logger.error("Gemini task extraction failed: %s", e)
            # Preserve the underlying error as __cause__ for debugging.
            raise RuntimeError("Failed to extract tasks from Gemini response.") from e


class MockLLMService(LLMService):
    """Mock LLM service for testing without external API calls."""

    async def chat_to_tasks(self, prompt: str) -> List[Dict]:
        """Derive a single pending task from the prompt via keyword heuristics."""
        tokens = prompt.lower().split()

        # Keyword-driven priority: urgency markers win over minor markers.
        if "urgent" in tokens or "important" in tokens:
            priority = "high"
        elif "minor" in tokens or "low" in tokens:
            priority = "low"
        else:
            priority = "medium"

        # Truncate long prompts for the title; keep the full text as the
        # description.
        title = prompt if len(prompt) <= 50 else prompt[:50] + "..."
        return [
            {
                "title": title,
                "description": prompt,
                "due_date": None,
                "priority": priority,
                "status": "pending",
            }
        ]


def get_llm_service() -> LLMService:
    """Factory to return the appropriate LLM service based on settings."""
    provider = settings.LLM_PROVIDER.lower()

    # Guard-clause dispatch: each provider also requires its API key.
    if provider == "openai" and settings.OPENAI_API_KEY:
        return OpenAIService()
    if provider == "gemini" and settings.GEMINI_API_KEY:
        return GeminiService()
    # NOTE(review): settings.environment is lowercase while the other
    # settings attributes are UPPER_CASE — confirm the attribute name.
    if provider == "mock" or settings.environment == "test":
        return MockLLMService()

    # Misconfigured or unknown provider: degrade gracefully to the mock.
    llm_provider = provider
    logger.warning(
        f"LLM provider '{llm_provider}' is not properly configured. Falling back to mock."
    )
    return MockLLMService()