Add source
This commit is contained in:
4
shared/__init__.py
Normal file
4
shared/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Shared modules for bot and web application
|
||||
"""
|
||||
|
||||
131
shared/config.py
Normal file
131
shared/config.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Shared configuration for bot and web application
|
||||
"""
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings, loaded from environment variables and .env files.

    Limit fields (MAX_FILE_SIZE, MAX_DURATION_MINUTES) are kept as raw
    strings so that empty values in .env files do not break validation;
    use the corresponding properties to obtain parsed integers.
    """

    # Telegram Bot
    BOT_TOKEN: str
    TELEGRAM_API_ID: int
    TELEGRAM_API_HASH: str
    OWNER_ID: int

    # Authorization (comma-separated Telegram user IDs)
    AUTHORIZED_USERS: str = ""
    ADMIN_IDS: str = ""
    BLOCKED_USERS: str = ""
    PRIVATE_MODE: bool = False  # If True, only users from AUTHORIZED_USERS or database can use the bot

    # Database
    DATABASE_URL: str = "sqlite+aiosqlite:///./data/bot.db"

    # Redis (for sessions)
    REDIS_HOST: str = "localhost"
    REDIS_PORT: int = 6379
    REDIS_DB: int = 0
    USE_REDIS_SESSIONS: bool = False  # Use Redis for sessions instead of in-memory

    # Web
    WEB_HOST: str = "0.0.0.0"
    WEB_PORT: int = 5000
    WEB_SECRET_KEY: str = ""  # SECURITY: must be set to a strong random value in production

    # Logging
    LOG_LEVEL: str = "INFO"
    LOG_FILE: str = "logs/bot.log"

    # Media Download
    COOKIES_FILE: Optional[str] = None  # Path to cookies file (Netscape format) for Instagram and other sites

    # Download Limits
    MAX_FILE_SIZE: Optional[str] = None  # Maximum file size in bytes (empty or None = no limit)
    MAX_DURATION_MINUTES: Optional[str] = None  # Maximum duration in minutes (empty or None = no limit)
    MAX_CONCURRENT_TASKS: int = 5  # Maximum number of concurrent tasks per user

    @staticmethod
    def _parse_optional_limit(raw: Optional[str]) -> Optional[int]:
        """Parse a raw limit string; None, '', 'none' or invalid input mean "no limit"."""
        if raw is None or not raw.strip() or raw.lower() == 'none':
            return None
        try:
            return int(raw)
        except ValueError:
            return None

    @staticmethod
    def _parse_id_list(raw: str) -> list[int]:
        """Parse a comma-separated string of Telegram user IDs into a list of ints."""
        if not raw:
            return []
        return [int(uid.strip()) for uid in raw.split(",") if uid.strip()]

    @property
    def max_file_size_bytes(self) -> Optional[int]:
        """Get maximum file size in bytes (None = no limit)."""
        return self._parse_optional_limit(self.MAX_FILE_SIZE)

    @property
    def max_duration_minutes_int(self) -> Optional[int]:
        """Get maximum duration in minutes (None = no limit)."""
        return self._parse_optional_limit(self.MAX_DURATION_MINUTES)

    class Config:
        # Configuration file structure:
        #   .env       - for Docker (used via docker-compose env_file / environment)
        #   .env.local - for local development only
        # Inside Docker (DOCKER_ENV is set) all variables arrive through the
        # process environment, so no env_file is loaded at all.
        env_file = None if os.getenv("DOCKER_ENV") else ".env.local"
        env_file_encoding = "utf-8"
        case_sensitive = True
        # pydantic-settings priority: environment variables > env_file > defaults

    @property
    def authorized_users_list(self) -> list[int]:
        """List of authorized users"""
        return self._parse_id_list(self.AUTHORIZED_USERS)

    @property
    def admin_ids_list(self) -> list[int]:
        """List of administrators"""
        return self._parse_id_list(self.ADMIN_IDS)

    @property
    def blocked_users_list(self) -> list[int]:
        """List of blocked users"""
        return self._parse_id_list(self.BLOCKED_USERS)
|
||||
|
||||
|
||||
# Global settings instance
settings = Settings()


# Log the DATABASE_URL being used on load (for debugging)
import logging

_logger = logging.getLogger(__name__)


def _mask_db_password(url: str) -> str:
    """Return *url* with any password replaced by '***' so it is safe to log.

    Splits on the LAST '@' so passwords that themselves contain '@' are still
    masked (the inline version this replaces handled only a single '@').
    """
    if '@' not in url:
        return url
    credentials, _, host = url.rpartition('@')
    scheme, sep, user_pass = credentials.partition('://')
    if sep and ':' in user_pass:
        user = user_pass.split(':', 1)[0]
        return f"{scheme}://{user}:***@{host}"
    return url


# Log only if DATABASE_URL is not the default SQLite database
if settings.DATABASE_URL and "sqlite" not in settings.DATABASE_URL.lower():
    _logger.info(f"Using DATABASE_URL: {_mask_db_password(settings.DATABASE_URL)}")
else:
    _logger.warning(f"⚠️ Using SQLite database: {settings.DATABASE_URL}. For Docker, use PostgreSQL!")
|
||||
|
||||
4
shared/database/__init__.py
Normal file
4
shared/database/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Shared database models
|
||||
"""
|
||||
|
||||
98
shared/database/migrations.py
Normal file
98
shared/database/migrations.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Module for automatic Alembic migration application
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from shared.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_alembic_config() -> Config:
    """Build the Alembic ``Config`` pointed at the project's alembic.ini.

    Returns:
        Config: Alembic configuration with ``sqlalchemy.url`` taken from settings.
    """
    # alembic.ini lives at the project root, three levels above this module
    project_root = Path(__file__).parent.parent.parent
    config = Config(str(project_root / "alembic.ini"))

    # Keep Alembic in sync with the application's database URL
    config.set_main_option("sqlalchemy.url", settings.DATABASE_URL)
    return config
|
||||
|
||||
|
||||
async def upgrade_database(revision: str = "head") -> None:
    """
    Apply Alembic migrations to the database.

    Args:
        revision: Revision to upgrade database to (default "head" - latest)

    Raises:
        Exception: re-raised after logging if the upgrade fails.
    """
    try:
        logger.info(f"Applying migrations to database (revision: {revision})...")

        # Get configuration
        alembic_cfg = get_alembic_config()

        # command.upgrade is synchronous; run it in a worker thread so the
        # event loop is not blocked. asyncio.to_thread replaces the deprecated
        # get_event_loop()/run_in_executor pattern.
        await asyncio.to_thread(command.upgrade, alembic_cfg, revision)

        logger.info("Migrations successfully applied")
    except Exception as e:
        logger.error(f"Error applying migrations: {e}", exc_info=True)
        raise
|
||||
|
||||
|
||||
async def check_migrations() -> bool:
    """
    Check for unapplied migrations.

    Returns:
        bool: True if there are unapplied migrations, False if all are applied

    NOTE: this is a simplified placeholder that always answers True; a full
    implementation would compare command.current() against command.heads().
    """
    try:
        get_alembic_config()
        return True
    except Exception as e:
        logger.warning(f"Failed to check migration status: {e}")
        # When the check itself fails, assume migrations are needed.
        return True
|
||||
|
||||
|
||||
async def init_db_with_migrations() -> None:
    """
    Initialize the database by applying all pending Alembic migrations.

    Replaces the legacy init_db() direct table creation; re-raises after
    logging on failure.
    """
    try:
        # Name the backend for log readability
        is_sqlite = "sqlite" in settings.DATABASE_URL.lower()
        db_type = "SQLite" if is_sqlite else "PostgreSQL"
        logger.info(f"Initializing {db_type} database with migrations...")

        # Upgrading to "head" also creates any missing tables
        await upgrade_database("head")

        logger.info(f"{db_type} database successfully initialized")
    except Exception as e:
        logger.error(f"Error initializing database: {e}", exc_info=True)
        raise
|
||||
|
||||
91
shared/database/models.py
Normal file
91
shared/database/models.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
ORM models for database
|
||||
"""
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, BigInteger, String, Boolean, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class User(Base):
    """Telegram user known to the bot."""
    __tablename__ = "users"

    # BigInteger: Telegram user IDs already exceed the 32-bit Integer range
    # (IDs above 2**31 have been issued since 2021).
    user_id = Column(BigInteger, primary_key=True, unique=True, index=True)
    username = Column(String(255), nullable=True)
    first_name = Column(String(255), nullable=True)
    last_name = Column(String(255), nullable=True)
    is_admin = Column(Boolean, default=False)
    is_blocked = Column(Boolean, default=False)
    # NOTE(review): datetime.utcnow stores naive UTC timestamps; kept as-is
    # for compatibility with existing rows.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    tasks = relationship("Task", back_populates="user")

    def __repr__(self):
        return f"<User(user_id={self.user_id}, username={self.username}, is_admin={self.is_admin})>"
|
||||
|
||||
|
||||
class Task(Base):
    """A unit of work (download, processing, ...) requested by a user."""
    __tablename__ = "tasks"

    id = Column(BigInteger, primary_key=True, index=True)
    # BigInteger to match users.user_id — Telegram IDs exceed the 32-bit range.
    user_id = Column(BigInteger, ForeignKey("users.user_id"), nullable=False, index=True)  # Index for frequent queries
    task_type = Column(String(50), nullable=False)  # download, process, etc.
    status = Column(String(50), default="pending", index=True)  # Index for status filtering
    url = Column(Text, nullable=True)
    file_path = Column(String(500), nullable=True)
    progress = Column(Integer, default=0)  # 0-100
    error_message = Column(String(1000), nullable=True)  # Limit to 1000 characters
    created_at = Column(DateTime, default=datetime.utcnow, index=True)  # Index for sorting
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    completed_at = Column(DateTime, nullable=True)

    # Relationships
    user = relationship("User", back_populates="tasks")

    def __repr__(self):
        return f"<Task(id={self.id}, user_id={self.user_id}, status={self.status})>"
|
||||
|
||||
|
||||
class Download(Base):
    """A single file download performed for a task."""
    __tablename__ = "downloads"

    id = Column(Integer, primary_key=True, index=True)
    task_id = Column(BigInteger, ForeignKey("tasks.id"), nullable=False)
    url = Column(Text, nullable=False)
    download_type = Column(String(50), nullable=False)  # direct, ytdlp
    file_path = Column(String(500), nullable=True)
    # BigInteger: video files can exceed 2 GiB, which overflows a 32-bit Integer.
    file_size = Column(BigInteger, nullable=True)
    duration = Column(Integer, nullable=True)  # Download duration in seconds
    created_at = Column(DateTime, default=datetime.utcnow)

    # Relationships
    task = relationship("Task")

    def __repr__(self):
        return f"<Download(id={self.id}, task_id={self.task_id}, download_type={self.download_type})>"
|
||||
|
||||
|
||||
class OTPCode(Base):
    """One-time password code for web interface authentication."""
    __tablename__ = "otp_codes"

    id = Column(Integer, primary_key=True, index=True)
    # BigInteger to match users.user_id — Telegram IDs exceed the 32-bit range.
    user_id = Column(BigInteger, ForeignKey("users.user_id"), nullable=False, index=True)
    code = Column(String(6), nullable=False, index=True)  # 6-digit code
    expires_at = Column(DateTime, nullable=False, index=True)
    used = Column(Boolean, default=False, index=True)
    created_at = Column(DateTime, default=datetime.utcnow)

    # Relationships
    user = relationship("User")

    def __repr__(self):
        return f"<OTPCode(id={self.id}, user_id={self.user_id}, code={self.code[:2]}**, used={self.used})>"
|
||||
|
||||
125
shared/database/session.py
Normal file
125
shared/database/session.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Unified module for database session management
|
||||
Used by both bot and web interface
|
||||
"""
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from shared.config import settings
|
||||
from shared.database.models import Base
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Unified database engine for the entire application
|
||||
_engine = None
|
||||
_AsyncSessionLocal = None
|
||||
|
||||
|
||||
def mask_db_url(url: str) -> str:
    """Return *url* with any password replaced by '***' so it is safe to log.

    Splits on the LAST '@' so passwords that themselves contain '@' are still
    masked (the inline version this replaces handled only a single '@').
    """
    if '@' not in url:
        return url
    credentials, _, host = url.rpartition('@')
    scheme, sep, user_pass = credentials.partition('://')
    if sep and ':' in user_pass:
        user = user_pass.split(':', 1)[0]
        return f"{scheme}://{user}:***@{host}"
    return url


def get_engine():
    """Return the process-wide async database engine (created on first call)."""
    global _engine
    if _engine is None:
        # Log the DATABASE_URL being used, with credentials masked
        logger.info(f"Creating database engine with URL: {mask_db_url(settings.DATABASE_URL)}")

        _engine = create_async_engine(
            settings.DATABASE_URL,
            echo=False,
            future=True,
            pool_pre_ping=True,  # Check connections before use
            pool_recycle=3600,  # Recycle connections every 3600 seconds
        )
        logger.info("Database engine created")
    return _engine
|
||||
|
||||
|
||||
def get_session_factory():
    """Return the process-wide async session factory (created on first call)."""
    global _AsyncSessionLocal
    if _AsyncSessionLocal is not None:
        return _AsyncSessionLocal

    _AsyncSessionLocal = async_sessionmaker(
        get_engine(),
        class_=AsyncSession,
        expire_on_commit=False,  # keep ORM objects usable after commit
    )
    logger.info("Session factory created")
    return _AsyncSessionLocal
|
||||
|
||||
|
||||
async def init_db():
    """
    Initialize the database schema.

    Prefers Alembic migrations (versioned schema, rollback support); falls
    back to direct ``Base.metadata.create_all`` when Alembic is unavailable.
    """
    try:
        from shared.database.migrations import init_db_with_migrations
        await init_db_with_migrations()
    except ImportError:
        # Alembic is not installed/configured — create the tables directly.
        logger.warning("Alembic not configured, using direct table creation")
        is_sqlite = "sqlite" in settings.DATABASE_URL.lower()
        db_type = "SQLite" if is_sqlite else "PostgreSQL"
        logger.info(f"Initializing {db_type} database...")
        async with get_engine().begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        logger.info(f"{db_type} database tables successfully created")
    except Exception as e:
        logger.error(f"Error initializing database: {e}", exc_info=True)
        raise
|
||||
|
||||
|
||||
async def get_db():
    """
    Yield a database session (generator for use with Depends in FastAPI).

    Shared by both the bot and the web interface.
    """
    factory = get_session_factory()
    async with factory() as session:
        try:
            yield session
        finally:
            # Redundant with the context manager, but kept to guarantee closure.
            await session.close()
|
||||
|
||||
|
||||
def get_async_session_local():
    """Return the session factory for direct use (e.g., in the bot)."""
    return get_session_factory()
|
||||
|
||||
|
||||
# For backward compatibility - export session factory as AsyncSessionLocal
|
||||
# Use proxy class that behaves like async_sessionmaker
|
||||
class AsyncSessionLocalProxy:
    """
    Backward-compatibility proxy for the old module-level AsyncSessionLocal.

    Defers session-factory creation until first use while still letting
    callers write ``AsyncSessionLocal()`` as if it were an async_sessionmaker.
    """

    def __call__(self, *args, **kwargs):
        # Calling the proxy creates a new session from the real factory.
        return get_session_factory()(*args, **kwargs)

    def __getattr__(self, name):
        # Any other attribute access is forwarded to the real factory.
        return getattr(get_session_factory(), name)


# Proxy instance exported for backward compatibility.
AsyncSessionLocal = AsyncSessionLocalProxy()
|
||||
|
||||
59
shared/database/user_helpers.py
Normal file
59
shared/database/user_helpers.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Helper functions for user management
|
||||
"""
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from shared.database.models import User
|
||||
from datetime import datetime
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def ensure_user_exists(user_id: int, db: AsyncSession) -> User:
    """
    Ensure user exists in database, create if not exists.

    Args:
        user_id: Telegram user ID
        db: Database session

    Returns:
        User object (existing or newly created)

    Raises:
        RuntimeError: if the user can neither be created nor fetched.
    """
    # Fast path: user already exists
    user = await db.get(User, user_id)
    if user:
        return user

    # User doesn't exist — create a bare record; profile fields are filled later
    user = User(
        user_id=user_id,
        username=None,
        first_name=None,
        last_name=None,
        is_admin=False,
        is_blocked=False,
        created_at=datetime.utcnow(),
    )

    try:
        db.add(user)
        await db.commit()
    except IntegrityError:
        # Lost a race: a concurrent request inserted the same user first.
        await db.rollback()
        user = await db.get(User, user_id)
        if user is None:
            logger.error(f"Failed to create or fetch user {user_id}")
            # RuntimeError (subclass of Exception) instead of bare Exception
            raise RuntimeError(f"Failed to ensure user {user_id} exists")
        logger.debug(f"User {user_id} was created concurrently, using existing record")
        return user
    except Exception as e:
        await db.rollback()
        logger.error(f"Error ensuring user {user_id} exists: {e}", exc_info=True)
        raise
    else:
        logger.info(f"Automatically created user {user_id} in database")
        return user
|
||||
|
||||
Reference in New Issue
Block a user