from typing import List, Optional

from sqlalchemy.orm import Session

from app.core.security import generate_uuid
from app.db.repositories.base import BaseRepository
from app.models.message import Message
from app.schemas.message import MessageCreate, MessageUpdate


class MessageRepository(BaseRepository[Message, MessageCreate, MessageUpdate]):
    """Data-access helpers for ``Message`` rows on top of ``BaseRepository``."""

    def create_with_sender(
        self, db: Session, *, obj_in: MessageCreate, sender_id: str
    ) -> Message:
        """Persist a new, unread message sent by ``sender_id``.

        Args:
            db: Session used to add and commit the new row.
            obj_in: Payload supplying ``content``, ``recipient_id`` and
                ``conversation_id``.
            sender_id: Identifier stored as the message's sender.

        Returns:
            The newly created ``Message``, refreshed after the commit.
        """
        message_id = generate_uuid()
        db_obj = Message(
            id=message_id,
            content=obj_in.content,
            sender_id=sender_id,
            recipient_id=obj_in.recipient_id,
            conversation_id=obj_in.conversation_id,
            is_read=False,
        )
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def get_conversation_messages(
        self,
        db: Session,
        *,
        conversation_id: str,
        skip: int = 0,
        limit: int = 100,
    ) -> List[Message]:
        """Return one page of a conversation's messages, newest first.

        Args:
            db: Session used to run the query.
            conversation_id: Conversation whose messages are selected.
            skip: Number of rows to skip (query offset).
            limit: Maximum number of rows to return.

        Returns:
            Messages ordered by ``created_at`` descending.
        """
        return (
            db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )

    def mark_as_read(
        self, db: Session, *, message_id: str, user_id: str
    ) -> Optional[Message]:
        """Flag a message as read if ``user_id`` is its recipient.

        Args:
            db: Session used to look up and update the row.
            message_id: Identifier of the message to update.
            user_id: Must match the message's ``recipient_id``.

        Returns:
            The matching message (whether or not it was already read), or
            ``None`` when no message has this id and recipient.
        """
        message = (
            db.query(Message)
            .filter(Message.id == message_id)
            .filter(Message.recipient_id == user_id)
            .first()
        )
        # Only write and commit when the flag actually changes.
        if message and not message.is_read:
            message.is_read = True
            db.add(message)
            db.commit()
            db.refresh(message)
        return message

    def get_unread_count(self, db: Session, *, user_id: str) -> int:
        """Count the unread messages addressed to ``user_id``.

        Args:
            db: Session used to run the query.
            user_id: Recipient whose unread messages are counted.

        Returns:
            Number of messages with this recipient where ``is_read`` is false.
        """
        return (
            db.query(Message)
            .filter(Message.recipient_id == user_id)
            # Use a column operator: Python's ``not`` is evaluated eagerly and
            # never becomes part of the SQL WHERE clause.
            .filter(Message.is_read.is_(False))
            .count()
        )


message_repository = MessageRepository(Message)