70 lines
2.0 KiB
Python
70 lines
2.0 KiB
Python
from typing import List, Optional
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.security import generate_uuid
|
|
from app.db.repositories.base import BaseRepository
|
|
from app.models.message import Message
|
|
from app.schemas.message import MessageCreate, MessageUpdate
|
|
|
|
class MessageRepository(BaseRepository[Message, MessageCreate, MessageUpdate]):
    """Data-access helpers for ``Message`` rows, built on ``BaseRepository``."""

    def create_with_sender(
        self, db: Session, *, obj_in: MessageCreate, sender_id: str
    ) -> Message:
        """Insert a new, unread message sent by ``sender_id`` and return it.

        Args:
            db: Active SQLAlchemy session. This method commits it.
            obj_in: Payload supplying ``content``, ``recipient_id`` and
                ``conversation_id``.
            sender_id: Id of the user sending the message.

        Returns:
            The persisted ``Message``, refreshed from the database after commit.
        """
        db_obj = Message(
            id=generate_uuid(),
            content=obj_in.content,
            sender_id=sender_id,
            recipient_id=obj_in.recipient_id,
            conversation_id=obj_in.conversation_id,
            is_read=False,
        )
        db.add(db_obj)
        db.commit()
        # Reload the row so any database-generated column values are visible.
        db.refresh(db_obj)
        return db_obj

    def get_conversation_messages(
        self, db: Session, *, conversation_id: str, skip: int = 0, limit: int = 100
    ) -> List[Message]:
        """Return one page of a conversation's messages, newest first.

        Args:
            db: Active SQLAlchemy session.
            conversation_id: Conversation whose messages are listed.
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.

        Returns:
            Messages ordered by ``created_at`` descending.
        """
        return (
            db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            # ``id`` as a secondary key makes the ordering total, so offset/limit
            # pages stay stable when several messages share a ``created_at``.
            .order_by(Message.created_at.desc(), Message.id.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )

    def mark_as_read(
        self, db: Session, *, message_id: str, user_id: str
    ) -> Optional[Message]:
        """Mark a message as read on behalf of its recipient.

        The lookup is restricted to ``recipient_id == user_id``, so only the
        recipient can mark a message as read.

        Args:
            db: Active SQLAlchemy session. Committed only if a change is made.
            message_id: Id of the message to mark.
            user_id: Id of the user who must be the message's recipient.

        Returns:
            The message (updated if it was unread), or ``None`` if no message
            with that id is addressed to ``user_id``.
        """
        message = (
            db.query(Message)
            .filter(Message.id == message_id, Message.recipient_id == user_id)
            .first()
        )

        # Skip the write/commit entirely when the message is already read.
        if message is not None and not message.is_read:
            message.is_read = True
            db.add(message)
            db.commit()
            db.refresh(message)

        return message

    def get_unread_count(
        self, db: Session, *, user_id: str
    ) -> int:
        """Return how many messages addressed to ``user_id`` are still unread.

        Args:
            db: Active SQLAlchemy session.
            user_id: Recipient whose unread messages are counted.

        Returns:
            Number of matching rows with ``is_read`` false.
        """
        return (
            db.query(Message)
            .filter(Message.recipient_id == user_id)
            # Must be a SQL expression: Python's ``not Message.is_read`` would be
            # evaluated on the column object itself rather than emitted as SQL.
            .filter(Message.is_read.is_(False))
            .count()
        )
# Module-level repository instance bound to the ``Message`` model, for importers.
message_repository: MessageRepository = MessageRepository(Message)