"""Helpers for validating and persisting user-uploaded images and documents.

Uploaded files are stored under ``STORAGE_DIR`` with random UUID-based names;
callers receive paths relative to ``STORAGE_DIR``.
"""

import asyncio
import os
import shutil
import uuid
from pathlib import Path
from typing import BinaryIO, Collection, List, Optional

from fastapi import HTTPException, UploadFile, status

# Define allowed extensions (lowercase, without the leading dot)
ALLOWED_IMAGE_EXTENSIONS = {"png", "jpg", "jpeg", "gif", "webp"}
ALLOWED_DOCUMENT_EXTENSIONS = {"pdf", "doc", "docx", "ppt", "pptx", "xls", "xlsx", "txt"}

# Define storage paths
STORAGE_DIR = Path("/app") / "storage"
IMAGES_DIR = STORAGE_DIR / "images"
DOCUMENTS_DIR = STORAGE_DIR / "documents"

# Create directories if they don't exist (runs once, at import time)
IMAGES_DIR.mkdir(parents=True, exist_ok=True)
DOCUMENTS_DIR.mkdir(parents=True, exist_ok=True)


def _get_extension(filename: Optional[str]) -> Optional[str]:
    """Return the lowercased text after the last '.' in *filename*.

    Returns None when *filename* is empty/None, contains no '.', or ends with
    '.', so a bare name such as "png" is not mistaken for a ".png" file.
    """
    if not filename:
        return None
    _, dot, extension = filename.rpartition(".")
    if not dot or not extension:
        return None
    return extension.lower()


def validate_file_extension(
    filename: Optional[str], allowed_extensions: Collection[str]
) -> bool:
    """Return True if *filename* has an extension listed in *allowed_extensions*.

    Args:
        filename: Client-supplied file name. May be None (Starlette types
            ``UploadFile.filename`` as optional); such files are rejected.
        allowed_extensions: Lowercase extensions without the leading dot.
    """
    extension = _get_extension(filename)
    return extension is not None and extension in allowed_extensions


def _write_stream_to_path(source: BinaryIO, file_path: Path) -> None:
    """Copy *source* into a newly created file at *file_path* (blocking).

    The file is opened with mode "xb", so an existing file is never
    overwritten. If copying fails, the partially written file is removed
    before the exception propagates.
    """
    # Open outside the try: if creation itself fails (e.g. FileExistsError),
    # the cleanup below must not delete a file this call did not create.
    buffer = open(file_path, "xb")
    try:
        with buffer:
            shutil.copyfileobj(source, buffer)
    except BaseException:
        file_path.unlink(missing_ok=True)
        raise


async def save_upload_file(
    upload_file: UploadFile, destination: Path, allowed_extensions: Collection[str]
) -> str:
    """Save *upload_file* into *destination* under a random, unique name.

    Args:
        upload_file: The incoming upload; only its extension and content are used.
        destination: Directory to write into; expected to live under STORAGE_DIR.
        allowed_extensions: Lowercase extensions (without '.') that are accepted.

    Returns:
        The saved file's path relative to STORAGE_DIR as a string.

    Raises:
        HTTPException: 400 if the file name has no allowed extension.
        ValueError: If *destination* is not inside STORAGE_DIR (raised before
            anything is written).
    """
    filename = upload_file.filename
    if not validate_file_extension(filename, allowed_extensions):
        # sorted() keeps the message stable; set iteration order is arbitrary.
        valid_extensions = ", ".join(sorted(allowed_extensions))
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"File extension not allowed. Allowed extensions: {valid_extensions}",
        )

    # Generate a unique filename to prevent overwriting. The client's name is
    # used only for its (already validated) extension, never as a path.
    file_extension = _get_extension(filename)
    new_filename = f"{uuid.uuid4()}.{file_extension}"
    file_path = destination / new_filename
    # Compute before writing so a bad destination fails without leaving a file.
    relative_path = file_path.relative_to(STORAGE_DIR)

    # Disk I/O is blocking; run it in a worker thread so the event loop stays free.
    await asyncio.to_thread(_write_stream_to_path, upload_file.file, file_path)

    return str(relative_path)


async def save_image(upload_file: UploadFile) -> str:
    """Save an uploaded image into IMAGES_DIR; see save_upload_file."""
    return await save_upload_file(upload_file, IMAGES_DIR, ALLOWED_IMAGE_EXTENSIONS)


async def save_document(upload_file: UploadFile) -> str:
    """Save an uploaded document into DOCUMENTS_DIR; see save_upload_file."""
    return await save_upload_file(upload_file, DOCUMENTS_DIR, ALLOWED_DOCUMENT_EXTENSIONS)


def delete_file(file_path: str) -> bool:
    """Delete a stored file given its path relative to STORAGE_DIR.

    Args:
        file_path: Path relative to STORAGE_DIR; leading slashes are ignored.

    Returns:
        True if a regular file was removed, False if no such file exists
        (including when the path names a directory).

    Raises:
        HTTPException: 400 if the path cannot be resolved or resolves to a
            location outside STORAGE_DIR (e.g. via ".." or a symlink).
    """
    # Joining an absolute path onto STORAGE_DIR would discard STORAGE_DIR
    # entirely, so treat "/images/x.png" as "images/x.png".
    full_path = STORAGE_DIR / file_path.lstrip("/")

    # Security check: ensure the file is within the storage directory.
    # resolve() collapses ".." and follows symlinks so the check sees the real
    # target; it can raise ValueError (embedded NUL) or fail on symlink loops.
    try:
        resolved = full_path.resolve()
    except (OSError, RuntimeError, ValueError):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid file path",
        ) from None
    if not resolved.is_relative_to(STORAGE_DIR.resolve()):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid file path",
        )

    # Only regular files are deletable; os.remove on a directory would raise.
    if not full_path.is_file():
        return False
    try:
        os.remove(full_path)
    except FileNotFoundError:
        # Removed concurrently between the check and the delete.
        return False
    return True