2025-05-21 12:49:05 +00:00

70 lines
2.0 KiB
Python

from typing import List, Optional
from sqlalchemy.orm import Session
from app.core.security import generate_uuid
from app.db.repositories.base import BaseRepository
from app.models.message import Message
from app.schemas.message import MessageCreate, MessageUpdate
class MessageRepository(BaseRepository[Message, MessageCreate, MessageUpdate]):
    """Data-access helpers for ``Message`` rows.

    Extends ``BaseRepository`` (parameterised with the ``Message`` model and
    its create/update schemas) with message-specific creation and queries.
    Methods that write to the database call ``db.commit()`` themselves.
    """

    def create_with_sender(
        self, db: Session, *, obj_in: MessageCreate, sender_id: str
    ) -> Message:
        """Create, commit and return a new message sent by ``sender_id``.

        Args:
            db: Active database session; this method commits it.
            obj_in: Payload supplying ``content``, ``recipient_id`` and
                ``conversation_id``.
            sender_id: Value stored as the message's ``sender_id``.

        Returns:
            The persisted ``Message``, refreshed from the database after the
            commit.
        """
        message_id = generate_uuid()
        db_obj = Message(
            id=message_id,
            content=obj_in.content,
            sender_id=sender_id,
            recipient_id=obj_in.recipient_id,
            conversation_id=obj_in.conversation_id,
            is_read=False,  # a freshly created message always starts unread
        )
        db.add(db_obj)
        db.commit()
        # Reload the row's attributes from the database after the commit.
        db.refresh(db_obj)
        return db_obj

    def get_conversation_messages(
        self, db: Session, *, conversation_id: str, skip: int = 0, limit: int = 100
    ) -> List[Message]:
        """Return one page of messages in ``conversation_id``, newest first.

        Args:
            db: Active database session.
            conversation_id: Conversation whose messages are fetched.
            skip: Number of rows to skip (offset) after ordering.
            limit: Maximum number of rows to return.

        Returns:
            Messages ordered by ``created_at`` descending. Because ordering is
            applied before offset/limit, ``skip=0`` yields the most recent
            ``limit`` messages.
        """
        return (
            db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )

    def mark_as_read(
        self, db: Session, *, message_id: str, user_id: str
    ) -> Optional[Message]:
        """Mark message ``message_id`` as read on behalf of its recipient.

        The lookup also requires ``recipient_id == user_id``, so only the
        recipient can mark a message read.

        Args:
            db: Active database session; committed only if a change is made.
            message_id: Id of the message to mark.
            user_id: Id of the user performing the action.

        Returns:
            ``None`` if no message with that id is addressed to ``user_id``.
            Otherwise the message. An already-read message is returned
            unchanged, without a commit.
        """
        message = (
            db.query(Message)
            .filter(Message.id == message_id)
            .filter(Message.recipient_id == user_id)
            .first()
        )
        if message and not message.is_read:
            message.is_read = True
            db.add(message)
            db.commit()
            db.refresh(message)
        return message

    def get_unread_count(
        self, db: Session, *, user_id: str
    ) -> int:
        """Return how many messages addressed to ``user_id`` are still unread.

        Args:
            db: Active database session.
            user_id: Recipient whose unread messages are counted.

        Returns:
            Number of ``Message`` rows with ``recipient_id == user_id`` and
            ``is_read`` false.
        """
        return (
            db.query(Message)
            .filter(Message.recipient_id == user_id)
            # Must be a SQL expression: Python's ``not`` cannot be overloaded
            # and always yields a plain bool, so ``not Message.is_read`` never
            # filtered rows.
            .filter(Message.is_read.is_(False))
            .count()
        )
# Module-level MessageRepository instance bound to the Message model.
message_repository: MessageRepository = MessageRepository(Message)