# SECURE: Parameterized queries with input validation
import asyncio
import os
import re
from datetime import datetime, date
from typing import Optional, List, Dict, Any

import asyncpg
import bcrypt
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel, validator, EmailStr
# Connection string for the asyncpg pool. Read it from the DATABASE_URL
# environment variable so real credentials do not have to live in source
# control; the literal is only a local-development fallback.
DATABASE_URL = os.environ.get(
    "DATABASE_URL", "postgresql://user:password@localhost/mydb"
)
app = FastAPI()
# Pydantic models for input validation
class LoginRequest(BaseModel):
    """Request body for ``POST /login``.

    ``username`` must be 3-30 characters of ASCII letters, digits or
    underscore; ``password`` is passed through unchanged.
    """

    username: str
    password: str

    @validator('username')
    def validate_username(cls, v):
        """Reject usernames that are not entirely ``[a-zA-Z0-9_]{3,30}``.

        ``re.fullmatch`` is used because ``re.match`` with ``^...$`` also
        accepts a value followed by a trailing newline (``$`` matches just
        before a final newline), which slipped past this check.
        """
        if not re.fullmatch(r'[a-zA-Z0-9_]{3,30}', v):
            raise ValueError('Invalid username format')
        return v
class ProfileUpdate(BaseModel):
    """Request body for a profile update; every field defaults to ``None``."""

    bio: Optional[str] = None
    website: Optional[str] = None
    location: Optional[str] = None

    @validator('bio')
    def validate_bio(cls, v):
        """Limit the biography to at most 500 characters."""
        exceeds_limit = bool(v) and len(v) > 500
        if exceeds_limit:
            raise ValueError('Bio too long')
        return v
class TransactionFilter(BaseModel):
    """Optional constraints for transaction search; all default to ``None``."""

    amount_min: Optional[float] = None
    amount_max: Optional[float] = None
    date_from: Optional[date] = None
    date_to: Optional[date] = None
    status: Optional[str] = None

    @validator('status')
    def validate_status(cls, v):
        """Accept only known status names; empty or ``None`` passes through."""
        if not v:
            return v
        if v in ('pending', 'completed', 'failed', 'cancelled'):
            return v
        raise ValueError('Invalid status')
# Database connection pool
class Database:
    """Holder for the shared asyncpg connection pool.

    ``pool`` is ``None`` until the startup hook assigns a created pool.
    """

    pool: Optional[asyncpg.Pool] = None


db = Database()
@app.on_event("startup")
async def startup():
    """Create the connection pool from DATABASE_URL and store it on ``db``."""
    created_pool = await asyncpg.create_pool(DATABASE_URL)
    db.pool = created_pool
@app.on_event("shutdown")
async def shutdown():
    """Close the connection pool stored on ``db``."""
    active_pool = db.pool
    await active_pool.close()
# Secure database functions with parameterized queries
async def authenticate(username: str, password: str) -> Optional[Dict]:
    """Verify a username/password pair against the ``users`` table.

    Args:
        username: Login name, looked up with a bound ``$1`` parameter.
        password: Plain-text password checked against the stored bcrypt
            ``password_hash``.

    Returns:
        ``{'id', 'username', 'email', 'role'}`` for a valid pair, otherwise
        ``None``. Exceptions raised by ``bcrypt.checkpw`` propagate unchanged.
    """
    async with db.pool.acquire() as conn:
        # Parameterized query: the username never becomes part of the SQL text.
        user = await conn.fetchrow(
            """
            SELECT id, username, email, role, password_hash
            FROM users
            WHERE username = $1
            """,
            username,
        )
    # The connection is returned to the pool before the slow hash check.
    if user is None:
        # NOTE(review): this path skips bcrypt, so response time may reveal
        # whether a username exists; consider checking a dummy hash here.
        return None
    # bcrypt is intentionally expensive. Calling it directly inside this
    # coroutine would block the event loop (and every concurrent request)
    # for the whole computation, so run it on the default executor instead.
    loop = asyncio.get_running_loop()
    password_ok = await loop.run_in_executor(
        None,
        bcrypt.checkpw,
        password.encode('utf-8'),
        user['password_hash'].encode('utf-8'),
    )
    if not password_ok:
        return None
    return {
        'id': user['id'],
        'username': user['username'],
        'email': user['email'],
        'role': user['role'],
    }
# Account types accepted by get_financial_report().
_REPORT_ACCOUNT_TYPES = frozenset({'checking', 'savings', 'credit', 'investment'})

# Transactions for one user ($1), account type ($2) and calendar year ($3).
# Shared by both report queries so their filters cannot drift apart.
_REPORT_ROWS_SQL = """
    SELECT
        t.transaction_date,
        t.amount,
        t.description,
        a.account_name,
        a.account_type
    FROM transactions t
    JOIN accounts a ON t.account_id = a.id
    WHERE a.user_id = $1
      AND a.account_type = $2
      AND EXTRACT(YEAR FROM t.transaction_date) = $3
"""

# Plain report: the filtered rows, newest first.
_REPORT_PLAIN_SQL = _REPORT_ROWS_SQL + """
    ORDER BY t.transaction_date DESC
"""

# Summary report: every filtered row plus the total and average amount over
# all filtered rows (COALESCE yields 0 when the aggregates are NULL).
_REPORT_SUMMARY_SQL = f"""
    WITH transaction_data AS ({_REPORT_ROWS_SQL})
    SELECT
        td.*,
        summary.total,
        summary.average
    FROM transaction_data td
    CROSS JOIN (
        SELECT
            COALESCE(SUM(amount), 0) AS total,
            COALESCE(AVG(amount), 0) AS average
        FROM transaction_data
    ) summary
    ORDER BY td.transaction_date DESC
"""


async def get_financial_report(user_id: int, account_type: str,
                               year: int, include_summary: bool) -> List[Dict]:
    """Return one user's transactions for an account type and calendar year.

    Args:
        user_id: Owner of the accounts; must be positive.
        account_type: One of ``checking``, ``savings``, ``credit``,
            ``investment``.
        year: From 2000 up to the current year, inclusive.
        include_summary: When true, each row also carries ``total`` and
            ``average`` of ``amount`` over all selected rows.

    Returns:
        Rows as dicts, newest ``transaction_date`` first.

    Raises:
        ValueError: if any argument fails validation (checked before the
            database is touched).
    """
    if user_id <= 0:
        raise ValueError("Invalid user ID")
    if account_type not in _REPORT_ACCOUNT_TYPES:
        raise ValueError("Invalid account type")
    if year < 2000 or year > datetime.now().year:
        raise ValueError("Invalid year")

    query = _REPORT_SUMMARY_SQL if include_summary else _REPORT_PLAIN_SQL
    async with db.pool.acquire() as conn:
        # Every input is bound as $1..$3; nothing is interpolated into SQL.
        results = await conn.fetch(query, user_id, account_type, year)
    return [dict(r) for r in results]
async def update_profile(user_id: int, profile_data: ProfileUpdate) -> bool:
    """Write the non-None fields of ``profile_data`` to ``user_profiles``.

    Column names come from a fixed tuple, never from input, and every value
    is bound as a ``$n`` parameter. ``updated_at`` is set to
    ``datetime.now()`` whenever at least one field is written.

    Returns:
        True if a row with ``user_id`` was updated; False when no fields were
        supplied or no matching row exists.

    Raises:
        ValueError: if ``user_id`` is not positive.
    """
    if user_id <= 0:
        raise ValueError("Invalid user ID")
    async with db.pool.acquire() as conn:
        assignments = []
        values = []
        for column in ('bio', 'website', 'location'):
            value = getattr(profile_data, column)
            if value is None:
                continue
            values.append(value)
            # Placeholder number == position of the value just appended.
            assignments.append(f"{column} = ${len(values)}")
        if not assignments:
            return False
        values.append(datetime.now())
        assignments.append(f"updated_at = ${len(values)}")
        values.append(user_id)
        sql = f"""
        UPDATE user_profiles
        SET {', '.join(assignments)}
        WHERE user_id = ${len(values)}
        RETURNING user_id
        """
        returned_id = await conn.fetchval(sql, *values)
        return returned_id is not None
async def search_transactions(search_query: Optional[str],
                              filters: TransactionFilter) -> List[Dict]:
    """Search transactions by free text plus optional amount/date/status filters.

    Args:
        search_query: Text to find in the description, reference number or
            account name. LIKE wildcards (percent, underscore) and the
            backslash escape character in it are matched literally. A falsy
            value disables the text search.
        filters: Amount, date and status constraints; unset fields add no
            condition.

    Returns:
        At most 1000 matching rows as dicts, newest ``transaction_date`` first.
    """
    # NOTE(review): the query is not restricted to a single user; confirm
    # callers enforce authorization before exposing these results.
    conditions = []
    params = []

    def bind(value):
        """Append *value* to ``params`` and return its ``$n`` placeholder."""
        params.append(value)
        return f"${len(params)}"

    if search_query:
        # Escape LIKE metacharacters (backslash first; it is PostgreSQL's
        # default LIKE escape character) so user text matches literally.
        escaped = (search_query.replace('\\', '\\\\')
                   .replace('%', '\\%')
                   .replace('_', '\\_'))
        # Bind the pattern ONCE and reuse its placeholder: the argument list
        # must hold exactly one value per distinct $n, and no $n may be skipped.
        pattern = bind(f"%{escaped}%")
        conditions.append(
            f"AND (t.description ILIKE {pattern}"
            f" OR t.reference_number LIKE {pattern}"
            f" OR a.account_name ILIKE {pattern})"
        )
    if filters.amount_min is not None:
        conditions.append(f"AND t.amount >= {bind(filters.amount_min)}")
    if filters.amount_max is not None:
        conditions.append(f"AND t.amount <= {bind(filters.amount_max)}")
    if filters.date_from:
        conditions.append(f"AND t.transaction_date >= {bind(filters.date_from)}")
    if filters.date_to:
        conditions.append(f"AND t.transaction_date <= {bind(filters.date_to)}")
    if filters.status:
        conditions.append(f"AND t.status = {bind(filters.status)}")

    query = " ".join([
        """
        SELECT t.id, t.amount, t.description, t.transaction_date,
               t.reference_number, t.status,
               a.account_name, u.username
        FROM transactions t
        JOIN accounts a ON t.account_id = a.id
        JOIN users u ON a.user_id = u.id
        WHERE 1=1
        """,
        *conditions,
        "ORDER BY t.transaction_date DESC LIMIT 1000",
    ])
    async with db.pool.acquire() as conn:
        results = await conn.fetch(query, *params)
    return [dict(r) for r in results]
# Secure API endpoints
@app.post("/login")
async def login(request: LoginRequest):
    """Check credentials; 401 on mismatch, 400 if authenticate raises ValueError."""
    try:
        user = await authenticate(request.username, request.password)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    if not user:
        raise HTTPException(status_code=401, detail="Invalid credentials")
    # NOTE(review): no session token is issued yet; a signed token (e.g. a
    # JWT) is expected here in production.
    return {"success": True, "user": user}
@app.get("/report/{user_id}")
async def get_report(user_id: int, account_type: str, year: int,
                     include_summary: bool = False):
    """Return the financial report for ``user_id``; invalid input yields 400."""
    # NOTE(review): nothing here checks that the caller may read user_id's
    # data; confirm an authorization dependency is applied elsewhere.
    try:
        rows = await get_financial_report(
            user_id, account_type, year, include_summary
        )
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    return {"report": rows}
@app.put("/profile/{user_id}")
async def update_user_profile(user_id: int, profile: ProfileUpdate):
    """Apply a partial profile update for ``user_id``.

    Responds 400 when the body has no fields to change or ``user_id`` is
    invalid, and 404 when no profile row matches ``user_id``.
    """
    # update_profile() returns False both for "no fields supplied" and for
    # "no matching row"; reject the empty body first so it is not
    # misreported as "User not found". Fields mirror those update_profile
    # writes.
    if (profile.bio is None and profile.website is None
            and profile.location is None):
        raise HTTPException(status_code=400,
                            detail="No profile fields to update")
    try:
        updated = await update_profile(user_id, profile)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    if not updated:
        raise HTTPException(status_code=404, detail="User not found")
    return {"success": True}