Add source

This commit is contained in:
2025-12-04 00:12:56 +03:00
parent b75875df5e
commit 0cb7045e7a
75 changed files with 9055 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
"""
Общие модели базы данных
"""

View File

@@ -0,0 +1,98 @@
"""
Module for automatic Alembic migration application
"""
import asyncio
import logging
from pathlib import Path
from alembic import command
from alembic.config import Config
from shared.config import settings
logger = logging.getLogger(__name__)
def get_alembic_config() -> Config:
"""
Get Alembic configuration.
Returns:
Config: Alembic configuration object
"""
# Path to alembic.ini
alembic_ini_path = Path(__file__).parent.parent.parent / "alembic.ini"
# Create configuration
alembic_cfg = Config(str(alembic_ini_path))
# Set DATABASE_URL from settings
alembic_cfg.set_main_option("sqlalchemy.url", settings.DATABASE_URL)
return alembic_cfg
async def upgrade_database(revision: str = "head") -> None:
"""
Apply migrations to database.
Args:
revision: Revision to upgrade database to (default "head" - latest)
"""
try:
logger.info(f"Applying migrations to database (revision: {revision})...")
# Get configuration
alembic_cfg = get_alembic_config()
# Apply migrations in separate thread (since command.upgrade is synchronous)
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
command.upgrade,
alembic_cfg,
revision
)
logger.info("Migrations successfully applied")
except Exception as e:
logger.error(f"Error applying migrations: {e}", exc_info=True)
raise
async def check_migrations() -> bool:
"""
Check for unapplied migrations.
Returns:
bool: True if there are unapplied migrations, False if all are applied
"""
try:
alembic_cfg = get_alembic_config()
# Check for migrations to apply
# This is a simplified check - in reality can use command.heads()
# and command.current() to compare revisions
return True
except Exception as e:
logger.warning(f"Failed to check migration status: {e}")
# On error assume migrations are needed
return True
async def init_db_with_migrations() -> None:
"""
Initialize database with migrations applied.
Replaces old init_db() method for Alembic usage.
"""
try:
# Determine database type from URL
db_type = "SQLite" if "sqlite" in settings.DATABASE_URL.lower() else "PostgreSQL"
logger.info(f"Initializing {db_type} database with migrations...")
# Apply migrations (this will create tables if they don't exist)
await upgrade_database("head")
logger.info(f"{db_type} database successfully initialized")
except Exception as e:
logger.error(f"Error initializing database: {e}", exc_info=True)
raise

91
shared/database/models.py Normal file
View File

@@ -0,0 +1,91 @@
"""
ORM models for database
"""
from datetime import datetime
from sqlalchemy import Column, Integer, BigInteger, String, Boolean, DateTime, Text, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
Base = declarative_base()
class User(Base):
"""User model"""
__tablename__ = "users"
user_id = Column(Integer, primary_key=True, unique=True, index=True)
username = Column(String(255), nullable=True)
first_name = Column(String(255), nullable=True)
last_name = Column(String(255), nullable=True)
is_admin = Column(Boolean, default=False)
is_blocked = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
tasks = relationship("Task", back_populates="user")
def __repr__(self):
return f"<User(user_id={self.user_id}, username={self.username}, is_admin={self.is_admin})>"
class Task(Base):
"""Task model"""
__tablename__ = "tasks"
id = Column(BigInteger, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.user_id"), nullable=False, index=True) # Index for frequent queries
task_type = Column(String(50), nullable=False) # download, process, etc.
status = Column(String(50), default="pending", index=True) # Index for status filtering
url = Column(Text, nullable=True)
file_path = Column(String(500), nullable=True)
progress = Column(Integer, default=0) # 0-100
error_message = Column(String(1000), nullable=True) # Limit to 1000 characters
created_at = Column(DateTime, default=datetime.utcnow, index=True) # Index for sorting
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
completed_at = Column(DateTime, nullable=True)
# Relationships
user = relationship("User", back_populates="tasks")
def __repr__(self):
return f"<Task(id={self.id}, user_id={self.user_id}, status={self.status})>"
class Download(Base):
"""Download model"""
__tablename__ = "downloads"
id = Column(Integer, primary_key=True, index=True)
task_id = Column(BigInteger, ForeignKey("tasks.id"), nullable=False)
url = Column(Text, nullable=False)
download_type = Column(String(50), nullable=False) # direct, ytdlp
file_path = Column(String(500), nullable=True)
file_size = Column(Integer, nullable=True)
duration = Column(Integer, nullable=True) # Download duration in seconds
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
task = relationship("Task")
def __repr__(self):
return f"<Download(id={self.id}, task_id={self.task_id}, download_type={self.download_type})>"
class OTPCode(Base):
"""One-time password code model for web interface authentication"""
__tablename__ = "otp_codes"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.user_id"), nullable=False, index=True)
code = Column(String(6), nullable=False, index=True) # 6-digit code
expires_at = Column(DateTime, nullable=False, index=True)
used = Column(Boolean, default=False, index=True)
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
user = relationship("User")
def __repr__(self):
return f"<OTPCode(id={self.id}, user_id={self.user_id}, code={self.code[:2]}**, used={self.used})>"

125
shared/database/session.py Normal file
View File

@@ -0,0 +1,125 @@
"""
Unified module for database session management
Used by both bot and web interface
"""
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from shared.config import settings
from shared.database.models import Base
import logging
logger = logging.getLogger(__name__)
# Unified database engine for the entire application
_engine = None
_AsyncSessionLocal = None
def get_engine():
"""Get database engine (created on first call)"""
global _engine
if _engine is None:
# Log DATABASE_URL being used (without password for security)
db_url_safe = settings.DATABASE_URL
if '@' in db_url_safe:
# Hide password in logs
parts = db_url_safe.split('@')
if len(parts) == 2:
auth_part = parts[0].split('://')
if len(auth_part) == 2:
scheme = auth_part[0]
user_pass = auth_part[1]
if ':' in user_pass:
user = user_pass.split(':')[0]
db_url_safe = f"{scheme}://{user}:***@{parts[1]}"
logger.info(f"Creating database engine with URL: {db_url_safe}")
_engine = create_async_engine(
settings.DATABASE_URL,
echo=False,
future=True,
pool_pre_ping=True, # Check connections before use
pool_recycle=3600, # Reuse connections every 3600 seconds
)
logger.info("Database engine created")
return _engine
def get_session_factory():
"""Get session factory (created on first call)"""
global _AsyncSessionLocal
if _AsyncSessionLocal is None:
engine = get_engine()
_AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
logger.info("Session factory created")
return _AsyncSessionLocal
async def init_db():
"""
Initialize database (create tables via Alembic migrations).
Uses Alembic to apply migrations instead of directly creating tables.
This ensures database schema versioning and the ability to rollback changes.
"""
try:
from shared.database.migrations import init_db_with_migrations
await init_db_with_migrations()
except ImportError:
# Fallback to old method if migrations not configured
logger.warning("Alembic not configured, using direct table creation")
engine = get_engine()
# Determine database type from URL
db_type = "SQLite" if "sqlite" in settings.DATABASE_URL.lower() else "PostgreSQL"
logger.info(f"Initializing {db_type} database...")
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info(f"{db_type} database tables successfully created")
except Exception as e:
logger.error(f"Error initializing database: {e}", exc_info=True)
raise
async def get_db():
"""
Get database session (generator for use with Depends in FastAPI)
Used by both bot and web interface
"""
AsyncSessionLocal = get_session_factory()
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()
def get_async_session_local():
"""
Get session factory for direct use (e.g., in bot)
"""
return get_session_factory()
# For backward compatibility - export session factory as AsyncSessionLocal
# Use proxy class that behaves like async_sessionmaker
class AsyncSessionLocalProxy:
"""
Proxy for AsyncSessionLocal that initializes session factory on first use.
Allows using AsyncSessionLocal() like async_sessionmaker().
"""
def __call__(self, *args, **kwargs):
"""Calling AsyncSessionLocal() creates a new session"""
factory = get_session_factory()
return factory(*args, **kwargs)
def __getattr__(self, name):
"""Proxy attributes to session factory"""
factory = get_session_factory()
return getattr(factory, name)
# Create proxy instance for backward compatibility
AsyncSessionLocal = AsyncSessionLocalProxy()

View File

@@ -0,0 +1,59 @@
"""
Helper functions for user management
"""
from sqlalchemy.ext.asyncio import AsyncSession
from shared.database.models import User
from datetime import datetime
from sqlalchemy.exc import IntegrityError
import logging
logger = logging.getLogger(__name__)
async def ensure_user_exists(user_id: int, db: AsyncSession) -> User:
"""
Ensure user exists in database, create if not exists.
Args:
user_id: User ID
db: Database session
Returns:
User object (existing or newly created)
"""
# Check if user exists
user = await db.get(User, user_id)
if user:
return user
# User doesn't exist, create it
user = User(
user_id=user_id,
username=None,
first_name=None,
last_name=None,
is_admin=False,
is_blocked=False,
created_at=datetime.utcnow()
)
try:
db.add(user)
await db.commit()
logger.info(f"Automatically created user {user_id} in database")
return user
except IntegrityError:
# User was created by another request/thread, fetch it
await db.rollback()
user = await db.get(User, user_id)
if user:
logger.debug(f"User {user_id} was created concurrently, using existing record")
return user
else:
logger.error(f"Failed to create or fetch user {user_id}")
raise Exception(f"Failed to ensure user {user_id} exists")
except Exception as e:
await db.rollback()
logger.error(f"Error ensuring user {user_id} exists: {e}", exc_info=True)
raise