87 lines
2.7 KiB
Python
87 lines
2.7 KiB
Python
from typing import List, Optional
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.crud.base import CRUDBase
|
|
from app.models.wallet import Wallet, WalletType
|
|
from app.schemas.wallet import WalletCreate, WalletUpdate
|
|
|
|
|
|
class CRUDWallet(CRUDBase[Wallet, WalletCreate, WalletUpdate]):
    """CRUD operations for ``Wallet`` rows, including balance changes and transfers.

    NOTE(review): balances are manipulated as Python floats here (new wallets
    start at ``0.0``). If the underlying column is a float type, rounding error
    can accumulate on money values; consider a Numeric/Decimal column. Confirm
    against the model definition.
    """

    def get_by_user_and_type(
        self, db: Session, *, user_id: int, wallet_type: WalletType
    ) -> Optional[Wallet]:
        """Return the first wallet owned by ``user_id`` with ``wallet_type``.

        Returns:
            The matching wallet, or ``None`` if the user has no such wallet.
        """
        return (
            db.query(Wallet)
            .filter(Wallet.user_id == user_id, Wallet.wallet_type == wallet_type)
            .first()
        )

    def get_by_user(self, db: Session, *, user_id: int) -> List[Wallet]:
        """Return every wallet owned by ``user_id`` (an empty list if none)."""
        return db.query(Wallet).filter(Wallet.user_id == user_id).all()

    def create_for_user(
        self, db: Session, *, user_id: int, wallet_type: WalletType
    ) -> Wallet:
        """Create a zero-balance wallet of ``wallet_type`` for ``user_id``.

        Commits the session and returns the instance refreshed from the database.
        """
        wallet = Wallet(user_id=user_id, wallet_type=wallet_type, balance=0.0)
        db.add(wallet)
        db.commit()
        db.refresh(wallet)
        return wallet

    def update_balance(
        self, db: Session, *, wallet_id: int, amount: float, add: bool = True
    ) -> Optional[Wallet]:
        """Add ``amount`` to, or subtract it from, a wallet's balance and commit.

        Args:
            db: Active database session; it is committed on success.
            wallet_id: Primary key of the wallet to modify.
            amount: Value to apply to the balance.
            add: ``True`` adds ``amount``; ``False`` subtracts it.

        Returns:
            The refreshed wallet, or ``None`` when ``self.get`` finds no wallet
            with ``wallet_id``.
        """
        wallet = self.get(db, id=wallet_id)
        if not wallet:
            return None

        if add:
            wallet.balance += amount
        else:
            wallet.balance -= amount

        # Keep the balance non-negative whichever direction was applied, so a
        # negative ``amount`` with ``add=True`` cannot push it below zero either.
        # Clamping silently discards any shortfall instead of rejecting it.
        if wallet.balance < 0:
            wallet.balance = 0

        db.add(wallet)
        db.commit()
        db.refresh(wallet)
        return wallet

    def transfer(
        self,
        db: Session,
        *,
        user_id: int,
        from_type: WalletType,
        to_type: WalletType,
        amount: float,
    ) -> tuple[Optional[Wallet], Optional[Wallet]]:
        """Move ``amount`` from one of a user's wallets to another, atomically.

        Both balance changes are written in a single commit. A failure therefore
        cannot leave the source debited without the destination being credited.

        Args:
            db: Active database session; it is committed on success and rolled
                back if the commit raises.
            user_id: Owner of both wallets.
            from_type: Type of the wallet to debit.
            to_type: Type of the wallet to credit.
            amount: Value to move; must be non-negative.

        Returns:
            ``(from_wallet, to_wallet)`` after the transfer, or ``(None, None)``
            if ``amount`` is negative or NaN, either wallet is missing, or the
            source balance is smaller than ``amount``.
        """
        # ``not amount >= 0`` rejects negatives *and* NaN. A negative amount
        # would otherwise pass the balance check below and move money backwards.
        if not amount >= 0:
            return None, None

        from_wallet = self.get_by_user_and_type(db, user_id=user_id, wallet_type=from_type)
        to_wallet = self.get_by_user_and_type(db, user_id=user_id, wallet_type=to_type)
        if not from_wallet or not to_wallet:
            return None, None

        if from_wallet.balance < amount:
            return None, None

        # NOTE(review): the balance check above and the update below are not
        # protected by a row lock, so two concurrent transfers could both pass
        # the check. Consider ``with_for_update()`` on the lookups if this path
        # sees concurrent traffic.
        from_wallet.balance -= amount
        to_wallet.balance += amount
        try:
            db.commit()
        except Exception:
            # Leave the session usable and apply neither side of the transfer.
            db.rollback()
            raise

        db.refresh(from_wallet)
        db.refresh(to_wallet)
        return from_wallet, to_wallet
|
|
|
|
|
|
# Shared module-level CRUD instance for the Wallet model.
wallet = CRUDWallet(Wallet)

# Plain-function aliases bound to the shared instance, so callers can use
# e.g. ``get_wallet(db, id=...)`` instead of ``wallet.get(db, id=...)``.

# Lookups.
get_wallet = wallet.get
get_wallets_by_user = wallet.get_by_user
get_wallet_by_user_and_type = wallet.get_by_user_and_type

# Mutations; each of these methods commits the session it is given.
create_wallet_for_user = wallet.create_for_user
update_wallet_balance = wallet.update_balance
transfer_between_wallets = wallet.transfer