# File listing metadata: 177 lines, 5.2 KiB, Python

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
import pytest
from datetime import datetime, timedelta
from app.database.session import Base
from app.models.task import Task
from main import app
from app.database import get_db
# In-memory SQLite engine for the test suite. StaticPool hands out a single
# shared connection, so every session created below (fixtures and request
# handlers alike) sees the same in-memory database. check_same_thread=False
# lets that one connection be used from whichever thread serves a request.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)
# Session factory bound to the test engine; flushes/commits happen only when
# the code under test asks for them explicitly.
TestingSessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
def override_get_db():
    """Yield a session bound to the in-memory test engine and close it afterwards.

    Used as a replacement for the application's ``get_db`` dependency so that
    requests made through the test client hit the test database.
    """
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        # Always release the session, even if the request handler raised.
        session.close()
# Route the app's database dependency to the in-memory test session instead of
# whatever ``get_db`` normally provides, then build one client shared by every
# test in this module.
app.dependency_overrides.update({get_db: override_get_db})
client = TestClient(app=app)
@pytest.fixture(scope="function")
def test_db():
    """Give each test a fresh schema: create all tables, then drop them afterwards."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    yield
    # Teardown: removing the tables isolates the next test from this one's rows.
    metadata.drop_all(bind=engine)
@pytest.fixture(scope="function")
def sample_tasks(test_db):
    """Insert three tasks into the freshly created schema and yield them.

    The rows cover each status (pending / in_progress / completed) and priority
    (high / medium / low); the completed task is due in the past, the others in
    the future. The session stays open for the duration of the test and is
    closed on teardown, or immediately if inserting the rows fails.
    """
    # One reference instant so the relative due dates are mutually consistent.
    now = datetime.utcnow()
    db = TestingSessionLocal()
    try:
        tasks = [
            Task(
                title="Task 1",
                description="Description 1",
                status="pending",
                priority="high",
                due_date=now + timedelta(days=1),
                completed=False,
            ),
            Task(
                title="Task 2",
                description="Description 2",
                status="in_progress",
                priority="medium",
                due_date=now + timedelta(days=2),
                completed=False,
            ),
            Task(
                title="Task 3",
                description="Description 3",
                status="completed",
                priority="low",
                due_date=now - timedelta(days=1),
                completed=True,
            ),
        ]
        db.add_all(tasks)
        db.commit()
        yield tasks
    finally:
        # Runs on normal teardown and also if add_all/commit raised during setup,
        # so the session is never leaked.
        db.close()
def test_health_endpoint(test_db):
    """The health probe reports the service and its database as healthy."""
    resp = client.get("/health")
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "ok"
    assert "timestamp" in body
    assert body["db_status"] == "ok"
def test_create_task(test_db):
    """POSTing a valid payload creates a task and echoes its fields back with an id."""
    due = datetime.utcnow() + timedelta(days=3)
    payload = {
        "title": "New Task",
        "description": "New Description",
        "status": "pending",
        "priority": "high",
        "due_date": due.isoformat(),
        "completed": False,
    }
    resp = client.post("/api/v1/tasks", json=payload)
    assert resp.status_code == 201
    created = resp.json()
    # due_date is deliberately not compared: its serialized form may differ.
    for field in ("title", "description", "status", "priority", "completed"):
        assert created[field] == payload[field]
    assert "id" in created
def test_get_all_tasks(sample_tasks):
    """Listing without filters returns every fixture task plus a matching total."""
    expected_count = len(sample_tasks)
    resp = client.get("/api/v1/tasks")
    assert resp.status_code == 200
    body = resp.json()
    assert "tasks" in body
    assert "total" in body
    assert body["total"] == expected_count
    assert len(body["tasks"]) == expected_count
def test_get_task_by_id(sample_tasks):
    """Fetching an existing task by id returns that task's stored fields."""
    # The fixture's first task is the first row inserted into a freshly created
    # table, so it is expected to receive id 1.
    response = client.get("/api/v1/tasks/1")
    assert response.status_code == 200
    data = response.json()
    assert data["id"] == 1
    assert data["title"] == "Task 1"
    assert data["description"] == "Description 1"
    assert data["status"] == "pending"
    assert data["priority"] == "high"
    assert data["completed"] is False
def test_get_non_existent_task(test_db):
    """Requesting an id that was never created yields 404.

    The ``test_db`` fixture is required so the tables exist. Without it the
    lookup would run against a schema-less in-memory database, because every
    other test drops the tables on teardown. The request would then fail with
    a database error instead of a clean 404.
    """
    response = client.get("/api/v1/tasks/999")
    assert response.status_code == 404
def test_update_task(sample_tasks):
    """PUT on an existing task applies the new values and returns the updated task."""
    changes = dict(
        title="Updated Task",
        description="Updated Description",
        status="completed",
        completed=True,
    )
    resp = client.put("/api/v1/tasks/1", json=changes)
    assert resp.status_code == 200
    updated = resp.json()
    for key, expected in changes.items():
        assert updated[key] == expected
def test_delete_task(sample_tasks):
    """DELETE removes a task: present before, 204 on delete, 404 afterwards."""
    task_url = "/api/v1/tasks/1"
    assert client.get(task_url).status_code == 200
    assert client.delete(task_url).status_code == 204
    assert client.get(task_url).status_code == 404
def _get_single_filtered_task(params):
    """List tasks with the given query ``params`` and return the single match.

    Asserts that the request succeeded and that exactly one task matched, both
    in the reported ``total`` and in the returned ``tasks`` list.
    """
    response = client.get("/api/v1/tasks", params=params)
    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 1
    assert len(data["tasks"]) == 1
    return data["tasks"][0]


def test_filter_tasks(sample_tasks):
    """Each filter (status, priority, completed) selects exactly one fixture task."""
    assert _get_single_filtered_task({"status": "pending"})["status"] == "pending"
    assert _get_single_filtered_task({"priority": "medium"})["priority"] == "medium"
    # Passed as the literal string "true" so the query string is identical
    # regardless of how the HTTP client would encode a Python bool.
    assert _get_single_filtered_task({"completed": "true"})["completed"] is True