Build comprehensive real-time chat API with advanced features
Complete rewrite from task management to full-featured chat system: Core Features: - Real-time WebSocket messaging with connection management - Direct messages and group chats with admin controls - Message types: text, images, videos, audio, documents - Message status tracking: sent, delivered, read receipts - Typing indicators and user presence (online/offline) - Message replies, editing, and deletion Security & Encryption: - End-to-end encryption with RSA + AES hybrid approach - JWT authentication for API and WebSocket connections - Secure file storage with access control - Automatic RSA key pair generation per user Media & File Sharing: - Multi-format file upload (images, videos, audio, documents) - Automatic thumbnail generation for images/videos - File size validation and MIME type checking - Secure download endpoints with permission checks Notifications & Alerts: - Real-time WebSocket notifications - Push notifications via Firebase integration - @username mention alerts with notification history - Unread message and mention counting - Custom notification types (message, mention, group invite) Advanced Features: - Group chat management with roles (member, admin, owner) - User search and chat member management - Message pagination and chat history - Last seen timestamps and activity tracking - Comprehensive API documentation with WebSocket events Architecture: - Clean layered architecture with services, models, schemas - WebSocket connection manager for real-time features - Modular notification system with multiple channels - Comprehensive error handling and validation - Production-ready with Docker support Technologies: FastAPI, WebSocket, SQLAlchemy, SQLite, Cryptography, Firebase, Pillow
This commit is contained in:
parent
f19a6fea04
commit
41bbd8c182
506
README.md
506
README.md
@ -1,15 +1,48 @@
|
||||
# Personal Task Management API
|
||||
# Real-time Chat API
|
||||
|
||||
A comprehensive FastAPI-based REST API for managing personal tasks, categories, and user authentication. Built with Python, FastAPI, SQLAlchemy, and SQLite.
|
||||
A comprehensive real-time chat API built with FastAPI, WebSocket, and advanced features including media sharing, end-to-end encryption, and push notifications. Perfect for building modern chat applications.
|
||||
|
||||
## Features
|
||||
## 🚀 Features
|
||||
|
||||
- **User Authentication**: JWT-based authentication with registration and login
|
||||
- **Task Management**: Full CRUD operations for tasks with status and priority tracking
|
||||
- **Category Management**: Organize tasks with custom categories
|
||||
- **User Authorization**: All operations are user-scoped and protected
|
||||
- **Database Migrations**: Alembic for database schema management
|
||||
- **API Documentation**: Auto-generated OpenAPI/Swagger documentation
|
||||
### Core Chat Features
|
||||
- **Real-time Messaging**: WebSocket-based instant messaging
|
||||
- **Direct Messages**: Private one-on-one conversations
|
||||
- **Group Chats**: Multi-user chat rooms with admin controls
|
||||
- **Message Types**: Text, images, videos, audio, documents
|
||||
- **Message Status**: Sent, delivered, read receipts
|
||||
- **Typing Indicators**: Real-time typing status
|
||||
- **Message Replies**: Reply to specific messages
|
||||
- **Message Editing**: Edit and delete messages
|
||||
|
||||
### Media & File Sharing
|
||||
- **Image Sharing**: Upload and share images with automatic thumbnail generation
|
||||
- **Video Support**: Share video files with metadata extraction
|
||||
- **Document Support**: PDF, Word, Excel, PowerPoint, and more
|
||||
- **Audio Messages**: Voice recordings and audio file sharing
|
||||
- **File Validation**: Size limits and type restrictions
|
||||
- **Secure Storage**: Files stored securely with access control
|
||||
|
||||
### Security & Encryption
|
||||
- **End-to-End Encryption**: RSA + AES hybrid encryption for messages
|
||||
- **JWT Authentication**: Secure token-based authentication
|
||||
- **Key Management**: Automatic RSA key pair generation
|
||||
- **Message Encryption**: Client-side encryption capabilities
|
||||
- **Secure File Storage**: Protected media file access
|
||||
|
||||
### Notifications & Alerts
|
||||
- **Push Notifications**: Firebase-based mobile notifications
|
||||
- **Mention Alerts**: @username notifications
|
||||
- **Real-time Notifications**: WebSocket-based instant alerts
|
||||
- **Notification History**: Persistent notification storage
|
||||
- **Custom Notification Types**: Message, mention, group invites
|
||||
|
||||
### Advanced Features
|
||||
- **User Presence**: Online/offline status tracking
|
||||
- **Last Seen**: User activity timestamps
|
||||
- **Chat Management**: Admin/owner roles and permissions
|
||||
- **Member Management**: Add/remove users, role assignment
|
||||
- **Message Search**: Find messages across chats
|
||||
- **Unread Counts**: Track unread messages per chat
|
||||
|
||||
## Quick Start
|
||||
|
||||
@ -43,112 +76,433 @@ The API will be available at `http://localhost:8000`
|
||||
- ReDoc: `http://localhost:8000/redoc`
|
||||
- OpenAPI JSON: `http://localhost:8000/openapi.json`
|
||||
|
||||
## API Endpoints
|
||||
### WebSocket Connection
|
||||
|
||||
### Authentication
|
||||
- `POST /auth/register` - Register a new user
|
||||
- `POST /auth/login` - Login with email and password
|
||||
- `GET /auth/me` - Get current user information
|
||||
Connect to the WebSocket endpoint for real-time messaging:
|
||||
|
||||
### Tasks
|
||||
- `GET /tasks/` - List all tasks (with filtering options)
|
||||
- `POST /tasks/` - Create a new task
|
||||
- `GET /tasks/{task_id}` - Get a specific task
|
||||
- `PUT /tasks/{task_id}` - Update a task
|
||||
- `DELETE /tasks/{task_id}` - Delete a task
|
||||
```javascript
|
||||
const ws = new WebSocket('ws://localhost:8000/ws/{access_token}');
|
||||
|
||||
### Categories
|
||||
- `GET /categories/` - List all categories
|
||||
- `POST /categories/` - Create a new category
|
||||
- `GET /categories/{category_id}` - Get a specific category
|
||||
- `PUT /categories/{category_id}` - Update a category
|
||||
- `DELETE /categories/{category_id}` - Delete a category
|
||||
// Send a message
|
||||
ws.send(JSON.stringify({
|
||||
type: 'send_message',
|
||||
chat_id: 1,
|
||||
content: 'Hello, world!',
|
||||
message_type: 'text'
|
||||
}));
|
||||
|
||||
// Listen for messages
|
||||
ws.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
console.log('Received:', data);
|
||||
};
|
||||
```
|
||||
|
||||
## 📡 API Endpoints
|
||||
|
||||
### Authentication (`/api/auth`)
|
||||
- `POST /register` - Register a new user
|
||||
- `POST /login` - Login with username and password
|
||||
- `GET /me` - Get current user information
|
||||
- `PUT /me` - Update user profile
|
||||
- `POST /device-token` - Update device token for push notifications
|
||||
|
||||
### Chats (`/api/chats`)
|
||||
- `GET /` - Get user's chat list
|
||||
- `POST /direct` - Create or get direct chat with another user
|
||||
- `POST /group` - Create a new group chat
|
||||
- `GET /{chat_id}` - Get chat details and members
|
||||
- `PUT /{chat_id}` - Update chat information (admin/owner only)
|
||||
- `POST /{chat_id}/members` - Add member to chat
|
||||
- `DELETE /{chat_id}/members/{user_id}` - Remove member from chat
|
||||
|
||||
### Messages (`/api/messages`)
|
||||
- `GET /{chat_id}` - Get chat messages with pagination
|
||||
- `PUT /{message_id}` - Edit a message
|
||||
- `DELETE /{message_id}` - Delete a message
|
||||
- `POST /{message_id}/read` - Mark message as read
|
||||
|
||||
### Media (`/api/media`)
|
||||
- `POST /upload/{message_id}` - Upload media files for a message
|
||||
- `GET /{media_id}` - Download media file
|
||||
- `GET /{media_id}/thumbnail` - Get media thumbnail
|
||||
- `GET /{media_id}/info` - Get media file information
|
||||
- `DELETE /{media_id}` - Delete media file
|
||||
|
||||
### Notifications (`/api/notifications`)
|
||||
- `GET /` - Get user notifications
|
||||
- `GET /mentions` - Get unread mentions
|
||||
- `POST /{notification_id}/read` - Mark notification as read
|
||||
- `POST /mentions/{mention_id}/read` - Mark mention as read
|
||||
- `POST /read-all` - Mark all notifications as read
|
||||
- `GET /count` - Get unread notification count
|
||||
|
||||
### Encryption (`/api/encryption`)
|
||||
- `POST /generate-keys` - Generate new RSA key pair
|
||||
- `GET /public-key` - Get current user's public key
|
||||
- `GET /public-key/{user_id}` - Get another user's public key
|
||||
- `POST /encrypt` - Encrypt a message for a user
|
||||
- `POST /decrypt` - Decrypt a message
|
||||
- `PUT /public-key` - Update user's public key
|
||||
|
||||
### System
|
||||
- `GET /` - API information and links
|
||||
- `GET /` - API information and features
|
||||
- `GET /health` - Health check endpoint
|
||||
- `GET /api/status` - System status and statistics
|
||||
|
||||
## Data Models
|
||||
|
||||
### Task
|
||||
- `id`: Unique identifier
|
||||
- `title`: Task title (required)
|
||||
- `description`: Optional task description
|
||||
- `status`: pending, in_progress, completed, cancelled
|
||||
- `priority`: low, medium, high, urgent
|
||||
- `due_date`: Optional due date
|
||||
- `category_id`: Optional category association
|
||||
- `completed_at`: Timestamp when task was completed
|
||||
- `created_at/updated_at`: Timestamps
|
||||
|
||||
### Category
|
||||
- `id`: Unique identifier
|
||||
- `name`: Category name (required)
|
||||
- `description`: Optional description
|
||||
- `color`: Optional color code
|
||||
- `created_at/updated_at`: Timestamps
|
||||
## 🏗️ Data Models
|
||||
|
||||
### User
|
||||
- `id`: Unique identifier
|
||||
- `email`: User email (unique, required)
|
||||
- `username`: Username (unique, required)
|
||||
- `email`: Email address (unique, required)
|
||||
- `full_name`: Optional full name
|
||||
- `avatar_url`: Profile picture URL
|
||||
- `bio`: User biography
|
||||
- `is_active`: Account status
|
||||
- `created_at/updated_at`: Timestamps
|
||||
- `is_online`: Current online status
|
||||
- `last_seen`: Last activity timestamp
|
||||
- `public_key`: RSA public key for E2E encryption
|
||||
- `device_token`: Firebase device token for push notifications
|
||||
|
||||
## Environment Variables
|
||||
### Chat
|
||||
- `id`: Unique identifier
|
||||
- `name`: Chat name (null for direct messages)
|
||||
- `description`: Chat description
|
||||
- `chat_type`: "direct" or "group"
|
||||
- `avatar_url`: Chat avatar image
|
||||
- `is_active`: Active status
|
||||
|
||||
### Message
|
||||
- `id`: Unique identifier
|
||||
- `chat_id`: Associated chat
|
||||
- `sender_id`: Message sender
|
||||
- `reply_to_id`: Replied message ID (optional)
|
||||
- `content`: Message content (encrypted)
|
||||
- `content_type`: "text", "image", "video", "audio", "file", "system"
|
||||
- `status`: "sent", "delivered", "read", "failed"
|
||||
- `is_edited`: Edit status
|
||||
- `is_deleted`: Deletion status
|
||||
- `created_at`: Send timestamp
|
||||
|
||||
### Media
|
||||
- `id`: Unique identifier
|
||||
- `message_id`: Associated message
|
||||
- `filename`: Stored filename
|
||||
- `original_filename`: Original upload name
|
||||
- `file_size`: Size in bytes
|
||||
- `mime_type`: File MIME type
|
||||
- `media_type`: "image", "video", "audio", "document", "other"
|
||||
- `width/height`: Image/video dimensions
|
||||
- `duration`: Audio/video duration
|
||||
- `thumbnail_path`: Thumbnail file path
|
||||
|
||||
### Chat Member
|
||||
- `id`: Unique identifier
|
||||
- `chat_id`: Associated chat
|
||||
- `user_id`: User ID
|
||||
- `role`: "member", "admin", "owner"
|
||||
- `nickname`: Chat-specific nickname
|
||||
- `is_muted`: Mute status
|
||||
- `joined_at`: Join timestamp
|
||||
- `last_read_message_id`: Last read message for unread count
|
||||
|
||||
### Notification
|
||||
- `id`: Unique identifier
|
||||
- `user_id`: Recipient user
|
||||
- `notification_type`: "message", "mention", "group_invite"
|
||||
- `title`: Notification title
|
||||
- `body`: Notification content
|
||||
- `data`: Additional JSON data
|
||||
- `is_read`: Read status
|
||||
- `created_at`: Creation timestamp
|
||||
|
||||
## 🔧 Environment Variables
|
||||
|
||||
Set the following environment variables for production:
|
||||
|
||||
- `SECRET_KEY`: JWT secret key for token signing (required for production)
|
||||
|
||||
Example:
|
||||
```bash
|
||||
export SECRET_KEY="your-super-secret-key-here"
|
||||
# Required
|
||||
SECRET_KEY="your-super-secret-key-change-in-production"
|
||||
|
||||
# Optional - Firebase Push Notifications
|
||||
FIREBASE_CREDENTIALS_PATH="/path/to/firebase-credentials.json"
|
||||
# OR
|
||||
FIREBASE_CREDENTIALS_JSON='{"type": "service_account", ...}'
|
||||
|
||||
# Optional - Redis for caching (if implemented)
|
||||
REDIS_URL="redis://localhost:6379"
|
||||
```
|
||||
|
||||
## Database
|
||||
## 💾 Database
|
||||
|
||||
The application uses SQLite by default with the database file stored at `/app/storage/db/db.sqlite`. The database is automatically created when the application starts.
|
||||
The application uses SQLite by default with the database file stored at `/app/storage/db/chat.sqlite`. The database is automatically created when the application starts.
|
||||
|
||||
## Authentication
|
||||
**File Storage Structure:**
|
||||
```
|
||||
/app/storage/
|
||||
├── db/
|
||||
│ └── chat.sqlite # Main database
|
||||
├── media/ # Uploaded media files
|
||||
│ ├── image_file.jpg
|
||||
│ └── document.pdf
|
||||
└── thumbnails/ # Generated thumbnails
|
||||
└── thumb_image_file.jpg
|
||||
```
|
||||
|
||||
The API uses JWT bearer tokens for authentication. To access protected endpoints:
|
||||
## 🔐 Authentication & Security
|
||||
|
||||
### JWT Authentication
|
||||
1. Register a new user or login with existing credentials
|
||||
2. Include the access token in the Authorization header: `Authorization: Bearer <token>`
|
||||
3. For WebSocket connections, pass the token in the URL: `/ws/{access_token}`
|
||||
|
||||
## Development
|
||||
### End-to-End Encryption
|
||||
1. **Key Generation**: Each user gets an RSA key pair on registration
|
||||
2. **Message Encryption**: Messages encrypted with recipient's public key
|
||||
3. **Hybrid Encryption**: Uses RSA + AES for large messages
|
||||
4. **Client-Side**: Encryption/decryption can be done client-side
|
||||
|
||||
Example encryption workflow:
|
||||
```python
|
||||
# Get recipient's public key
|
||||
GET /api/encryption/public-key/{user_id}
|
||||
|
||||
# Encrypt message client-side
|
||||
POST /api/encryption/encrypt
|
||||
{
|
||||
"message": "Hello, this is encrypted!",
|
||||
"recipient_user_id": 123
|
||||
}
|
||||
|
||||
# Send encrypted message via WebSocket
|
||||
ws.send({
|
||||
"type": "send_message",
|
||||
"chat_id": 1,
|
||||
"content": "encrypted_content_here"
|
||||
})
|
||||
```
|
||||
|
||||
## 📱 WebSocket Events
|
||||
|
||||
### Client → Server Events
|
||||
```javascript
|
||||
// Send message
|
||||
{
|
||||
"type": "send_message",
|
||||
"chat_id": 1,
|
||||
"content": "Hello!",
|
||||
"message_type": "text",
|
||||
"reply_to_id": null
|
||||
}
|
||||
|
||||
// Typing indicators
|
||||
{
|
||||
"type": "typing_start",
|
||||
"chat_id": 1
|
||||
}
|
||||
|
||||
{
|
||||
"type": "typing_stop",
|
||||
"chat_id": 1
|
||||
}
|
||||
|
||||
// Mark message as read
|
||||
{
|
||||
"type": "message_read",
|
||||
"message_id": 123,
|
||||
"chat_id": 1
|
||||
}
|
||||
|
||||
// Join/leave chat room
|
||||
{
|
||||
"type": "join_chat",
|
||||
"chat_id": 1
|
||||
}
|
||||
```
|
||||
|
||||
### Server → Client Events
|
||||
```javascript
|
||||
// New message received
|
||||
{
|
||||
"type": "new_message",
|
||||
"message": {
|
||||
"id": 123,
|
||||
"chat_id": 1,
|
||||
"sender_id": 456,
|
||||
"sender_username": "john_doe",
|
||||
"content": "Hello!",
|
||||
"created_at": "2024-12-20T12:00:00Z"
|
||||
}
|
||||
}
|
||||
|
||||
// User status updates
|
||||
{
|
||||
"type": "user_status",
|
||||
"user_id": 456,
|
||||
"username": "john_doe",
|
||||
"status": "online"
|
||||
}
|
||||
|
||||
// Typing indicators
|
||||
{
|
||||
"type": "typing_start",
|
||||
"chat_id": 1,
|
||||
"user_id": 456,
|
||||
"username": "john_doe"
|
||||
}
|
||||
|
||||
// Notifications
|
||||
{
|
||||
"type": "mention_notification",
|
||||
"title": "Mentioned by john_doe",
|
||||
"body": "You were mentioned in a message",
|
||||
"chat_id": 1
|
||||
}
|
||||
```
|
||||
|
||||
## 🔧 Development
|
||||
|
||||
### Code Quality
|
||||
|
||||
The project uses Ruff for linting and code formatting:
|
||||
|
||||
```bash
|
||||
ruff check .
|
||||
# Lint and format code
|
||||
ruff check . --fix
|
||||
ruff format .
|
||||
```
|
||||
|
||||
### Database Migrations
|
||||
|
||||
Create new migrations after model changes:
|
||||
|
||||
```bash
|
||||
# Create new migration
|
||||
alembic revision --autogenerate -m "Description of changes"
|
||||
|
||||
# Apply migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Check current migration
|
||||
alembic current
|
||||
```
|
||||
|
||||
## Architecture
|
||||
### Testing WebSocket Connection
|
||||
```bash
|
||||
# Install websocat for testing
|
||||
pip install websockets
|
||||
|
||||
- **FastAPI**: Modern, fast web framework for building APIs
|
||||
- **SQLAlchemy**: ORM for database operations
|
||||
- **Alembic**: Database migration tool
|
||||
# Test WebSocket connection
|
||||
python -c "
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
|
||||
async def test():
|
||||
uri = 'ws://localhost:8000/ws/your_jwt_token_here'
|
||||
async with websockets.connect(uri) as websocket:
|
||||
# Send a test message
|
||||
await websocket.send(json.dumps({
|
||||
'type': 'send_message',
|
||||
'chat_id': 1,
|
||||
'content': 'Test message'
|
||||
}))
|
||||
|
||||
# Listen for response
|
||||
response = await websocket.recv()
|
||||
print(json.loads(response))
|
||||
|
||||
asyncio.run(test())
|
||||
"
|
||||
```
|
||||
|
||||
## 🏗️ Architecture
|
||||
|
||||
### Technology Stack
|
||||
- **FastAPI**: Modern, fast web framework with WebSocket support
|
||||
- **SQLAlchemy**: Python ORM for database operations
|
||||
- **Alembic**: Database migration management
|
||||
- **Pydantic**: Data validation using Python type annotations
|
||||
- **WebSocket**: Real-time bidirectional communication
|
||||
- **JWT**: JSON Web Tokens for authentication
|
||||
- **SQLite**: Lightweight database for development and small deployments
|
||||
- **SQLite**: Lightweight database for development
|
||||
- **Cryptography**: RSA + AES encryption for E2E security
|
||||
- **Firebase Admin**: Push notification service
|
||||
- **Pillow**: Image processing and thumbnail generation
|
||||
|
||||
The application follows a clean architecture pattern with separate layers for:
|
||||
- API routes (`app/api/`)
|
||||
- Database models (`app/models/`)
|
||||
- Data schemas (`app/schemas/`)
|
||||
- Core utilities (`app/core/`)
|
||||
- Database configuration (`app/db/`)
|
||||
### Project Structure
|
||||
```
|
||||
app/
|
||||
├── api/ # API route handlers
|
||||
│ ├── auth.py # Authentication endpoints
|
||||
│ ├── chats.py # Chat management
|
||||
│ ├── messages.py # Message operations
|
||||
│ ├── media.py # File upload/download
|
||||
│ ├── notifications.py # Notification management
|
||||
│ └── encryption.py # E2E encryption utilities
|
||||
├── core/ # Core utilities
|
||||
│ ├── deps.py # Dependency injection
|
||||
│ └── security.py # Security functions
|
||||
├── db/ # Database configuration
|
||||
│ ├── base.py # SQLAlchemy base
|
||||
│ └── session.py # Database session management
|
||||
├── models/ # SQLAlchemy models
|
||||
│ ├── user.py # User model
|
||||
│ ├── chat.py # Chat model
|
||||
│ ├── message.py # Message model
|
||||
│ └── ...
|
||||
├── schemas/ # Pydantic schemas
|
||||
│ ├── user.py # User schemas
|
||||
│ ├── chat.py # Chat schemas
|
||||
│ └── ...
|
||||
├── services/ # Business logic services
|
||||
│ ├── media_service.py # File handling
|
||||
│ ├── encryption_service.py # Encryption logic
|
||||
│ ├── notification_service.py # Notification management
|
||||
│ └── push_notification_service.py
|
||||
├── websocket/ # WebSocket handling
|
||||
│ ├── connection_manager.py # Connection management
|
||||
│ └── chat_handler.py # Message handling
|
||||
└── utils/ # Utility functions
|
||||
```
|
||||
|
||||
### Design Patterns
|
||||
- **Repository Pattern**: Data access abstraction
|
||||
- **Service Layer**: Business logic separation
|
||||
- **Dependency Injection**: Loose coupling
|
||||
- **Event-Driven**: WebSocket event handling
|
||||
- **Clean Architecture**: Separation of concerns
|
||||
|
||||
## 📊 Performance Considerations
|
||||
|
||||
- **Connection Pooling**: SQLAlchemy connection management
|
||||
- **File Storage**: Local storage with direct file serving
|
||||
- **Memory Management**: Efficient WebSocket connection tracking
|
||||
- **Thumbnail Generation**: Automatic image thumbnail creation
|
||||
- **Message Pagination**: Efficient large chat history loading
|
||||
- **Real-time Optimization**: WebSocket connection pooling
|
||||
|
||||
## 🚀 Deployment
|
||||
|
||||
### Production Checklist
|
||||
- [ ] Set strong `SECRET_KEY` environment variable
|
||||
- [ ] Configure Firebase credentials for push notifications
|
||||
- [ ] Set up proper file storage (S3, etc.) for production
|
||||
- [ ] Configure reverse proxy (nginx) for WebSocket support
|
||||
- [ ] Set up SSL/TLS certificates
|
||||
- [ ] Configure database backups
|
||||
- [ ] Set up monitoring and logging
|
||||
- [ ] Configure CORS for your frontend domain
|
||||
|
||||
### Docker Deployment
|
||||
```dockerfile
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
```
|
||||
|
||||
This comprehensive chat API provides a solid foundation for building modern messaging applications with real-time communication, security, and scalability in mind.
|
||||
|
@ -2,7 +2,7 @@
|
||||
script_location = alembic
|
||||
prepend_sys_path = .
|
||||
version_path_separator = os
|
||||
sqlalchemy.url = sqlite:////app/storage/db/db.sqlite
|
||||
sqlalchemy.url = sqlite:////app/storage/db/chat.sqlite
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
"""Initial migration
|
||||
"""Initial migration for chat system
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
@ -20,58 +20,149 @@ def upgrade() -> None:
|
||||
# Create users table
|
||||
op.create_table('users',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('username', sa.String(), nullable=False),
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('hashed_password', sa.String(), nullable=False),
|
||||
sa.Column('full_name', sa.String(), nullable=True),
|
||||
sa.Column('avatar_url', sa.String(), nullable=True),
|
||||
sa.Column('bio', sa.Text(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('is_online', sa.Boolean(), nullable=True),
|
||||
sa.Column('last_seen', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('public_key', sa.Text(), nullable=True),
|
||||
sa.Column('device_token', sa.String(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
|
||||
|
||||
# Create categories table
|
||||
op.create_table('categories',
|
||||
# Create chats table
|
||||
op.create_table('chats',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('description', sa.String(), nullable=True),
|
||||
sa.Column('color', sa.String(), nullable=True),
|
||||
sa.Column('owner_id', sa.Integer(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['owner_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_categories_id'), 'categories', ['id'], unique=False)
|
||||
|
||||
# Create tasks table
|
||||
op.create_table('tasks',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('title', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('status', sa.Enum('PENDING', 'IN_PROGRESS', 'COMPLETED', 'CANCELLED', name='taskstatus'), nullable=True),
|
||||
sa.Column('priority', sa.Enum('LOW', 'MEDIUM', 'HIGH', 'URGENT', name='taskpriority'), nullable=True),
|
||||
sa.Column('due_date', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('owner_id', sa.Integer(), nullable=True),
|
||||
sa.Column('category_id', sa.Integer(), nullable=True),
|
||||
sa.Column('chat_type', sa.Enum('DIRECT', 'GROUP', name='chattype'), nullable=False),
|
||||
sa.Column('avatar_url', sa.String(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['category_id'], ['categories.id'], ),
|
||||
sa.ForeignKeyConstraint(['owner_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_tasks_id'), 'tasks', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_tasks_title'), 'tasks', ['title'], unique=False)
|
||||
op.create_index(op.f('ix_chats_id'), 'chats', ['id'], unique=False)
|
||||
|
||||
# Create chat_members table
|
||||
op.create_table('chat_members',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('chat_id', sa.Integer(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('role', sa.Enum('MEMBER', 'ADMIN', 'OWNER', name='memberrole'), nullable=True),
|
||||
sa.Column('nickname', sa.String(), nullable=True),
|
||||
sa.Column('is_muted', sa.Boolean(), nullable=True),
|
||||
sa.Column('is_banned', sa.Boolean(), nullable=True),
|
||||
sa.Column('joined_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.Column('last_read_message_id', sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['chat_id'], ['chats.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_chat_members_id'), 'chat_members', ['id'], unique=False)
|
||||
|
||||
# Create messages table
|
||||
op.create_table('messages',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('chat_id', sa.Integer(), nullable=False),
|
||||
sa.Column('sender_id', sa.Integer(), nullable=False),
|
||||
sa.Column('reply_to_id', sa.Integer(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('content_type', sa.Enum('TEXT', 'IMAGE', 'VIDEO', 'AUDIO', 'FILE', 'SYSTEM', name='messagetype'), nullable=True),
|
||||
sa.Column('status', sa.Enum('SENT', 'DELIVERED', 'READ', 'FAILED', name='messagestatus'), nullable=True),
|
||||
sa.Column('is_edited', sa.Boolean(), nullable=True),
|
||||
sa.Column('is_deleted', sa.Boolean(), nullable=True),
|
||||
sa.Column('edited_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.ForeignKeyConstraint(['chat_id'], ['chats.id'], ),
|
||||
sa.ForeignKeyConstraint(['reply_to_id'], ['messages.id'], ),
|
||||
sa.ForeignKeyConstraint(['sender_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_messages_id'), 'messages', ['id'], unique=False)
|
||||
|
||||
# Create media table
|
||||
op.create_table('media',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('message_id', sa.Integer(), nullable=False),
|
||||
sa.Column('uploader_id', sa.Integer(), nullable=False),
|
||||
sa.Column('filename', sa.String(), nullable=False),
|
||||
sa.Column('original_filename', sa.String(), nullable=False),
|
||||
sa.Column('file_path', sa.String(), nullable=False),
|
||||
sa.Column('file_size', sa.BigInteger(), nullable=False),
|
||||
sa.Column('mime_type', sa.String(), nullable=False),
|
||||
sa.Column('media_type', sa.Enum('IMAGE', 'VIDEO', 'AUDIO', 'DOCUMENT', 'OTHER', name='mediatype'), nullable=False),
|
||||
sa.Column('width', sa.Integer(), nullable=True),
|
||||
sa.Column('height', sa.Integer(), nullable=True),
|
||||
sa.Column('duration', sa.Integer(), nullable=True),
|
||||
sa.Column('thumbnail_path', sa.String(), nullable=True),
|
||||
sa.Column('is_processed', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.ForeignKeyConstraint(['message_id'], ['messages.id'], ),
|
||||
sa.ForeignKeyConstraint(['uploader_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_media_id'), 'media', ['id'], unique=False)
|
||||
|
||||
# Create mentions table
|
||||
op.create_table('mentions',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('message_id', sa.Integer(), nullable=False),
|
||||
sa.Column('mentioned_user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('is_read', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.Column('read_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['mentioned_user_id'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['message_id'], ['messages.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_mentions_id'), 'mentions', ['id'], unique=False)
|
||||
|
||||
# Create notifications table
|
||||
op.create_table('notifications',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('chat_id', sa.Integer(), nullable=True),
|
||||
sa.Column('message_id', sa.Integer(), nullable=True),
|
||||
sa.Column('notification_type', sa.Enum('MESSAGE', 'MENTION', 'GROUP_INVITE', 'GROUP_JOIN', 'GROUP_LEAVE', name='notificationtype'), nullable=False),
|
||||
sa.Column('title', sa.String(), nullable=False),
|
||||
sa.Column('body', sa.Text(), nullable=False),
|
||||
sa.Column('data', sa.Text(), nullable=True),
|
||||
sa.Column('is_read', sa.Boolean(), nullable=True),
|
||||
sa.Column('is_sent', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.Column('read_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['chat_id'], ['chats.id'], ),
|
||||
sa.ForeignKeyConstraint(['message_id'], ['messages.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_notifications_id'), 'notifications', ['id'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f('ix_tasks_title'), table_name='tasks')
|
||||
op.drop_index(op.f('ix_tasks_id'), table_name='tasks')
|
||||
op.drop_table('tasks')
|
||||
op.drop_index(op.f('ix_categories_id'), table_name='categories')
|
||||
op.drop_table('categories')
|
||||
op.drop_index(op.f('ix_notifications_id'), table_name='notifications')
|
||||
op.drop_table('notifications')
|
||||
op.drop_index(op.f('ix_mentions_id'), table_name='mentions')
|
||||
op.drop_table('mentions')
|
||||
op.drop_index(op.f('ix_media_id'), table_name='media')
|
||||
op.drop_table('media')
|
||||
op.drop_index(op.f('ix_messages_id'), table_name='messages')
|
||||
op.drop_table('messages')
|
||||
op.drop_index(op.f('ix_chat_members_id'), table_name='chat_members')
|
||||
op.drop_table('chat_members')
|
||||
op.drop_index(op.f('ix_chats_id'), table_name='chats')
|
||||
op.drop_table('chats')
|
||||
op.drop_index(op.f('ix_users_username'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_id'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_table('users')
|
124
app/api/auth.py
124
app/api/auth.py
@ -1,60 +1,144 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.session import get_db
|
||||
from app.core.security import verify_password, get_password_hash, create_access_token
|
||||
from app.core.security import verify_password, get_password_hash, create_access_token, generate_rsa_key_pair
|
||||
from app.core.deps import get_current_active_user
|
||||
from app.models.user import User as UserModel
|
||||
from app.schemas.user import User, UserCreate, Token
|
||||
from app.schemas.user import User, UserCreate, UserLogin, UserUpdate, Token, UserPublic
|
||||
|
||||
router = APIRouter()
|
||||
security = HTTPBearer()
|
||||
|
||||
@router.post("/register", response_model=User)
|
||||
@router.post("/register", response_model=Token)
|
||||
def register(
|
||||
user: UserCreate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
# Check if user already exists
|
||||
db_user = db.query(UserModel).filter(UserModel.email == user.email).first()
|
||||
if db_user:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Email already registered"
|
||||
)
|
||||
# Check if username or email already exists
|
||||
existing_user = db.query(UserModel).filter(
|
||||
(UserModel.username == user.username) | (UserModel.email == user.email)
|
||||
).first()
|
||||
|
||||
if existing_user:
|
||||
if existing_user.username == user.username:
|
||||
raise HTTPException(status_code=400, detail="Username already registered")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
|
||||
# Generate E2E encryption key pair
|
||||
private_key, public_key = generate_rsa_key_pair()
|
||||
|
||||
# Create new user
|
||||
hashed_password = get_password_hash(user.password)
|
||||
db_user = UserModel(
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
hashed_password=hashed_password,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active
|
||||
bio=user.bio,
|
||||
is_active=user.is_active,
|
||||
public_key=public_key
|
||||
)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
return db_user
|
||||
|
||||
# Create access token
|
||||
access_token = create_access_token(subject=db_user.id)
|
||||
|
||||
user_public = UserPublic(
|
||||
id=db_user.id,
|
||||
username=db_user.username,
|
||||
full_name=db_user.full_name,
|
||||
avatar_url=db_user.avatar_url,
|
||||
is_online=db_user.is_online,
|
||||
last_seen=db_user.last_seen
|
||||
)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
user=user_public
|
||||
)
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
def login(
|
||||
email: str,
|
||||
password: str,
|
||||
credentials: UserLogin,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
user = db.query(UserModel).filter(UserModel.email == email).first()
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
user = db.query(UserModel).filter(UserModel.username == credentials.username).first()
|
||||
if not user or not verify_password(credentials.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password",
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
|
||||
access_token = create_access_token(subject=user.id)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
user_public = UserPublic(
|
||||
id=user.id,
|
||||
username=user.username,
|
||||
full_name=user.full_name,
|
||||
avatar_url=user.avatar_url,
|
||||
is_online=user.is_online,
|
||||
last_seen=user.last_seen
|
||||
)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
user=user_public
|
||||
)
|
||||
|
||||
@router.get("/me", response_model=User)
|
||||
def read_users_me(current_user: UserModel = Depends(get_current_active_user)):
|
||||
return current_user
|
||||
|
||||
@router.put("/me", response_model=User)
|
||||
def update_user_me(
|
||||
user_update: UserUpdate,
|
||||
current_user: UserModel = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
update_data = user_update.dict(exclude_unset=True)
|
||||
|
||||
# Handle password update
|
||||
if "password" in update_data:
|
||||
update_data["hashed_password"] = get_password_hash(update_data.pop("password"))
|
||||
|
||||
# Check username/email uniqueness if being updated
|
||||
if "username" in update_data:
|
||||
existing = db.query(UserModel).filter(
|
||||
UserModel.username == update_data["username"],
|
||||
UserModel.id != current_user.id
|
||||
).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="Username already taken")
|
||||
|
||||
if "email" in update_data:
|
||||
existing = db.query(UserModel).filter(
|
||||
UserModel.email == update_data["email"],
|
||||
UserModel.id != current_user.id
|
||||
).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="Email already taken")
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(current_user, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(current_user)
|
||||
return current_user
|
||||
|
||||
@router.post("/device-token")
|
||||
def update_device_token(
|
||||
device_token: str,
|
||||
current_user: UserModel = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Update user's device token for push notifications"""
|
||||
current_user.device_token = device_token
|
||||
db.commit()
|
||||
return {"message": "Device token updated successfully"}
|
@ -1,87 +0,0 @@
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.session import get_db
|
||||
from app.core.deps import get_current_active_user
|
||||
from app.models.user import User
|
||||
from app.models.category import Category as CategoryModel
|
||||
from app.schemas.category import Category, CategoryCreate, CategoryUpdate
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=Category)
|
||||
def create_category(
|
||||
category: CategoryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
db_category = CategoryModel(**category.dict(), owner_id=current_user.id)
|
||||
db.add(db_category)
|
||||
db.commit()
|
||||
db.refresh(db_category)
|
||||
return db_category
|
||||
|
||||
@router.get("/", response_model=List[Category])
|
||||
def read_categories(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
categories = db.query(CategoryModel).filter(
|
||||
CategoryModel.owner_id == current_user.id
|
||||
).offset(skip).limit(limit).all()
|
||||
return categories
|
||||
|
||||
@router.get("/{category_id}", response_model=Category)
|
||||
def read_category(
|
||||
category_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
category = db.query(CategoryModel).filter(
|
||||
CategoryModel.id == category_id,
|
||||
CategoryModel.owner_id == current_user.id
|
||||
).first()
|
||||
if category is None:
|
||||
raise HTTPException(status_code=404, detail="Category not found")
|
||||
return category
|
||||
|
||||
@router.put("/{category_id}", response_model=Category)
|
||||
def update_category(
|
||||
category_id: int,
|
||||
category_update: CategoryUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
category = db.query(CategoryModel).filter(
|
||||
CategoryModel.id == category_id,
|
||||
CategoryModel.owner_id == current_user.id
|
||||
).first()
|
||||
if category is None:
|
||||
raise HTTPException(status_code=404, detail="Category not found")
|
||||
|
||||
update_data = category_update.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(category, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(category)
|
||||
return category
|
||||
|
||||
@router.delete("/{category_id}")
|
||||
def delete_category(
|
||||
category_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
category = db.query(CategoryModel).filter(
|
||||
CategoryModel.id == category_id,
|
||||
CategoryModel.owner_id == current_user.id
|
||||
).first()
|
||||
if category is None:
|
||||
raise HTTPException(status_code=404, detail="Category not found")
|
||||
|
||||
db.delete(category)
|
||||
db.commit()
|
||||
return {"message": "Category deleted successfully"}
|
391
app/api/chats.py
Normal file
391
app/api/chats.py
Normal file
@ -0,0 +1,391 @@
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from app.db.session import get_db
|
||||
from app.core.deps import get_current_active_user
|
||||
from app.models import User, Chat as ChatModel, ChatMember as ChatMemberModel, Message as MessageModel, ChatType, MemberRole
|
||||
from app.schemas.chat import Chat, ChatCreate, DirectChatCreate, ChatUpdate, ChatList, ChatMember, ChatMemberCreate
|
||||
from app.schemas.user import UserPublic
|
||||
from app.websocket.connection_manager import connection_manager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=List[ChatList])
|
||||
def get_user_chats(
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get all chats for the current user"""
|
||||
# Get user's chat memberships
|
||||
memberships = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.user_id == current_user.id
|
||||
).offset(skip).limit(limit).all()
|
||||
|
||||
chats = []
|
||||
for membership in memberships:
|
||||
chat = membership.chat
|
||||
|
||||
# Get last message
|
||||
last_message = db.query(MessageModel).filter(
|
||||
MessageModel.chat_id == chat.id
|
||||
).order_by(desc(MessageModel.created_at)).first()
|
||||
|
||||
# Calculate unread count
|
||||
unread_count = 0
|
||||
if membership.last_read_message_id:
|
||||
unread_count = db.query(MessageModel).filter(
|
||||
MessageModel.chat_id == chat.id,
|
||||
MessageModel.id > membership.last_read_message_id,
|
||||
MessageModel.sender_id != current_user.id
|
||||
).count()
|
||||
else:
|
||||
unread_count = db.query(MessageModel).filter(
|
||||
MessageModel.chat_id == chat.id,
|
||||
MessageModel.sender_id != current_user.id
|
||||
).count()
|
||||
|
||||
# For direct chats, get the other user's online status
|
||||
is_online = False
|
||||
chat_name = chat.name
|
||||
if chat.chat_type == ChatType.DIRECT:
|
||||
other_member = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == chat.id,
|
||||
ChatMemberModel.user_id != current_user.id
|
||||
).first()
|
||||
if other_member:
|
||||
is_online = other_member.user.is_online
|
||||
if not chat_name:
|
||||
chat_name = other_member.user.username
|
||||
|
||||
chat_data = ChatList(
|
||||
id=chat.id,
|
||||
name=chat_name,
|
||||
chat_type=chat.chat_type,
|
||||
avatar_url=chat.avatar_url,
|
||||
last_message={
|
||||
"id": last_message.id,
|
||||
"content": last_message.content[:100] if last_message.content else "",
|
||||
"sender_username": last_message.sender.username,
|
||||
"created_at": last_message.created_at.isoformat()
|
||||
} if last_message else None,
|
||||
last_activity=last_message.created_at if last_message else chat.updated_at,
|
||||
unread_count=unread_count,
|
||||
is_online=is_online
|
||||
)
|
||||
chats.append(chat_data)
|
||||
|
||||
# Sort by last activity
|
||||
chats.sort(key=lambda x: x.last_activity or x.id, reverse=True)
|
||||
return chats
|
||||
|
||||
@router.post("/direct", response_model=Chat)
|
||||
def create_direct_chat(
|
||||
chat_data: DirectChatCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Create or get existing direct chat"""
|
||||
other_user = db.query(User).filter(User.id == chat_data.other_user_id).first()
|
||||
if not other_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if other_user.id == current_user.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot create chat with yourself")
|
||||
|
||||
# Check if direct chat already exists
|
||||
existing_chat = db.query(ChatModel).join(ChatMemberModel).filter(
|
||||
ChatModel.chat_type == ChatType.DIRECT,
|
||||
ChatMemberModel.user_id.in_([current_user.id, other_user.id])
|
||||
).group_by(ChatModel.id).having(
|
||||
db.func.count(ChatMemberModel.user_id) == 2
|
||||
).first()
|
||||
|
||||
if existing_chat:
|
||||
return _get_chat_with_details(existing_chat.id, current_user.id, db)
|
||||
|
||||
# Create new direct chat
|
||||
chat = ChatModel(
|
||||
chat_type=ChatType.DIRECT,
|
||||
is_active=True
|
||||
)
|
||||
db.add(chat)
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
|
||||
# Add both users as members
|
||||
for user in [current_user, other_user]:
|
||||
member = ChatMemberModel(
|
||||
chat_id=chat.id,
|
||||
user_id=user.id,
|
||||
role=MemberRole.MEMBER
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Add users to connection manager
|
||||
connection_manager.add_user_to_chat(current_user.id, chat.id)
|
||||
connection_manager.add_user_to_chat(other_user.id, chat.id)
|
||||
|
||||
return _get_chat_with_details(chat.id, current_user.id, db)
|
||||
|
||||
@router.post("/group", response_model=Chat)
|
||||
def create_group_chat(
|
||||
chat_data: ChatCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Create a new group chat"""
|
||||
if not chat_data.name:
|
||||
raise HTTPException(status_code=400, detail="Group name is required")
|
||||
|
||||
# Create group chat
|
||||
chat = ChatModel(
|
||||
name=chat_data.name,
|
||||
description=chat_data.description,
|
||||
chat_type=ChatType.GROUP,
|
||||
avatar_url=chat_data.avatar_url,
|
||||
is_active=True
|
||||
)
|
||||
db.add(chat)
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
|
||||
# Add creator as owner
|
||||
owner_member = ChatMemberModel(
|
||||
chat_id=chat.id,
|
||||
user_id=current_user.id,
|
||||
role=MemberRole.OWNER
|
||||
)
|
||||
db.add(owner_member)
|
||||
|
||||
# Add other members
|
||||
for member_id in chat_data.member_ids:
|
||||
if member_id != current_user.id: # Don't add creator twice
|
||||
user = db.query(User).filter(User.id == member_id).first()
|
||||
if user:
|
||||
member = ChatMemberModel(
|
||||
chat_id=chat.id,
|
||||
user_id=member_id,
|
||||
role=MemberRole.MEMBER
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Add users to connection manager
|
||||
connection_manager.add_user_to_chat(current_user.id, chat.id)
|
||||
for member_id in chat_data.member_ids:
|
||||
connection_manager.add_user_to_chat(member_id, chat.id)
|
||||
|
||||
return _get_chat_with_details(chat.id, current_user.id, db)
|
||||
|
||||
@router.get("/{chat_id}", response_model=Chat)
|
||||
def get_chat(
|
||||
chat_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get chat details"""
|
||||
# Verify user is member of chat
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == chat_id,
|
||||
ChatMemberModel.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
raise HTTPException(status_code=404, detail="Chat not found")
|
||||
|
||||
return _get_chat_with_details(chat_id, current_user.id, db)
|
||||
|
||||
@router.put("/{chat_id}", response_model=Chat)
|
||||
def update_chat(
|
||||
chat_id: int,
|
||||
chat_update: ChatUpdate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Update chat details (admin/owner only)"""
|
||||
# Verify user is admin/owner of chat
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == chat_id,
|
||||
ChatMemberModel.user_id == current_user.id,
|
||||
ChatMemberModel.role.in_([MemberRole.ADMIN, MemberRole.OWNER])
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
chat = db.query(ChatModel).filter(ChatModel.id == chat_id).first()
|
||||
if not chat:
|
||||
raise HTTPException(status_code=404, detail="Chat not found")
|
||||
|
||||
update_data = chat_update.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(chat, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
|
||||
return _get_chat_with_details(chat_id, current_user.id, db)
|
||||
|
||||
@router.post("/{chat_id}/members", response_model=ChatMember)
|
||||
def add_chat_member(
|
||||
chat_id: int,
|
||||
member_data: ChatMemberCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Add member to chat (admin/owner only)"""
|
||||
# Verify permissions
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == chat_id,
|
||||
ChatMemberModel.user_id == current_user.id,
|
||||
ChatMemberModel.role.in_([MemberRole.ADMIN, MemberRole.OWNER])
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# Check if user is already a member
|
||||
existing_member = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == chat_id,
|
||||
ChatMemberModel.user_id == member_data.user_id
|
||||
).first()
|
||||
|
||||
if existing_member:
|
||||
raise HTTPException(status_code=400, detail="User is already a member")
|
||||
|
||||
# Verify user exists
|
||||
user = db.query(User).filter(User.id == member_data.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Add member
|
||||
new_member = ChatMemberModel(
|
||||
chat_id=chat_id,
|
||||
user_id=member_data.user_id,
|
||||
role=member_data.role,
|
||||
nickname=member_data.nickname
|
||||
)
|
||||
db.add(new_member)
|
||||
db.commit()
|
||||
db.refresh(new_member)
|
||||
|
||||
# Add to connection manager
|
||||
connection_manager.add_user_to_chat(member_data.user_id, chat_id)
|
||||
|
||||
return ChatMember(
|
||||
id=new_member.id,
|
||||
chat_id=new_member.chat_id,
|
||||
user_id=new_member.user_id,
|
||||
role=new_member.role,
|
||||
nickname=new_member.nickname,
|
||||
is_muted=new_member.is_muted,
|
||||
is_banned=new_member.is_banned,
|
||||
joined_at=new_member.joined_at,
|
||||
last_read_message_id=new_member.last_read_message_id,
|
||||
user=UserPublic.from_orm(user)
|
||||
)
|
||||
|
||||
@router.delete("/{chat_id}/members/{user_id}")
|
||||
def remove_chat_member(
|
||||
chat_id: int,
|
||||
user_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Remove member from chat"""
|
||||
# Users can remove themselves, or admins/owners can remove others
|
||||
if user_id != current_user.id:
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == chat_id,
|
||||
ChatMemberModel.user_id == current_user.id,
|
||||
ChatMemberModel.role.in_([MemberRole.ADMIN, MemberRole.OWNER])
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# Find and remove member
|
||||
member = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == chat_id,
|
||||
ChatMemberModel.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not member:
|
||||
raise HTTPException(status_code=404, detail="Member not found")
|
||||
|
||||
db.delete(member)
|
||||
db.commit()
|
||||
|
||||
# Remove from connection manager
|
||||
connection_manager.remove_user_from_chat(user_id, chat_id)
|
||||
|
||||
return {"message": "Member removed successfully"}
|
||||
|
||||
def _get_chat_with_details(chat_id: int, current_user_id: int, db: Session) -> Chat:
|
||||
"""Helper function to get chat with all details"""
|
||||
chat = db.query(ChatModel).filter(ChatModel.id == chat_id).first()
|
||||
|
||||
# Get members
|
||||
members = db.query(ChatMemberModel).filter(ChatMemberModel.chat_id == chat_id).all()
|
||||
chat_members = []
|
||||
|
||||
for member in members:
|
||||
user_public = UserPublic.from_orm(member.user)
|
||||
chat_member = ChatMember(
|
||||
id=member.id,
|
||||
chat_id=member.chat_id,
|
||||
user_id=member.user_id,
|
||||
role=member.role,
|
||||
nickname=member.nickname,
|
||||
is_muted=member.is_muted,
|
||||
is_banned=member.is_banned,
|
||||
joined_at=member.joined_at,
|
||||
last_read_message_id=member.last_read_message_id,
|
||||
user=user_public
|
||||
)
|
||||
chat_members.append(chat_member)
|
||||
|
||||
# Get last message
|
||||
last_message = db.query(MessageModel).filter(
|
||||
MessageModel.chat_id == chat_id
|
||||
).order_by(desc(MessageModel.created_at)).first()
|
||||
|
||||
# Calculate unread count
|
||||
current_member = next((m for m in members if m.user_id == current_user_id), None)
|
||||
unread_count = 0
|
||||
if current_member and current_member.last_read_message_id:
|
||||
unread_count = db.query(MessageModel).filter(
|
||||
MessageModel.chat_id == chat_id,
|
||||
MessageModel.id > current_member.last_read_message_id,
|
||||
MessageModel.sender_id != current_user_id
|
||||
).count()
|
||||
else:
|
||||
unread_count = db.query(MessageModel).filter(
|
||||
MessageModel.chat_id == chat_id,
|
||||
MessageModel.sender_id != current_user_id
|
||||
).count()
|
||||
|
||||
return Chat(
|
||||
id=chat.id,
|
||||
name=chat.name,
|
||||
description=chat.description,
|
||||
chat_type=chat.chat_type,
|
||||
avatar_url=chat.avatar_url,
|
||||
is_active=chat.is_active,
|
||||
created_at=chat.created_at,
|
||||
updated_at=chat.updated_at,
|
||||
members=chat_members,
|
||||
last_message={
|
||||
"id": last_message.id,
|
||||
"content": last_message.content,
|
||||
"sender_username": last_message.sender.username,
|
||||
"created_at": last_message.created_at.isoformat()
|
||||
} if last_message else None,
|
||||
unread_count=unread_count
|
||||
)
|
119
app/api/encryption.py
Normal file
119
app/api/encryption.py
Normal file
@ -0,0 +1,119 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.session import get_db
|
||||
from app.core.deps import get_current_active_user
|
||||
from app.models.user import User
|
||||
from app.services.encryption_service import encryption_service
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class KeyPairResponse(BaseModel):
|
||||
public_key: str
|
||||
private_key: str
|
||||
|
||||
class EncryptRequest(BaseModel):
|
||||
message: str
|
||||
recipient_user_id: int
|
||||
|
||||
class EncryptResponse(BaseModel):
|
||||
encrypted_message: str
|
||||
|
||||
class DecryptRequest(BaseModel):
|
||||
encrypted_message: str
|
||||
private_key: str
|
||||
|
||||
class DecryptResponse(BaseModel):
|
||||
decrypted_message: str
|
||||
|
||||
@router.post("/generate-keys", response_model=KeyPairResponse)
|
||||
def generate_encryption_keys(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate new RSA key pair for E2E encryption"""
|
||||
private_key, public_key = encryption_service.generate_rsa_key_pair()
|
||||
|
||||
# Update user's public key in database
|
||||
current_user.public_key = public_key
|
||||
db.commit()
|
||||
|
||||
return KeyPairResponse(
|
||||
public_key=public_key,
|
||||
private_key=private_key
|
||||
)
|
||||
|
||||
@router.get("/public-key")
|
||||
def get_public_key(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Get current user's public key"""
|
||||
return {
|
||||
"public_key": current_user.public_key,
|
||||
"user_id": current_user.id
|
||||
}
|
||||
|
||||
@router.get("/public-key/{user_id}")
|
||||
def get_user_public_key(
|
||||
user_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get another user's public key"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return {
|
||||
"public_key": user.public_key,
|
||||
"user_id": user.id,
|
||||
"username": user.username
|
||||
}
|
||||
|
||||
@router.post("/encrypt", response_model=EncryptResponse)
|
||||
def encrypt_message(
|
||||
request: EncryptRequest,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Encrypt a message for a specific user"""
|
||||
recipient = db.query(User).filter(User.id == request.recipient_user_id).first()
|
||||
if not recipient:
|
||||
raise HTTPException(status_code=404, detail="Recipient not found")
|
||||
|
||||
if not recipient.public_key:
|
||||
raise HTTPException(status_code=400, detail="Recipient has no public key")
|
||||
|
||||
encrypted_message = encryption_service.encrypt_message(
|
||||
request.message,
|
||||
recipient.public_key
|
||||
)
|
||||
|
||||
return EncryptResponse(encrypted_message=encrypted_message)
|
||||
|
||||
@router.post("/decrypt", response_model=DecryptResponse)
|
||||
def decrypt_message(
|
||||
request: DecryptRequest,
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Decrypt a message using private key"""
|
||||
try:
|
||||
decrypted_message = encryption_service.decrypt_message(
|
||||
request.encrypted_message,
|
||||
request.private_key
|
||||
)
|
||||
return DecryptResponse(decrypted_message=decrypted_message)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Decryption failed: {str(e)}")
|
||||
|
||||
@router.put("/public-key")
|
||||
def update_public_key(
|
||||
public_key: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Update user's public key"""
|
||||
current_user.public_key = public_key
|
||||
db.commit()
|
||||
|
||||
return {"message": "Public key updated successfully"}
|
224
app/api/media.py
Normal file
224
app/api/media.py
Normal file
@ -0,0 +1,224 @@
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.session import get_db
|
||||
from app.core.deps import get_current_active_user
|
||||
from app.models import User, Message as MessageModel, Media as MediaModel, ChatMember as ChatMemberModel
|
||||
from app.schemas.message import MediaFile
|
||||
from app.services.media_service import media_service
|
||||
import io
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/upload/{message_id}", response_model=List[MediaFile])
|
||||
async def upload_media(
|
||||
message_id: int,
|
||||
files: List[UploadFile] = File(...),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Upload media files for a message"""
|
||||
# Verify message exists and user can access it
|
||||
message = db.query(MessageModel).filter(MessageModel.id == message_id).first()
|
||||
if not message:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
# Verify user is member of the chat
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == message.chat_id,
|
||||
ChatMemberModel.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Verify user owns the message
|
||||
if message.sender_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Can only upload media to your own messages")
|
||||
|
||||
media_files = []
|
||||
|
||||
for file in files:
|
||||
try:
|
||||
# Save file
|
||||
filename, file_path, file_size, media_type, width, height, thumbnail_path = await media_service.save_file(
|
||||
file, current_user.id
|
||||
)
|
||||
|
||||
# Create media record
|
||||
media = MediaModel(
|
||||
message_id=message_id,
|
||||
uploader_id=current_user.id,
|
||||
filename=filename,
|
||||
original_filename=file.filename,
|
||||
file_path=file_path,
|
||||
file_size=file_size,
|
||||
mime_type=file.content_type or "application/octet-stream",
|
||||
media_type=media_type,
|
||||
width=width,
|
||||
height=height,
|
||||
thumbnail_path=thumbnail_path,
|
||||
is_processed=True
|
||||
)
|
||||
|
||||
db.add(media)
|
||||
db.commit()
|
||||
db.refresh(media)
|
||||
|
||||
# Create response
|
||||
media_file = MediaFile(
|
||||
id=media.id,
|
||||
filename=media.filename,
|
||||
original_filename=media.original_filename,
|
||||
file_size=media.file_size,
|
||||
mime_type=media.mime_type,
|
||||
media_type=media.media_type.value,
|
||||
width=media.width,
|
||||
height=media.height,
|
||||
duration=media.duration,
|
||||
thumbnail_url=f"/api/media/{media.id}/thumbnail" if media.thumbnail_path else None
|
||||
)
|
||||
media_files.append(media_file)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to upload file: {str(e)}")
|
||||
|
||||
return media_files
|
||||
|
||||
@router.get("/{media_id}")
|
||||
async def download_media(
|
||||
media_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Download media file"""
|
||||
media = db.query(MediaModel).filter(MediaModel.id == media_id).first()
|
||||
if not media:
|
||||
raise HTTPException(status_code=404, detail="Media not found")
|
||||
|
||||
# Verify user can access this media
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == media.message.chat_id,
|
||||
ChatMemberModel.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
try:
|
||||
content = await media_service.get_file_content(media.filename)
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(content),
|
||||
media_type=media.mime_type,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename={media.original_filename}"
|
||||
}
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
@router.get("/{media_id}/thumbnail")
|
||||
async def get_media_thumbnail(
|
||||
media_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get media thumbnail"""
|
||||
media = db.query(MediaModel).filter(MediaModel.id == media_id).first()
|
||||
if not media:
|
||||
raise HTTPException(status_code=404, detail="Media not found")
|
||||
|
||||
if not media.thumbnail_path:
|
||||
raise HTTPException(status_code=404, detail="Thumbnail not available")
|
||||
|
||||
# Verify user can access this media
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == media.message.chat_id,
|
||||
ChatMemberModel.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
try:
|
||||
content = await media_service.get_thumbnail_content(media.filename)
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(content),
|
||||
media_type="image/jpeg"
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Thumbnail not found")
|
||||
|
||||
@router.delete("/{media_id}")
|
||||
async def delete_media(
|
||||
media_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Delete media file"""
|
||||
media = db.query(MediaModel).filter(MediaModel.id == media_id).first()
|
||||
if not media:
|
||||
raise HTTPException(status_code=404, detail="Media not found")
|
||||
|
||||
# Verify user owns the media or is admin/owner of chat
|
||||
can_delete = media.uploader_id == current_user.id
|
||||
|
||||
if not can_delete:
|
||||
from app.models.chat_member import MemberRole
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == media.message.chat_id,
|
||||
ChatMemberModel.user_id == current_user.id,
|
||||
ChatMemberModel.role.in_([MemberRole.ADMIN, MemberRole.OWNER])
|
||||
).first()
|
||||
can_delete = membership is not None
|
||||
|
||||
if not can_delete:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# Delete file from storage
|
||||
thumbnail_filename = f"thumb_{media.filename}" if media.thumbnail_path else None
|
||||
media_service.delete_file(media.filename, thumbnail_filename)
|
||||
|
||||
# Delete database record
|
||||
db.delete(media)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Media deleted successfully"}
|
||||
|
||||
@router.get("/{media_id}/info", response_model=MediaFile)
|
||||
async def get_media_info(
|
||||
media_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get media file information"""
|
||||
media = db.query(MediaModel).filter(MediaModel.id == media_id).first()
|
||||
if not media:
|
||||
raise HTTPException(status_code=404, detail="Media not found")
|
||||
|
||||
# Verify user can access this media
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == media.message.chat_id,
|
||||
ChatMemberModel.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
return MediaFile(
|
||||
id=media.id,
|
||||
filename=media.filename,
|
||||
original_filename=media.original_filename,
|
||||
file_size=media.file_size,
|
||||
mime_type=media.mime_type,
|
||||
media_type=media.media_type.value,
|
||||
width=media.width,
|
||||
height=media.height,
|
||||
duration=media.duration,
|
||||
thumbnail_url=f"/api/media/{media.id}/thumbnail" if media.thumbnail_path else None
|
||||
)
|
236
app/api/messages.py
Normal file
236
app/api/messages.py
Normal file
@ -0,0 +1,236 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from app.db.session import get_db
|
||||
from app.core.deps import get_current_active_user
|
||||
from app.models import User, ChatMember as ChatMemberModel, Message as MessageModel
|
||||
from app.schemas.message import Message, MessageUpdate, MessageList, MediaFile, MessageMention
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/{chat_id}", response_model=MessageList)
|
||||
def get_chat_messages(
|
||||
chat_id: int,
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get messages for a chat with pagination"""
|
||||
# Verify user is member of chat
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == chat_id,
|
||||
ChatMemberModel.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
raise HTTPException(status_code=404, detail="Chat not found")
|
||||
|
||||
# Calculate offset
|
||||
offset = (page - 1) * limit
|
||||
|
||||
# Get total count
|
||||
total = db.query(MessageModel).filter(
|
||||
MessageModel.chat_id == chat_id,
|
||||
MessageModel.is_deleted.is_(False)
|
||||
).count()
|
||||
|
||||
# Get messages
|
||||
messages_query = db.query(MessageModel).filter(
|
||||
MessageModel.chat_id == chat_id,
|
||||
MessageModel.is_deleted.is_(False)
|
||||
).order_by(desc(MessageModel.created_at)).offset(offset).limit(limit)
|
||||
|
||||
db_messages = messages_query.all()
|
||||
|
||||
# Convert to response format
|
||||
messages = []
|
||||
for msg in db_messages:
|
||||
# Get media files
|
||||
media_files = []
|
||||
for media in msg.media_files:
|
||||
media_file = MediaFile(
|
||||
id=media.id,
|
||||
filename=media.filename,
|
||||
original_filename=media.original_filename,
|
||||
file_size=media.file_size,
|
||||
mime_type=media.mime_type,
|
||||
media_type=media.media_type.value,
|
||||
width=media.width,
|
||||
height=media.height,
|
||||
duration=media.duration,
|
||||
thumbnail_url=f"/api/media/{media.id}/thumbnail" if media.thumbnail_path else None
|
||||
)
|
||||
media_files.append(media_file)
|
||||
|
||||
# Get mentions
|
||||
mentions = []
|
||||
for mention in msg.mentions:
|
||||
mention_data = MessageMention(
|
||||
id=mention.mentioned_user.id,
|
||||
username=mention.mentioned_user.username,
|
||||
full_name=mention.mentioned_user.full_name
|
||||
)
|
||||
mentions.append(mention_data)
|
||||
|
||||
# Get reply-to message if exists
|
||||
reply_to = None
|
||||
if msg.reply_to:
|
||||
reply_to = Message(
|
||||
id=msg.reply_to.id,
|
||||
chat_id=msg.reply_to.chat_id,
|
||||
sender_id=msg.reply_to.sender_id,
|
||||
content=msg.reply_to.content,
|
||||
content_type=msg.reply_to.content_type,
|
||||
status=msg.reply_to.status,
|
||||
is_edited=msg.reply_to.is_edited,
|
||||
is_deleted=msg.reply_to.is_deleted,
|
||||
edited_at=msg.reply_to.edited_at,
|
||||
created_at=msg.reply_to.created_at,
|
||||
sender_username=msg.reply_to.sender.username,
|
||||
sender_avatar=msg.reply_to.sender.avatar_url,
|
||||
media_files=[],
|
||||
mentions=[],
|
||||
reply_to=None
|
||||
)
|
||||
|
||||
message = Message(
|
||||
id=msg.id,
|
||||
chat_id=msg.chat_id,
|
||||
sender_id=msg.sender_id,
|
||||
content=msg.content,
|
||||
content_type=msg.content_type,
|
||||
reply_to_id=msg.reply_to_id,
|
||||
status=msg.status,
|
||||
is_edited=msg.is_edited,
|
||||
is_deleted=msg.is_deleted,
|
||||
edited_at=msg.edited_at,
|
||||
created_at=msg.created_at,
|
||||
sender_username=msg.sender.username,
|
||||
sender_avatar=msg.sender.avatar_url,
|
||||
media_files=media_files,
|
||||
mentions=mentions,
|
||||
reply_to=reply_to
|
||||
)
|
||||
messages.append(message)
|
||||
|
||||
# Reverse to show oldest first
|
||||
messages.reverse()
|
||||
|
||||
return MessageList(
|
||||
messages=messages,
|
||||
total=total,
|
||||
page=page,
|
||||
has_more=offset + limit < total
|
||||
)
|
||||
|
||||
@router.put("/{message_id}", response_model=Message)
|
||||
def update_message(
|
||||
message_id: int,
|
||||
message_update: MessageUpdate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Update/edit a message (sender only)"""
|
||||
message = db.query(MessageModel).filter(MessageModel.id == message_id).first()
|
||||
|
||||
if not message:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
if message.sender_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Can only edit your own messages")
|
||||
|
||||
if message.is_deleted:
|
||||
raise HTTPException(status_code=400, detail="Cannot edit deleted message")
|
||||
|
||||
# Update message
|
||||
update_data = message_update.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(message, field, value)
|
||||
|
||||
message.is_edited = True
|
||||
from datetime import datetime
|
||||
message.edited_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(message)
|
||||
|
||||
# Return updated message (simplified response)
|
||||
return Message(
|
||||
id=message.id,
|
||||
chat_id=message.chat_id,
|
||||
sender_id=message.sender_id,
|
||||
content=message.content,
|
||||
content_type=message.content_type,
|
||||
reply_to_id=message.reply_to_id,
|
||||
status=message.status,
|
||||
is_edited=message.is_edited,
|
||||
is_deleted=message.is_deleted,
|
||||
edited_at=message.edited_at,
|
||||
created_at=message.created_at,
|
||||
sender_username=message.sender.username,
|
||||
sender_avatar=message.sender.avatar_url,
|
||||
media_files=[],
|
||||
mentions=[],
|
||||
reply_to=None
|
||||
)
|
||||
|
||||
@router.delete("/{message_id}")
|
||||
def delete_message(
|
||||
message_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Delete a message (sender only or admin/owner)"""
|
||||
message = db.query(MessageModel).filter(MessageModel.id == message_id).first()
|
||||
|
||||
if not message:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
# Check if user can delete (sender or admin/owner of chat)
|
||||
can_delete = message.sender_id == current_user.id
|
||||
|
||||
if not can_delete:
|
||||
# Check if user is admin/owner of the chat
|
||||
from app.models.chat_member import MemberRole
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == message.chat_id,
|
||||
ChatMemberModel.user_id == current_user.id,
|
||||
ChatMemberModel.role.in_([MemberRole.ADMIN, MemberRole.OWNER])
|
||||
).first()
|
||||
can_delete = membership is not None
|
||||
|
||||
if not can_delete:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# Soft delete
|
||||
message.is_deleted = True
|
||||
message.content = "[Deleted]"
|
||||
db.commit()
|
||||
|
||||
return {"message": "Message deleted successfully"}
|
||||
|
||||
@router.post("/{message_id}/read")
|
||||
def mark_message_read(
|
||||
message_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Mark message as read"""
|
||||
message = db.query(MessageModel).filter(MessageModel.id == message_id).first()
|
||||
|
||||
if not message:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
# Update user's last read message
|
||||
membership = db.query(ChatMemberModel).filter(
|
||||
ChatMemberModel.chat_id == message.chat_id,
|
||||
ChatMemberModel.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if membership:
|
||||
if not membership.last_read_message_id or membership.last_read_message_id < message.id:
|
||||
membership.last_read_message_id = message.id
|
||||
db.commit()
|
||||
|
||||
return {"message": "Message marked as read"}
|
112
app/api/notifications.py
Normal file
112
app/api/notifications.py
Normal file
@ -0,0 +1,112 @@
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.session import get_db
|
||||
from app.core.deps import get_current_active_user
|
||||
from app.models.user import User
|
||||
from app.services.notification_service import notification_service
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class NotificationResponse(BaseModel):
|
||||
id: int
|
||||
type: str
|
||||
title: str
|
||||
body: str
|
||||
data: dict
|
||||
is_read: bool
|
||||
created_at: str
|
||||
read_at: str = None
|
||||
|
||||
class MentionResponse(BaseModel):
|
||||
id: int
|
||||
message_id: int
|
||||
chat_id: int
|
||||
chat_name: str
|
||||
sender_username: str
|
||||
message_content: str
|
||||
created_at: str
|
||||
|
||||
@router.get("/", response_model=List[NotificationResponse])
|
||||
def get_notifications(
|
||||
unread_only: bool = Query(False),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Get user notifications"""
|
||||
notifications = notification_service.get_user_notifications(
|
||||
user_id=current_user.id,
|
||||
unread_only=unread_only,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
return [NotificationResponse(**notif) for notif in notifications]
|
||||
|
||||
@router.get("/mentions", response_model=List[MentionResponse])
|
||||
def get_unread_mentions(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Get unread mentions for user"""
|
||||
mentions = notification_service.get_unread_mentions(current_user.id)
|
||||
return [MentionResponse(**mention) for mention in mentions]
|
||||
|
||||
@router.post("/{notification_id}/read")
|
||||
async def mark_notification_read(
|
||||
notification_id: int,
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Mark notification as read"""
|
||||
await notification_service.mark_notification_read(notification_id, current_user.id)
|
||||
return {"message": "Notification marked as read"}
|
||||
|
||||
@router.post("/mentions/{mention_id}/read")
|
||||
async def mark_mention_read(
|
||||
mention_id: int,
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Mark mention as read"""
|
||||
await notification_service.mark_mention_read(mention_id, current_user.id)
|
||||
return {"message": "Mention marked as read"}
|
||||
|
||||
@router.post("/read-all")
|
||||
async def mark_all_notifications_read(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Mark all notifications as read"""
|
||||
from app.models.notification import Notification
|
||||
from datetime import datetime
|
||||
|
||||
db.query(Notification).filter(
|
||||
Notification.user_id == current_user.id,
|
||||
Notification.is_read.is_(False)
|
||||
).update({
|
||||
"is_read": True,
|
||||
"read_at": datetime.utcnow()
|
||||
})
|
||||
|
||||
db.commit()
|
||||
|
||||
return {"message": "All notifications marked as read"}
|
||||
|
||||
@router.get("/count")
|
||||
def get_notification_count(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Get unread notification count"""
|
||||
notifications = notification_service.get_user_notifications(
|
||||
user_id=current_user.id,
|
||||
unread_only=True,
|
||||
limit=1000
|
||||
)
|
||||
|
||||
mentions = notification_service.get_unread_mentions(current_user.id)
|
||||
|
||||
return {
|
||||
"unread_notifications": len(notifications),
|
||||
"unread_mentions": len(mentions),
|
||||
"total_unread": len(notifications) + len(mentions)
|
||||
}
|
102
app/api/tasks.py
102
app/api/tasks.py
@ -1,102 +0,0 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
from app.db.session import get_db
|
||||
from app.core.deps import get_current_active_user
|
||||
from app.models.user import User
|
||||
from app.models.task import Task as TaskModel, TaskStatus
|
||||
from app.schemas.task import Task, TaskCreate, TaskUpdate
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=Task)
|
||||
def create_task(
|
||||
task: TaskCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
db_task = TaskModel(**task.dict(), owner_id=current_user.id)
|
||||
db.add(db_task)
|
||||
db.commit()
|
||||
db.refresh(db_task)
|
||||
return db_task
|
||||
|
||||
@router.get("/", response_model=List[Task])
|
||||
def read_tasks(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
status: Optional[TaskStatus] = Query(None),
|
||||
category_id: Optional[int] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
query = db.query(TaskModel).filter(TaskModel.owner_id == current_user.id)
|
||||
|
||||
if status:
|
||||
query = query.filter(TaskModel.status == status)
|
||||
if category_id:
|
||||
query = query.filter(TaskModel.category_id == category_id)
|
||||
|
||||
tasks = query.offset(skip).limit(limit).all()
|
||||
return tasks
|
||||
|
||||
@router.get("/{task_id}", response_model=Task)
|
||||
def read_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
task = db.query(TaskModel).filter(
|
||||
TaskModel.id == task_id,
|
||||
TaskModel.owner_id == current_user.id
|
||||
).first()
|
||||
if task is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task
|
||||
|
||||
@router.put("/{task_id}", response_model=Task)
|
||||
def update_task(
|
||||
task_id: int,
|
||||
task_update: TaskUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
task = db.query(TaskModel).filter(
|
||||
TaskModel.id == task_id,
|
||||
TaskModel.owner_id == current_user.id
|
||||
).first()
|
||||
if task is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
update_data = task_update.dict(exclude_unset=True)
|
||||
|
||||
# Set completed_at when task is marked as completed
|
||||
if update_data.get("status") == TaskStatus.COMPLETED and task.status != TaskStatus.COMPLETED:
|
||||
update_data["completed_at"] = datetime.utcnow()
|
||||
elif update_data.get("status") != TaskStatus.COMPLETED:
|
||||
update_data["completed_at"] = None
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(task, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
@router.delete("/{task_id}")
|
||||
def delete_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
task = db.query(TaskModel).filter(
|
||||
TaskModel.id == task_id,
|
||||
TaskModel.owner_id == current_user.id
|
||||
).first()
|
||||
if task is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
db.delete(task)
|
||||
db.commit()
|
||||
return {"message": "Task deleted successfully"}
|
@ -37,3 +37,18 @@ def get_current_active_user(
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
async def get_user_from_token(token: str, db: Session) -> User:
|
||||
"""Get user from WebSocket token"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
if user_id is None:
|
||||
return None
|
||||
|
||||
user = db.query(User).filter(User.id == int(user_id)).first()
|
||||
if user is None or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
except JWTError:
|
||||
return None
|
@ -3,6 +3,10 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Union
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_public_key, load_pem_private_key
|
||||
import base64
|
||||
|
||||
SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key-change-in-production")
|
||||
ALGORITHM = "HS256"
|
||||
@ -28,3 +32,59 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
# E2E Encryption utilities
|
||||
def generate_rsa_key_pair():
|
||||
"""Generate RSA key pair for E2E encryption"""
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
)
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# Serialize keys
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
)
|
||||
|
||||
public_pem = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
|
||||
return private_pem.decode(), public_pem.decode()
|
||||
|
||||
def encrypt_message(message: str, public_key_pem: str) -> str:
|
||||
"""Encrypt message using recipient's public key"""
|
||||
try:
|
||||
public_key = load_pem_public_key(public_key_pem.encode())
|
||||
encrypted = public_key.encrypt(
|
||||
message.encode(),
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
)
|
||||
return base64.b64encode(encrypted).decode()
|
||||
except Exception:
|
||||
return message # Fallback to unencrypted if encryption fails
|
||||
|
||||
def decrypt_message(encrypted_message: str, private_key_pem: str) -> str:
|
||||
"""Decrypt message using recipient's private key"""
|
||||
try:
|
||||
private_key = load_pem_private_key(private_key_pem.encode(), password=None)
|
||||
encrypted_bytes = base64.b64decode(encrypted_message.encode())
|
||||
decrypted = private_key.decrypt(
|
||||
encrypted_bytes,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
)
|
||||
return decrypted.decode()
|
||||
except Exception:
|
||||
return encrypted_message # Fallback to encrypted if decryption fails
|
@ -6,7 +6,7 @@ from app.db.base import Base
|
||||
DB_DIR = Path("/app/storage/db")
|
||||
DB_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_DIR}/db.sqlite"
|
||||
SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_DIR}/chat.sqlite"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL,
|
||||
|
@ -1,5 +1,17 @@
|
||||
from .user import User
|
||||
from .task import Task, TaskStatus, TaskPriority
|
||||
from .category import Category
|
||||
from .chat import Chat, ChatType
|
||||
from .chat_member import ChatMember, MemberRole
|
||||
from .message import Message, MessageType, MessageStatus
|
||||
from .media import Media, MediaType
|
||||
from .mention import Mention
|
||||
from .notification import Notification, NotificationType
|
||||
|
||||
__all__ = ["User", "Task", "TaskStatus", "TaskPriority", "Category"]
|
||||
__all__ = [
|
||||
"User",
|
||||
"Chat", "ChatType",
|
||||
"ChatMember", "MemberRole",
|
||||
"Message", "MessageType", "MessageStatus",
|
||||
"Media", "MediaType",
|
||||
"Mention",
|
||||
"Notification", "NotificationType"
|
||||
]
|
@ -1,18 +0,0 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from app.db.base import Base
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = "categories"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
color = Column(String, nullable=True)
|
||||
owner_id = Column(Integer, ForeignKey("users.id"))
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
owner = relationship("User", back_populates="categories")
|
||||
tasks = relationship("Task", back_populates="category")
|
25
app/models/chat.py
Normal file
25
app/models/chat.py
Normal file
@ -0,0 +1,25 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from enum import Enum as PyEnum
|
||||
from app.db.base import Base
|
||||
|
||||
class ChatType(PyEnum):
|
||||
DIRECT = "direct"
|
||||
GROUP = "group"
|
||||
|
||||
class Chat(Base):
|
||||
__tablename__ = "chats"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=True) # Null for direct messages
|
||||
description = Column(Text, nullable=True)
|
||||
chat_type = Column(Enum(ChatType), nullable=False)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
# Relationships
|
||||
messages = relationship("Message", back_populates="chat", cascade="all, delete-orphan")
|
||||
members = relationship("ChatMember", back_populates="chat", cascade="all, delete-orphan")
|
27
app/models/chat_member.py
Normal file
27
app/models/chat_member.py
Normal file
@ -0,0 +1,27 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from enum import Enum as PyEnum
|
||||
from app.db.base import Base
|
||||
|
||||
class MemberRole(PyEnum):
|
||||
MEMBER = "member"
|
||||
ADMIN = "admin"
|
||||
OWNER = "owner"
|
||||
|
||||
class ChatMember(Base):
|
||||
__tablename__ = "chat_members"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
chat_id = Column(Integer, ForeignKey("chats.id"), nullable=False)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
role = Column(Enum(MemberRole), default=MemberRole.MEMBER)
|
||||
nickname = Column(String, nullable=True) # Chat-specific nickname
|
||||
is_muted = Column(Boolean, default=False)
|
||||
is_banned = Column(Boolean, default=False)
|
||||
joined_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
last_read_message_id = Column(Integer, nullable=True) # For unread count
|
||||
|
||||
# Relationships
|
||||
chat = relationship("Chat", back_populates="members")
|
||||
user = relationship("User", back_populates="chat_members")
|
35
app/models/media.py
Normal file
35
app/models/media.py
Normal file
@ -0,0 +1,35 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, BigInteger, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from enum import Enum as PyEnum
|
||||
from app.db.base import Base
|
||||
|
||||
class MediaType(PyEnum):
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
DOCUMENT = "document"
|
||||
OTHER = "other"
|
||||
|
||||
class Media(Base):
|
||||
__tablename__ = "media"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
message_id = Column(Integer, ForeignKey("messages.id"), nullable=False)
|
||||
uploader_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
filename = Column(String, nullable=False)
|
||||
original_filename = Column(String, nullable=False)
|
||||
file_path = Column(String, nullable=False)
|
||||
file_size = Column(BigInteger, nullable=False) # Size in bytes
|
||||
mime_type = Column(String, nullable=False)
|
||||
media_type = Column(Enum(MediaType), nullable=False)
|
||||
width = Column(Integer, nullable=True) # For images/videos
|
||||
height = Column(Integer, nullable=True) # For images/videos
|
||||
duration = Column(Integer, nullable=True) # For videos/audio in seconds
|
||||
thumbnail_path = Column(String, nullable=True) # For videos/images
|
||||
is_processed = Column(Boolean, default=False) # For media processing
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
message = relationship("Message", back_populates="media_files")
|
||||
uploader = relationship("User", back_populates="sent_media")
|
18
app/models/mention.py
Normal file
18
app/models/mention.py
Normal file
@ -0,0 +1,18 @@
|
||||
from sqlalchemy import Column, Integer, DateTime, Boolean, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from app.db.base import Base
|
||||
|
||||
class Mention(Base):
|
||||
__tablename__ = "mentions"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
message_id = Column(Integer, ForeignKey("messages.id"), nullable=False)
|
||||
mentioned_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
is_read = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
read_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
message = relationship("Message", back_populates="mentions")
|
||||
mentioned_user = relationship("User")
|
41
app/models/message.py
Normal file
41
app/models/message.py
Normal file
@ -0,0 +1,41 @@
|
||||
from sqlalchemy import Column, Integer, DateTime, Boolean, Text, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from enum import Enum as PyEnum
|
||||
from app.db.base import Base
|
||||
|
||||
class MessageType(PyEnum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
FILE = "file"
|
||||
SYSTEM = "system" # For system messages like "User joined"
|
||||
|
||||
class MessageStatus(PyEnum):
|
||||
SENT = "sent"
|
||||
DELIVERED = "delivered"
|
||||
READ = "read"
|
||||
FAILED = "failed"
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
chat_id = Column(Integer, ForeignKey("chats.id"), nullable=False)
|
||||
sender_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
reply_to_id = Column(Integer, ForeignKey("messages.id"), nullable=True) # For replies
|
||||
content = Column(Text, nullable=True) # Encrypted content
|
||||
content_type = Column(Enum(MessageType), default=MessageType.TEXT)
|
||||
status = Column(Enum(MessageStatus), default=MessageStatus.SENT)
|
||||
is_edited = Column(Boolean, default=False)
|
||||
is_deleted = Column(Boolean, default=False)
|
||||
edited_at = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
chat = relationship("Chat", back_populates="messages")
|
||||
sender = relationship("User", foreign_keys=[sender_id], back_populates="sent_messages")
|
||||
reply_to = relationship("Message", remote_side=[id])
|
||||
media_files = relationship("Media", back_populates="message", cascade="all, delete-orphan")
|
||||
mentions = relationship("Mention", back_populates="message", cascade="all, delete-orphan")
|
33
app/models/notification.py
Normal file
33
app/models/notification.py
Normal file
@ -0,0 +1,33 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from enum import Enum as PyEnum
|
||||
from app.db.base import Base
|
||||
|
||||
class NotificationType(PyEnum):
|
||||
MESSAGE = "message"
|
||||
MENTION = "mention"
|
||||
GROUP_INVITE = "group_invite"
|
||||
GROUP_JOIN = "group_join"
|
||||
GROUP_LEAVE = "group_leave"
|
||||
|
||||
class Notification(Base):
|
||||
__tablename__ = "notifications"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
chat_id = Column(Integer, ForeignKey("chats.id"), nullable=True)
|
||||
message_id = Column(Integer, ForeignKey("messages.id"), nullable=True)
|
||||
notification_type = Column(Enum(NotificationType), nullable=False)
|
||||
title = Column(String, nullable=False)
|
||||
body = Column(Text, nullable=False)
|
||||
data = Column(Text, nullable=True) # JSON data for additional info
|
||||
is_read = Column(Boolean, default=False)
|
||||
is_sent = Column(Boolean, default=False) # Push notification sent
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
read_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = relationship("User")
|
||||
chat = relationship("Chat")
|
||||
message = relationship("Message")
|
@ -1,35 +0,0 @@
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from enum import Enum as PyEnum
|
||||
from app.db.base import Base
|
||||
|
||||
class TaskStatus(PyEnum):
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
class TaskPriority(PyEnum):
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
URGENT = "urgent"
|
||||
|
||||
class Task(Base):
|
||||
__tablename__ = "tasks"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
title = Column(String, nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
status = Column(Enum(TaskStatus), default=TaskStatus.PENDING)
|
||||
priority = Column(Enum(TaskPriority), default=TaskPriority.MEDIUM)
|
||||
due_date = Column(DateTime(timezone=True), nullable=True)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
owner_id = Column(Integer, ForeignKey("users.id"))
|
||||
category_id = Column(Integer, ForeignKey("categories.id"), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
owner = relationship("User", back_populates="tasks")
|
||||
category = relationship("Category", back_populates="tasks")
|
@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from app.db.base import Base
|
||||
@ -7,12 +7,21 @@ class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
username = Column(String, unique=True, index=True, nullable=False)
|
||||
email = Column(String, unique=True, index=True, nullable=False)
|
||||
hashed_password = Column(String, nullable=False)
|
||||
full_name = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
bio = Column(Text, nullable=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_online = Column(Boolean, default=False)
|
||||
last_seen = Column(DateTime(timezone=True), nullable=True)
|
||||
public_key = Column(Text, nullable=True) # For E2E encryption
|
||||
device_token = Column(String, nullable=True) # For push notifications
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
tasks = relationship("Task", back_populates="owner")
|
||||
categories = relationship("Category", back_populates="owner")
|
||||
# Relationships
|
||||
sent_messages = relationship("Message", foreign_keys="Message.sender_id", back_populates="sender")
|
||||
chat_members = relationship("ChatMember", back_populates="user")
|
||||
sent_media = relationship("Media", back_populates="uploader")
|
@ -1,9 +1,9 @@
|
||||
from .user import User, UserCreate, UserUpdate, Token
|
||||
from .task import Task, TaskCreate, TaskUpdate
|
||||
from .category import Category, CategoryCreate, CategoryUpdate
|
||||
from .user import User, UserCreate, UserUpdate, UserLogin, UserPublic, Token
|
||||
from .chat import Chat, ChatCreate, DirectChatCreate, ChatUpdate, ChatMember, ChatMemberCreate, ChatMemberUpdate, ChatList
|
||||
from .message import Message, MessageCreate, MessageUpdate, MessageList, MediaFile
|
||||
|
||||
__all__ = [
|
||||
"User", "UserCreate", "UserUpdate", "Token",
|
||||
"Task", "TaskCreate", "TaskUpdate",
|
||||
"Category", "CategoryCreate", "CategoryUpdate"
|
||||
"User", "UserCreate", "UserUpdate", "UserLogin", "UserPublic", "Token",
|
||||
"Chat", "ChatCreate", "DirectChatCreate", "ChatUpdate", "ChatMember", "ChatMemberCreate", "ChatMemberUpdate", "ChatList",
|
||||
"Message", "MessageCreate", "MessageUpdate", "MessageList", "MediaFile"
|
||||
]
|
@ -1,28 +0,0 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
class CategoryBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
|
||||
class CategoryCreate(CategoryBase):
|
||||
pass
|
||||
|
||||
class CategoryUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
|
||||
class CategoryInDBBase(CategoryBase):
|
||||
id: int
|
||||
owner_id: int
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class Category(CategoryInDBBase):
|
||||
pass
|
75
app/schemas/chat.py
Normal file
75
app/schemas/chat.py
Normal file
@ -0,0 +1,75 @@
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from app.models.chat import ChatType
|
||||
from app.models.chat_member import MemberRole
|
||||
from app.schemas.user import UserPublic
|
||||
|
||||
class ChatBase(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
chat_type: ChatType
|
||||
avatar_url: Optional[str] = None
|
||||
|
||||
class ChatCreate(ChatBase):
|
||||
member_ids: List[int] = [] # For group chats
|
||||
|
||||
class DirectChatCreate(BaseModel):
|
||||
other_user_id: int
|
||||
|
||||
class ChatUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
|
||||
class ChatMemberBase(BaseModel):
|
||||
user_id: int
|
||||
role: MemberRole = MemberRole.MEMBER
|
||||
nickname: Optional[str] = None
|
||||
|
||||
class ChatMemberCreate(ChatMemberBase):
|
||||
pass
|
||||
|
||||
class ChatMemberUpdate(BaseModel):
|
||||
role: Optional[MemberRole] = None
|
||||
nickname: Optional[str] = None
|
||||
is_muted: Optional[bool] = None
|
||||
|
||||
class ChatMember(ChatMemberBase):
|
||||
id: int
|
||||
chat_id: int
|
||||
is_muted: bool = False
|
||||
is_banned: bool = False
|
||||
joined_at: datetime
|
||||
last_read_message_id: Optional[int] = None
|
||||
user: UserPublic
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class ChatInDBBase(ChatBase):
|
||||
id: int
|
||||
is_active: bool = True
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class Chat(ChatInDBBase):
|
||||
members: List[ChatMember] = []
|
||||
last_message: Optional[dict] = None
|
||||
unread_count: int = 0
|
||||
|
||||
class ChatList(BaseModel):
|
||||
id: int
|
||||
name: Optional[str] = None
|
||||
chat_type: ChatType
|
||||
avatar_url: Optional[str] = None
|
||||
last_message: Optional[dict] = None
|
||||
last_activity: Optional[datetime] = None
|
||||
unread_count: int = 0
|
||||
is_online: bool = False # For direct chats
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
64
app/schemas/message.py
Normal file
64
app/schemas/message.py
Normal file
@ -0,0 +1,64 @@
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from app.models.message import MessageType, MessageStatus
|
||||
|
||||
class MessageBase(BaseModel):
|
||||
content: Optional[str] = None
|
||||
content_type: MessageType = MessageType.TEXT
|
||||
reply_to_id: Optional[int] = None
|
||||
|
||||
class MessageCreate(MessageBase):
|
||||
chat_id: int
|
||||
|
||||
class MessageUpdate(BaseModel):
|
||||
content: Optional[str] = None
|
||||
|
||||
class MediaFile(BaseModel):
|
||||
id: int
|
||||
filename: str
|
||||
original_filename: str
|
||||
file_size: int
|
||||
mime_type: str
|
||||
media_type: str
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
duration: Optional[int] = None
|
||||
thumbnail_url: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class MessageMention(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
full_name: Optional[str] = None
|
||||
|
||||
class MessageInDBBase(MessageBase):
|
||||
id: int
|
||||
chat_id: int
|
||||
sender_id: int
|
||||
status: MessageStatus
|
||||
is_edited: bool = False
|
||||
is_deleted: bool = False
|
||||
edited_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class Message(MessageInDBBase):
|
||||
sender_username: str
|
||||
sender_avatar: Optional[str] = None
|
||||
media_files: List[MediaFile] = []
|
||||
mentions: List[MessageMention] = []
|
||||
reply_to: Optional['Message'] = None
|
||||
|
||||
class MessageList(BaseModel):
|
||||
messages: List[Message]
|
||||
total: int
|
||||
page: int
|
||||
has_more: bool
|
||||
|
||||
# Enable forward references
|
||||
Message.model_rebuild()
|
@ -1,36 +0,0 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from app.models.task import TaskStatus, TaskPriority
|
||||
|
||||
class TaskBase(BaseModel):
|
||||
title: str
|
||||
description: Optional[str] = None
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
priority: TaskPriority = TaskPriority.MEDIUM
|
||||
due_date: Optional[datetime] = None
|
||||
category_id: Optional[int] = None
|
||||
|
||||
class TaskCreate(TaskBase):
|
||||
pass
|
||||
|
||||
class TaskUpdate(BaseModel):
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
status: Optional[TaskStatus] = None
|
||||
priority: Optional[TaskPriority] = None
|
||||
due_date: Optional[datetime] = None
|
||||
category_id: Optional[int] = None
|
||||
|
||||
class TaskInDBBase(TaskBase):
|
||||
id: int
|
||||
owner_id: int
|
||||
completed_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class Task(TaskInDBBase):
|
||||
pass
|
@ -3,21 +3,33 @@ from pydantic import BaseModel, EmailStr
|
||||
from datetime import datetime
|
||||
|
||||
class UserBase(BaseModel):
|
||||
username: str
|
||||
email: EmailStr
|
||||
full_name: Optional[str] = None
|
||||
bio: Optional[str] = None
|
||||
is_active: bool = True
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
username: Optional[str] = None
|
||||
email: Optional[EmailStr] = None
|
||||
full_name: Optional[str] = None
|
||||
bio: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
class UserInDBBase(UserBase):
|
||||
id: int
|
||||
avatar_url: Optional[str] = None
|
||||
is_online: bool = False
|
||||
last_seen: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
@ -27,12 +39,26 @@ class UserInDBBase(UserBase):
|
||||
class User(UserInDBBase):
|
||||
pass
|
||||
|
||||
class UserPublic(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
full_name: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
is_online: bool = False
|
||||
last_seen: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class UserInDB(UserInDBBase):
|
||||
hashed_password: str
|
||||
public_key: Optional[str] = None
|
||||
device_token: Optional[str] = None
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
user: UserPublic
|
||||
|
||||
class TokenData(BaseModel):
|
||||
user_id: Optional[int] = None
|
1
app/services/__init__.py
Normal file
1
app/services/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Services package
|
208
app/services/encryption_service.py
Normal file
208
app/services/encryption_service.py
Normal file
@ -0,0 +1,208 @@
|
||||
import json
|
||||
import base64
|
||||
from typing import Dict
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_public_key, load_pem_private_key
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user import User
|
||||
from app.models.chat_member import ChatMember
|
||||
import os
|
||||
|
||||
class EncryptionService:
|
||||
def __init__(self):
|
||||
self.algorithm = hashes.SHA256()
|
||||
|
||||
def generate_rsa_key_pair(self) -> tuple[str, str]:
|
||||
"""Generate RSA key pair for E2E encryption"""
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
)
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# Serialize keys
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
)
|
||||
|
||||
public_pem = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
|
||||
return private_pem.decode(), public_pem.decode()
|
||||
|
||||
def encrypt_message_for_chat(self, message: str, chat_id: int, db: Session) -> Dict[str, str]:
|
||||
"""
|
||||
Encrypt message for all chat members
|
||||
Returns a dict with user_id as key and encrypted message as value
|
||||
"""
|
||||
# Get all chat members with their public keys
|
||||
members = db.query(ChatMember).join(User).filter(
|
||||
ChatMember.chat_id == chat_id,
|
||||
User.public_key.isnot(None)
|
||||
).all()
|
||||
|
||||
encrypted_messages = {}
|
||||
|
||||
for member in members:
|
||||
user = member.user
|
||||
if user.public_key:
|
||||
try:
|
||||
encrypted_msg = self.encrypt_message(message, user.public_key)
|
||||
encrypted_messages[str(user.id)] = encrypted_msg
|
||||
except Exception as e:
|
||||
print(f"Failed to encrypt for user {user.id}: {e}")
|
||||
# Store unencrypted as fallback
|
||||
encrypted_messages[str(user.id)] = message
|
||||
else:
|
||||
# No public key, store unencrypted
|
||||
encrypted_messages[str(user.id)] = message
|
||||
|
||||
return encrypted_messages
|
||||
|
||||
def encrypt_message(self, message: str, public_key_pem: str) -> str:
|
||||
"""Encrypt message using recipient's public key"""
|
||||
try:
|
||||
public_key = load_pem_public_key(public_key_pem.encode())
|
||||
|
||||
# For messages longer than RSA key size, use hybrid encryption
|
||||
if len(message.encode()) > 190: # RSA 2048 can encrypt ~190 bytes
|
||||
return self._hybrid_encrypt(message, public_key)
|
||||
else:
|
||||
encrypted = public_key.encrypt(
|
||||
message.encode(),
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
)
|
||||
return base64.b64encode(encrypted).decode()
|
||||
except Exception as e:
|
||||
print(f"Encryption failed: {e}")
|
||||
return message # Fallback to unencrypted
|
||||
|
||||
def decrypt_message(self, encrypted_message: str, private_key_pem: str) -> str:
|
||||
"""Decrypt message using recipient's private key"""
|
||||
try:
|
||||
private_key = load_pem_private_key(private_key_pem.encode(), password=None)
|
||||
|
||||
# Check if it's hybrid encryption (contains ':')
|
||||
if ':' in encrypted_message:
|
||||
return self._hybrid_decrypt(encrypted_message, private_key)
|
||||
else:
|
||||
encrypted_bytes = base64.b64decode(encrypted_message.encode())
|
||||
decrypted = private_key.decrypt(
|
||||
encrypted_bytes,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
)
|
||||
return decrypted.decode()
|
||||
except Exception as e:
|
||||
print(f"Decryption failed: {e}")
|
||||
return encrypted_message # Return encrypted if decryption fails
|
||||
|
||||
def _hybrid_encrypt(self, message: str, public_key) -> str:
|
||||
"""
|
||||
Hybrid encryption: Use AES for message, RSA for AES key
|
||||
Format: base64(encrypted_aes_key):base64(encrypted_message)
|
||||
"""
|
||||
# Generate AES key
|
||||
aes_key = os.urandom(32) # 256-bit key
|
||||
iv = os.urandom(16) # 128-bit IV
|
||||
|
||||
# Encrypt message with AES
|
||||
cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv))
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
# Pad message to multiple of 16 bytes
|
||||
padded_message = self._pad_message(message.encode())
|
||||
encrypted_message = encryptor.update(padded_message) + encryptor.finalize()
|
||||
|
||||
# Encrypt AES key + IV with RSA
|
||||
key_iv = aes_key + iv
|
||||
encrypted_key_iv = public_key.encrypt(
|
||||
key_iv,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
)
|
||||
|
||||
# Combine encrypted key and message
|
||||
encrypted_key_b64 = base64.b64encode(encrypted_key_iv).decode()
|
||||
encrypted_msg_b64 = base64.b64encode(encrypted_message).decode()
|
||||
|
||||
return f"{encrypted_key_b64}:{encrypted_msg_b64}"
|
||||
|
||||
def _hybrid_decrypt(self, encrypted_data: str, private_key) -> str:
|
||||
"""Hybrid decryption: Decrypt AES key with RSA, then message with AES"""
|
||||
try:
|
||||
encrypted_key_b64, encrypted_msg_b64 = encrypted_data.split(':', 1)
|
||||
|
||||
# Decrypt AES key + IV
|
||||
encrypted_key_iv = base64.b64decode(encrypted_key_b64.encode())
|
||||
key_iv = private_key.decrypt(
|
||||
encrypted_key_iv,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
)
|
||||
|
||||
aes_key = key_iv[:32]
|
||||
iv = key_iv[32:]
|
||||
|
||||
# Decrypt message with AES
|
||||
encrypted_message = base64.b64decode(encrypted_msg_b64.encode())
|
||||
cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv))
|
||||
decryptor = cipher.decryptor()
|
||||
|
||||
padded_message = decryptor.update(encrypted_message) + decryptor.finalize()
|
||||
message = self._unpad_message(padded_message).decode()
|
||||
|
||||
return message
|
||||
except Exception as e:
|
||||
print(f"Hybrid decryption failed: {e}")
|
||||
return encrypted_data
|
||||
|
||||
def _pad_message(self, message: bytes) -> bytes:
|
||||
"""PKCS7 padding"""
|
||||
padding_length = 16 - (len(message) % 16)
|
||||
padding = bytes([padding_length] * padding_length)
|
||||
return message + padding
|
||||
|
||||
def _unpad_message(self, padded_message: bytes) -> bytes:
|
||||
"""Remove PKCS7 padding"""
|
||||
padding_length = padded_message[-1]
|
||||
return padded_message[:-padding_length]
|
||||
|
||||
def get_user_encrypted_message(self, encrypted_messages: Dict[str, str], user_id: int) -> str:
|
||||
"""Get encrypted message for specific user"""
|
||||
return encrypted_messages.get(str(user_id), "")
|
||||
|
||||
def encrypt_file_metadata(self, metadata: Dict, public_key_pem: str) -> str:
|
||||
"""Encrypt file metadata"""
|
||||
metadata_json = json.dumps(metadata)
|
||||
return self.encrypt_message(metadata_json, public_key_pem)
|
||||
|
||||
def decrypt_file_metadata(self, encrypted_metadata: str, private_key_pem: str) -> Dict:
|
||||
"""Decrypt file metadata"""
|
||||
try:
|
||||
metadata_json = self.decrypt_message(encrypted_metadata, private_key_pem)
|
||||
return json.loads(metadata_json)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
# Global encryption service instance
|
||||
encryption_service = EncryptionService()
|
183
app/services/media_service.py
Normal file
183
app/services/media_service.py
Normal file
@ -0,0 +1,183 @@
|
||||
import uuid
|
||||
import aiofiles
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from fastapi import UploadFile
|
||||
from PIL import Image
|
||||
import magic
|
||||
from app.models.media import MediaType
|
||||
|
||||
class MediaService:
|
||||
def __init__(self):
|
||||
self.storage_path = Path("/app/storage/media")
|
||||
self.thumbnails_path = Path("/app/storage/thumbnails")
|
||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||
self.thumbnails_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Max file sizes (in bytes)
|
||||
self.max_file_sizes = {
|
||||
MediaType.IMAGE: 10 * 1024 * 1024, # 10MB
|
||||
MediaType.VIDEO: 100 * 1024 * 1024, # 100MB
|
||||
MediaType.AUDIO: 50 * 1024 * 1024, # 50MB
|
||||
MediaType.DOCUMENT: 25 * 1024 * 1024 # 25MB
|
||||
}
|
||||
|
||||
# Allowed MIME types
|
||||
self.allowed_types = {
|
||||
MediaType.IMAGE: [
|
||||
"image/jpeg", "image/png", "image/gif", "image/webp", "image/bmp"
|
||||
],
|
||||
MediaType.VIDEO: [
|
||||
"video/mp4", "video/avi", "video/mov", "video/wmv", "video/flv", "video/webm"
|
||||
],
|
||||
MediaType.AUDIO: [
|
||||
"audio/mp3", "audio/wav", "audio/ogg", "audio/m4a", "audio/flac"
|
||||
],
|
||||
MediaType.DOCUMENT: [
|
||||
"application/pdf", "application/msword", "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-excel", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"application/vnd.ms-powerpoint", "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
"text/plain", "text/csv", "application/zip", "application/x-rar-compressed"
|
||||
]
|
||||
}
|
||||
|
||||
async def save_file(self, file: UploadFile, uploader_id: int) -> Tuple[str, str, int, MediaType, Optional[int], Optional[int], Optional[str]]:
|
||||
"""
|
||||
Save uploaded file and return file info
|
||||
Returns: (filename, file_path, file_size, media_type, width, height, thumbnail_path)
|
||||
"""
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
# Detect MIME type
|
||||
mime_type = magic.from_buffer(content, mime=True)
|
||||
|
||||
# Determine media type
|
||||
media_type = self._get_media_type(mime_type)
|
||||
|
||||
# Validate file type and size
|
||||
self._validate_file(media_type, mime_type, file_size)
|
||||
|
||||
# Generate unique filename
|
||||
file_extension = self._get_file_extension(file.filename, mime_type)
|
||||
filename = f"{uuid.uuid4()}{file_extension}"
|
||||
file_path = self.storage_path / filename
|
||||
|
||||
# Save file
|
||||
async with aiofiles.open(file_path, 'wb') as f:
|
||||
await f.write(content)
|
||||
|
||||
# Process media (get dimensions, create thumbnail)
|
||||
width, height, thumbnail_path = await self._process_media(
|
||||
file_path, media_type, filename
|
||||
)
|
||||
|
||||
return str(filename), str(file_path), file_size, media_type, width, height, thumbnail_path
|
||||
|
||||
def _get_media_type(self, mime_type: str) -> MediaType:
|
||||
"""Determine media type from MIME type"""
|
||||
for media_type, allowed_mimes in self.allowed_types.items():
|
||||
if mime_type in allowed_mimes:
|
||||
return media_type
|
||||
return MediaType.OTHER
|
||||
|
||||
def _validate_file(self, media_type: MediaType, mime_type: str, file_size: int):
|
||||
"""Validate file type and size"""
|
||||
# Check if MIME type is allowed
|
||||
if media_type in self.allowed_types:
|
||||
if mime_type not in self.allowed_types[media_type]:
|
||||
raise ValueError(f"File type {mime_type} not allowed")
|
||||
|
||||
# Check file size
|
||||
max_size = self.max_file_sizes.get(media_type, 10 * 1024 * 1024) # Default 10MB
|
||||
if file_size > max_size:
|
||||
raise ValueError(f"File size {file_size} exceeds maximum {max_size}")
|
||||
|
||||
def _get_file_extension(self, original_filename: str, mime_type: str) -> str:
|
||||
"""Get appropriate file extension"""
|
||||
if original_filename and '.' in original_filename:
|
||||
return '.' + original_filename.rsplit('.', 1)[1].lower()
|
||||
|
||||
# Fallback based on MIME type
|
||||
mime_extensions = {
|
||||
'image/jpeg': '.jpg',
|
||||
'image/png': '.png',
|
||||
'image/gif': '.gif',
|
||||
'video/mp4': '.mp4',
|
||||
'audio/mp3': '.mp3',
|
||||
'application/pdf': '.pdf',
|
||||
'text/plain': '.txt'
|
||||
}
|
||||
return mime_extensions.get(mime_type, '.bin')
|
||||
|
||||
async def _process_media(self, file_path: Path, media_type: MediaType, filename: str) -> Tuple[Optional[int], Optional[int], Optional[str]]:
|
||||
"""Process media file (get dimensions, create thumbnails)"""
|
||||
width, height, thumbnail_path = None, None, None
|
||||
|
||||
if media_type == MediaType.IMAGE:
|
||||
try:
|
||||
with Image.open(file_path) as img:
|
||||
width, height = img.size
|
||||
|
||||
# Create thumbnail
|
||||
thumbnail_filename = f"thumb_{filename}"
|
||||
thumbnail_path = self.thumbnails_path / thumbnail_filename
|
||||
|
||||
# Create thumbnail (max 300x300)
|
||||
img.thumbnail((300, 300), Image.Resampling.LANCZOS)
|
||||
img.save(thumbnail_path, format='JPEG', quality=85)
|
||||
thumbnail_path = str(thumbnail_path)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing image {filename}: {e}")
|
||||
|
||||
elif media_type == MediaType.VIDEO:
|
||||
# For videos, you'd typically use ffmpeg to get dimensions and create thumbnails
|
||||
# This is a simplified version
|
||||
try:
|
||||
# You would use ffprobe to get video dimensions
|
||||
# For now, we'll set default values
|
||||
width, height = 1920, 1080 # Default values
|
||||
|
||||
# Create video thumbnail using ffmpeg
|
||||
# thumbnail_filename = f"thumb_{filename}.jpg"
|
||||
# thumbnail_path = self.thumbnails_path / thumbnail_filename
|
||||
# ... ffmpeg command to extract frame ...
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing video {filename}: {e}")
|
||||
|
||||
return width, height, thumbnail_path
|
||||
|
||||
async def get_file_content(self, filename: str) -> bytes:
|
||||
"""Get file content"""
|
||||
file_path = self.storage_path / filename
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"File {filename} not found")
|
||||
|
||||
async with aiofiles.open(file_path, 'rb') as f:
|
||||
return await f.read()
|
||||
|
||||
async def get_thumbnail_content(self, filename: str) -> bytes:
|
||||
"""Get thumbnail content"""
|
||||
thumbnail_path = self.thumbnails_path / f"thumb_{filename}"
|
||||
if not thumbnail_path.exists():
|
||||
raise FileNotFoundError(f"Thumbnail for {filename} not found")
|
||||
|
||||
async with aiofiles.open(thumbnail_path, 'rb') as f:
|
||||
return await f.read()
|
||||
|
||||
def delete_file(self, filename: str, thumbnail_filename: Optional[str] = None):
|
||||
"""Delete file and its thumbnail"""
|
||||
file_path = self.storage_path / filename
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
if thumbnail_filename:
|
||||
thumbnail_path = self.thumbnails_path / thumbnail_filename
|
||||
if thumbnail_path.exists():
|
||||
thumbnail_path.unlink()
|
||||
|
||||
# Global media service instance
|
||||
media_service = MediaService()
|
298
app/services/notification_service.py
Normal file
298
app/services/notification_service.py
Normal file
@ -0,0 +1,298 @@
|
||||
import json
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from app.models import User, Chat, Message, Notification, NotificationType, Mention
|
||||
from app.db.session import SessionLocal
|
||||
from app.websocket.connection_manager import connection_manager
|
||||
|
||||
class NotificationService:
|
||||
def __init__(self):
|
||||
self.notification_queue = []
|
||||
|
||||
async def send_mention_notification(
|
||||
self,
|
||||
mentioned_user_id: int,
|
||||
sender_username: str,
|
||||
message_content: str,
|
||||
chat_id: int,
|
||||
message_id: int
|
||||
):
|
||||
"""Send notification for user mention"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
mentioned_user = db.query(User).filter(User.id == mentioned_user_id).first()
|
||||
chat = db.query(Chat).filter(Chat.id == chat_id).first()
|
||||
|
||||
if not mentioned_user or not chat:
|
||||
return
|
||||
|
||||
# Create notification record
|
||||
notification = Notification(
|
||||
user_id=mentioned_user_id,
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
notification_type=NotificationType.MENTION,
|
||||
title=f"Mentioned by {sender_username}",
|
||||
body=f"You were mentioned: {message_content}",
|
||||
data=json.dumps({
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"sender_username": sender_username,
|
||||
"chat_name": chat.name or f"Chat with {sender_username}"
|
||||
})
|
||||
)
|
||||
|
||||
db.add(notification)
|
||||
db.commit()
|
||||
db.refresh(notification)
|
||||
|
||||
# Send real-time notification via WebSocket
|
||||
await self._send_realtime_notification(mentioned_user_id, {
|
||||
"type": "mention_notification",
|
||||
"notification_id": notification.id,
|
||||
"title": notification.title,
|
||||
"body": notification.body,
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"sender_username": sender_username,
|
||||
"created_at": notification.created_at.isoformat()
|
||||
})
|
||||
|
||||
# Add to push notification queue
|
||||
if mentioned_user.device_token:
|
||||
await self._queue_push_notification(notification)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def send_message_notification(
|
||||
self,
|
||||
recipient_user_id: int,
|
||||
sender_username: str,
|
||||
message_content: str,
|
||||
chat_id: int,
|
||||
message_id: int
|
||||
):
|
||||
"""Send notification for new message"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
recipient = db.query(User).filter(User.id == recipient_user_id).first()
|
||||
chat = db.query(Chat).filter(Chat.id == chat_id).first()
|
||||
|
||||
if not recipient or not chat:
|
||||
return
|
||||
|
||||
# Check if user is online and active in this chat
|
||||
is_online = recipient_user_id in connection_manager.active_connections
|
||||
if is_online:
|
||||
# Don't send push notification if user is online
|
||||
return
|
||||
|
||||
# Create notification record
|
||||
notification = Notification(
|
||||
user_id=recipient_user_id,
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
notification_type=NotificationType.MESSAGE,
|
||||
title=f"New message from {sender_username}",
|
||||
body=message_content[:100],
|
||||
data=json.dumps({
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"sender_username": sender_username,
|
||||
"chat_name": chat.name or f"Chat with {sender_username}"
|
||||
})
|
||||
)
|
||||
|
||||
db.add(notification)
|
||||
db.commit()
|
||||
db.refresh(notification)
|
||||
|
||||
# Add to push notification queue
|
||||
if recipient.device_token:
|
||||
await self._queue_push_notification(notification)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def send_group_invite_notification(
|
||||
self,
|
||||
invited_user_id: int,
|
||||
inviter_username: str,
|
||||
chat_id: int,
|
||||
group_name: str
|
||||
):
|
||||
"""Send notification for group chat invitation"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
invited_user = db.query(User).filter(User.id == invited_user_id).first()
|
||||
|
||||
if not invited_user:
|
||||
return
|
||||
|
||||
notification = Notification(
|
||||
user_id=invited_user_id,
|
||||
chat_id=chat_id,
|
||||
notification_type=NotificationType.GROUP_INVITE,
|
||||
title=f"Invited to {group_name}",
|
||||
body=f"{inviter_username} invited you to join {group_name}",
|
||||
data=json.dumps({
|
||||
"chat_id": chat_id,
|
||||
"inviter_username": inviter_username,
|
||||
"group_name": group_name
|
||||
})
|
||||
)
|
||||
|
||||
db.add(notification)
|
||||
db.commit()
|
||||
db.refresh(notification)
|
||||
|
||||
# Send real-time notification
|
||||
await self._send_realtime_notification(invited_user_id, {
|
||||
"type": "group_invite_notification",
|
||||
"notification_id": notification.id,
|
||||
"title": notification.title,
|
||||
"body": notification.body,
|
||||
"chat_id": chat_id,
|
||||
"group_name": group_name,
|
||||
"inviter_username": inviter_username,
|
||||
"created_at": notification.created_at.isoformat()
|
||||
})
|
||||
|
||||
if invited_user.device_token:
|
||||
await self._queue_push_notification(notification)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def mark_notification_read(self, notification_id: int, user_id: int):
|
||||
"""Mark notification as read"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
notification = db.query(Notification).filter(
|
||||
Notification.id == notification_id,
|
||||
Notification.user_id == user_id
|
||||
).first()
|
||||
|
||||
if notification:
|
||||
notification.is_read = True
|
||||
notification.read_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
# Send real-time update
|
||||
await self._send_realtime_notification(user_id, {
|
||||
"type": "notification_read",
|
||||
"notification_id": notification_id
|
||||
})
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def mark_mention_read(self, mention_id: int, user_id: int):
|
||||
"""Mark mention as read"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
mention = db.query(Mention).filter(
|
||||
Mention.id == mention_id,
|
||||
Mention.mentioned_user_id == user_id
|
||||
).first()
|
||||
|
||||
if mention:
|
||||
mention.is_read = True
|
||||
mention.read_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_user_notifications(
|
||||
self,
|
||||
user_id: int,
|
||||
unread_only: bool = False,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[dict]:
|
||||
"""Get user notifications"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
query = db.query(Notification).filter(Notification.user_id == user_id)
|
||||
|
||||
if unread_only:
|
||||
query = query.filter(Notification.is_read.is_(False))
|
||||
|
||||
notifications = query.order_by(
|
||||
Notification.created_at.desc()
|
||||
).offset(offset).limit(limit).all()
|
||||
|
||||
result = []
|
||||
for notif in notifications:
|
||||
data = json.loads(notif.data) if notif.data else {}
|
||||
result.append({
|
||||
"id": notif.id,
|
||||
"type": notif.notification_type.value,
|
||||
"title": notif.title,
|
||||
"body": notif.body,
|
||||
"data": data,
|
||||
"is_read": notif.is_read,
|
||||
"created_at": notif.created_at.isoformat(),
|
||||
"read_at": notif.read_at.isoformat() if notif.read_at else None
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_unread_mentions(self, user_id: int) -> List[dict]:
|
||||
"""Get unread mentions for user"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
mentions = db.query(Mention).join(Message).join(Chat).filter(
|
||||
Mention.mentioned_user_id == user_id,
|
||||
Mention.is_read.is_(False)
|
||||
).all()
|
||||
|
||||
result = []
|
||||
for mention in mentions:
|
||||
message = mention.message
|
||||
result.append({
|
||||
"id": mention.id,
|
||||
"message_id": message.id,
|
||||
"chat_id": message.chat_id,
|
||||
"chat_name": message.chat.name,
|
||||
"sender_username": message.sender.username,
|
||||
"message_content": message.content[:100],
|
||||
"created_at": mention.created_at.isoformat()
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _send_realtime_notification(self, user_id: int, notification_data: dict):
|
||||
"""Send real-time notification via WebSocket"""
|
||||
await connection_manager.send_personal_message(notification_data, user_id)
|
||||
|
||||
async def _queue_push_notification(self, notification: Notification):
|
||||
"""Queue push notification for sending"""
|
||||
# In a real implementation, this would add to a task queue (like Celery)
|
||||
# For now, we'll just store in memory
|
||||
push_data = {
|
||||
"user_id": notification.user_id,
|
||||
"title": notification.title,
|
||||
"body": notification.body,
|
||||
"data": json.loads(notification.data) if notification.data else {}
|
||||
}
|
||||
self.notification_queue.append(push_data)
|
||||
|
||||
# Mark as queued
|
||||
db = SessionLocal()
|
||||
try:
|
||||
notification.is_sent = True
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Global notification service instance
|
||||
notification_service = NotificationService()
|
255
app/services/push_notification_service.py
Normal file
255
app/services/push_notification_service.py
Normal file
@ -0,0 +1,255 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
import firebase_admin
|
||||
from firebase_admin import credentials, messaging
|
||||
from app.models.user import User
|
||||
from app.db.session import SessionLocal
|
||||
|
||||
class PushNotificationService:
|
||||
def __init__(self):
|
||||
self.firebase_app = None
|
||||
self._initialize_firebase()
|
||||
|
||||
def _initialize_firebase(self):
|
||||
"""Initialize Firebase Admin SDK"""
|
||||
try:
|
||||
# Check if Firebase credentials are available
|
||||
firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH")
|
||||
firebase_cred_json = os.getenv("FIREBASE_CREDENTIALS_JSON")
|
||||
|
||||
if firebase_cred_path and os.path.exists(firebase_cred_path):
|
||||
cred = credentials.Certificate(firebase_cred_path)
|
||||
self.firebase_app = firebase_admin.initialize_app(cred)
|
||||
elif firebase_cred_json:
|
||||
cred_dict = json.loads(firebase_cred_json)
|
||||
cred = credentials.Certificate(cred_dict)
|
||||
self.firebase_app = firebase_admin.initialize_app(cred)
|
||||
else:
|
||||
print("Firebase credentials not found. Push notifications will be disabled.")
|
||||
self.firebase_app = None
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize Firebase: {e}")
|
||||
self.firebase_app = None
|
||||
|
||||
async def send_push_notification(
|
||||
self,
|
||||
device_token: str,
|
||||
title: str,
|
||||
body: str,
|
||||
data: Optional[Dict[str, str]] = None
|
||||
) -> bool:
|
||||
"""Send push notification to a device"""
|
||||
if not self.firebase_app or not device_token:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Create message
|
||||
message = messaging.Message(
|
||||
notification=messaging.Notification(
|
||||
title=title,
|
||||
body=body
|
||||
),
|
||||
data=data or {},
|
||||
token=device_token,
|
||||
android=messaging.AndroidConfig(
|
||||
notification=messaging.AndroidNotification(
|
||||
icon="ic_notification",
|
||||
color="#2196F3",
|
||||
sound="default",
|
||||
channel_id="chat_messages"
|
||||
),
|
||||
priority="high"
|
||||
),
|
||||
apns=messaging.APNSConfig(
|
||||
payload=messaging.APNSPayload(
|
||||
aps=messaging.Aps(
|
||||
alert=messaging.ApsAlert(
|
||||
title=title,
|
||||
body=body
|
||||
),
|
||||
badge=1,
|
||||
sound="default"
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Send message
|
||||
response = messaging.send(message)
|
||||
print(f"Successfully sent message: {response}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to send push notification: {e}")
|
||||
return False
|
||||
|
||||
async def send_message_notification(
|
||||
self,
|
||||
recipient_user_id: int,
|
||||
sender_username: str,
|
||||
message_content: str,
|
||||
chat_id: int,
|
||||
chat_name: Optional[str] = None
|
||||
):
|
||||
"""Send push notification for new message"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.id == recipient_user_id).first()
|
||||
if not user or not user.device_token:
|
||||
return
|
||||
|
||||
title = f"New message from {sender_username}"
|
||||
if chat_name:
|
||||
title = f"New message in {chat_name}"
|
||||
|
||||
body = message_content[:100]
|
||||
if len(message_content) > 100:
|
||||
body += "..."
|
||||
|
||||
data = {
|
||||
"type": "message",
|
||||
"chat_id": str(chat_id),
|
||||
"sender_username": sender_username,
|
||||
"chat_name": chat_name or ""
|
||||
}
|
||||
|
||||
await self.send_push_notification(
|
||||
device_token=user.device_token,
|
||||
title=title,
|
||||
body=body,
|
||||
data=data
|
||||
)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def send_mention_notification(
|
||||
self,
|
||||
mentioned_user_id: int,
|
||||
sender_username: str,
|
||||
message_content: str,
|
||||
chat_id: int,
|
||||
chat_name: Optional[str] = None
|
||||
):
|
||||
"""Send push notification for mention"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.id == mentioned_user_id).first()
|
||||
if not user or not user.device_token:
|
||||
return
|
||||
|
||||
title = f"You were mentioned by {sender_username}"
|
||||
if chat_name:
|
||||
title = f"Mentioned in {chat_name}"
|
||||
|
||||
body = message_content[:100]
|
||||
if len(message_content) > 100:
|
||||
body += "..."
|
||||
|
||||
data = {
|
||||
"type": "mention",
|
||||
"chat_id": str(chat_id),
|
||||
"sender_username": sender_username,
|
||||
"chat_name": chat_name or ""
|
||||
}
|
||||
|
||||
await self.send_push_notification(
|
||||
device_token=user.device_token,
|
||||
title=title,
|
||||
body=body,
|
||||
data=data
|
||||
)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def send_group_invite_notification(
|
||||
self,
|
||||
invited_user_id: int,
|
||||
inviter_username: str,
|
||||
group_name: str,
|
||||
chat_id: int
|
||||
):
|
||||
"""Send push notification for group invitation"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.id == invited_user_id).first()
|
||||
if not user or not user.device_token:
|
||||
return
|
||||
|
||||
title = f"Invited to {group_name}"
|
||||
body = f"{inviter_username} invited you to join {group_name}"
|
||||
|
||||
data = {
|
||||
"type": "group_invite",
|
||||
"chat_id": str(chat_id),
|
||||
"inviter_username": inviter_username,
|
||||
"group_name": group_name
|
||||
}
|
||||
|
||||
await self.send_push_notification(
|
||||
device_token=user.device_token,
|
||||
title=title,
|
||||
body=body,
|
||||
data=data
|
||||
)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def send_bulk_notifications(
|
||||
self,
|
||||
tokens_and_messages: List[Dict]
|
||||
) -> Dict[str, int]:
|
||||
"""Send multiple notifications at once"""
|
||||
if not self.firebase_app:
|
||||
return {"success": 0, "failure": len(tokens_and_messages)}
|
||||
|
||||
messages = []
|
||||
for item in tokens_and_messages:
|
||||
message = messaging.Message(
|
||||
notification=messaging.Notification(
|
||||
title=item["title"],
|
||||
body=item["body"]
|
||||
),
|
||||
data=item.get("data", {}),
|
||||
token=item["token"]
|
||||
)
|
||||
messages.append(message)
|
||||
|
||||
try:
|
||||
response = messaging.send_all(messages)
|
||||
return {
|
||||
"success": response.success_count,
|
||||
"failure": response.failure_count
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Failed to send bulk notifications: {e}")
|
||||
return {"success": 0, "failure": len(tokens_and_messages)}
|
||||
|
||||
def validate_device_token(self, device_token: str) -> bool:
|
||||
"""Validate if device token is valid"""
|
||||
if not self.firebase_app or not device_token:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Try to send a test message (dry run)
|
||||
message = messaging.Message(
|
||||
notification=messaging.Notification(
|
||||
title="Test",
|
||||
body="Test"
|
||||
),
|
||||
token=device_token
|
||||
)
|
||||
|
||||
# Dry run - doesn't actually send
|
||||
messaging.send(message, dry_run=True)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Global push notification service instance
|
||||
push_notification_service = PushNotificationService()
|
1
app/utils/__init__.py
Normal file
1
app/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Utils package
|
4
app/websocket/__init__.py
Normal file
4
app/websocket/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .connection_manager import connection_manager
|
||||
from .chat_handler import chat_handler
|
||||
|
||||
__all__ = ["connection_manager", "chat_handler"]
|
203
app/websocket/chat_handler.py
Normal file
203
app/websocket/chat_handler.py
Normal file
@ -0,0 +1,203 @@
|
||||
from typing import Dict, Any
|
||||
from fastapi import WebSocket
|
||||
from sqlalchemy.orm import Session
|
||||
from app.websocket.connection_manager import connection_manager
|
||||
from app.models import Message, Chat, User, ChatMember, Mention, MessageType, MessageStatus
|
||||
from app.services.notification_service import notification_service
|
||||
from app.db.session import SessionLocal
|
||||
|
||||
class ChatHandler:
|
||||
async def handle_message(self, websocket: WebSocket, data: Dict[str, Any]):
|
||||
"""Handle incoming chat messages"""
|
||||
user_id = connection_manager.connection_user_map.get(websocket)
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
message_type = data.get("type")
|
||||
|
||||
if message_type == "send_message":
|
||||
await self._handle_send_message(user_id, data, db)
|
||||
elif message_type == "typing_start":
|
||||
await self._handle_typing_start(user_id, data, db)
|
||||
elif message_type == "typing_stop":
|
||||
await self._handle_typing_stop(user_id, data, db)
|
||||
elif message_type == "message_read":
|
||||
await self._handle_message_read(user_id, data, db)
|
||||
elif message_type == "join_chat":
|
||||
await self._handle_join_chat(user_id, data, db)
|
||||
elif message_type == "leave_chat":
|
||||
await self._handle_leave_chat(user_id, data, db)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _handle_send_message(self, user_id: int, data: Dict[str, Any], db: Session):
|
||||
"""Handle sending a new message"""
|
||||
chat_id = data.get("chat_id")
|
||||
content = data.get("content", "")
|
||||
reply_to_id = data.get("reply_to_id")
|
||||
message_type = data.get("message_type", "text")
|
||||
|
||||
# Verify user is member of chat
|
||||
member = db.query(ChatMember).filter(
|
||||
ChatMember.chat_id == chat_id,
|
||||
ChatMember.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not member:
|
||||
return
|
||||
|
||||
# Get chat and sender info
|
||||
chat = db.query(Chat).filter(Chat.id == chat_id).first()
|
||||
sender = db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
if not chat or not sender:
|
||||
return
|
||||
|
||||
# Process mentions
|
||||
mentions = self._extract_mentions(content, db)
|
||||
|
||||
# Encrypt message content if recipients have public keys
|
||||
encrypted_content = await self._encrypt_for_recipients(content, chat_id, db)
|
||||
|
||||
# Create message
|
||||
message = Message(
|
||||
chat_id=chat_id,
|
||||
sender_id=user_id,
|
||||
reply_to_id=reply_to_id,
|
||||
content=encrypted_content,
|
||||
content_type=getattr(MessageType, message_type.upper()),
|
||||
status=MessageStatus.SENT
|
||||
)
|
||||
|
||||
db.add(message)
|
||||
db.commit()
|
||||
db.refresh(message)
|
||||
|
||||
# Create mention records
|
||||
for mentioned_user in mentions:
|
||||
mention = Mention(
|
||||
message_id=message.id,
|
||||
mentioned_user_id=mentioned_user.id
|
||||
)
|
||||
db.add(mention)
|
||||
|
||||
if mentions:
|
||||
db.commit()
|
||||
|
||||
# Prepare message data for broadcast
|
||||
message_data = {
|
||||
"type": "new_message",
|
||||
"message": {
|
||||
"id": message.id,
|
||||
"chat_id": chat_id,
|
||||
"sender_id": user_id,
|
||||
"sender_username": sender.username,
|
||||
"sender_avatar": sender.avatar_url,
|
||||
"content": content, # Send unencrypted for real-time display
|
||||
"content_type": message_type,
|
||||
"reply_to_id": reply_to_id,
|
||||
"created_at": message.created_at.isoformat(),
|
||||
"mentions": [{"id": u.id, "username": u.username} for u in mentions]
|
||||
}
|
||||
}
|
||||
|
||||
# Broadcast to chat members
|
||||
await connection_manager.send_to_chat(message_data, chat_id, exclude_user_id=user_id)
|
||||
|
||||
# Send delivery confirmation to sender
|
||||
await connection_manager.send_personal_message({
|
||||
"type": "message_sent",
|
||||
"message_id": message.id,
|
||||
"chat_id": chat_id,
|
||||
"status": "sent"
|
||||
}, user_id)
|
||||
|
||||
# Send push notifications for mentions
|
||||
for mentioned_user in mentions:
|
||||
await notification_service.send_mention_notification(
|
||||
mentioned_user.id, sender.username, content[:100], chat_id, message.id
|
||||
)
|
||||
|
||||
async def _handle_typing_start(self, user_id: int, data: Dict[str, Any], db: Session):
|
||||
"""Handle typing indicator start"""
|
||||
chat_id = data.get("chat_id")
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
typing_data = {
|
||||
"type": "typing_start",
|
||||
"chat_id": chat_id,
|
||||
"user_id": user_id,
|
||||
"username": user.username if user else "Unknown"
|
||||
}
|
||||
|
||||
await connection_manager.send_to_chat(typing_data, chat_id, exclude_user_id=user_id)
|
||||
|
||||
async def _handle_typing_stop(self, user_id: int, data: Dict[str, Any], db: Session):
|
||||
"""Handle typing indicator stop"""
|
||||
chat_id = data.get("chat_id")
|
||||
|
||||
typing_data = {
|
||||
"type": "typing_stop",
|
||||
"chat_id": chat_id,
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
await connection_manager.send_to_chat(typing_data, chat_id, exclude_user_id=user_id)
|
||||
|
||||
async def _handle_message_read(self, user_id: int, data: Dict[str, Any], db: Session):
|
||||
"""Handle message read receipt"""
|
||||
message_id = data.get("message_id")
|
||||
chat_id = data.get("chat_id")
|
||||
|
||||
# Update message status
|
||||
message = db.query(Message).filter(Message.id == message_id).first()
|
||||
if message and message.sender_id != user_id:
|
||||
message.status = MessageStatus.READ
|
||||
db.commit()
|
||||
|
||||
# Notify sender
|
||||
read_receipt = {
|
||||
"type": "message_read",
|
||||
"message_id": message_id,
|
||||
"chat_id": chat_id,
|
||||
"read_by_user_id": user_id
|
||||
}
|
||||
|
||||
await connection_manager.send_personal_message(read_receipt, message.sender_id)
|
||||
|
||||
async def _handle_join_chat(self, user_id: int, data: Dict[str, Any], db: Session):
|
||||
"""Handle user joining chat room (for real-time updates)"""
|
||||
chat_id = data.get("chat_id")
|
||||
connection_manager.add_user_to_chat(user_id, chat_id)
|
||||
|
||||
async def _handle_leave_chat(self, user_id: int, data: Dict[str, Any], db: Session):
|
||||
"""Handle user leaving chat room"""
|
||||
chat_id = data.get("chat_id")
|
||||
connection_manager.remove_user_from_chat(user_id, chat_id)
|
||||
|
||||
def _extract_mentions(self, content: str, db: Session) -> list:
|
||||
"""Extract mentioned users from message content"""
|
||||
import re
|
||||
mentions = []
|
||||
# Find @username patterns
|
||||
mention_pattern = r'@(\w+)'
|
||||
usernames = re.findall(mention_pattern, content)
|
||||
|
||||
for username in usernames:
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if user:
|
||||
mentions.append(user)
|
||||
|
||||
return mentions
|
||||
|
||||
async def _encrypt_for_recipients(self, content: str, chat_id: int, db: Session) -> str:
|
||||
"""Encrypt message for chat recipients (simplified E2E)"""
|
||||
# For now, return content as-is
|
||||
# In a full implementation, you'd encrypt for each recipient's public key
|
||||
return content
|
||||
|
||||
# Global chat handler instance
|
||||
chat_handler = ChatHandler()
|
171
app/websocket/connection_manager.py
Normal file
171
app/websocket/connection_manager.py
Normal file
@ -0,0 +1,171 @@
|
||||
import json
|
||||
from typing import Dict, List, Set
|
||||
from fastapi import WebSocket
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user import User
|
||||
from app.models.chat_member import ChatMember
|
||||
from app.core.deps import get_user_from_token
|
||||
from app.db.session import SessionLocal
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
# user_id -> list of websocket connections (multiple devices/tabs)
|
||||
self.active_connections: Dict[int, List[WebSocket]] = {}
|
||||
# chat_id -> set of user_ids
|
||||
self.chat_members: Dict[int, Set[int]] = {}
|
||||
# websocket -> user_id mapping
|
||||
self.connection_user_map: Dict[WebSocket, int] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, token: str):
|
||||
"""Accept websocket connection and authenticate user"""
|
||||
await websocket.accept()
|
||||
|
||||
# Get database session
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Authenticate user
|
||||
user = await get_user_from_token(token, db)
|
||||
if not user:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": "auth_error",
|
||||
"message": "Authentication failed"
|
||||
}))
|
||||
await websocket.close()
|
||||
return None
|
||||
|
||||
# Update user online status
|
||||
user.is_online = True
|
||||
db.commit()
|
||||
|
||||
# Store connection
|
||||
if user.id not in self.active_connections:
|
||||
self.active_connections[user.id] = []
|
||||
self.active_connections[user.id].append(websocket)
|
||||
self.connection_user_map[websocket] = user.id
|
||||
|
||||
# Load user's chat memberships
|
||||
await self._load_user_chats(user.id, db)
|
||||
|
||||
# Send connection success
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": "connected",
|
||||
"user_id": user.id,
|
||||
"message": "Connected successfully"
|
||||
}))
|
||||
|
||||
# Notify other users in chats that this user is online
|
||||
await self._broadcast_user_status(user.id, "online", db)
|
||||
|
||||
return user.id
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def disconnect(self, websocket: WebSocket):
|
||||
"""Handle websocket disconnection"""
|
||||
if websocket not in self.connection_user_map:
|
||||
return
|
||||
|
||||
user_id = self.connection_user_map[websocket]
|
||||
|
||||
# Remove connection
|
||||
if user_id in self.active_connections:
|
||||
self.active_connections[user_id].remove(websocket)
|
||||
if not self.active_connections[user_id]:
|
||||
del self.active_connections[user_id]
|
||||
|
||||
# Update user offline status if no active connections
|
||||
db = SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user:
|
||||
user.is_online = False
|
||||
from datetime import datetime
|
||||
user.last_seen = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
# Notify other users that this user is offline
|
||||
await self._broadcast_user_status(user_id, "offline", db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
del self.connection_user_map[websocket]
|
||||
|
||||
async def send_personal_message(self, message: dict, user_id: int):
|
||||
"""Send message to specific user (all their connections)"""
|
||||
if user_id in self.active_connections:
|
||||
disconnected = []
|
||||
for websocket in self.active_connections[user_id]:
|
||||
try:
|
||||
await websocket.send_text(json.dumps(message))
|
||||
except Exception:
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Clean up disconnected websockets
|
||||
for ws in disconnected:
|
||||
if ws in self.connection_user_map:
|
||||
await self.disconnect(ws)
|
||||
|
||||
async def send_to_chat(self, message: dict, chat_id: int, exclude_user_id: int = None):
|
||||
"""Send message to all members of a chat"""
|
||||
if chat_id in self.chat_members:
|
||||
for user_id in self.chat_members[chat_id]:
|
||||
if exclude_user_id and user_id == exclude_user_id:
|
||||
continue
|
||||
await self.send_personal_message(message, user_id)
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
"""Broadcast message to all connected users"""
|
||||
disconnected = []
|
||||
for user_id, connections in self.active_connections.items():
|
||||
for websocket in connections:
|
||||
try:
|
||||
await websocket.send_text(json.dumps(message))
|
||||
except Exception:
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Clean up disconnected websockets
|
||||
for ws in disconnected:
|
||||
if ws in self.connection_user_map:
|
||||
await self.disconnect(ws)
|
||||
|
||||
async def _load_user_chats(self, user_id: int, db: Session):
|
||||
"""Load all chats for a user into memory"""
|
||||
chat_members = db.query(ChatMember).filter(ChatMember.user_id == user_id).all()
|
||||
for member in chat_members:
|
||||
chat_id = member.chat_id
|
||||
if chat_id not in self.chat_members:
|
||||
self.chat_members[chat_id] = set()
|
||||
self.chat_members[chat_id].add(user_id)
|
||||
|
||||
async def _broadcast_user_status(self, user_id: int, status: str, db: Session):
|
||||
"""Broadcast user online/offline status to relevant chats"""
|
||||
chat_members = db.query(ChatMember).filter(ChatMember.user_id == user_id).all()
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
status_message = {
|
||||
"type": "user_status",
|
||||
"user_id": user_id,
|
||||
"username": user.username if user else "Unknown",
|
||||
"status": status,
|
||||
"last_seen": user.last_seen.isoformat() if user and user.last_seen else None
|
||||
}
|
||||
|
||||
for member in chat_members:
|
||||
await self.send_to_chat(status_message, member.chat_id, exclude_user_id=user_id)
|
||||
|
||||
def add_user_to_chat(self, user_id: int, chat_id: int):
|
||||
"""Add user to chat members tracking"""
|
||||
if chat_id not in self.chat_members:
|
||||
self.chat_members[chat_id] = set()
|
||||
self.chat_members[chat_id].add(user_id)
|
||||
|
||||
def remove_user_from_chat(self, user_id: int, chat_id: int):
|
||||
"""Remove user from chat members tracking"""
|
||||
if chat_id in self.chat_members:
|
||||
self.chat_members[chat_id].discard(user_id)
|
||||
if not self.chat_members[chat_id]:
|
||||
del self.chat_members[chat_id]
|
||||
|
||||
# Global connection manager instance
|
||||
connection_manager = ConnectionManager()
|
109
main.py
109
main.py
@ -1,11 +1,14 @@
|
||||
from fastapi import FastAPI
|
||||
import json
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.api import auth, tasks, categories
|
||||
from app.api import auth, chats, messages, media, notifications, encryption
|
||||
from app.db.session import create_tables
|
||||
from app.websocket.connection_manager import connection_manager
|
||||
from app.websocket.chat_handler import chat_handler
|
||||
|
||||
app = FastAPI(
|
||||
title="Personal Task Management API",
|
||||
description="A comprehensive API for managing personal tasks, categories, and user authentication",
|
||||
title="Real-time Chat API",
|
||||
description="A comprehensive real-time chat API with WebSocket support, media sharing, E2E encryption, and push notifications",
|
||||
version="1.0.0",
|
||||
openapi_url="/openapi.json"
|
||||
)
|
||||
@ -22,18 +25,32 @@ app.add_middleware(
|
||||
# Create database tables
|
||||
create_tables()
|
||||
|
||||
# Include routers
|
||||
app.include_router(auth.router, prefix="/auth", tags=["authentication"])
|
||||
app.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
||||
app.include_router(categories.router, prefix="/categories", tags=["categories"])
|
||||
# Include API routers
|
||||
app.include_router(auth.router, prefix="/api/auth", tags=["authentication"])
|
||||
app.include_router(chats.router, prefix="/api/chats", tags=["chats"])
|
||||
app.include_router(messages.router, prefix="/api/messages", tags=["messages"])
|
||||
app.include_router(media.router, prefix="/api/media", tags=["media"])
|
||||
app.include_router(notifications.router, prefix="/api/notifications", tags=["notifications"])
|
||||
app.include_router(encryption.router, prefix="/api/encryption", tags=["encryption"])
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
return {
|
||||
"title": "Personal Task Management API",
|
||||
"description": "A comprehensive API for managing personal tasks and categories",
|
||||
"title": "Real-time Chat API",
|
||||
"description": "A comprehensive real-time chat API with WebSocket support",
|
||||
"version": "1.0.0",
|
||||
"features": [
|
||||
"Real-time messaging with WebSocket",
|
||||
"Direct messages and group chats",
|
||||
"Media sharing (images, videos, documents)",
|
||||
"End-to-end encryption",
|
||||
"Push notifications",
|
||||
"Mention alerts",
|
||||
"Message read receipts",
|
||||
"Typing indicators"
|
||||
],
|
||||
"documentation": "/docs",
|
||||
"websocket_endpoint": "/ws/{token}",
|
||||
"health_check": "/health"
|
||||
}
|
||||
|
||||
@ -41,6 +58,74 @@ def read_root():
|
||||
def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "Personal Task Management API",
|
||||
"version": "1.0.0"
|
||||
"service": "Real-time Chat API",
|
||||
"version": "1.0.0",
|
||||
"features": {
|
||||
"websocket": "enabled",
|
||||
"media_upload": "enabled",
|
||||
"encryption": "enabled",
|
||||
"notifications": "enabled"
|
||||
}
|
||||
}
|
||||
|
||||
@app.websocket("/ws/{token}")
|
||||
async def websocket_endpoint(websocket: WebSocket, token: str):
|
||||
"""WebSocket endpoint for real-time chat"""
|
||||
user_id = await connection_manager.connect(websocket, token)
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Receive message from client
|
||||
data = await websocket.receive_text()
|
||||
|
||||
try:
|
||||
message_data = json.loads(data)
|
||||
await chat_handler.handle_message(websocket, message_data)
|
||||
except json.JSONDecodeError:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": "error",
|
||||
"message": "Invalid JSON format"
|
||||
}))
|
||||
except Exception as e:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": "error",
|
||||
"message": f"Error processing message: {str(e)}"
|
||||
}))
|
||||
|
||||
except WebSocketDisconnect:
|
||||
await connection_manager.disconnect(websocket)
|
||||
except Exception as e:
|
||||
print(f"WebSocket error: {e}")
|
||||
await connection_manager.disconnect(websocket)
|
||||
|
||||
@app.get("/api/status")
|
||||
def get_system_status():
|
||||
"""Get system status and statistics"""
|
||||
return {
|
||||
"active_connections": len(connection_manager.active_connections),
|
||||
"active_chats": len(connection_manager.chat_members),
|
||||
"server_time": "2024-12-20T12:00:00Z",
|
||||
"uptime": "0 days, 0 hours, 0 minutes"
|
||||
}
|
||||
|
||||
# Add startup event
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
print("🚀 Real-time Chat API starting up...")
|
||||
print("📡 WebSocket endpoint: /ws/{token}")
|
||||
print("📖 API Documentation: /docs")
|
||||
print("💚 Health check: /health")
|
||||
|
||||
# Add shutdown event
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
print("🛑 Real-time Chat API shutting down...")
|
||||
# Clean up active connections
|
||||
for user_id, connections in connection_manager.active_connections.items():
|
||||
for websocket in connections:
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
@ -1,5 +1,7 @@
|
||||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
websockets==12.0
|
||||
python-socketio==5.10.0
|
||||
sqlalchemy==2.0.23
|
||||
alembic==1.13.1
|
||||
pydantic==2.5.0
|
||||
@ -7,4 +9,13 @@ python-multipart==0.0.6
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
python-decouple==3.8
|
||||
cryptography==41.0.7
|
||||
aiofiles==23.2.1
|
||||
pillow==10.1.0
|
||||
python-magic==0.4.27
|
||||
redis==5.0.1
|
||||
celery==5.3.4
|
||||
firebase-admin==6.2.0
|
||||
ruff==0.1.6
|
||||
pytest==7.4.3
|
||||
pytest-asyncio==0.21.1
|
Loading…
x
Reference in New Issue
Block a user