Add source
This commit is contained in:
4
web/utils/__init__.py
Normal file
4
web/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Web application utilities
|
||||
"""
|
||||
|
||||
310
web/utils/auth.py
Normal file
310
web/utils/auth.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""
|
||||
Authentication for web interface
|
||||
"""
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import Request, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from shared.config import settings
|
||||
from shared.database.session import AsyncSessionLocal
|
||||
from shared.database.models import User
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global variable for storing sessions (use Redis in production)
|
||||
# Format: {session_id: {"user_id": int, "is_owner": bool, "created_at": datetime}}
|
||||
_sessions: dict[str, dict] = {}
|
||||
_sessions_lock = asyncio.Lock() # Use asyncio.Lock for async context
|
||||
|
||||
# Constants for time intervals
|
||||
SECONDS_PER_HOUR = 3600
|
||||
HOURS_PER_DAY = 24
|
||||
|
||||
# TTL for sessions (7 days)
|
||||
SESSION_LIFETIME_DAYS = 7
|
||||
SESSION_TTL_DAYS = 7
|
||||
SESSION_TTL = timedelta(days=SESSION_TTL_DAYS)
|
||||
|
||||
# Session cleanup interval (6 hours)
|
||||
SESSION_CLEANUP_INTERVAL_HOURS = 6
|
||||
|
||||
# Flag for background cleanup task
|
||||
_cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
async def verify_web_user(user_id: int) -> bool:
|
||||
"""
|
||||
Check user access rights to web interface
|
||||
All authorized users can use web interface
|
||||
|
||||
Args:
|
||||
user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
True if user has access, False otherwise
|
||||
"""
|
||||
from bot.modules.access_control.auth import is_authorized
|
||||
|
||||
# Check user authorization
|
||||
return await is_authorized(user_id)
|
||||
|
||||
|
||||
async def create_session(user_id: int) -> str:
|
||||
"""
|
||||
Create session for user
|
||||
|
||||
Args:
|
||||
user_id: Telegram user ID
|
||||
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
# Use Redis if enabled
|
||||
if settings.USE_REDIS_SESSIONS:
|
||||
try:
|
||||
from web.utils.redis_session import create_redis_session
|
||||
session_id = await create_redis_session(user_id)
|
||||
if session_id:
|
||||
return session_id
|
||||
else:
|
||||
logger.warning("Failed to create session in Redis, using in-memory")
|
||||
return await _create_in_memory_session(user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create session in Redis, using in-memory: {e}")
|
||||
return await _create_in_memory_session(user_id)
|
||||
else:
|
||||
return await _create_in_memory_session(user_id)
|
||||
|
||||
|
||||
async def _create_in_memory_session(user_id: int) -> str:
|
||||
"""Create in-memory session (fallback)"""
|
||||
import secrets
|
||||
from web.utils.csrf import generate_csrf_token
|
||||
|
||||
session_id = secrets.token_urlsafe(32)
|
||||
csrf_token = generate_csrf_token()
|
||||
|
||||
with _sessions_lock:
|
||||
_sessions[session_id] = {
|
||||
"user_id": user_id,
|
||||
"is_owner": user_id == settings.OWNER_ID,
|
||||
"created_at": datetime.utcnow(),
|
||||
"expires_at": datetime.utcnow() + timedelta(days=SESSION_LIFETIME_DAYS),
|
||||
"csrf_token": csrf_token
|
||||
}
|
||||
return session_id
|
||||
|
||||
|
||||
async def get_session(session_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Get session data (async version).
|
||||
|
||||
Checks Redis first if enabled, then falls back to in-memory sessions.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
Session data or None (if session expired)
|
||||
"""
|
||||
# Use Redis if enabled
|
||||
if settings.USE_REDIS_SESSIONS:
|
||||
try:
|
||||
from web.utils.redis_session import get_redis_session
|
||||
session_data = await get_redis_session(session_id)
|
||||
if session_data:
|
||||
return session_data
|
||||
# If not found in Redis, check in-memory (fallback)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get session from Redis, using in-memory: {e}")
|
||||
|
||||
# Fallback to in-memory sessions
|
||||
async with _sessions_lock:
|
||||
session_data = _sessions.get(session_id)
|
||||
if not session_data:
|
||||
return None
|
||||
|
||||
# Check TTL by expires_at (more reliable)
|
||||
expires_at = session_data.get("expires_at")
|
||||
if expires_at:
|
||||
if isinstance(expires_at, datetime):
|
||||
if datetime.utcnow() >= expires_at:
|
||||
# Session expired, delete
|
||||
del _sessions[session_id]
|
||||
logger.debug(f"Session {session_id} expired (expires_at: {expires_at})")
|
||||
return None
|
||||
else:
|
||||
# If expires_at in different format, try to parse
|
||||
try:
|
||||
if isinstance(expires_at, str):
|
||||
expires_at_dt = datetime.fromisoformat(expires_at)
|
||||
if datetime.utcnow() >= expires_at_dt:
|
||||
del _sessions[session_id]
|
||||
logger.debug(f"Session {session_id} expired (expires_at: {expires_at})")
|
||||
return None
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"Failed to parse expires_at for session {session_id}: {expires_at}")
|
||||
|
||||
# Fallback: check by created_at (for backward compatibility)
|
||||
if not expires_at:
|
||||
created_at = session_data.get("created_at")
|
||||
if created_at and isinstance(created_at, datetime):
|
||||
if datetime.utcnow() - created_at > SESSION_TTL:
|
||||
# Session expired, delete
|
||||
del _sessions[session_id]
|
||||
logger.debug(f"Session {session_id} expired (by created_at)")
|
||||
return None
|
||||
|
||||
return session_data
|
||||
|
||||
|
||||
async def delete_session(session_id: str):
|
||||
"""
|
||||
Delete session
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
"""
|
||||
# Use Redis if enabled
|
||||
if settings.USE_REDIS_SESSIONS:
|
||||
try:
|
||||
from web.utils.redis_session import delete_redis_session
|
||||
await delete_redis_session(session_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session from Redis: {e}")
|
||||
|
||||
# Also remove from memory just in case
|
||||
async with _sessions_lock:
|
||||
if session_id in _sessions:
|
||||
del _sessions[session_id]
|
||||
|
||||
|
||||
async def cleanup_expired_sessions():
|
||||
"""
|
||||
Cleanup expired sessions
|
||||
Called periodically in background
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
expired_sessions = []
|
||||
|
||||
async with _sessions_lock:
|
||||
for session_id, session_data in _sessions.items():
|
||||
created_at = session_data.get("created_at")
|
||||
if created_at and isinstance(created_at, datetime):
|
||||
if now - created_at > SESSION_TTL:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
del _sessions[session_id]
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")
|
||||
|
||||
|
||||
async def cleanup_sessions_periodically():
|
||||
"""
|
||||
Periodic cleanup of expired sessions
|
||||
Runs in background every 6 hours
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(SESSION_CLEANUP_INTERVAL_HOURS * SECONDS_PER_HOUR)
|
||||
await cleanup_expired_sessions()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Session cleanup task stopped")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up sessions: {e}", exc_info=True)
|
||||
|
||||
|
||||
def start_session_cleanup_task():
|
||||
"""
|
||||
Start background session cleanup task
|
||||
"""
|
||||
global _cleanup_task
|
||||
if _cleanup_task is None or _cleanup_task.done():
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
_cleanup_task = loop.create_task(cleanup_sessions_periodically())
|
||||
logger.info("Background session cleanup task started")
|
||||
except RuntimeError:
|
||||
# If no running loop, try to get current one
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
_cleanup_task = loop.create_task(cleanup_sessions_periodically())
|
||||
logger.info("Background session cleanup task started")
|
||||
else:
|
||||
logger.warning("Event loop not running, session cleanup task will be started later")
|
||||
except RuntimeError:
|
||||
logger.warning("Failed to start session cleanup task: no event loop")
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> dict:
|
||||
"""
|
||||
Get current user from session
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
Dictionary with user data
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is not authorized
|
||||
"""
|
||||
session_id = request.cookies.get("session_id")
|
||||
if not session_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authorized"
|
||||
)
|
||||
|
||||
# Get session (from Redis or in-memory)
|
||||
session_data = None
|
||||
if settings.USE_REDIS_SESSIONS:
|
||||
try:
|
||||
from web.utils.redis_session import get_redis_session
|
||||
session_data = await get_redis_session(session_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting session from Redis: {e}, trying in-memory")
|
||||
session_data = await get_session(session_id)
|
||||
else:
|
||||
session_data = await get_session(session_id)
|
||||
|
||||
if not session_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session expired"
|
||||
)
|
||||
|
||||
# Check that user still has access
|
||||
user_id = session_data.get("user_id")
|
||||
if not await verify_web_user(user_id):
|
||||
await delete_session(session_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied"
|
||||
)
|
||||
|
||||
return session_data
|
||||
|
||||
|
||||
async def is_owner_web(request: Request) -> bool:
|
||||
"""
|
||||
Check if current user is Owner
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
True if Owner, False otherwise
|
||||
"""
|
||||
try:
|
||||
user_data = await get_current_user(request)
|
||||
return user_data.get("is_owner", False)
|
||||
except HTTPException:
|
||||
return False
|
||||
72
web/utils/bot_client.py
Normal file
72
web/utils/bot_client.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Utilities for interacting with bot from web interface
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from bot.modules.task_scheduler.executor import get_app_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def send_otp_to_user(user_id: int, code: str) -> bool:
|
||||
"""
|
||||
Send OTP code to user via Telegram bot
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
code: OTP code
|
||||
|
||||
Returns:
|
||||
True if successfully sent, False otherwise
|
||||
"""
|
||||
try:
|
||||
app_client = get_app_client()
|
||||
if not app_client:
|
||||
logger.warning(f"Bot client not available for sending OTP to user {user_id}")
|
||||
return False
|
||||
|
||||
# Check that client is started and connected
|
||||
try:
|
||||
if not hasattr(app_client, 'is_connected') or not app_client.is_connected:
|
||||
logger.warning(f"Bot client not connected for sending OTP to user {user_id}")
|
||||
return False
|
||||
except Exception as check_error:
|
||||
logger.warning(f"Failed to check bot connection status: {check_error}")
|
||||
# Continue sending attempt, client might be working
|
||||
|
||||
from shared.config import settings
|
||||
# Form URL for web interface
|
||||
if settings.WEB_HOST == "0.0.0.0":
|
||||
login_url = f"localhost:{settings.WEB_PORT}"
|
||||
else:
|
||||
login_url = f"{settings.WEB_HOST}:{settings.WEB_PORT}"
|
||||
|
||||
message = (
|
||||
f"🔐 **Ваш код для входа в веб-интерфейс:**\n\n"
|
||||
f"**`{code}`**\n\n"
|
||||
f"⏰ Код действителен 10 минут\n\n"
|
||||
f"🌐 Перейдите на http://{login_url}/admin/login и введите этот код"
|
||||
)
|
||||
|
||||
try:
|
||||
# Try to send message
|
||||
result = await app_client.send_message(user_id, message)
|
||||
logger.info(f"OTP code successfully sent to user {user_id}, message_id: {result.id if result else 'unknown'}")
|
||||
return True
|
||||
except Exception as send_error:
|
||||
error_msg = str(send_error)
|
||||
logger.error(f"Error sending message to user {user_id}: {error_msg}", exc_info=True)
|
||||
|
||||
# Check error type for more informative message
|
||||
if "chat not found" in error_msg.lower() or "user not found" in error_msg.lower():
|
||||
logger.error(f"User {user_id} not found or hasn't started dialog with bot")
|
||||
elif "flood" in error_msg.lower():
|
||||
logger.error(f"Message sending rate limit exceeded for user {user_id}")
|
||||
elif "unauthorized" in error_msg.lower():
|
||||
logger.error(f"Bot not authorized or stopped")
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Critical error sending OTP to user {user_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
123
web/utils/csrf.py
Normal file
123
web/utils/csrf.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
CSRF token utilities
|
||||
"""
|
||||
import secrets
|
||||
from typing import Optional
|
||||
from fastapi import Request, HTTPException, status
|
||||
from web.utils.auth import get_session
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# CSRF token length
|
||||
CSRF_TOKEN_LENGTH = 32
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
"""
|
||||
Generate CSRF token
|
||||
|
||||
Returns:
|
||||
Random token
|
||||
"""
|
||||
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
||||
|
||||
|
||||
async def get_csrf_token(request: Request) -> Optional[str]:
|
||||
"""
|
||||
Get CSRF token from session.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
CSRF token or None
|
||||
"""
|
||||
session_id = request.cookies.get("session_id")
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
session_data = await get_session(session_id)
|
||||
if not session_data:
|
||||
return None
|
||||
|
||||
return session_data.get("csrf_token")
|
||||
|
||||
|
||||
async def set_csrf_token(request: Request, token: str) -> None:
|
||||
"""
|
||||
Set CSRF token in session.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
token: CSRF token
|
||||
"""
|
||||
session_id = request.cookies.get("session_id")
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
session_data = await get_session(session_id)
|
||||
if session_data:
|
||||
session_data["csrf_token"] = token
|
||||
# Save session back to Redis if using Redis
|
||||
from shared.config import settings
|
||||
if settings.USE_REDIS_SESSIONS:
|
||||
from web.utils.redis_session import update_redis_session
|
||||
await update_redis_session(session_id, session_data)
|
||||
|
||||
|
||||
async def validate_csrf_token(request: Request, token: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Validate CSRF token
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
token: Token to validate (if None, taken from form/headers)
|
||||
|
||||
Returns:
|
||||
True if token is valid
|
||||
"""
|
||||
# Get token from different sources
|
||||
if not token:
|
||||
# Try to get from form
|
||||
try:
|
||||
form = await request.form()
|
||||
token = form.get("csrf_token")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get CSRF token from form: {e}")
|
||||
pass
|
||||
|
||||
# If not found in form, try from headers
|
||||
if not token:
|
||||
token = request.headers.get("X-CSRF-Token")
|
||||
|
||||
if not token:
|
||||
return False
|
||||
|
||||
# Get token from session
|
||||
session_token = await get_csrf_token(request)
|
||||
if not session_token:
|
||||
return False
|
||||
|
||||
# Compare tokens
|
||||
return secrets.compare_digest(token, session_token)
|
||||
|
||||
|
||||
async def verify_csrf(request: Request, token: Optional[str] = None):
|
||||
"""
|
||||
Verify CSRF token with exception on error
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
token: Token to verify
|
||||
|
||||
Raises:
|
||||
HTTPException: If token is invalid
|
||||
"""
|
||||
if not await validate_csrf_token(request, token):
|
||||
logger.warning(f"CSRF token invalid for IP {request.client.host if request.client else 'unknown'}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid CSRF token"
|
||||
)
|
||||
|
||||
9
web/utils/database.py
Normal file
9
web/utils/database.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Database utilities for web interface (wrapper over shared module)
|
||||
Uses unified module from shared/database/session.py
|
||||
"""
|
||||
from shared.database.session import get_db
|
||||
|
||||
# Export function for use in FastAPI Depends
|
||||
__all__ = ['get_db']
|
||||
|
||||
288
web/utils/otp.py
Normal file
288
web/utils/otp.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
One-time password (OTP) utilities
|
||||
"""
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
from shared.database.models import OTPCode, User
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OTP code lifetime (10 minutes)
|
||||
OTP_EXPIRY_MINUTES = 10
|
||||
|
||||
# Rate limiting for OTP code verification
|
||||
# Format: {ip_address: [(timestamp, success), ...]}
|
||||
_otp_attempts: dict[str, list[tuple[float, bool]]] = defaultdict(list)
|
||||
import threading
|
||||
_otp_locks: dict[str, threading.Lock] = defaultdict(lambda: threading.Lock())
|
||||
|
||||
# Rate limiting settings
|
||||
MAX_OTP_ATTEMPTS = 5 # Maximum attempts
|
||||
OTP_ATTEMPT_WINDOW = 60 # Time window in seconds (1 minute)
|
||||
OTP_CLEANUP_INTERVAL = 3600 # Old attempts cleanup interval (1 hour)
|
||||
|
||||
|
||||
def generate_otp_code() -> str:
|
||||
"""
|
||||
Generate 6-digit OTP code
|
||||
|
||||
Returns:
|
||||
6-digit code
|
||||
"""
|
||||
return ''.join(secrets.choice(string.digits) for _ in range(6))
|
||||
|
||||
|
||||
async def create_otp_code(user_id: int, db: AsyncSession) -> Optional[str]:
|
||||
"""
|
||||
Create new OTP code for user
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OTP code or None on error
|
||||
"""
|
||||
try:
|
||||
# Invalidate all previous unused codes for user
|
||||
await invalidate_user_otp_codes(user_id, db)
|
||||
|
||||
# Generate new code
|
||||
code = generate_otp_code()
|
||||
expires_at = datetime.utcnow() + timedelta(minutes=OTP_EXPIRY_MINUTES)
|
||||
|
||||
# Create database record
|
||||
otp = OTPCode(
|
||||
user_id=user_id,
|
||||
code=code,
|
||||
expires_at=expires_at,
|
||||
used=False
|
||||
)
|
||||
db.add(otp)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"OTP code created for user {user_id}, expires at {expires_at}")
|
||||
return code
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating OTP code: {e}")
|
||||
await db.rollback()
|
||||
return None
|
||||
|
||||
|
||||
def _check_rate_limit(ip_address: str) -> bool:
|
||||
"""
|
||||
Check rate limiting for IP address
|
||||
|
||||
Args:
|
||||
ip_address: Client IP address
|
||||
|
||||
Returns:
|
||||
True if can continue, False if limit exceeded
|
||||
"""
|
||||
lock = _otp_locks[ip_address]
|
||||
with lock:
|
||||
now = time.time()
|
||||
attempts = _otp_attempts[ip_address]
|
||||
|
||||
# Remove old attempts (older than time window)
|
||||
attempts[:] = [(ts, success) for ts, success in attempts if now - ts < OTP_ATTEMPT_WINDOW]
|
||||
|
||||
# Check attempt count
|
||||
if len(attempts) >= MAX_OTP_ATTEMPTS:
|
||||
logger.warning(f"OTP verification rate limit exceeded for IP {ip_address}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _record_otp_attempt(ip_address: str, success: bool):
|
||||
"""
|
||||
Record OTP verification attempt
|
||||
|
||||
Args:
|
||||
ip_address: Client IP address
|
||||
success: Whether attempt was successful
|
||||
"""
|
||||
lock = _otp_locks[ip_address]
|
||||
with lock:
|
||||
now = time.time()
|
||||
_otp_attempts[ip_address].append((now, success))
|
||||
|
||||
|
||||
async def verify_otp_code(code: str, db: AsyncSession, ip_address: Optional[str] = None) -> Optional[int]:
|
||||
"""
|
||||
Verify and use OTP code with brute force protection
|
||||
|
||||
Args:
|
||||
code: OTP code to verify
|
||||
db: Database session
|
||||
ip_address: Client IP address (for rate limiting)
|
||||
|
||||
Returns:
|
||||
user_id if code is valid, None otherwise
|
||||
"""
|
||||
# Check rate limiting
|
||||
if ip_address and not _check_rate_limit(ip_address):
|
||||
logger.warning(f"OTP verification rate limit exceeded for IP {ip_address}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Search for code
|
||||
result = await db.execute(
|
||||
select(OTPCode).where(
|
||||
and_(
|
||||
OTPCode.code == code,
|
||||
OTPCode.used == False,
|
||||
OTPCode.expires_at > datetime.utcnow()
|
||||
)
|
||||
)
|
||||
)
|
||||
otp = result.scalar_one_or_none()
|
||||
|
||||
if not otp:
|
||||
logger.warning(f"Invalid or expired OTP code: {code}")
|
||||
if ip_address:
|
||||
_record_otp_attempt(ip_address, False)
|
||||
return None
|
||||
|
||||
# Mark code as used
|
||||
otp.used = True
|
||||
await db.commit()
|
||||
|
||||
if ip_address:
|
||||
_record_otp_attempt(ip_address, True)
|
||||
|
||||
logger.info(f"OTP code used for user {otp.user_id}")
|
||||
return otp.user_id
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying OTP code: {e}")
|
||||
await db.rollback()
|
||||
if ip_address:
|
||||
_record_otp_attempt(ip_address, False)
|
||||
return None
|
||||
|
||||
|
||||
async def invalidate_user_otp_codes(user_id: int, db: AsyncSession):
|
||||
"""
|
||||
Invalidate all unused OTP codes for user
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
db: Database session
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OTPCode).where(
|
||||
and_(
|
||||
OTPCode.user_id == user_id,
|
||||
OTPCode.used == False
|
||||
)
|
||||
)
|
||||
)
|
||||
otps = result.scalars().all()
|
||||
|
||||
for otp in otps:
|
||||
otp.used = True
|
||||
|
||||
if otps:
|
||||
await db.commit()
|
||||
logger.info(f"Invalidated {len(otps)} OTP codes for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error invalidating OTP codes: {e}")
|
||||
await db.rollback()
|
||||
|
||||
|
||||
async def cleanup_expired_otp_codes(db: AsyncSession):
|
||||
"""
|
||||
Cleanup expired and used OTP codes older than 24 hours
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
"""
|
||||
try:
|
||||
cutoff_time = datetime.utcnow() - timedelta(hours=24)
|
||||
result = await db.execute(
|
||||
select(OTPCode).where(
|
||||
and_(
|
||||
OTPCode.created_at < cutoff_time,
|
||||
OTPCode.used == True
|
||||
)
|
||||
)
|
||||
)
|
||||
expired_otps = result.scalars().all()
|
||||
|
||||
for otp in expired_otps:
|
||||
await db.delete(otp)
|
||||
|
||||
if expired_otps:
|
||||
await db.commit()
|
||||
logger.info(f"Cleaned up {len(expired_otps)} old OTP codes")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up old OTP codes: {e}")
|
||||
await db.rollback()
|
||||
|
||||
|
||||
async def get_user_by_identifier(identifier: str, db: AsyncSession) -> Optional[User]:
|
||||
"""
|
||||
Search user by ID or username
|
||||
|
||||
Args:
|
||||
identifier: User ID (number) or username (string with @ or without)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
User object or None
|
||||
"""
|
||||
try:
|
||||
# Try to find by user_id
|
||||
try:
|
||||
user_id = int(identifier)
|
||||
user = await db.get(User, user_id)
|
||||
if user:
|
||||
return user
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Search by username in database
|
||||
username = identifier.lstrip('@')
|
||||
result = await db.execute(
|
||||
select(User).where(User.username == username)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if user:
|
||||
return user
|
||||
|
||||
# If not found in database, try to find by username via Telegram API
|
||||
# This is needed for cases when user is not yet registered in database
|
||||
# Bot API doesn't allow searching users by username directly,
|
||||
# but we can try get_chat if bot has already interacted with user
|
||||
try:
|
||||
from bot.modules.task_scheduler.executor import get_app_client
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
# Try to get user information via get_chat
|
||||
# This only works if bot has already interacted with user
|
||||
chat = await app_client.get_chat(username)
|
||||
if chat and hasattr(chat, 'id'):
|
||||
# Found user, return None to create later
|
||||
logger.info(f"Found user {chat.id} by username {username} via Telegram API")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get user {username} via get_chat: {e}")
|
||||
# get_chat didn't work, user must use ID for first login
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get user information via Telegram API: {e}")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching for user: {e}")
|
||||
return None
|
||||
|
||||
138
web/utils/redis_session.py
Normal file
138
web/utils/redis_session.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Redis session utilities
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
from shared.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_redis_client = None
|
||||
|
||||
|
||||
def get_redis_client():
|
||||
"""Get Redis client"""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
try:
|
||||
import redis.asyncio as redis
|
||||
_redis_client = redis.Redis(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
decode_responses=True
|
||||
)
|
||||
logger.info(f"Connected to Redis: {settings.REDIS_HOST}:{settings.REDIS_PORT}")
|
||||
except ImportError:
|
||||
logger.error("Redis library not installed. Install: pip install redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to Redis: {e}")
|
||||
return None
|
||||
return _redis_client
|
||||
|
||||
|
||||
async def create_redis_session(user_id: int) -> str:
|
||||
"""Create session in Redis"""
|
||||
import secrets
|
||||
from web.utils.csrf import generate_csrf_token
|
||||
|
||||
session_id = secrets.token_urlsafe(32)
|
||||
csrf_token = generate_csrf_token()
|
||||
|
||||
session_data = {
|
||||
"user_id": user_id,
|
||||
"is_owner": user_id == settings.OWNER_ID,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"expires_at": (datetime.utcnow() + timedelta(days=7)).isoformat(),
|
||||
"csrf_token": csrf_token
|
||||
}
|
||||
|
||||
redis_client = get_redis_client()
|
||||
if not redis_client:
|
||||
logger.warning("Redis client unavailable, session not created")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Save session with TTL 7 days
|
||||
await redis_client.setex(
|
||||
f"session:{session_id}",
|
||||
7 * 24 * 3600, # 7 days in seconds
|
||||
json.dumps(session_data)
|
||||
)
|
||||
logger.debug(f"Session created in Redis: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating session in Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_redis_session(session_id: str) -> Optional[dict]:
|
||||
"""Get session from Redis"""
|
||||
redis_client = get_redis_client()
|
||||
if not redis_client:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await redis_client.get(f"session:{session_id}")
|
||||
if data:
|
||||
session_data = json.loads(data)
|
||||
# Check expires_at
|
||||
expires_at_str = session_data.get('expires_at')
|
||||
if expires_at_str:
|
||||
try:
|
||||
expires_at = datetime.fromisoformat(expires_at_str)
|
||||
if datetime.utcnow() < expires_at:
|
||||
return session_data
|
||||
else:
|
||||
# Session expired, delete
|
||||
logger.debug(f"Session {session_id} expired in Redis (expires_at: {expires_at})")
|
||||
await delete_redis_session(session_id)
|
||||
return None
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse expires_at for session {session_id}: {e}")
|
||||
# If parsing failed, consider session valid (Redis TTL will still work)
|
||||
return session_data
|
||||
else:
|
||||
# If expires_at missing, return session (Redis TTL will still work)
|
||||
return session_data
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session from Redis: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def delete_redis_session(session_id: str) -> None:
|
||||
"""Delete session from Redis"""
|
||||
redis_client = get_redis_client()
|
||||
if redis_client:
|
||||
try:
|
||||
await redis_client.delete(f"session:{session_id}")
|
||||
logger.debug(f"Session deleted from Redis: {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting session from Redis: {e}")
|
||||
|
||||
|
||||
async def update_redis_session(session_id: str, session_data: dict) -> bool:
|
||||
"""Update session in Redis"""
|
||||
redis_client = get_redis_client()
|
||||
if not redis_client:
|
||||
return False
|
||||
|
||||
try:
|
||||
expires_at = datetime.fromisoformat(session_data['expires_at'])
|
||||
ttl = int((expires_at - datetime.utcnow()).total_seconds())
|
||||
if ttl > 0:
|
||||
await redis_client.setex(
|
||||
f"session:{session_id}",
|
||||
ttl,
|
||||
json.dumps(session_data)
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating session in Redis: {e}")
|
||||
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user