# (Source-viewer header residue: 169 lines, 4.6 KiB, Python)
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.db.database import Base, get_db
|
|
from app.models.todo import Todo
|
|
from main import app
|
|
|
|
# Tests run against a private in-memory SQLite database, never the real one.
# An in-memory SQLite database exists only as long as its connection, so
# StaticPool hands every session the very same connection; check_same_thread
# is disabled so that shared connection may be used from whichever thread
# serves the request.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the test engine; autocommit and autoflush are off,
# so changes reach the database only when a test commits or flushes them.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
|
|
|
|
|
@pytest.fixture
def test_db():
    """Yield a SQLAlchemy session on a freshly created test schema.

    Every table registered on ``Base.metadata`` is created before the test and
    dropped afterwards, so each test starts from an empty database.
    """
    Base.metadata.create_all(bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        # Drop the schema even if closing the session fails: the StaticPool
        # engine shares one in-memory connection across the whole module, so
        # leftover tables/rows would otherwise leak into the next test.
        try:
            db.close()
        finally:
            Base.metadata.drop_all(bind=engine)
|
|
|
|
|
|
@pytest.fixture
def client(test_db):
    """Yield a TestClient whose ``get_db`` dependency resolves to ``test_db``.

    Routing the app through the fixture's session lets tests seed and inspect
    the same database the endpoints read and write.
    """

    def override_get_db():
        # The session's lifetime is owned by the ``test_db`` fixture, so
        # nothing is closed here.
        yield test_db

    app.dependency_overrides[get_db] = override_get_db
    try:
        with TestClient(app) as c:
            yield c
    finally:
        # Always remove the override, even if app startup inside TestClient
        # raises, so it cannot leak into later tests.
        app.dependency_overrides.clear()
|
|
|
|
|
|
def test_health_check(client):
    """GET /health responds 200 with a healthy status payload."""
    resp = client.get("/health")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload == {"status": "healthy"}
|
|
|
|
|
|
def test_create_todo(client):
    """POST /todos/ returns 201 and echoes the todo plus server-side fields."""
    payload = {
        "title": "Test Todo",
        "description": "This is a test",
        "completed": False,
    }
    resp = client.post("/todos/", json=payload)
    assert resp.status_code == 201

    created = resp.json()
    assert created["title"] == payload["title"]
    assert created["description"] == payload["description"]
    assert created["completed"] is False
    # Fields not present in the request payload that must appear in the response.
    for server_field in ("id", "created_at"):
        assert server_field in created
|
|
|
|
|
|
def test_read_todos_empty(client):
    """GET /todos/ returns an empty JSON list when nothing has been stored."""
    resp = client.get("/todos/")
    assert resp.status_code == 200
    todos = resp.json()
    assert todos == []
|
|
|
|
|
|
def test_read_todos(client, test_db):
    """GET /todos/ lists a todo that was inserted directly into the database."""
    seeded = Todo(title="Test Todo", description="Test Description", completed=False)
    test_db.add(seeded)
    test_db.commit()
    test_db.refresh(seeded)

    resp = client.get("/todos/")
    assert resp.status_code == 200

    todos = resp.json()
    assert len(todos) == 1
    first = todos[0]
    assert first["title"] == "Test Todo"
    assert first["description"] == "Test Description"
    assert first["id"] == seeded.id
|
|
|
|
|
|
def test_read_todo(client, test_db):
    """GET /todos/{id} returns the todo that was seeded with that id."""
    seeded = Todo(title="Test Todo", description="Test Description", completed=False)
    test_db.add(seeded)
    test_db.commit()
    test_db.refresh(seeded)

    resp = client.get(f"/todos/{seeded.id}")
    assert resp.status_code == 200

    fetched = resp.json()
    assert fetched["title"] == "Test Todo"
    assert fetched["description"] == "Test Description"
    assert fetched["id"] == seeded.id
|
|
|
|
|
|
def test_read_todo_not_found(client):
    """GET /todos/{id} on an empty database yields 404 with an error detail."""
    missing_id = 999
    resp = client.get(f"/todos/{missing_id}")
    assert resp.status_code == 404
    assert resp.json() == {"detail": "Todo not found"}
|
|
|
|
|
|
def test_update_todo(client, test_db):
    """PUT /todos/{id} applies a partial update and leaves omitted fields alone."""
    seeded = Todo(title="Test Todo", description="Test Description", completed=False)
    test_db.add(seeded)
    test_db.commit()
    test_db.refresh(seeded)

    changes = {"title": "Updated Todo", "completed": True}
    resp = client.put(f"/todos/{seeded.id}", json=changes)
    assert resp.status_code == 200

    updated = resp.json()
    assert updated["title"] == "Updated Todo"
    # "description" is absent from the request body, so it must be preserved.
    assert updated["description"] == "Test Description"
    assert updated["completed"] is True
    assert updated["id"] == seeded.id
|
|
|
|
|
|
def test_update_todo_not_found(client):
    """PUT /todos/{id} for an id that does not exist yields 404."""
    changes = {"title": "Updated Todo", "completed": True}
    resp = client.put("/todos/999", json=changes)
    assert resp.status_code == 404
    assert resp.json() == {"detail": "Todo not found"}
|
|
|
|
|
|
def test_delete_todo(client, test_db):
    """DELETE /todos/{id} returns 204 and removes the row from the database."""
    todo = Todo(title="Test Todo", description="Test Description", completed=False)
    test_db.add(todo)
    test_db.commit()
    test_db.refresh(todo)
    # Capture the primary key while the instance is still loaded. The app's
    # get_db is overridden to use this same session, so after the endpoint
    # deletes and commits, reading attributes off ``todo`` may try to reload
    # a row that no longer exists (depending on how the endpoint deletes).
    todo_id = todo.id

    response = client.delete(f"/todos/{todo_id}")
    assert response.status_code == 204

    # Verify via a fresh query by id rather than through the deleted instance.
    todo_check = test_db.query(Todo).filter(Todo.id == todo_id).first()
    assert todo_check is None
|
|
|
|
|
|
def test_delete_todo_not_found(client):
    """DELETE /todos/{id} for an id that does not exist yields 404."""
    resp = client.delete("/todos/999")
    assert resp.status_code == 404
    expected_error = {"detail": "Todo not found"}
    assert resp.json() == expected_error