124 lines
3.0 KiB
Python
124 lines
3.0 KiB
Python
"""
|
|
CSRF token utilities
|
|
"""
|
|
import secrets
|
|
from typing import Optional
|
|
from fastapi import Request, HTTPException, status
|
|
from web.utils.auth import get_session
|
|
import logging
|
|
|
|
logger: logging.Logger = logging.getLogger(__name__)

# Entropy of each generated CSRF token, in bytes. This value is passed to
# ``secrets.token_urlsafe`` as ``nbytes``, so the resulting base64-encoded
# token string is longer than this many characters.
CSRF_TOKEN_LENGTH: int = 32
|
|
|
|
|
|
def generate_csrf_token() -> str:
    """Create a fresh random CSRF token.

    Draws ``CSRF_TOKEN_LENGTH`` random bytes from the ``secrets`` module and
    returns them as URL-safe base64 text, so the returned string is longer
    than ``CSRF_TOKEN_LENGTH`` characters.

    Returns:
        A new URL-safe random token string.
    """
    fresh_token = secrets.token_urlsafe(nbytes=CSRF_TOKEN_LENGTH)
    return fresh_token
|
|
|
|
|
|
async def get_csrf_token(request: Request) -> Optional[str]:
    """Look up the CSRF token stored in the caller's session.

    The session is identified by the ``session_id`` cookie and loaded with
    ``get_session``. ``get_session`` is not called when the cookie is absent.

    Args:
        request: FastAPI Request object

    Returns:
        The session's ``"csrf_token"`` entry. Returns None when there is no
        ``session_id`` cookie, when no session exists for it, or when no
        token is stored.
    """
    cookie_sid = request.cookies.get("session_id")
    session = await get_session(cookie_sid) if cookie_sid else None
    return session.get("csrf_token") if session else None
|
|
|
|
|
|
async def set_csrf_token(request: Request, token: str) -> None:
    """Store *token* as the CSRF token of the caller's session.

    The session is located through the ``session_id`` cookie and loaded with
    ``get_session``. The token is written into the returned session dict
    under the ``"csrf_token"`` key. When ``settings.USE_REDIS_SESSIONS`` is
    enabled, the updated dict is also written back with
    ``update_redis_session``.

    Nothing is stored when the request has no ``session_id`` cookie or when
    no session is found for it. A debug message is logged in that case,
    because a later ``validate_csrf_token`` call for this client will then
    fail.

    Args:
        request: FastAPI Request object
        token: CSRF token to store
    """
    session_id = request.cookies.get("session_id")
    if not session_id:
        # Previously a silent no-op; log it so failed validations can be traced.
        logger.debug("Cannot store CSRF token: request has no session_id cookie")
        return

    session_data = await get_session(session_id)
    if not session_data:
        logger.debug("Cannot store CSRF token: no session found for session_id cookie")
        return

    # Without Redis sessions, only this in-place mutation is performed.
    # NOTE(review): that assumes get_session returns the live stored dict,
    # not a copy. Confirm against get_session's backing store.
    session_data["csrf_token"] = token

    # Local imports kept as in the original (presumably to avoid an import
    # cycle or import-time config loading; unverified).
    from shared.config import settings

    if settings.USE_REDIS_SESSIONS:
        from web.utils.redis_session import update_redis_session

        # Save session back to Redis.
        await update_redis_session(session_id, session_data)
|
|
|
|
|
|
async def _submitted_csrf_token(request: Request) -> Optional[str]:
    """Return the CSRF token submitted with *request*, or None.

    Looks in the ``csrf_token`` form field first, then in the
    ``X-CSRF-Token`` header. Only string form values are accepted; a non-str
    value (e.g. a file uploaded under that field name) is ignored.
    """
    token: Optional[str] = None
    try:
        form = await request.form()
    except Exception as e:
        # Best effort: an unreadable or unparseable body must not abort
        # validation. Fall back to the header instead.
        logger.debug("Failed to get CSRF token from form: %s", e)
    else:
        form_value = form.get("csrf_token")
        if isinstance(form_value, str):
            token = form_value

    # If not found in form, try from headers.
    if not token:
        token = request.headers.get("X-CSRF-Token")
    return token


async def validate_csrf_token(request: Request, token: Optional[str] = None) -> bool:
    """Check a submitted CSRF token against the one stored in the session.

    When *token* is not given (or is empty), it is read from the
    ``csrf_token`` form field, falling back to the ``X-CSRF-Token`` header.

    The comparison is constant-time. Both values are UTF-8 encoded first,
    because ``secrets.compare_digest`` raises ``TypeError`` for ``str``
    arguments that contain non-ASCII characters. Without the encoding, a
    crafted header or form value would turn a rejection into a server error.

    Args:
        request: FastAPI Request object
        token: Token to validate (if None, taken from form/headers)

    Returns:
        True if the submitted token matches the session token. False
        otherwise, including when either token is missing or not a string.
    """
    if not token:
        token = await _submitted_csrf_token(request)
    if not isinstance(token, str) or not token:
        return False

    session_token = await get_csrf_token(request)
    if not isinstance(session_token, str) or not session_token:
        return False

    # "surrogatepass" keeps encoding total even for lone surrogates.
    return secrets.compare_digest(
        token.encode("utf-8", "surrogatepass"),
        session_token.encode("utf-8", "surrogatepass"),
    )
|
|
|
|
|
|
async def verify_csrf(request: Request, token: Optional[str] = None) -> None:
    """Require a valid CSRF token, raising HTTP 403 otherwise.

    Delegates the check to ``validate_csrf_token``. On failure, a warning
    with the client IP is logged and the request is rejected.

    Args:
        request: FastAPI Request object
        token: Token to verify (if None, ``validate_csrf_token`` reads it
            from the form/headers)

    Raises:
        HTTPException: 403 Forbidden if the token is missing or invalid
    """
    if await validate_csrf_token(request, token):
        return

    # request.client may be None, hence the "unknown" fallback.
    client_host = request.client.host if request.client else "unknown"
    # Lazy %-style args: the message is only formatted if the record is emitted.
    logger.warning("CSRF token invalid for IP %s", client_host)
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid CSRF token",
    )
|
|
|