Add source
This commit is contained in:
310
web/utils/auth.py
Normal file
310
web/utils/auth.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""
|
||||
Authentication for web interface
|
||||
"""
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import Request, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from shared.config import settings
|
||||
from shared.database.session import AsyncSessionLocal
|
||||
from shared.database.models import User
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global variable for storing sessions (use Redis in production)
|
||||
# Format: {session_id: {"user_id": int, "is_owner": bool, "created_at": datetime}}
|
||||
_sessions: dict[str, dict] = {}
|
||||
_sessions_lock = asyncio.Lock() # Use asyncio.Lock for async context
|
||||
|
||||
# Constants for time intervals
|
||||
SECONDS_PER_HOUR = 3600
|
||||
HOURS_PER_DAY = 24
|
||||
|
||||
# TTL for sessions (7 days)
|
||||
SESSION_LIFETIME_DAYS = 7
|
||||
SESSION_TTL_DAYS = 7
|
||||
SESSION_TTL = timedelta(days=SESSION_TTL_DAYS)
|
||||
|
||||
# Session cleanup interval (6 hours)
|
||||
SESSION_CLEANUP_INTERVAL_HOURS = 6
|
||||
|
||||
# Flag for background cleanup task
|
||||
_cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
async def verify_web_user(user_id: int) -> bool:
|
||||
"""
|
||||
Check user access rights to web interface
|
||||
All authorized users can use web interface
|
||||
|
||||
Args:
|
||||
user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
True if user has access, False otherwise
|
||||
"""
|
||||
from bot.modules.access_control.auth import is_authorized
|
||||
|
||||
# Check user authorization
|
||||
return await is_authorized(user_id)
|
||||
|
||||
|
||||
async def create_session(user_id: int) -> str:
|
||||
"""
|
||||
Create session for user
|
||||
|
||||
Args:
|
||||
user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
# Use Redis if enabled
|
||||
if settings.USE_REDIS_SESSIONS:
|
||||
try:
|
||||
from web.utils.redis_session import create_redis_session
|
||||
session_id = await create_redis_session(user_id)
|
||||
if session_id:
|
||||
return session_id
|
||||
else:
|
||||
logger.warning("Failed to create session in Redis, using in-memory")
|
||||
return await _create_in_memory_session(user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create session in Redis, using in-memory: {e}")
|
||||
return await _create_in_memory_session(user_id)
|
||||
else:
|
||||
return await _create_in_memory_session(user_id)
|
||||
|
||||
|
||||
async def _create_in_memory_session(user_id: int) -> str:
|
||||
"""Create in-memory session (fallback)"""
|
||||
import secrets
|
||||
from web.utils.csrf import generate_csrf_token
|
||||
|
||||
session_id = secrets.token_urlsafe(32)
|
||||
csrf_token = generate_csrf_token()
|
||||
|
||||
with _sessions_lock:
|
||||
_sessions[session_id] = {
|
||||
"user_id": user_id,
|
||||
"is_owner": user_id == settings.OWNER_ID,
|
||||
"created_at": datetime.utcnow(),
|
||||
"expires_at": datetime.utcnow() + timedelta(days=SESSION_LIFETIME_DAYS),
|
||||
"csrf_token": csrf_token
|
||||
}
|
||||
return session_id
|
||||
|
||||
|
||||
async def get_session(session_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Get session data (async version).
|
||||
|
||||
Checks Redis first if enabled, then falls back to in-memory sessions.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
Session data or None (if session expired)
|
||||
"""
|
||||
# Use Redis if enabled
|
||||
if settings.USE_REDIS_SESSIONS:
|
||||
try:
|
||||
from web.utils.redis_session import get_redis_session
|
||||
session_data = await get_redis_session(session_id)
|
||||
if session_data:
|
||||
return session_data
|
||||
# If not found in Redis, check in-memory (fallback)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get session from Redis, using in-memory: {e}")
|
||||
|
||||
# Fallback to in-memory sessions
|
||||
async with _sessions_lock:
|
||||
session_data = _sessions.get(session_id)
|
||||
if not session_data:
|
||||
return None
|
||||
|
||||
# Check TTL by expires_at (more reliable)
|
||||
expires_at = session_data.get("expires_at")
|
||||
if expires_at:
|
||||
if isinstance(expires_at, datetime):
|
||||
if datetime.utcnow() >= expires_at:
|
||||
# Session expired, delete
|
||||
del _sessions[session_id]
|
||||
logger.debug(f"Session {session_id} expired (expires_at: {expires_at})")
|
||||
return None
|
||||
else:
|
||||
# If expires_at in different format, try to parse
|
||||
try:
|
||||
if isinstance(expires_at, str):
|
||||
expires_at_dt = datetime.fromisoformat(expires_at)
|
||||
if datetime.utcnow() >= expires_at_dt:
|
||||
del _sessions[session_id]
|
||||
logger.debug(f"Session {session_id} expired (expires_at: {expires_at})")
|
||||
return None
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"Failed to parse expires_at for session {session_id}: {expires_at}")
|
||||
|
||||
# Fallback: check by created_at (for backward compatibility)
|
||||
if not expires_at:
|
||||
created_at = session_data.get("created_at")
|
||||
if created_at and isinstance(created_at, datetime):
|
||||
if datetime.utcnow() - created_at > SESSION_TTL:
|
||||
# Session expired, delete
|
||||
del _sessions[session_id]
|
||||
logger.debug(f"Session {session_id} expired (by created_at)")
|
||||
return None
|
||||
|
||||
return session_data
|
||||
|
||||
|
||||
async def delete_session(session_id: str):
|
||||
"""
|
||||
Delete session
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
"""
|
||||
# Use Redis if enabled
|
||||
if settings.USE_REDIS_SESSIONS:
|
||||
try:
|
||||
from web.utils.redis_session import delete_redis_session
|
||||
await delete_redis_session(session_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session from Redis: {e}")
|
||||
|
||||
# Also remove from memory just in case
|
||||
async with _sessions_lock:
|
||||
if session_id in _sessions:
|
||||
del _sessions[session_id]
|
||||
|
||||
|
||||
async def cleanup_expired_sessions():
|
||||
"""
|
||||
Cleanup expired sessions
|
||||
Called periodically in background
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
expired_sessions = []
|
||||
|
||||
async with _sessions_lock:
|
||||
for session_id, session_data in _sessions.items():
|
||||
created_at = session_data.get("created_at")
|
||||
if created_at and isinstance(created_at, datetime):
|
||||
if now - created_at > SESSION_TTL:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
del _sessions[session_id]
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")
|
||||
|
||||
|
||||
async def cleanup_sessions_periodically():
|
||||
"""
|
||||
Periodic cleanup of expired sessions
|
||||
Runs in background every 6 hours
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(SESSION_CLEANUP_INTERVAL_HOURS * SECONDS_PER_HOUR)
|
||||
await cleanup_expired_sessions()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Session cleanup task stopped")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up sessions: {e}", exc_info=True)
|
||||
|
||||
|
||||
def start_session_cleanup_task():
|
||||
"""
|
||||
Start background session cleanup task
|
||||
"""
|
||||
global _cleanup_task
|
||||
if _cleanup_task is None or _cleanup_task.done():
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
_cleanup_task = loop.create_task(cleanup_sessions_periodically())
|
||||
logger.info("Background session cleanup task started")
|
||||
except RuntimeError:
|
||||
# If no running loop, try to get current one
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
_cleanup_task = loop.create_task(cleanup_sessions_periodically())
|
||||
logger.info("Background session cleanup task started")
|
||||
else:
|
||||
logger.warning("Event loop not running, session cleanup task will be started later")
|
||||
except RuntimeError:
|
||||
logger.warning("Failed to start session cleanup task: no event loop")
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> dict:
|
||||
"""
|
||||
Get current user from session
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
Dictionary with user data
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is not authorized
|
||||
"""
|
||||
session_id = request.cookies.get("session_id")
|
||||
if not session_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authorized"
|
||||
)
|
||||
|
||||
# Get session (from Redis or in-memory)
|
||||
session_data = None
|
||||
if settings.USE_REDIS_SESSIONS:
|
||||
try:
|
||||
from web.utils.redis_session import get_redis_session
|
||||
session_data = await get_redis_session(session_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting session from Redis: {e}, trying in-memory")
|
||||
session_data = await get_session(session_id)
|
||||
else:
|
||||
session_data = await get_session(session_id)
|
||||
|
||||
if not session_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session expired"
|
||||
)
|
||||
|
||||
# Check that user still has access
|
||||
user_id = session_data.get("user_id")
|
||||
if not await verify_web_user(user_id):
|
||||
await delete_session(session_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied"
|
||||
)
|
||||
|
||||
return session_data
|
||||
|
||||
|
||||
async def is_owner_web(request: Request) -> bool:
|
||||
"""
|
||||
Check if current user is Owner
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
True if Owner, False otherwise
|
||||
"""
|
||||
try:
|
||||
user_data = await get_current_user(request)
|
||||
return user_data.get("is_owner", False)
|
||||
except HTTPException:
|
||||
return False
|
||||
Reference in New Issue
Block a user