""" CSRF token utilities """ import secrets from typing import Optional from fastapi import Request, HTTPException, status from web.utils.auth import get_session import logging logger = logging.getLogger(__name__) # CSRF token length CSRF_TOKEN_LENGTH = 32 def generate_csrf_token() -> str: """ Generate CSRF token Returns: Random token """ return secrets.token_urlsafe(CSRF_TOKEN_LENGTH) async def get_csrf_token(request: Request) -> Optional[str]: """ Get CSRF token from session. Args: request: FastAPI Request object Returns: CSRF token or None """ session_id = request.cookies.get("session_id") if not session_id: return None session_data = await get_session(session_id) if not session_data: return None return session_data.get("csrf_token") async def set_csrf_token(request: Request, token: str) -> None: """ Set CSRF token in session. Args: request: FastAPI Request object token: CSRF token """ session_id = request.cookies.get("session_id") if not session_id: return session_data = await get_session(session_id) if session_data: session_data["csrf_token"] = token # Save session back to Redis if using Redis from shared.config import settings if settings.USE_REDIS_SESSIONS: from web.utils.redis_session import update_redis_session await update_redis_session(session_id, session_data) async def validate_csrf_token(request: Request, token: Optional[str] = None) -> bool: """ Validate CSRF token Args: request: FastAPI Request object token: Token to validate (if None, taken from form/headers) Returns: True if token is valid """ # Get token from different sources if not token: # Try to get from form try: form = await request.form() token = form.get("csrf_token") except Exception as e: logger.debug(f"Failed to get CSRF token from form: {e}") pass # If not found in form, try from headers if not token: token = request.headers.get("X-CSRF-Token") if not token: return False # Get token from session session_token = await get_csrf_token(request) if not session_token: return False # Compare tokens return secrets.compare_digest(token, session_token) async def verify_csrf(request: Request, token: Optional[str] = None): """ Verify CSRF token with exception on error Args: request: FastAPI Request object token: Token to verify Raises: HTTPException: If token is invalid """ if not await validate_csrf_token(request, token): logger.warning(f"CSRF token invalid for IP {request.client.host if request.client else 'unknown'}") raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Invalid CSRF token" )