126 lines
4.1 KiB
Python
126 lines
4.1 KiB
Python
"""
|
|
Unified module for database session management
|
|
Used by both bot and web interface
|
|
"""
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
from shared.config import settings
|
|
from shared.database.models import Base
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)

# Process-wide singletons shared by the whole application (bot and web).
# Both start unset and are filled in lazily: the engine by get_engine(),
# the session factory by get_session_factory().
_engine = None
_AsyncSessionLocal = None
|
|
|
|
|
|
def _mask_db_url(url):
    """Return *url* with its password replaced by ``***`` so it is safe to log.

    Expects the ``scheme://user:password@location`` form. Credentials are
    split from the location at the LAST ``@`` after the scheme, so a password
    that itself contains an unescaped ``@`` is still hidden. URLs without a
    password are returned unchanged. A string with no ``://`` but containing
    ``@`` is replaced entirely rather than risk echoing a secret.
    """
    scheme, sep, rest = url.partition("://")
    if not sep:
        return "***" if "@" in url else url
    credentials, at, location = rest.rpartition("@")
    if not at:
        # No credentials section at all.
        return url
    user, colon, _password = credentials.partition(":")
    if not colon:
        # Username only, nothing secret to hide.
        return url
    return f"{scheme}://{user}:***@{location}"


def get_engine():
    """Return the shared async database engine, creating it on first call.

    The engine is built from ``settings.DATABASE_URL`` and cached in the
    module-level ``_engine``, so every later call returns the same instance.
    """
    global _engine
    if _engine is None:
        # Log which database is being used, never the password itself.
        logger.info(
            "Creating database engine with URL: %s",
            _mask_db_url(settings.DATABASE_URL),
        )
        _engine = create_async_engine(
            settings.DATABASE_URL,
            echo=False,
            future=True,
            pool_pre_ping=True,  # Test a pooled connection before handing it out
            pool_recycle=3600,  # Replace connections older than 3600 seconds
        )
        logger.info("Database engine created")
    return _engine
|
|
|
|
|
|
def get_session_factory():
    """Return the shared async session factory, building it on first use.

    The factory is bound to the engine from get_engine() and cached in the
    module-level ``_AsyncSessionLocal``, so all callers share one instance.
    It produces ``AsyncSession`` objects configured with
    ``expire_on_commit=False`` (attributes are not expired after commit).
    """
    global _AsyncSessionLocal
    if _AsyncSessionLocal is not None:
        return _AsyncSessionLocal
    _AsyncSessionLocal = async_sessionmaker(
        get_engine(),
        class_=AsyncSession,
        expire_on_commit=False,
    )
    logger.info("Session factory created")
    return _AsyncSessionLocal
|
|
|
|
|
|
def _load_migration_runner():
    """Return ``init_db_with_migrations``, or ``None`` if it cannot be imported.

    Only the import is guarded here. An ImportError raised while migrations
    are actually running is a real failure and must not silently trigger the
    non-versioned create_all fallback.
    """
    try:
        from shared.database.migrations import init_db_with_migrations
    except ImportError:
        logger.debug("Could not import shared.database.migrations", exc_info=True)
        return None
    return init_db_with_migrations


async def _create_tables_directly():
    """Create all tables in ``Base.metadata`` without Alembic (legacy fallback)."""
    engine = get_engine()
    # Used only for log messages: any URL not mentioning sqlite is reported
    # as PostgreSQL.
    db_type = "SQLite" if "sqlite" in settings.DATABASE_URL.lower() else "PostgreSQL"
    logger.info(f"Initializing {db_type} database...")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    logger.info(f"{db_type} database tables successfully created")


async def init_db():
    """
    Initialize the database schema.

    Applies Alembic migrations through
    ``shared.database.migrations.init_db_with_migrations``, which keeps the
    schema versioned and allows changes to be rolled back. If that module
    cannot be imported, falls back to ``Base.metadata.create_all`` (no
    versioning).

    Raises:
        Exception: any failure from either path is logged and re-raised.
    """
    try:
        run_migrations = _load_migration_runner()
        if run_migrations is not None:
            await run_migrations()
        else:
            # Fallback to the old method if migrations are not configured.
            logger.warning("Alembic not configured, using direct table creation")
            await _create_tables_directly()
    except Exception as e:
        logger.error(f"Error initializing database: {e}", exc_info=True)
        raise
|
|
|
|
|
|
async def get_db():
    """
    Yield a database session (generator for use with Depends in FastAPI).

    A new session is opened from the shared factory for each use. It is
    closed when the generator finishes, including when an exception is thrown
    into it at the ``yield``. Used by both the bot and the web interface.
    """
    session_factory = get_session_factory()
    async with session_factory() as session:
        # Leaving the ``async with`` closes the session; an extra explicit
        # close() would only repeat that work.
        yield session
|
|
|
|
|
|
def get_async_session_local():
    """
    Return the shared session factory for direct use (e.g. in the bot).

    This is the same object that get_session_factory() returns.
    """
    factory = get_session_factory()
    return factory
|
|
|
|
|
|
# For backward compatibility, the session factory is exported as
# AsyncSessionLocal through a proxy that behaves like async_sessionmaker.
class AsyncSessionLocalProxy:
    """
    Lazy stand-in for the session factory, exported as ``AsyncSessionLocal``.

    The real factory (and the engine behind it) is only built on first use,
    via get_session_factory(). Calling the proxy, as in ``AsyncSessionLocal()``,
    opens a new session. Other non-dunder attribute lookups are forwarded to
    the factory.
    """

    def __call__(self, *args, **kwargs):
        """Calling AsyncSessionLocal() creates a new session."""
        factory = get_session_factory()
        return factory(*args, **kwargs)

    def __getattr__(self, name):
        """Forward missing attributes to the session factory.

        Dunder names are not forwarded. Introspection such as
        ``hasattr(obj, "__wrapped__")`` or ``copy.deepcopy`` probes them
        speculatively, and forwarding would build the engine as a side effect.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        factory = get_session_factory()
        return getattr(factory, name)


# Create proxy instance for backward compatibility
AsyncSessionLocal = AsyncSessionLocalProxy()
|
|
|