Add source
This commit is contained in:
123
web/utils/csrf.py
Normal file
123
web/utils/csrf.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
CSRF token utilities
|
||||
"""
|
||||
import secrets
|
||||
from typing import Optional
|
||||
from fastapi import Request, HTTPException, status
|
||||
from web.utils.auth import get_session
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# CSRF token length
|
||||
CSRF_TOKEN_LENGTH = 32
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
"""
|
||||
Generate CSRF token
|
||||
|
||||
Returns:
|
||||
Random token
|
||||
"""
|
||||
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
||||
|
||||
|
||||
async def get_csrf_token(request: Request) -> Optional[str]:
|
||||
"""
|
||||
Get CSRF token from session.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
CSRF token or None
|
||||
"""
|
||||
session_id = request.cookies.get("session_id")
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
session_data = await get_session(session_id)
|
||||
if not session_data:
|
||||
return None
|
||||
|
||||
return session_data.get("csrf_token")
|
||||
|
||||
|
||||
async def set_csrf_token(request: Request, token: str) -> None:
|
||||
"""
|
||||
Set CSRF token in session.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
token: CSRF token
|
||||
"""
|
||||
session_id = request.cookies.get("session_id")
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
session_data = await get_session(session_id)
|
||||
if session_data:
|
||||
session_data["csrf_token"] = token
|
||||
# Save session back to Redis if using Redis
|
||||
from shared.config import settings
|
||||
if settings.USE_REDIS_SESSIONS:
|
||||
from web.utils.redis_session import update_redis_session
|
||||
await update_redis_session(session_id, session_data)
|
||||
|
||||
|
||||
async def validate_csrf_token(request: Request, token: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Validate CSRF token
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
token: Token to validate (if None, taken from form/headers)
|
||||
|
||||
Returns:
|
||||
True if token is valid
|
||||
"""
|
||||
# Get token from different sources
|
||||
if not token:
|
||||
# Try to get from form
|
||||
try:
|
||||
form = await request.form()
|
||||
token = form.get("csrf_token")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get CSRF token from form: {e}")
|
||||
pass
|
||||
|
||||
# If not found in form, try from headers
|
||||
if not token:
|
||||
token = request.headers.get("X-CSRF-Token")
|
||||
|
||||
if not token:
|
||||
return False
|
||||
|
||||
# Get token from session
|
||||
session_token = await get_csrf_token(request)
|
||||
if not session_token:
|
||||
return False
|
||||
|
||||
# Compare tokens
|
||||
return secrets.compare_digest(token, session_token)
|
||||
|
||||
|
||||
async def verify_csrf(request: Request, token: Optional[str] = None):
|
||||
"""
|
||||
Verify CSRF token with exception on error
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
token: Token to verify
|
||||
|
||||
Raises:
|
||||
HTTPException: If token is invalid
|
||||
"""
|
||||
if not await validate_csrf_token(request, token):
|
||||
logger.warning(f"CSRF token invalid for IP {request.client.host if request.client else 'unknown'}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid CSRF token"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user