"""
Authentication for the web interface.

Session management (Redis-backed when enabled, with an in-memory fallback)
and helpers for resolving the current web user from the session cookie.
"""
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.middleware.sessions import SessionMiddleware

from shared.config import settings
from shared.database.models import User
from shared.database.session import AsyncSessionLocal
logger = logging.getLogger(__name__)


# In-memory session store: used when Redis sessions are disabled, and as a
# fallback when Redis is unavailable (use Redis in production).
# Format: {session_id: {"user_id": int, "is_owner": bool, "created_at": datetime,
#                       "expires_at": datetime, "csrf_token": str}}
_sessions: dict[str, dict] = {}

# Guards _sessions. This is an asyncio.Lock, so it must be acquired with
# `async with` (it does not support the plain `with` statement).
_sessions_lock = asyncio.Lock()


# Constants for time intervals
SECONDS_PER_HOUR = 3600
HOURS_PER_DAY = 24


# Session lifetime (7 days). SESSION_LIFETIME_DAYS is kept for existing callers
# but aliased to SESSION_TTL_DAYS so session creation and expiry checks can
# never disagree on the lifetime.
SESSION_TTL_DAYS = 7
SESSION_LIFETIME_DAYS = SESSION_TTL_DAYS
SESSION_TTL = timedelta(days=SESSION_TTL_DAYS)


# Interval between background session cleanup passes (6 hours)
SESSION_CLEANUP_INTERVAL_HOURS = 6


# Handle of the background cleanup task; None until it has been started
_cleanup_task: Optional[asyncio.Task] = None
async def verify_web_user(user_id: int) -> bool:
    """
    Decide whether a user may use the web interface.

    Web access is granted to every user the bot treats as authorized, so the
    decision is delegated entirely to the access-control module.

    Args:
        user_id: Telegram user ID

    Returns:
        True if the user has access, False otherwise
    """
    # Imported lazily, at call time, as in the original implementation
    from bot.modules.access_control.auth import is_authorized as check_authorized

    has_access = await check_authorized(user_id)
    return has_access
async def create_session(user_id: int) -> str:
    """
    Create a session for a user.

    When Redis sessions are enabled, Redis is tried first. If the Redis
    helper raises or returns no session ID, a warning is logged and the
    session is created in the in-memory store instead.

    Args:
        user_id: Telegram user ID

    Returns:
        Session ID
    """
    if settings.USE_REDIS_SESSIONS:
        # Only the Redis interaction is inside the try: a failure of the
        # in-memory fallback must not be reported as a Redis failure, nor
        # cause the fallback to run twice.
        try:
            from web.utils.redis_session import create_redis_session
            session_id = await create_redis_session(user_id)
        except Exception as e:
            logger.warning(f"Failed to create session in Redis, using in-memory: {e}")
        else:
            if session_id:
                return session_id
            logger.warning("Failed to create session in Redis, using in-memory")

    return await _create_in_memory_session(user_id)
async def _create_in_memory_session(user_id: int) -> str:
    """
    Create a session in the process-local in-memory store (fallback).

    Used when Redis sessions are disabled, or when Redis could not create
    the session.

    Args:
        user_id: Telegram user ID

    Returns:
        The newly generated session ID
    """
    import secrets
    from web.utils.csrf import generate_csrf_token

    session_id = secrets.token_urlsafe(32)
    csrf_token = generate_csrf_token()
    # One timestamp for both fields so expires_at is exactly the lifetime
    # after created_at.
    now = datetime.utcnow()

    # _sessions_lock is an asyncio.Lock: it must be used with `async with`;
    # a plain `with` statement raises at runtime.
    async with _sessions_lock:
        _sessions[session_id] = {
            "user_id": user_id,
            "is_owner": user_id == settings.OWNER_ID,
            "created_at": now,
            "expires_at": now + timedelta(days=SESSION_LIFETIME_DAYS),
            "csrf_token": csrf_token,
        }

    return session_id
def _to_naive_utc(value) -> Optional[datetime]:
    """
    Coerce a stored timestamp into a naive UTC datetime.

    The result can be compared with datetime.utcnow(). Accepts datetime
    instances and ISO-8601 strings. Timezone-aware values are converted to
    UTC and stripped of tzinfo, because comparing aware and naive datetimes
    raises TypeError. Returns None for anything that cannot be interpreted
    (None, empty or unparseable strings, other types).
    """
    if isinstance(value, str):
        try:
            value = datetime.fromisoformat(value)
        except ValueError:
            return None
    if not isinstance(value, datetime):
        return None
    if value.tzinfo is not None:
        value = value.astimezone(timezone.utc).replace(tzinfo=None)
    return value


async def get_session(session_id: str) -> Optional[dict]:
    """
    Get session data (async version).

    Checks Redis first if enabled. If Redis has no such session, or the
    Redis lookup fails, falls back to the in-memory store.

    In-memory sessions expire according to ``expires_at`` when it can be
    interpreted (a datetime or an ISO-8601 string; aware values are
    normalized to naive UTC). Otherwise they expire by ``created_at`` +
    SESSION_TTL. Expired in-memory sessions are deleted from the store.

    Args:
        session_id: Session ID

    Returns:
        Session data, or None if the session does not exist or has expired.
        For in-memory sessions this is the stored dict itself, not a copy.
    """
    if settings.USE_REDIS_SESSIONS:
        try:
            from web.utils.redis_session import get_redis_session
            session_data = await get_redis_session(session_id)
            if session_data:
                return session_data
            # Not found in Redis: fall through to the in-memory store
        except Exception as e:
            logger.warning(f"Failed to get session from Redis, using in-memory: {e}")

    async with _sessions_lock:
        session_data = _sessions.get(session_id)
        if not session_data:
            return None

        now = datetime.utcnow()
        raw_expires_at = session_data.get("expires_at")
        expires_at = _to_naive_utc(raw_expires_at)
        if raw_expires_at and expires_at is None:
            logger.warning(f"Failed to parse expires_at for session {session_id}: {raw_expires_at}")

        if expires_at is not None:
            # Preferred check: the explicit expiry timestamp
            if now >= expires_at:
                del _sessions[session_id]
                logger.debug(f"Session {session_id} expired (expires_at: {raw_expires_at})")
                return None
        else:
            # No usable expires_at (missing or unparseable): fall back to the
            # session age, for backward compatibility.
            created_at = _to_naive_utc(session_data.get("created_at"))
            if created_at is not None and now - created_at > SESSION_TTL:
                del _sessions[session_id]
                logger.debug(f"Session {session_id} expired (by created_at)")
                return None

        return session_data
async def delete_session(session_id: str):
    """
    Remove a session from every store it may live in.

    With Redis sessions enabled, the Redis copy is deleted first; any error
    there is logged rather than raised. The in-memory copy is then dropped
    unconditionally, because create_session may have stored the session in
    memory as a fallback.

    Args:
        session_id: Session ID
    """
    if settings.USE_REDIS_SESSIONS:
        try:
            from web.utils.redis_session import delete_redis_session
            await delete_redis_session(session_id)
        except Exception as exc:
            logger.warning(f"Failed to delete session from Redis: {exc}")

    async with _sessions_lock:
        # pop with a default: removing an absent ID is a no-op
        _sessions.pop(session_id, None)
async def cleanup_expired_sessions():
    """
    Remove expired sessions from the in-memory store.

    A session whose ``expires_at`` is a datetime is expired once
    now >= expires_at. Sessions without a datetime ``expires_at`` fall back
    to ``created_at`` + SESSION_TTL. Sessions with neither timestamp are
    kept. Only the in-memory store is scanned; Redis is not touched.
    Called periodically in the background.
    """
    now = datetime.utcnow()

    def _is_expired(data: dict) -> bool:
        # Honor the explicit expiry first, so that a session with a usable
        # expires_at is reclaimed even without a datetime created_at.
        expires_at = data.get("expires_at")
        if isinstance(expires_at, datetime):
            return now >= expires_at
        created_at = data.get("created_at")
        return isinstance(created_at, datetime) and now - created_at > SESSION_TTL

    async with _sessions_lock:
        # Collect first, then delete: a dict must not change size while
        # being iterated.
        expired_sessions = [
            sid for sid, data in _sessions.items() if _is_expired(data)
        ]
        for sid in expired_sessions:
            del _sessions[sid]

    if expired_sessions:
        logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")
async def cleanup_sessions_periodically():
    """
    Background loop that purges expired in-memory sessions.

    Sleeps SESSION_CLEANUP_INTERVAL_HOURS between passes; the first pass
    runs only after one full interval. An error during a pass is logged and
    the loop keeps going.

    On cancellation a message is logged and the coroutine returns normally,
    instead of re-raising CancelledError.
    """
    while True:
        try:
            await asyncio.sleep(SESSION_CLEANUP_INTERVAL_HOURS * SECONDS_PER_HOUR)
            await cleanup_expired_sessions()
        except asyncio.CancelledError:
            logger.info("Session cleanup task stopped")
            return
        except Exception as exc:
            logger.error(f"Error cleaning up sessions: {exc}", exc_info=True)
def start_session_cleanup_task():
    """
    Start the background session cleanup task on the running event loop.

    Does nothing while a previously started cleanup task is still running.
    Must be called while an event loop is running in the current thread.
    Otherwise a warning is logged and no task is created.
    """
    global _cleanup_task

    if _cleanup_task is not None and not _cleanup_task.done():
        return

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so there is nothing to schedule
        # the task on. Falling back to asyncio.get_event_loop() cannot help:
        # it never returns a *running* loop here, may create a stray loop,
        # and is deprecated or raises in this situation on newer Pythons.
        logger.warning("Event loop not running, session cleanup task will be started later")
        return

    # Keep a reference in the module global so the task is not garbage-collected
    _cleanup_task = loop.create_task(cleanup_sessions_periodically())
    logger.info("Background session cleanup task started")
async def get_current_user(request: Request) -> dict:
    """
    Get the current user from the session cookie.

    Args:
        request: FastAPI Request object

    Returns:
        Dictionary with the session/user data

    Raises:
        HTTPException: 401 if there is no ``session_id`` cookie, or the
            session is missing or expired. 403 if the user no longer has web
            access; the session is deleted in that case.
    """
    session_id = request.cookies.get("session_id")
    if not session_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authorized"
        )

    # get_session checks Redis (when enabled) and then the in-memory store.
    # That way sessions created by create_session's in-memory fallback are
    # also found while Redis is up, and Redis is queried only once.
    session_data = await get_session(session_id)
    if not session_data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Session expired"
        )

    # Re-check access on every request so revoked users are logged out
    user_id = session_data.get("user_id")
    if not await verify_web_user(user_id):
        await delete_session(session_id)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied"
        )

    return session_data
async def is_owner_web(request: Request) -> bool:
    """
    Report whether the request belongs to the configured Owner.

    Any HTTPException raised by get_current_user (no cookie, expired
    session, access denied) is treated as "not Owner" rather than
    propagated.

    Args:
        request: FastAPI Request object

    Returns:
        The session's "is_owner" value, or False when it is absent or the
        user is not authenticated
    """
    try:
        session = await get_current_user(request)
    except HTTPException:
        return False
    return session.get("is_owner", False)