289 lines
8.4 KiB
Python
289 lines
8.4 KiB
Python
"""
|
|
One-time password (OTP) utilities
|
|
"""
|
|
import secrets
|
|
import string
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, and_
|
|
from shared.database.models import OTPCode, User
|
|
import logging
|
|
from collections import defaultdict
|
|
import time
|
|
|
|
logger = logging.getLogger(__name__)

# How long a newly issued OTP code stays valid, in minutes.
OTP_EXPIRY_MINUTES = 10

import threading  # noqa: E402 - only needed for the per-IP locks below

# In-process brute-force protection for OTP verification, keyed by client IP.
# Each history entry is a (time.time() timestamp, success flag) pair.
_otp_attempts: dict[str, list[tuple[float, bool]]] = defaultdict(list)

# One lock per IP guarding that IP's history list; created on first access.
_otp_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)

# Rate limiting settings
MAX_OTP_ATTEMPTS = 5  # attempts per IP allowed inside the window
OTP_ATTEMPT_WINDOW = 60  # sliding window length in seconds (1 minute)
# Old attempts cleanup interval in seconds (1 hour).
# NOTE(review): not referenced anywhere else in this module - TODO wire up or remove.
OTP_CLEANUP_INTERVAL = 3600
|
|
|
|
|
|
def generate_otp_code(length: int = 6) -> str:
    """
    Generate a random numeric OTP code.

    Every digit is drawn independently with ``secrets.choice``, so the code is
    suitable for authentication and may contain leading zeros.

    Args:
        length: Number of digits in the code (default 6)

    Returns:
        String of ``length`` decimal digits

    Raises:
        ValueError: If ``length`` is less than 1 (an empty code must never be issued)
    """
    if length < 1:
        raise ValueError(f"OTP length must be at least 1, got {length}")
    return ''.join(secrets.choice(string.digits) for _ in range(length))
|
|
|
|
|
|
# Upper bound on regenerations when a fresh code collides with a code that is
# still active (unused and unexpired) for some user.
_MAX_CODE_GENERATION_ATTEMPTS = 10


async def _otp_code_in_use(code: str, db: AsyncSession, now: datetime) -> bool:
    """Return True if an unused, unexpired OTP code with this value already exists."""
    result = await db.execute(
        select(OTPCode.code)
        .where(
            and_(
                OTPCode.code == code,
                OTPCode.used == False,  # noqa: E712 - SQL expression, not a Python bool test
                OTPCode.expires_at > now,
            )
        )
        .limit(1)
    )
    return result.first() is not None


async def create_otp_code(user_id: int, db: AsyncSession) -> Optional[str]:
    """
    Create a new OTP code for a user.

    All of the user's previously unused codes are invalidated first. The new
    value is regenerated while it matches another active code, because
    verify_otp_code looks codes up by value alone and cannot tell two
    identical active codes apart.

    Args:
        user_id: User ID
        db: Database session

    Returns:
        OTP code, or None on error (including failure to find a free value)
    """
    try:
        # Invalidate all previous unused codes for user
        await invalidate_user_otp_codes(user_id, db)

        now = datetime.utcnow()
        for _ in range(_MAX_CODE_GENERATION_ATTEMPTS):
            code = generate_otp_code()
            if not await _otp_code_in_use(code, db, now):
                break
        else:
            logger.error(f"Could not generate a unique OTP code for user {user_id}")
            return None
        # NOTE(review): a concurrent create could still insert the same value
        # between the check and the commit; a DB uniqueness constraint on active
        # codes would close that window - TODO consider.

        expires_at = now + timedelta(minutes=OTP_EXPIRY_MINUTES)

        # Create database record
        otp = OTPCode(
            user_id=user_id,
            code=code,
            expires_at=expires_at,
            used=False,
        )
        db.add(otp)
        await db.commit()

        logger.info(f"OTP code created for user {user_id}, expires at {expires_at}")
        return code
    except Exception as e:
        logger.error(f"Error creating OTP code: {e}")
        await db.rollback()
        return None
|
|
|
|
|
|
def _check_rate_limit(ip_address: str) -> bool:
    """
    Check whether an IP address may make another OTP verification attempt.

    Attempts older than OTP_ATTEMPT_WINDOW seconds are pruned first. The IP is
    blocked once MAX_OTP_ATTEMPTS attempts (successful or not) remain inside the
    window. An IP whose pruned history is empty is removed from _otp_attempts,
    so the dict does not keep an entry for every address ever seen.

    Args:
        ip_address: Client IP address

    Returns:
        True if can continue, False if limit exceeded
    """
    with _otp_locks[ip_address]:
        now = time.time()
        # .get() rather than indexing: indexing the defaultdict would insert an
        # empty list for every new IP merely by checking it.
        attempts = _otp_attempts.get(ip_address, [])

        # Remove old attempts (older than time window)
        attempts[:] = [(ts, success) for ts, success in attempts if now - ts < OTP_ATTEMPT_WINDOW]

        if len(attempts) >= MAX_OTP_ATTEMPTS:
            logger.warning(f"OTP verification rate limit exceeded for IP {ip_address}")
            return False

        if not attempts:
            # Nothing left in the window: forget this IP to bound memory.
            _otp_attempts.pop(ip_address, None)
        # NOTE(review): _otp_locks still keeps one lock per IP ever seen. It is
        # not evicted here, because dropping a lock another caller may be waiting
        # on would let two callers use different locks for the same IP.
        return True
|
|
|
|
|
|
def _record_otp_attempt(ip_address: str, success: bool):
    """
    Append one OTP verification attempt to the IP's in-memory history.

    Args:
        ip_address: Client IP address the attempt came from
        success: True if the submitted code was accepted
    """
    # The per-IP lock serializes this append with the pruning in _check_rate_limit.
    with _otp_locks[ip_address]:
        stamp = time.time()
        _otp_attempts[ip_address].append((stamp, success))
|
|
|
|
|
|
async def verify_otp_code(code: str, db: AsyncSession, ip_address: Optional[str] = None) -> Optional[int]:
    """
    Verify an OTP code and mark it as used, with per-IP brute force protection.

    The code is matched against all unused, unexpired codes; it is not bound to
    a particular user. On success the code is marked ``used`` and committed, so
    it cannot be redeemed again.

    Args:
        code: OTP code to verify
        db: Database session
        ip_address: Client IP address for rate limiting. When None, no rate
            limit is applied and no attempt is recorded.

    Returns:
        user_id of the code's owner if the code is valid; None if it is invalid,
        expired, ambiguous, rate-limited, or a database error occurred
    """
    # _check_rate_limit logs its own warning when the limit is hit.
    if ip_address and not _check_rate_limit(ip_address):
        return None

    # NOTE(review): the rate-limit check and the attempt record below are
    # separated by awaits. Concurrent requests from one IP can therefore all
    # pass the check before any is recorded. TODO: reserve the slot atomically.

    try:
        result = await db.execute(
            select(OTPCode)
            .where(
                and_(
                    OTPCode.code == code,
                    OTPCode.used == False,  # noqa: E712 - SQL expression
                    OTPCode.expires_at > datetime.utcnow(),
                )
            )
            # Lock the matching row (where the backend supports FOR UPDATE).
            # Two concurrent requests then cannot both see used == False and
            # redeem the same one-time code.
            .with_for_update()
        )
        matches = result.scalars().all()

        if not matches:
            logger.warning(f"Invalid or expired OTP code: {code}")
            if ip_address:
                _record_otp_attempt(ip_address, False)
            return None

        if len(matches) > 1:
            # Codes are looked up by value only, so several active codes with
            # the same value make it impossible to tell whose code this is.
            logger.warning(f"Ambiguous OTP code: {len(matches)} active codes match")
            await db.rollback()  # end the transaction and release the row locks
            if ip_address:
                _record_otp_attempt(ip_address, False)
            return None

        otp = matches[0]
        # Read the owner BEFORE commit. With expire_on_commit=True, reading an
        # expired attribute afterwards needs a lazy load, which AsyncSession
        # cannot do implicitly. The code would already be burned by then.
        user_id = otp.user_id

        # Mark code as used
        otp.used = True
        await db.commit()

        if ip_address:
            _record_otp_attempt(ip_address, True)

        logger.info(f"OTP code used for user {user_id}")
        return user_id
    except Exception as e:
        logger.error(f"Error verifying OTP code: {e}")
        await db.rollback()
        if ip_address:
            _record_otp_attempt(ip_address, False)
        return None
|
|
|
|
|
|
async def invalidate_user_otp_codes(user_id: int, db: AsyncSession):
    """
    Mark every not-yet-used OTP code of a user as used.

    Commits only when at least one code was changed. Errors are logged and the
    session is rolled back instead of raising.

    Args:
        user_id: User whose pending codes should be invalidated
        db: Database session
    """
    try:
        pending_query = select(OTPCode).where(
            and_(
                OTPCode.user_id == user_id,
                OTPCode.used == False,  # noqa: E712 - SQL expression
            )
        )
        pending = (await db.execute(pending_query)).scalars().all()

        if not pending:
            return

        for entry in pending:
            entry.used = True

        await db.commit()
        logger.info(f"Invalidated {len(pending)} OTP codes for user {user_id}")
    except Exception as e:
        logger.error(f"Error invalidating OTP codes: {e}")
        await db.rollback()
|
|
|
|
|
|
async def cleanup_expired_otp_codes(db: AsyncSession):
    """
    Delete OTP codes older than 24 hours that can no longer be redeemed.

    A code is deleted when it was created more than 24 hours ago AND it is
    either used or expired. Errors are logged and the session is rolled back
    instead of raising.

    Args:
        db: Database session
    """
    from sqlalchemy import or_  # local import keeps the module import block unchanged

    try:
        now = datetime.utcnow()
        cutoff_time = now - timedelta(hours=24)
        result = await db.execute(
            select(OTPCode).where(
                and_(
                    OTPCode.created_at < cutoff_time,
                    or_(
                        OTPCode.used == True,  # noqa: E712 - SQL expression
                        # Codes are only flagged used when redeemed or superseded.
                        # Codes that simply expired must be matched explicitly,
                        # or they are never removed. verify_otp_code treats a
                        # code as valid while expires_at > now.
                        OTPCode.expires_at <= now,
                    ),
                )
            )
        )
        expired_otps = result.scalars().all()

        for otp in expired_otps:
            await db.delete(otp)

        if expired_otps:
            await db.commit()
            logger.info(f"Cleaned up {len(expired_otps)} old OTP codes")
    except Exception as e:
        logger.error(f"Error cleaning up old OTP codes: {e}")
        await db.rollback()
|
|
|
|
|
|
async def _lookup_user_by_id(identifier: str, db: AsyncSession) -> Optional[User]:
    """Return the User whose primary key is int(identifier), or None.

    A ValueError (e.g. a non-numeric identifier) yields None.
    """
    try:
        candidate = await db.get(User, int(identifier))
    except ValueError:
        return None
    return candidate if candidate else None


async def _probe_telegram_username(username: str) -> None:
    """Best-effort Telegram lookup of ``username``; it only logs.

    Exceptions are logged at debug level and swallowed. Per the original notes,
    the Bot API cannot search users by username. ``get_chat`` only succeeds if
    the bot has already interacted with the user.
    """
    try:
        from bot.modules.task_scheduler.executor import get_app_client

        client = get_app_client()
        if not client:
            return
        try:
            chat = await client.get_chat(username)
            if chat and hasattr(chat, 'id'):
                # The user exists on Telegram but not in the database. Nothing
                # is created here, so the caller still receives None.
                logger.info(f"Found user {chat.id} by username {username} via Telegram API")
        except Exception as e:
            # get_chat didn't work; the user must use the numeric ID for first login.
            logger.debug(f"Failed to get user {username} via get_chat: {e}")
    except Exception as e:
        logger.debug(f"Failed to get user information via Telegram API: {e}")


async def get_user_by_identifier(identifier: str, db: AsyncSession) -> Optional[User]:
    """
    Look up a registered user by numeric ID or by username.

    Lookup order:
      1. If ``identifier`` parses as an int, fetch the User with that primary key.
      2. Otherwise (or if no such user exists), match ``User.username`` against
         the identifier with leading '@' characters stripped.
      3. If still not found, query the Telegram client (when available) about
         the username. This step only logs and never returns a user.

    Args:
        identifier: User ID (number) or username (with or without '@')
        db: Database session

    Returns:
        User object, or None if not found or on any error
    """
    try:
        by_id = await _lookup_user_by_id(identifier, db)
        if by_id:
            return by_id

        username = identifier.lstrip('@')
        lookup = await db.execute(select(User).where(User.username == username))
        by_name = lookup.scalar_one_or_none()
        if by_name:
            return by_name

        await _probe_telegram_username(username)
        return None
    except Exception as e:
        logger.error(f"Error searching for user: {e}")
        return None
|
|
|