import json
from datetime import datetime
from typing import Dict, List, Optional, Set

from fastapi import WebSocket
from sqlalchemy.orm import Session

from app.models.user import User
from app.models.chat_member import ChatMember
from app.core.deps import get_user_from_token
from app.db.session import SessionLocal


class ConnectionManager:
    """Track live websocket connections and in-memory chat membership.

    A user may hold several sockets at once (multiple devices/tabs). On
    connect/disconnect the user's ``is_online`` flag (and ``last_seen`` on
    disconnect) is written to the ``User`` row, and a ``user_status`` message
    is sent to the other members of every chat the user belongs to.
    """

    def __init__(self):
        # user_id -> list of websocket connections (multiple devices/tabs)
        self.active_connections: Dict[int, List[WebSocket]] = {}
        # chat_id -> set of user_ids
        self.chat_members: Dict[int, Set[int]] = {}
        # websocket -> user_id mapping
        self.connection_user_map: Dict[WebSocket, int] = {}

    async def connect(self, websocket: WebSocket, token: str) -> Optional[int]:
        """Accept a websocket, authenticate it and register the connection.

        The socket is accepted before authentication so that an
        ``auth_error`` message can be delivered over it.

        Args:
            websocket: The incoming websocket connection.
            token: Credential handed to ``get_user_from_token``.

        Returns:
            The authenticated user's id, or ``None`` if authentication failed
            (in which case an ``auth_error`` message is sent and the socket
            is closed).
        """
        await websocket.accept()

        db = SessionLocal()
        try:
            user = await get_user_from_token(token, db)
            if not user:
                await websocket.send_text(json.dumps({
                    "type": "auth_error",
                    "message": "Authentication failed"
                }))
                await websocket.close()
                return None

            # Persist online status.
            user.is_online = True
            db.commit()

            # Register this socket alongside any other sockets of the user.
            self.active_connections.setdefault(user.id, []).append(websocket)
            self.connection_user_map[websocket] = user.id

            # Populate chat_id -> members tracking for this user's chats.
            await self._load_user_chats(user.id, db)

            await websocket.send_text(json.dumps({
                "type": "connected",
                "user_id": user.id,
                "message": "Connected successfully"
            }))

            # Tell other members of the user's chats that they are online.
            await self._broadcast_user_status(user.id, "online", db)
            return user.id
        finally:
            db.close()

    async def disconnect(self, websocket: WebSocket) -> None:
        """Unregister ``websocket``; mark the user offline if it was their last one.

        Calling this more than once for the same socket is harmless: the
        socket's mapping entry is removed before any ``await``, so re-entrant
        or concurrent calls (e.g. triggered by a failed send during the
        offline broadcast) return immediately.
        """
        # Pop first so a second call for the same socket is a no-op and the
        # entry cannot leak if the DB work below raises.
        user_id = self.connection_user_map.pop(websocket, None)
        if user_id is None:
            return

        connections = self.active_connections.get(user_id)
        if connections is not None:
            if websocket in connections:
                connections.remove(websocket)
            if connections:
                # The user still has other devices/tabs connected: stay online.
                return
            del self.active_connections[user_id]

        # Update user offline status since no active connections remain.
        db = SessionLocal()
        try:
            user = db.query(User).filter(User.id == user_id).first()
            if user:
                user.is_online = False
                user.last_seen = datetime.utcnow()
                db.commit()
            # Notify other users that this user is offline.
            await self._broadcast_user_status(user_id, "offline", db)
        finally:
            db.close()

    async def _send_to_sockets(self, websockets: List[WebSocket], payload: str) -> None:
        """Send an already-serialized ``payload`` to each socket.

        Sockets whose ``send_text`` raises are treated as dead and are passed
        to ``disconnect`` after the send loop finishes.
        """
        failed: List[WebSocket] = []
        for websocket in websockets:
            try:
                await websocket.send_text(payload)
            except Exception:
                # Best-effort delivery: any send failure means the socket is
                # unusable, so clean it up below instead of propagating.
                failed.append(websocket)
        for websocket in failed:
            await self.disconnect(websocket)

    async def send_personal_message(self, message: dict, user_id: int) -> None:
        """Send ``message`` as JSON to every open connection of ``user_id``."""
        connections = self.active_connections.get(user_id)
        if not connections:
            return
        # Snapshot: the live list can change while we await each send.
        await self._send_to_sockets(list(connections), json.dumps(message))

    async def send_to_chat(self, message: dict, chat_id: int,
                           exclude_user_id: Optional[int] = None) -> None:
        """Send ``message`` to every tracked member of ``chat_id``.

        Args:
            message: JSON-serializable payload.
            chat_id: Chat whose tracked members receive the message.
            exclude_user_id: Optional user id to skip (e.g. the sender).
        """
        members = self.chat_members.get(chat_id)
        if not members:
            return
        # Snapshot: membership can change while we await each send.
        for user_id in list(members):
            # Compare by value so that a user id of 0 is also excluded.
            if user_id == exclude_user_id:
                continue
            await self.send_personal_message(message, user_id)

    async def broadcast(self, message: dict) -> None:
        """Send ``message`` as JSON to every connected socket of every user."""
        # Snapshot all sockets up front: connect/disconnect may mutate
        # active_connections while we await individual sends.
        sockets = [ws for conns in self.active_connections.values() for ws in conns]
        await self._send_to_sockets(sockets, json.dumps(message))

    async def _load_user_chats(self, user_id: int, db: Session) -> None:
        """Record ``user_id`` as a member of each chat it has a ChatMember row for."""
        chat_members = db.query(ChatMember).filter(ChatMember.user_id == user_id).all()
        for member in chat_members:
            self.add_user_to_chat(user_id, member.chat_id)

    async def _broadcast_user_status(self, user_id: int, status: str, db: Session) -> None:
        """Send a ``user_status`` message about ``user_id`` to members of its chats.

        Each recipient gets the message once, even if they share several
        chats with the user. The user themself is never a recipient.
        """
        chat_members = db.query(ChatMember).filter(ChatMember.user_id == user_id).all()
        user = db.query(User).filter(User.id == user_id).first()

        status_message = {
            "type": "user_status",
            "user_id": user_id,
            "username": user.username if user else "Unknown",
            "status": status,
            "last_seen": user.last_seen.isoformat() if user and user.last_seen else None
        }

        # Collect the union of tracked members across all of the user's chats
        # so nobody receives duplicate notifications.
        recipients: Set[int] = set()
        for member in chat_members:
            recipients.update(self.chat_members.get(member.chat_id, ()))
        recipients.discard(user_id)

        for recipient_id in sorted(recipients):
            await self.send_personal_message(status_message, recipient_id)

    def add_user_to_chat(self, user_id: int, chat_id: int) -> None:
        """Add user to chat members tracking."""
        self.chat_members.setdefault(chat_id, set()).add(user_id)

    def remove_user_from_chat(self, user_id: int, chat_id: int) -> None:
        """Remove user from chat members tracking; drop the chat entry when empty."""
        members = self.chat_members.get(chat_id)
        if members is None:
            return
        members.discard(user_id)
        if not members:
            del self.chat_members[chat_id]


# Global connection manager instance
connection_manager = ConnectionManager()