74 lines
2.6 KiB
Python
74 lines
2.6 KiB
Python
from typing import List, Optional
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.security import generate_uuid
|
|
from app.db.repositories.base import BaseRepository
|
|
from app.models.conversation import Conversation, conversation_participants
|
|
from app.models.user import User
|
|
from app.schemas.conversation import ConversationCreate, ConversationUpdate
|
|
|
|
class ConversationRepository(BaseRepository[Conversation, ConversationCreate, ConversationUpdate]):
    """Repository for ``Conversation`` rows and their participant links."""

    def create_with_participants(
        self, db: Session, *, obj_in: ConversationCreate, creator_id: str
    ) -> Conversation:
        """Create a conversation and attach its participants, including the creator.

        Args:
            db: SQLAlchemy session. This method flushes and commits it.
            obj_in: Creation payload; ``name``, ``is_group`` and
                ``participant_ids`` are read from it.
            creator_id: ID of the creating user. It is always added to the
                participant set, whether or not ``obj_in`` lists it.

        Returns:
            The newly persisted conversation, refreshed after commit.

        Requested participant IDs with no matching ``User`` row are skipped
        silently.
        """
        db_obj = Conversation(
            id=generate_uuid(),
            name=obj_in.name,
            is_group=obj_in.is_group,
        )
        db.add(db_obj)
        db.flush()

        # set() de-duplicates the requested IDs; add() is a no-op when the
        # creator is already present. `or ()` tolerates a missing list —
        # NOTE(review): confirm whether ConversationCreate allows None here.
        participant_ids = set(obj_in.participant_ids or ())
        participant_ids.add(creator_id)

        # Fetch every participant in one query instead of one query per ID.
        users = db.query(User).filter(User.id.in_(participant_ids)).all()
        for user in users:
            db_obj.participants.append(user)

        db.commit()
        db.refresh(db_obj)
        return db_obj

    def get_user_conversations(
        self, db: Session, *, user_id: str, skip: int = 0, limit: int = 100
    ) -> List[Conversation]:
        """Return one page of the conversations that ``user_id`` participates in.

        Args:
            db: SQLAlchemy session used for the query.
            user_id: ID of the participant to look up.
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).

        Returns:
            The matching conversations, ordered by ``Conversation.id``.
        """
        return (
            db.query(Conversation)
            .join(
                conversation_participants,
                Conversation.id == conversation_participants.c.conversation_id,
            )
            .filter(conversation_participants.c.user_id == user_id)
            # OFFSET/LIMIT without ORDER BY has no guaranteed row order, so
            # successive pages could overlap or skip rows. Ordering by the
            # primary key keeps pagination deterministic.
            .order_by(Conversation.id)
            .offset(skip)
            .limit(limit)
            .all()
        )

    def get_conversation_between_users(
        self, db: Session, *, user_id_1: str, user_id_2: str
    ) -> Optional[Conversation]:
        """Find the direct (non-group) conversation between exactly two users.

        Args:
            db: SQLAlchemy session used for the query.
            user_id_1: ID of one participant.
            user_id_2: ID of the other participant.

        Returns:
            The first non-group conversation whose participant IDs are exactly
            ``{user_id_1, user_id_2}``, or ``None`` if there is none.
        """
        wanted = {user_id_1, user_id_2}
        candidates = (
            db.query(Conversation)
            .join(
                conversation_participants,
                Conversation.id == conversation_participants.c.conversation_id,
            )
            # The previous `not Conversation.is_group` was evaluated by Python,
            # not SQL, so it never produced an `is_group = false` predicate.
            # Build a real SQL comparison instead.
            .filter(Conversation.is_group.is_(False))
            # Any exact match must contain user_id_1, so filtering on it alone
            # narrows the candidates without losing results.
            .filter(conversation_participants.c.user_id == user_id_1)
            .all()
        )

        # Keep only conversations whose full participant set equals the pair.
        for conversation in candidates:
            if {participant.id for participant in conversation.participants} == wanted:
                return conversation

        return None
|
|
|
|
# Module-level repository instance bound to the Conversation model.
conversation_repository = ConversationRepository(Conversation)