"""End-to-end message encryption: RSA-OAEP for short messages, RSA + AES-CBC hybrid for long ones."""
import base64
import json
import logging
import os
from typing import Dict

from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.serialization import load_pem_public_key, load_pem_private_key
from sqlalchemy.orm import Session

from app.models.user import User
from app.models.chat_member import ChatMember

logger = logging.getLogger(__name__)

# AES-256 key length in bytes.
_AES_KEY_SIZE = 32
# AES block size in bytes; also the CBC IV length and the PKCS#7 padding unit.
_AES_BLOCK_SIZE = 16


class EncryptionService:
    """Encrypt/decrypt chat messages with per-recipient RSA public keys.

    Wire formats produced by ``encrypt_message``:
      * short messages: ``base64(RSA-OAEP(message))``
      * long messages:  ``base64(RSA-OAEP(aes_key + iv)):base64(AES-CBC(pkcs7(message)))``
    The ``':'`` separator is what ``decrypt_message`` uses to tell them apart
    (the base64 alphabet never contains ``':'``).
    """

    def __init__(self):
        # Not referenced by any method of this class; retained in case external
        # code reads it (not verified from here).
        self.algorithm = hashes.SHA256()

    # ------------------------------------------------------------------ keys

    def generate_rsa_key_pair(self) -> tuple[str, str]:
        """Generate a 2048-bit RSA key pair for E2E encryption.

        Returns:
            ``(private_pem, public_pem)``: the private key as *unencrypted*
            PKCS#8 PEM and the public key as SubjectPublicKeyInfo PEM, both ``str``.
        """
        private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        # NOTE(review): the private key is serialized without a passphrase, so
        # whatever stores it must protect it at rest.
        private_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
        public_pem = private_key.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        return private_pem.decode(), public_pem.decode()

    # ---------------------------------------------------------- chat fan-out

    def encrypt_message_for_chat(self, message: str, chat_id: int, db: Session) -> Dict[str, str]:
        """Encrypt ``message`` separately for every member of chat ``chat_id``.

        Only members whose ``User.public_key`` is not NULL are queried.

        Args:
            message: Plaintext to encrypt.
            chat_id: Chat whose members receive the message.
            db: SQLAlchemy session used for the member lookup.

        Returns:
            Mapping of ``str(user.id)`` to that user's ciphertext.
            NOTE(review): a user whose key is an empty string, or for whom
            encryption fails, receives the *plaintext* (see ``encrypt_message``).
        """
        members = (
            db.query(ChatMember)
            .join(User)
            .filter(ChatMember.chat_id == chat_id, User.public_key.isnot(None))
            .all()
        )

        encrypted_messages: Dict[str, str] = {}
        for member in members:
            user = member.user
            user_key = str(user.id)
            if not user.public_key:
                # Key is present but falsy (e.g. ""): nothing to encrypt with.
                encrypted_messages[user_key] = message
                continue
            try:
                encrypted_messages[user_key] = self.encrypt_message(message, user.public_key)
            except Exception as e:
                # encrypt_message already catches its own errors; this is a
                # defensive guard that preserves the plaintext fallback.
                logger.warning("Failed to encrypt for user %s: %s", user.id, e)
                encrypted_messages[user_key] = message
        return encrypted_messages

    # ------------------------------------------------------ single recipient

    def encrypt_message(self, message: str, public_key_pem: str) -> str:
        """Encrypt ``message`` for the holder of ``public_key_pem``.

        Messages that fit in one RSA-OAEP block are encrypted directly.
        Longer ones use hybrid AES + RSA encryption.

        Returns:
            The ciphertext string (see the class docstring for formats).
            NOTE(review): on ANY failure (bad PEM, non-RSA key, ...) the
            original plaintext is returned unchanged. That silently defeats E2E
            encryption; consider raising instead once callers are audited.
        """
        try:
            public_key = load_pem_public_key(public_key_pem.encode())
            message_bytes = message.encode()
            # The OAEP capacity depends on the recipient's key size
            # (190 bytes for RSA-2048), so it is computed rather than hard-coded.
            if len(message_bytes) > self._max_oaep_plaintext(public_key):
                return self._hybrid_encrypt(message, public_key)
            encrypted = public_key.encrypt(message_bytes, self._oaep_padding())
            return base64.b64encode(encrypted).decode()
        except Exception as e:
            logger.warning("Encryption failed: %s", e)
            return message  # Fallback to unencrypted (see NOTE above)

    def decrypt_message(self, encrypted_message: str, private_key_pem: str) -> str:
        """Decrypt a string produced by ``encrypt_message``.

        Returns:
            The plaintext, or ``encrypted_message`` unchanged if decryption fails
            for any reason.
        """
        try:
            private_key = load_pem_private_key(private_key_pem.encode(), password=None)
            # Base64 never contains ':', so a colon marks the hybrid "key:payload" format.
            if ':' in encrypted_message:
                return self._hybrid_decrypt(encrypted_message, private_key)
            encrypted_bytes = base64.b64decode(encrypted_message.encode())
            return private_key.decrypt(encrypted_bytes, self._oaep_padding()).decode()
        except Exception as e:
            logger.warning("Decryption failed: %s", e)
            return encrypted_message  # Return encrypted if decryption fails

    # ------------------------------------------------------------- internals

    @staticmethod
    def _oaep_padding() -> padding.OAEP:
        """Return the OAEP scheme used everywhere here (SHA-256 for MGF1 and label hash)."""
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )

    @staticmethod
    def _max_oaep_plaintext(public_key) -> int:
        """Return the largest plaintext, in bytes, RSA-OAEP-SHA256 can encrypt under ``public_key``.

        RFC 8017 section 7.1.1 gives the limit k - 2*hLen - 2, where k is the
        modulus length in bytes. That is 190 for a 2048-bit key.
        """
        modulus_bytes = (public_key.key_size + 7) // 8
        return modulus_bytes - 2 * hashes.SHA256().digest_size - 2

    def _hybrid_encrypt(self, message: str, public_key) -> str:
        """Encrypt with a fresh AES-256-CBC key, and wrap ``key + iv`` with RSA-OAEP.

        Format: ``base64(encrypted_aes_key_and_iv):base64(encrypted_message)``.
        """
        aes_key = os.urandom(_AES_KEY_SIZE)
        iv = os.urandom(_AES_BLOCK_SIZE)

        # NOTE(review): CBC without a MAC is unauthenticated (malleable
        # ciphertext). Switching to AES-GCM would change the stored wire format.
        encryptor = Cipher(algorithms.AES(aes_key), modes.CBC(iv)).encryptor()
        padded_message = self._pad_message(message.encode())
        encrypted_message = encryptor.update(padded_message) + encryptor.finalize()

        encrypted_key_iv = public_key.encrypt(aes_key + iv, self._oaep_padding())

        encrypted_key_b64 = base64.b64encode(encrypted_key_iv).decode()
        encrypted_msg_b64 = base64.b64encode(encrypted_message).decode()
        return f"{encrypted_key_b64}:{encrypted_msg_b64}"

    def _hybrid_decrypt(self, encrypted_data: str, private_key) -> str:
        """Reverse ``_hybrid_encrypt``.

        Returns:
            The plaintext, or ``encrypted_data`` unchanged on any failure
            (bad format, wrong key, invalid padding, non-UTF-8 result).
        """
        try:
            encrypted_key_b64, encrypted_msg_b64 = encrypted_data.split(':', 1)

            encrypted_key_iv = base64.b64decode(encrypted_key_b64.encode())
            key_iv = private_key.decrypt(encrypted_key_iv, self._oaep_padding())
            aes_key, iv = key_iv[:_AES_KEY_SIZE], key_iv[_AES_KEY_SIZE:]

            encrypted_message = base64.b64decode(encrypted_msg_b64.encode())
            decryptor = Cipher(algorithms.AES(aes_key), modes.CBC(iv)).decryptor()
            padded_message = decryptor.update(encrypted_message) + decryptor.finalize()
            return self._unpad_message(padded_message).decode()
        except Exception as e:
            logger.warning("Hybrid decryption failed: %s", e)
            return encrypted_data

    def _pad_message(self, message: bytes) -> bytes:
        """Apply PKCS#7 padding up to the next multiple of the AES block size.

        A full block of padding is added when ``message`` is already aligned.
        """
        pad_len = _AES_BLOCK_SIZE - (len(message) % _AES_BLOCK_SIZE)
        return message + bytes([pad_len]) * pad_len

    def _unpad_message(self, padded_message: bytes) -> bytes:
        """Remove and validate PKCS#7 padding.

        Raises:
            ValueError: If the input is not a non-empty multiple of the block
                size, or the padding bytes are inconsistent. This typically
                means a wrong key or corrupted ciphertext.
        """
        if not padded_message or len(padded_message) % _AES_BLOCK_SIZE:
            raise ValueError("Invalid padded message length")
        pad_len = padded_message[-1]
        if not 1 <= pad_len <= _AES_BLOCK_SIZE or padded_message[-pad_len:] != bytes([pad_len]) * pad_len:
            raise ValueError("Invalid PKCS7 padding")
        return padded_message[:-pad_len]

    # -------------------------------------------------------------- helpers

    def get_user_encrypted_message(self, encrypted_messages: Dict[str, str], user_id: int) -> str:
        """Return ``user_id``'s entry from an ``encrypt_message_for_chat`` result, or ``""``."""
        return encrypted_messages.get(str(user_id), "")

    def encrypt_file_metadata(self, metadata: Dict, public_key_pem: str) -> str:
        """Serialize ``metadata`` to JSON and encrypt it with ``encrypt_message``."""
        return self.encrypt_message(json.dumps(metadata), public_key_pem)

    def decrypt_file_metadata(self, encrypted_metadata: str, private_key_pem: str) -> Dict:
        """Decrypt and JSON-decode metadata produced by ``encrypt_file_metadata``.

        Returns:
            The decoded dict, or ``{}`` if the result is not valid JSON or not
            a JSON object.
        """
        metadata_json = self.decrypt_message(encrypted_metadata, private_key_pem)
        try:
            metadata = json.loads(metadata_json)
        except (TypeError, ValueError):
            return {}
        return metadata if isinstance(metadata, dict) else {}


# Global encryption service instance
encryption_service = EncryptionService()