Add source
This commit is contained in:
125
shared/database/session.py
Normal file
125
shared/database/session.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Unified module for database session management
|
||||
Used by both bot and web interface
|
||||
"""
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from shared.config import settings
|
||||
from shared.database.models import Base
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Unified database engine for the entire application
|
||||
_engine = None
|
||||
_AsyncSessionLocal = None
|
||||
|
||||
|
||||
def get_engine():
|
||||
"""Get database engine (created on first call)"""
|
||||
global _engine
|
||||
if _engine is None:
|
||||
# Log DATABASE_URL being used (without password for security)
|
||||
db_url_safe = settings.DATABASE_URL
|
||||
if '@' in db_url_safe:
|
||||
# Hide password in logs
|
||||
parts = db_url_safe.split('@')
|
||||
if len(parts) == 2:
|
||||
auth_part = parts[0].split('://')
|
||||
if len(auth_part) == 2:
|
||||
scheme = auth_part[0]
|
||||
user_pass = auth_part[1]
|
||||
if ':' in user_pass:
|
||||
user = user_pass.split(':')[0]
|
||||
db_url_safe = f"{scheme}://{user}:***@{parts[1]}"
|
||||
logger.info(f"Creating database engine with URL: {db_url_safe}")
|
||||
|
||||
_engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=False,
|
||||
future=True,
|
||||
pool_pre_ping=True, # Check connections before use
|
||||
pool_recycle=3600, # Reuse connections every 3600 seconds
|
||||
)
|
||||
logger.info("Database engine created")
|
||||
return _engine
|
||||
|
||||
|
||||
def get_session_factory():
|
||||
"""Get session factory (created on first call)"""
|
||||
global _AsyncSessionLocal
|
||||
if _AsyncSessionLocal is None:
|
||||
engine = get_engine()
|
||||
_AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
logger.info("Session factory created")
|
||||
return _AsyncSessionLocal
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""
|
||||
Initialize database (create tables via Alembic migrations).
|
||||
|
||||
Uses Alembic to apply migrations instead of directly creating tables.
|
||||
This ensures database schema versioning and the ability to rollback changes.
|
||||
"""
|
||||
try:
|
||||
from shared.database.migrations import init_db_with_migrations
|
||||
await init_db_with_migrations()
|
||||
except ImportError:
|
||||
# Fallback to old method if migrations not configured
|
||||
logger.warning("Alembic not configured, using direct table creation")
|
||||
engine = get_engine()
|
||||
# Determine database type from URL
|
||||
db_type = "SQLite" if "sqlite" in settings.DATABASE_URL.lower() else "PostgreSQL"
|
||||
logger.info(f"Initializing {db_type} database...")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info(f"{db_type} database tables successfully created")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing database: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def get_db():
|
||||
"""
|
||||
Get database session (generator for use with Depends in FastAPI)
|
||||
Used by both bot and web interface
|
||||
"""
|
||||
AsyncSessionLocal = get_session_factory()
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
def get_async_session_local():
|
||||
"""
|
||||
Get session factory for direct use (e.g., in bot)
|
||||
"""
|
||||
return get_session_factory()
|
||||
|
||||
|
||||
# For backward compatibility - export session factory as AsyncSessionLocal
|
||||
# Use proxy class that behaves like async_sessionmaker
|
||||
class AsyncSessionLocalProxy:
|
||||
"""
|
||||
Proxy for AsyncSessionLocal that initializes session factory on first use.
|
||||
Allows using AsyncSessionLocal() like async_sessionmaker().
|
||||
"""
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Calling AsyncSessionLocal() creates a new session"""
|
||||
factory = get_session_factory()
|
||||
return factory(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Proxy attributes to session factory"""
|
||||
factory = get_session_factory()
|
||||
return getattr(factory, name)
|
||||
|
||||
# Create proxy instance for backward compatibility
|
||||
AsyncSessionLocal = AsyncSessionLocalProxy()
|
||||
|
||||
Reference in New Issue
Block a user