# Snapshot metadata carried over from the original paste:
# 2025-05-21 12:49:05 +00:00 · 236 lines · 9.5 KiB · Python

import json
import logging
from datetime import datetime
from typing import Dict, Any

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from jose import jwt, JWTError
from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.repositories.conversation import conversation_repository
from app.db.repositories.message import message_repository
from app.db.repositories.user import user_repository
from app.db.session import get_db
from app.schemas.message import MessageCreate

router = APIRouter(tags=["websocket"])

# In-process registry of live WebSocket connections, keyed by user id
# (annotated as str; presumably matches the type of ``user.id`` — confirm).
# Holds one socket per user: registering again overwrites the previous entry.
connected_users: Dict[str, WebSocket] = dict()
async def get_user_from_token(token: str, db: Session) -> Any:
    """Resolve the user identified by a JWT.

    The token is decoded and verified with ``settings.SECRET_KEY`` (HS256 only),
    and its ``sub`` claim is passed as ``id`` to ``user_repository.get``.

    Args:
        token: Encoded JWT supplied by the client.
        db: Database session used for the user lookup.

    Returns:
        The result of ``user_repository.get`` for the ``sub`` claim, or ``None``
        when verification fails (``JWTError``) or the claim is absent.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
        subject = claims.get("sub")
        return None if subject is None else user_repository.get(db, id=subject)
    except JWTError:
        return None
_logger = logging.getLogger(__name__)


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Kept naive (no ``+00:00`` suffix) so every frame's ``timestamp`` field
    keeps the format produced by ``datetime.utcnow().isoformat()``.
    """
    return datetime.utcnow().isoformat()


async def _send_error(websocket: WebSocket, message: str) -> None:
    """Send a ``{"type": "error", "message": ..., "timestamp": ...}`` frame."""
    await websocket.send_json({
        "type": "error",
        "message": message,
        "timestamp": _utc_timestamp(),
    })


async def _send_to_user(user_id: Any, payload: Dict[str, Any]) -> None:
    """Best-effort delivery of *payload* to *user_id*'s registered socket.

    Does nothing when the user has no entry in ``connected_users``. A failed
    send is logged and swallowed so one broken peer cannot abort the
    caller's session; the peer's own endpoint loop handles its cleanup.
    """
    peer = connected_users.get(user_id)
    if peer is None:
        return
    try:
        await peer.send_json(payload)
    except Exception:
        _logger.warning(
            "Failed to deliver websocket frame to user %s", user_id, exc_info=True
        )


async def _authenticate(websocket: WebSocket, db: Session) -> Any:
    """Read the first frame and resolve it to a user, closing the socket on failure.

    The frame must be a JSON object with a ``token`` field. When the token is
    missing (or the frame is not a JSON object) an ``{"error": "Authentication
    required"}`` frame is sent; when it is present but does not resolve to a
    user, ``{"error": "Invalid authentication token"}``. In both cases the
    socket is closed and ``None`` is returned.
    """
    raw = await websocket.receive_text()
    try:
        auth_data = json.loads(raw)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        auth_data = None
    token = auth_data.get("token") if isinstance(auth_data, dict) else None
    if not token:
        await websocket.send_json({"error": "Authentication required"})
        await websocket.close()
        return None
    # A non-string token (e.g. a JSON number) cannot be a JWT; reject it
    # here rather than letting the JWT library fail with a non-JWTError.
    user = await get_user_from_token(token, db) if isinstance(token, str) else None
    if not user:
        await websocket.send_json({"error": "Invalid authentication token"})
        await websocket.close()
        return None
    return user


async def _handle_chat_message(
    websocket: WebSocket, db: Session, user: Any, message_data: Dict[str, Any]
) -> None:
    """Persist a ``message`` frame and fan it out to the conversation.

    Expected frame::

        {"type": "message", "conversation_id": "...", "content": "...",
         "recipient_id": "..."}   # recipient_id optional (direct messages)

    The stored message is echoed to the sender as confirmation and sent to
    every other participant that currently has a registered connection.
    """
    conversation_id = message_data.get("conversation_id")
    content = message_data.get("content")
    recipient_id = message_data.get("recipient_id")
    if not conversation_id or not content:
        await _send_error(websocket, "Missing required fields")
        return

    conversation = conversation_repository.get(db, id=conversation_id)
    if not conversation:
        await _send_error(websocket, "Conversation not found")
        return
    if user not in conversation.participants:
        await _send_error(websocket, "Not a participant of this conversation")
        return

    try:
        message_in = MessageCreate(
            content=content,
            conversation_id=conversation_id,
            recipient_id=recipient_id,
        )
    except ValueError:
        # NOTE(review): assumes MessageCreate is a pydantic model, whose
        # ValidationError subclasses ValueError — confirm.
        await _send_error(websocket, "Invalid message fields")
        return

    message = message_repository.create_with_sender(
        db, obj_in=message_in, sender_id=user.id
    )
    message_out = {
        "type": "new_message",
        "message": {
            "id": message.id,
            "content": message.content,
            "sender_id": message.sender_id,
            "recipient_id": message.recipient_id,
            "conversation_id": message.conversation_id,
            "is_read": message.is_read,
            "created_at": message.created_at.isoformat(),
            "updated_at": message.updated_at.isoformat(),
        },
        "timestamp": _utc_timestamp(),
    }

    # Confirmation to the sender goes over this socket; a failure here is
    # this session's own problem and is allowed to propagate.
    await websocket.send_json(message_out)
    for participant in conversation.participants:
        if participant.id != user.id:
            await _send_to_user(participant.id, message_out)


async def _handle_mark_read(
    websocket: WebSocket, db: Session, user: Any, message_data: Dict[str, Any]
) -> None:
    """Mark ``message_id`` as read by *user* and notify the original sender.

    ``message_repository.mark_as_read`` returning a falsy value is reported to
    the client as "not found or not the recipient".
    """
    message_id = message_data.get("message_id")
    if not message_id:
        await _send_error(websocket, "Missing message_id field")
        return

    message = message_repository.mark_as_read(
        db, message_id=message_id, user_id=user.id
    )
    if not message:
        await _send_error(websocket, "Message not found or you're not the recipient")
        return

    await websocket.send_json({
        "type": "message_read",
        "message_id": message_id,
        "timestamp": _utc_timestamp(),
    })
    await _send_to_user(message.sender_id, {
        "type": "message_read_by_recipient",
        "message_id": message_id,
        "read_by": user.id,
        "timestamp": _utc_timestamp(),
    })


async def _handle_typing(
    websocket: WebSocket, db: Session, user: Any, message_data: Dict[str, Any]
) -> None:
    """Relay a ``typing`` indicator to the other connected participants.

    Unknown conversations and conversations the user is not part of are
    ignored without sending an error frame.
    """
    conversation_id = message_data.get("conversation_id")
    if not conversation_id:
        await _send_error(websocket, "Missing conversation_id field")
        return

    conversation = conversation_repository.get(db, id=conversation_id)
    if not conversation or user not in conversation.participants:
        return

    for participant in conversation.participants:
        if participant.id != user.id:
            await _send_to_user(participant.id, {
                "type": "user_typing",
                "user_id": user.id,
                "conversation_id": conversation_id,
                "timestamp": _utc_timestamp(),
            })


# Dispatch table for post-authentication frames, keyed by the frame's "type".
_MESSAGE_HANDLERS: Dict[str, Any] = {
    "message": _handle_chat_message,
    "mark_read": _handle_mark_read,
    "typing": _handle_typing,
}


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, db: Session = Depends(get_db)):
    """Serve the authenticated real-time messaging channel at ``/ws``.

    Protocol:
      * The first client frame must be JSON ``{"token": "<jwt>"}``. On failure
        an ``{"error": ...}`` frame is sent and the socket is closed.
      * On success the server sends ``connection_established`` followed by
        ``unread_count``.
      * Subsequent frames are JSON objects dispatched on their ``type``:
        ``message``, ``mark_read`` or ``typing``. Any other ``type`` is ignored.
      * Invalid JSON or non-object frames get an ``error`` frame and the
        connection stays open.

    On exit the user's ``connected_users`` entry is removed only if it still
    points at this socket. A newer connection from the same user therefore
    stays registered when an older one closes.

    NOTE(review): repository calls are synchronous and run on the event-loop
    thread; if they perform blocking DB I/O they stall all other connections
    while running — consider offloading them to a thread pool.
    """
    await websocket.accept()
    user = None
    try:
        user = await _authenticate(websocket, db)
        if user is None:
            return

        # A later connection for the same user replaces the earlier one.
        connected_users[user.id] = websocket

        await websocket.send_json({
            "type": "connection_established",
            "user_id": user.id,
            "timestamp": _utc_timestamp(),
        })
        unread_count = message_repository.get_unread_count(db, user_id=user.id)
        await websocket.send_json({
            "type": "unread_count",
            "count": unread_count,
            "timestamp": _utc_timestamp(),
        })

        while True:
            raw = await websocket.receive_text()
            try:
                message_data = json.loads(raw)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                await _send_error(websocket, "Invalid JSON")
                continue
            if not isinstance(message_data, dict):
                await _send_error(websocket, "Invalid message format")
                continue

            frame_type = message_data.get("type")
            # Guard the lookup: a JSON list/object "type" is unhashable.
            handler = (
                _MESSAGE_HANDLERS.get(frame_type) if isinstance(frame_type, str) else None
            )
            if handler is not None:
                await handler(websocket, db, user, message_data)
    except WebSocketDisconnect:
        # Normal client disconnect; registry cleanup happens in `finally`.
        pass
    except Exception:
        _logger.exception("Unexpected error in websocket session")
        try:
            # 1011 = "internal error" close code (RFC 6455 §7.4.1).
            await websocket.close(code=1011)
        except Exception:
            # Best-effort: the transport may already be gone.
            _logger.debug("Could not close websocket after error", exc_info=True)
    finally:
        if user is not None and connected_users.get(user.id) is websocket:
            del connected_users[user.id]