"""WebSocket endpoint for real-time chat: messages, read receipts and typing events.

Protocol (JSON text frames):

1. The first frame from the client must be ``{"token": "<JWT>"}``. If it is
   missing or invalid the server replies with ``{"error": ...}`` and closes
   the socket.
2. Subsequent frames are JSON objects with a ``type`` field:

   * ``{"type": "message", "conversation_id": ..., "content": ...,
     "recipient_id": ... (optional, for direct messages)}``
   * ``{"type": "mark_read", "message_id": ...}``
   * ``{"type": "typing", "conversation_id": ...}``

   Frames with any other ``type`` are ignored; frames that are not a JSON
   object get an ``{"type": "error", ...}`` reply.
"""
import json
import logging
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, Optional

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.repositories.conversation import conversation_repository
from app.db.repositories.message import message_repository
from app.db.repositories.user import user_repository
from app.db.session import get_db
from app.schemas.message import MessageCreate

logger = logging.getLogger(__name__)

router = APIRouter(tags=["websocket"])

# Registry of live connections: user.id -> that user's most recent WebSocket.
# Only one socket per user is tracked; a newer connection replaces the older
# entry. This is a module-level dict, so it is only shared within one process.
# NOTE(review): annotated with str keys; confirm user.id is actually a str.
connected_users: Dict[str, WebSocket] = {}

# RFC 6455 close codes used when the server closes the connection itself.
_WS_POLICY_VIOLATION = 1008  # authentication failed
_WS_INTERNAL_ERROR = 1011  # unexpected server-side error


async def get_user_from_token(token: str, db: Session) -> Any:
    """Resolve the user named by a JWT access token.

    The token is verified with ``settings.SECRET_KEY`` (HS256) and its ``sub``
    claim is used as the user id.

    Returns:
        ``None`` if *token* is not a non-empty string, if decoding raises
        ``JWTError``, or if the payload has no ``sub`` claim; otherwise whatever
        ``user_repository.get`` returns for that id (presumably ``None`` for an
        unknown user).
    """
    # Non-string tokens (e.g. a number in the auth frame) can make jwt.decode
    # raise something other than JWTError, so reject them up front.
    if not isinstance(token, str) or not token:
        return None
    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
    except JWTError:
        return None
    user_id = payload.get("sub")
    if user_id is None:
        return None
    return user_repository.get(db, id=user_id)


def _timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string (the wire format)."""
    return datetime.utcnow().isoformat()


def _isoformat(value: Optional[datetime]) -> Optional[str]:
    """ISO-format *value*, passing ``None`` through instead of raising."""
    return value.isoformat() if value is not None else None


def _parse_json_object(raw: str) -> Optional[Dict[str, Any]]:
    """Decode *raw* as JSON; return the dict, or ``None`` if it is not a JSON object."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None
    return data if isinstance(data, dict) else None


def _is_participant(conversation: Any, user: Any) -> bool:
    """Return True if *user* is among ``conversation.participants`` (compared by id)."""
    return any(participant.id == user.id for participant in conversation.participants)


async def _send_error(websocket: WebSocket, message: str) -> None:
    """Send an ``error`` event carrying *message* to *websocket*."""
    await websocket.send_json({
        "type": "error",
        "message": message,
        "timestamp": _timestamp(),
    })


async def _send_to_user(user_id: Any, payload: Dict[str, Any]) -> None:
    """Best-effort push of *payload* to *user_id* if they have a registered socket.

    Delivery failures are logged and otherwise ignored, so one broken recipient
    cannot tear down the sender's connection.
    """
    target = connected_users.get(user_id)
    if target is None:
        return
    try:
        await target.send_json(payload)
    except Exception:
        logger.warning("Failed to deliver websocket event to user %s", user_id, exc_info=True)


async def _send_to_other_participants(conversation: Any, sender_id: Any, payload: Dict[str, Any]) -> None:
    """Push *payload* to every connected participant of *conversation* except *sender_id*."""
    for participant in conversation.participants:
        if participant.id != sender_id:
            await _send_to_user(participant.id, payload)


async def _authenticate(websocket: WebSocket, db: Session) -> Any:
    """Read the auth frame and return the user, or close the socket and return ``None``."""
    auth_data = _parse_json_object(await websocket.receive_text())
    token = auth_data.get("token") if auth_data is not None else None
    if not token:
        await websocket.send_json({"error": "Authentication required"})
        await websocket.close(code=_WS_POLICY_VIOLATION)
        return None

    user = await get_user_from_token(token, db)
    if not user:
        await websocket.send_json({"error": "Invalid authentication token"})
        await websocket.close(code=_WS_POLICY_VIOLATION)
        return None
    return user


async def _handle_chat_message(websocket: WebSocket, db: Session, user: Any, message_data: Dict[str, Any]) -> None:
    """Handle a ``message`` frame: persist it, echo it to the sender, fan it out."""
    conversation_id = message_data.get("conversation_id")
    content = message_data.get("content")
    recipient_id = message_data.get("recipient_id")

    if not conversation_id or not content:
        await _send_error(websocket, "Missing required fields")
        return

    conversation = conversation_repository.get(db, id=conversation_id)
    if not conversation:
        await _send_error(websocket, "Conversation not found")
        return
    if not _is_participant(conversation, user):
        await _send_error(websocket, "Not a participant of this conversation")
        return

    message_in = MessageCreate(
        content=content,
        conversation_id=conversation_id,
        recipient_id=recipient_id,
    )
    message = message_repository.create_with_sender(db, obj_in=message_in, sender_id=user.id)

    message_out = {
        "type": "new_message",
        "message": {
            "id": message.id,
            "content": message.content,
            "sender_id": message.sender_id,
            "recipient_id": message.recipient_id,
            "conversation_id": message.conversation_id,
            "is_read": message.is_read,
            "created_at": _isoformat(message.created_at),
            "updated_at": _isoformat(message.updated_at),
        },
        "timestamp": _timestamp(),
    }

    # Echo to the sender as confirmation, then to the other connected participants.
    await websocket.send_json(message_out)
    await _send_to_other_participants(conversation, user.id, message_out)


async def _handle_mark_read(websocket: WebSocket, db: Session, user: Any, message_data: Dict[str, Any]) -> None:
    """Handle a ``mark_read`` frame: mark the message read and notify both sides."""
    message_id = message_data.get("message_id")
    if not message_id:
        await _send_error(websocket, "Missing message_id field")
        return

    message = message_repository.mark_as_read(db, message_id=message_id, user_id=user.id)
    if not message:
        await _send_error(websocket, "Message not found or you're not the recipient")
        return

    await websocket.send_json({
        "type": "message_read",
        "message_id": message_id,
        "timestamp": _timestamp(),
    })
    await _send_to_user(message.sender_id, {
        "type": "message_read_by_recipient",
        "message_id": message_id,
        "read_by": user.id,
        "timestamp": _timestamp(),
    })


async def _handle_typing(websocket: WebSocket, db: Session, user: Any, message_data: Dict[str, Any]) -> None:
    """Handle a ``typing`` frame: tell the other connected participants the user is typing."""
    conversation_id = message_data.get("conversation_id")
    if not conversation_id:
        await _send_error(websocket, "Missing conversation_id field")
        return

    conversation = conversation_repository.get(db, id=conversation_id)
    # Typing indicators are fire-and-forget: unknown conversations or
    # non-participants are ignored without an error reply.
    if not conversation or not _is_participant(conversation, user):
        return

    await _send_to_other_participants(conversation, user.id, {
        "type": "user_typing",
        "user_id": user.id,
        "conversation_id": conversation_id,
        "timestamp": _timestamp(),
    })


# Frame ``type`` -> handler. Unknown types are silently ignored.
_MESSAGE_HANDLERS: Dict[str, Callable[[WebSocket, Session, Any, Dict[str, Any]], Awaitable[None]]] = {
    "message": _handle_chat_message,
    "mark_read": _handle_mark_read,
    "typing": _handle_typing,
}


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, db: Session = Depends(get_db)):
    """Serve one client connection for its whole lifetime.

    Steps:

    1. Authenticate the first frame and register the socket in ``connected_users``.
    2. Send the ``connection_established`` and ``unread_count`` events.
    3. Dispatch incoming frames until the client disconnects.

    On exit the registry entry is removed only if it still points at *this*
    socket, so a user who has reconnected is not unregistered when their
    previous connection drops.

    NOTE(review): the ``db`` session from ``get_db`` is used for the entire
    lifetime of the connection; confirm long-lived sessions are acceptable
    (stale reads, connection-pool usage).
    """
    await websocket.accept()
    user = None
    try:
        user = await _authenticate(websocket, db)
        if user is None:
            return

        # A newer connection for the same user replaces any older entry.
        connected_users[user.id] = websocket

        await websocket.send_json({
            "type": "connection_established",
            "user_id": user.id,
            "timestamp": _timestamp(),
        })

        unread_count = message_repository.get_unread_count(db, user_id=user.id)
        await websocket.send_json({
            "type": "unread_count",
            "count": unread_count,
            "timestamp": _timestamp(),
        })

        while True:
            message_data = _parse_json_object(await websocket.receive_text())
            if message_data is None:
                # A malformed frame should not drop the whole connection.
                await _send_error(websocket, "Invalid message format")
                continue

            message_type = message_data.get("type")
            # isinstance guard: an unhashable "type" (list/dict) would make the
            # dict lookup raise TypeError.
            handler = _MESSAGE_HANDLERS.get(message_type) if isinstance(message_type, str) else None
            if handler is not None:
                await handler(websocket, db, user, message_data)
    except WebSocketDisconnect:
        pass  # normal client disconnect; cleanup happens in `finally`
    except Exception:
        logger.exception(
            "Unhandled error in websocket connection for user %s",
            getattr(user, "id", None),
        )
        try:
            await websocket.close(code=_WS_INTERNAL_ERROR)
        except Exception:
            # The socket may already be closed; nothing more can be done.
            logger.debug("Could not close websocket after error", exc_info=True)
    finally:
        # Only unregister if the entry is still ours; a reconnect may have replaced it.
        if user is not None and connected_users.get(user.id) is websocket:
            del connected_users[user.id]