diff --git a/helpers/ride_helpers.py b/helpers/ride_helpers.py index 6c59774..001d3d4 100644 --- a/helpers/ride_helpers.py +++ b/helpers/ride_helpers.py @@ -1,85 +1,91 @@ from typing import List, Optional -from uuid import UUID -from datetime import datetime -from pydantic import BaseModel from sqlalchemy.orm import Session from models.ride import Ride from schemas.ride import RideCreate, RideSchema -def get_all_rides(db: Session) -> List[Ride]: +def get_all_rides(db: Session) -> List[RideSchema]: """ - Get all rides from the database. + Retrieve all dispatched rides from the database. Args: db (Session): SQLAlchemy database session. Returns: - List[Ride]: List of Ride objects. + List[RideSchema]: List of ride schemas. """ - return db.query(Ride).all() + rides = db.query(Ride).all() + return [RideSchema.from_orm(ride) for ride in rides] -def get_ride_by_id(db: Session, ride_id: UUID) -> Optional[Ride]: - """ - Get a ride by its ID. - - Args: - db (Session): SQLAlchemy database session. - ride_id (UUID): ID of the ride to retrieve. - - Returns: - Optional[Ride]: Ride object if found, None otherwise. - """ - return db.query(Ride).filter(Ride.id == ride_id).first() - -def create_ride(db: Session, ride: RideCreate) -> Ride: +def create_ride(db: Session, ride: RideCreate) -> Optional[RideSchema]: """ Create a new ride in the database. Args: db (Session): SQLAlchemy database session. - ride (RideCreate): Pydantic model for creating a new ride. + ride (RideCreate): Ride data for creation. Returns: - Ride: The newly created Ride object. + Optional[RideSchema]: Ride schema if created successfully, None otherwise. """ db_ride = Ride(**ride.dict()) db.add(db_ride) - db.commit() - db.refresh(db_ride) - return db_ride + try: + db.commit() + db.refresh(db_ride) + return RideSchema.from_orm(db_ride) + except Exception as e: + db.rollback() + return None -def dispatch_ride(db: Session, ride: Ride) -> Ride: +def get_ride_by_id(db: Session, ride_id: int) -> Optional[RideSchema]: """ - Dispatch a ride by updating its status and assigning a driver. + Retrieve a ride by its ID from the database. Args: db (Session): SQLAlchemy database session. - ride (Ride): Ride object to be dispatched. + ride_id (int): ID of the ride to retrieve. Returns: - Ride: The updated Ride object with a dispatched status. + Optional[RideSchema]: Ride schema if found, None otherwise. """ - # Implement logic to assign a driver and update ride status - ride.ride_status = "dispatched" - ride.driver_id = get_available_driver(db) # Implement get_available_driver function - db.commit() - db.refresh(ride) - return ride + ride = db.query(Ride).filter(Ride.id == ride_id).first() + return RideSchema.from_orm(ride) if ride else None -def calculate_ride_fare(ride: Ride) -> float: +def update_ride(db: Session, ride_id: int, ride_data: RideCreate) -> Optional[RideSchema]: """ - Calculate the fare for a given ride based on distance and duration. + Update an existing ride in the database. Args: - ride (Ride): Ride object for which the fare needs to be calculated. + db (Session): SQLAlchemy database session. + ride_id (int): ID of the ride to update. + ride_data (RideCreate): Updated ride data. Returns: - float: The calculated fare amount. + Optional[RideSchema]: Updated ride schema if successful, None otherwise. """ - # Implement logic to calculate fare based on ride distance and duration - base_fare = 5.0 - distance_rate = 1.5 # per km - duration_rate = 0.2 # per minute + ride = db.query(Ride).filter(Ride.id == ride_id).first() + if ride: + for key, value in ride_data.dict(exclude_unset=True).items(): + setattr(ride, key, value) + db.commit() + db.refresh(ride) + return RideSchema.from_orm(ride) + return None - fare = base_fare + (ride.ride_distance * distance_rate) + (ride.ride_duration / 60 * duration_rate) - return fare \ No newline at end of file +def delete_ride(db: Session, ride_id: int) -> bool: + """ + Delete a ride from the database. + + Args: + db (Session): SQLAlchemy database session. + ride_id (int): ID of the ride to delete. + + Returns: + bool: True if the ride was deleted successfully, False otherwise. + """ + ride = db.query(Ride).filter(Ride.id == ride_id).first() + if ride: + db.delete(ride) + db.commit() + return True + return False \ No newline at end of file