139 lines
4.5 KiB
Python
139 lines
4.5 KiB
Python
"""
|
|
Redis session utilities
|
|
"""
|
|
import json
|
|
import logging
|
|
from typing import Optional
|
|
from datetime import datetime, timedelta
|
|
from shared.config import settings
|
|
|
|
logger = logging.getLogger(name=__name__)

# Shared asyncio Redis client for this module. It starts out unset and is
# populated by get_redis_client() on the first successful construction;
# every session helper below obtains it through that function.
_redis_client = None
|
|
|
|
|
|
def get_redis_client():
    """Return the module-wide asyncio Redis client, creating it on first use.

    The client is built from ``settings.REDIS_HOST``, ``settings.REDIS_PORT``
    and ``settings.REDIS_DB`` with ``decode_responses=True``. It is cached in
    the module-level ``_redis_client`` and reused by later calls.

    Returns:
        The ``redis.asyncio.Redis`` instance, or ``None`` if the ``redis``
        package is not installed or the client could not be constructed.
        A failed construction is not cached, so the next call tries again.
    """
    global _redis_client

    if _redis_client is not None:
        return _redis_client

    try:
        # Imported lazily so this module can be imported even when the
        # optional ``redis`` dependency is missing.
        import redis.asyncio as redis

        _redis_client = redis.Redis(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=settings.REDIS_DB,
            decode_responses=True,
        )
    except ImportError:
        logger.error("Redis library not installed. Install: pip install redis")
        return None
    except Exception:
        logger.exception("Error creating Redis client")
        return None

    # Constructing the client does not open a connection (redis-py connects
    # on the first command), so do not claim the server is reachable here.
    logger.info(
        "Redis client configured for %s:%s",
        settings.REDIS_HOST,
        settings.REDIS_PORT,
    )
    return _redis_client
|
|
|
|
|
|
async def create_redis_session(user_id: int) -> Optional[str]:
    """Create a new session for *user_id* and store it in Redis.

    The stored record is a JSON object with ``user_id``, ``is_owner``
    (``user_id == settings.OWNER_ID``), ISO-8601 ``created_at`` and
    ``expires_at`` timestamps (naive UTC) and a freshly generated
    ``csrf_token``. It is written to the key ``session:<session_id>`` with a
    Redis TTL equal to the session lifetime.

    Args:
        user_id: Identifier of the authenticated user.

    Returns:
        The new URL-safe session id, or ``None`` if Redis is unavailable or
        the write failed.
    """
    import secrets

    from web.utils.csrf import generate_csrf_token

    # Single source of truth for the session lifetime: both the stored
    # ``expires_at`` and the Redis key TTL derive from it, so they cannot drift.
    lifetime = timedelta(days=7)

    session_id = secrets.token_urlsafe(32)
    csrf_token = generate_csrf_token()

    # Take "now" once so created_at and expires_at are exactly ``lifetime``
    # apart. Naive UTC is kept deliberately: the readers in this module
    # compare ``expires_at`` against naive ``datetime.utcnow()``.
    now = datetime.utcnow()
    session_data = {
        "user_id": user_id,
        "is_owner": user_id == settings.OWNER_ID,
        "created_at": now.isoformat(),
        "expires_at": (now + lifetime).isoformat(),
        "csrf_token": csrf_token,
    }

    redis_client = get_redis_client()
    if not redis_client:
        logger.warning("Redis client unavailable, session not created")
        return None

    try:
        await redis_client.setex(
            f"session:{session_id}",
            int(lifetime.total_seconds()),
            json.dumps(session_data),
        )
    except Exception as e:
        logger.error("Error creating session in Redis: %s", e)
        return None

    # Log only a prefix: the full session id is a bearer credential.
    logger.debug("Session created in Redis: %s...", session_id[:8])
    return session_id
|
|
|
|
|
|
async def get_redis_session(session_id: str) -> Optional[dict]:
    """Load the session stored under ``session:<session_id>`` in Redis.

    If the record's ``expires_at`` timestamp is not in the future, the
    session is deleted and treated as missing. If ``expires_at`` is absent,
    or cannot be parsed or compared, the session is returned unchanged and
    expiry is left to the Redis key TTL.

    Args:
        session_id: Session identifier, as returned by ``create_redis_session``.

    Returns:
        The decoded session dict. Returns ``None`` in any of these cases:
        the id is empty, Redis is unavailable, the key does not exist, the
        session has expired, or an error occurred.
    """
    if not session_id:
        return None

    redis_client = get_redis_client()
    if not redis_client:
        return None

    # Only a prefix is logged: the full session id is a bearer credential.
    log_id = f"{str(session_id)[:8]}..."

    try:
        data = await redis_client.get(f"session:{session_id}")
        if not data:
            return None

        session_data = json.loads(data)

        expires_at_str = session_data.get('expires_at')
        if not expires_at_str:
            # No expiry recorded: rely on the Redis key TTL.
            return session_data

        try:
            expires_at = datetime.fromisoformat(expires_at_str)
            still_valid = datetime.utcnow() < expires_at
        except (ValueError, TypeError) as e:
            logger.warning(
                "Failed to parse expires_at for session %s: %s", log_id, e
            )
            # Unparseable expiry: treat as valid; the Redis TTL still applies.
            return session_data

        if still_valid:
            return session_data

        logger.debug(
            "Session %s expired in Redis (expires_at: %s)", log_id, expires_at
        )
        await delete_redis_session(session_id)
        return None
    except Exception as e:
        logger.error("Error getting session from Redis: %s", e)
        return None
|
|
|
|
|
|
async def delete_redis_session(session_id: str) -> None:
    """Delete the ``session:<session_id>`` key from Redis (best effort).

    Does nothing if the id is empty or no Redis client is available.
    Redis errors are logged, not raised.

    Args:
        session_id: Session identifier to remove.
    """
    if not session_id:
        return

    redis_client = get_redis_client()
    if not redis_client:
        return

    try:
        await redis_client.delete(f"session:{session_id}")
    except Exception as e:
        logger.error("Error deleting session from Redis: %s", e)
        return

    # Only a prefix is logged: the full session id is a bearer credential.
    logger.debug("Session deleted from Redis: %s...", str(session_id)[:8])
|
|
|
|
|
|
async def update_redis_session(session_id: str, session_data: dict) -> bool:
    """Overwrite the stored session with *session_data*.

    The Redis TTL is recomputed from ``session_data['expires_at']``. That
    value is an ISO-8601 timestamp, compared against naive
    ``datetime.utcnow()``, so the key keeps expiring when the record says
    it should.

    Args:
        session_id: Session identifier.
        session_data: Full session record to store; must contain
            ``expires_at``.

    Returns:
        True if the session was written. False in any of these cases:
        Redis is unavailable; ``expires_at`` is missing or invalid; the
        session has already expired (its stale key is then removed); or
        the write failed.
    """
    redis_client = get_redis_client()
    if not redis_client:
        return False

    key = f"session:{session_id}"

    try:
        expires_at = datetime.fromisoformat(session_data['expires_at'])
        # Whole seconds remaining; SETEX rejects a non-positive TTL.
        ttl = int((expires_at - datetime.utcnow()).total_seconds())
    except (KeyError, ValueError, TypeError) as e:
        logger.warning("Cannot update session: invalid expires_at: %s", e)
        return False

    try:
        if ttl <= 0:
            # Already expired: drop the stale record, matching how
            # get_redis_session handles expired sessions.
            await redis_client.delete(key)
            return False
        await redis_client.setex(key, ttl, json.dumps(session_data))
        return True
    except Exception as e:
        logger.error("Error updating session in Redis: %s", e)
        return False
|
|
|