236 lines
9.5 KiB
Python
236 lines
9.5 KiB
Python
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.repositories.conversation import conversation_repository
from app.db.repositories.message import message_repository
from app.db.repositories.user import user_repository
from app.db.session import get_db
from app.schemas.message import MessageCreate
|
|
|
|
# FastAPI router that carries this module's WebSocket route(s); grouped under
# the "websocket" tag in the generated OpenAPI docs.
router = APIRouter(tags=["websocket"])

# In-process registry of live sockets: user id -> the WebSocket currently
# registered for that user. Being a plain dict, it holds at most one socket
# per id and is local to this process.
# NOTE(review): annotated as str keys, but the endpoint stores `user.id`,
# whose type is defined by the user model - confirm.
connected_users: Dict[str, WebSocket] = dict()
|
|
|
|
async def get_user_from_token(token: str, db: Session) -> Any:
    """Resolve a JWT string to a user record.

    The token is decoded and verified with ``settings.SECRET_KEY``, accepting
    only the HS256 algorithm. The ``sub`` claim is then used as the user id.

    Args:
        token: Encoded JWT supplied by the client.
        db: Database session handed to ``user_repository.get``.

    Returns:
        Whatever ``user_repository.get`` returns for the ``sub`` claim
        (presumably ``None`` when no row matches - confirm against the
        repository). Returns ``None`` when the token has no ``sub`` claim or
        ``jwt.decode`` raises ``JWTError``.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
        subject = claims.get("sub")
        return None if subject is None else user_repository.get(db, id=subject)
    except JWTError:
        # Bad signature, malformed token, or any other decode/verify failure.
        return None
|
|
|
|
_logger = logging.getLogger(__name__)


def _utc_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string with no UTC offset.

    Produces the same text as the deprecated ``datetime.utcnow().isoformat()``,
    so timestamps on the wire keep their existing format.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


async def _send_error(websocket: WebSocket, message: str) -> None:
    """Send a ``{"type": "error", ...}`` frame to a single client."""
    await websocket.send_json({
        "type": "error",
        "message": message,
        "timestamp": _utc_timestamp(),
    })


async def _send_to_user(user_id: Any, payload: Dict[str, Any]) -> None:
    """Best-effort push of ``payload`` to ``user_id``'s registered socket.

    Does nothing if the user has no registered socket. Delivery failures are
    logged rather than raised, so a broken peer connection cannot tear down
    the session of the client that triggered the event.
    """
    peer = connected_users.get(user_id)
    if peer is None:
        return
    try:
        await peer.send_json(payload)
    except Exception:
        _logger.warning(
            "Failed to deliver websocket event to user %s", user_id, exc_info=True
        )


async def _broadcast_to_others(
    conversation: Any, sender_id: Any, payload: Dict[str, Any]
) -> None:
    """Push ``payload`` to every connected participant except ``sender_id``."""
    for participant in conversation.participants:
        if participant.id != sender_id:
            await _send_to_user(participant.id, payload)


async def _authenticate(websocket: WebSocket, db: Session) -> Any:
    """Read the first frame, expected to be ``{"token": "<jwt>"}``, and resolve the user.

    On failure an ``{"error": ...}`` frame is sent, the socket is closed and
    ``None`` is returned. A frame that is not a JSON object is treated the same
    as a missing token.
    """
    auth_message = await websocket.receive_text()
    try:
        auth_data = json.loads(auth_message)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        auth_data = None
    token = auth_data.get("token") if isinstance(auth_data, dict) else None

    if not token:
        await websocket.send_json({"error": "Authentication required"})
        await websocket.close()
        return None

    user = await get_user_from_token(token, db)
    if not user:
        await websocket.send_json({"error": "Invalid authentication token"})
        await websocket.close()
        return None
    return user


async def _handle_chat_message(
    websocket: WebSocket, db: Session, user: Any, message_data: Dict[str, Any]
) -> None:
    """Persist a chat message, echo it to the sender and fan it out to peers.

    Expected frame:
        {"type": "message", "conversation_id": ..., "content": ...,
         "recipient_id": ... (optional, for direct messages)}
    """
    conversation_id = message_data.get("conversation_id")
    content = message_data.get("content")
    recipient_id = message_data.get("recipient_id")

    if not conversation_id or not content:
        await _send_error(websocket, "Missing required fields")
        return

    conversation = conversation_repository.get(db, id=conversation_id)
    if not conversation:
        await _send_error(websocket, "Conversation not found")
        return
    if user not in conversation.participants:
        await _send_error(websocket, "Not a participant of this conversation")
        return

    try:
        message_in = MessageCreate(
            content=content,
            conversation_id=conversation_id,
            recipient_id=recipient_id,
        )
    except ValueError:
        # Pydantic's ValidationError subclasses ValueError. Reject this frame
        # instead of dropping the whole connection.
        await _send_error(websocket, "Invalid message fields")
        return

    message = message_repository.create_with_sender(
        db, obj_in=message_in, sender_id=user.id
    )

    message_out = {
        "type": "new_message",
        "message": {
            "id": message.id,
            "content": message.content,
            "sender_id": message.sender_id,
            "recipient_id": message.recipient_id,
            "conversation_id": message.conversation_id,
            "is_read": message.is_read,
            "created_at": message.created_at.isoformat(),
            # Guard against a not-yet-set update timestamp instead of raising.
            "updated_at": (
                message.updated_at.isoformat() if message.updated_at else None
            ),
        },
        "timestamp": _utc_timestamp(),
    }

    # Echo to the sender as confirmation, then to the other live participants.
    await websocket.send_json(message_out)
    await _broadcast_to_others(conversation, user.id, message_out)


async def _handle_mark_read(
    websocket: WebSocket, db: Session, user: Any, message_data: Dict[str, Any]
) -> None:
    """Mark a message read for ``user`` and notify the message's sender if connected."""
    message_id = message_data.get("message_id")
    if not message_id:
        await _send_error(websocket, "Missing message_id field")
        return

    message = message_repository.mark_as_read(
        db, message_id=message_id, user_id=user.id
    )
    if not message:
        await _send_error(websocket, "Message not found or you're not the recipient")
        return

    await websocket.send_json({
        "type": "message_read",
        "message_id": message_id,
        "timestamp": _utc_timestamp(),
    })
    await _send_to_user(message.sender_id, {
        "type": "message_read_by_recipient",
        "message_id": message_id,
        "read_by": user.id,
        "timestamp": _utc_timestamp(),
    })


async def _handle_typing(
    websocket: WebSocket, db: Session, user: Any, message_data: Dict[str, Any]
) -> None:
    """Relay a typing indicator to the other connected participants of a conversation."""
    conversation_id = message_data.get("conversation_id")
    if not conversation_id:
        await _send_error(websocket, "Missing conversation_id field")
        return

    conversation = conversation_repository.get(db, id=conversation_id)
    # Typing indicators are fire-and-forget: unknown conversations and
    # non-members are ignored without an error reply.
    if not conversation or user not in conversation.participants:
        return

    await _broadcast_to_others(conversation, user.id, {
        "type": "user_typing",
        "user_id": user.id,
        "conversation_id": conversation_id,
        "timestamp": _utc_timestamp(),
    })


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, db: Session = Depends(get_db)):
    """Real-time chat WebSocket.

    Protocol:
      1. The first frame must be ``{"token": "<jwt>"}``. On failure the server
         replies ``{"error": ...}`` and closes the socket.
      2. The server sends ``connection_established`` and then ``unread_count``.
      3. The client sends frames whose ``type`` is ``"message"``,
         ``"mark_read"`` or ``"typing"``. Other types are ignored. Frames that
         are not JSON objects get an ``error`` reply and the session continues.

    The authenticated socket is registered in ``connected_users`` under
    ``user.id``, replacing any earlier socket for that user. On exit the entry
    is removed only if it still refers to this socket.
    """
    await websocket.accept()
    user = None

    try:
        user = await _authenticate(websocket, db)
        if user is None:
            return

        connected_users[user.id] = websocket

        await websocket.send_json({
            "type": "connection_established",
            "user_id": user.id,
            "timestamp": _utc_timestamp(),
        })

        unread_count = message_repository.get_unread_count(db, user_id=user.id)
        await websocket.send_json({
            "type": "unread_count",
            "count": unread_count,
            "timestamp": _utc_timestamp(),
        })

        while True:
            raw = await websocket.receive_text()
            try:
                message_data = json.loads(raw)
            except ValueError:
                message_data = None
            if not isinstance(message_data, dict):
                # One malformed frame should not end an authenticated session.
                await _send_error(websocket, "Invalid message format")
                continue

            # Compare with == (not a dict lookup) so an unhashable "type"
            # value cannot raise.
            message_type = message_data.get("type")
            if message_type == "message":
                await _handle_chat_message(websocket, db, user, message_data)
            elif message_type == "mark_read":
                await _handle_mark_read(websocket, db, user, message_data)
            elif message_type == "typing":
                await _handle_typing(websocket, db, user, message_data)

    except WebSocketDisconnect:
        # Normal client disconnect; registry cleanup happens in `finally`.
        pass
    except Exception:
        _logger.exception(
            "Unexpected error in websocket session for user %s",
            getattr(user, "id", None),
        )
    finally:
        # A newer connection for the same user may have replaced this socket in
        # the registry; only remove the entry if it is still ours.
        if user is not None and connected_users.get(user.id) is websocket:
            del connected_users[user.id]