Add source
This commit is contained in:
4
bot/modules/__init__.py
Normal file
4
bot/modules/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Модули Telegram бота
|
||||
"""
|
||||
|
||||
4
bot/modules/access_control/__init__.py
Normal file
4
bot/modules/access_control/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Access control module
|
||||
"""
|
||||
|
||||
98
bot/modules/access_control/auth.py
Normal file
98
bot/modules/access_control/auth.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
User authorization
|
||||
"""
|
||||
from typing import Optional
|
||||
from bot.config import settings
|
||||
from bot.modules.database.session import AsyncSessionLocal
|
||||
from bot.modules.database.models import User
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def is_authorized(user_id: int) -> bool:
|
||||
"""
|
||||
Check user authorization
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
True if user is authorized, False otherwise
|
||||
"""
|
||||
# Check blacklist
|
||||
if user_id in settings.blocked_users_list:
|
||||
return False
|
||||
|
||||
# Check in database
|
||||
async with AsyncSessionLocal() as session:
|
||||
user = await session.get(User, user_id)
|
||||
if user and user.is_blocked:
|
||||
return False
|
||||
|
||||
# If private mode is enabled, check only whitelist
|
||||
if settings.PRIVATE_MODE:
|
||||
# Check in configuration
|
||||
if user_id in settings.authorized_users_list:
|
||||
return True
|
||||
|
||||
# Check in database (users added via /adduser)
|
||||
async with AsyncSessionLocal() as session:
|
||||
user = await session.get(User, user_id)
|
||||
if user and not user.is_blocked:
|
||||
return True
|
||||
|
||||
# In private mode, access only for authorized users
|
||||
return False
|
||||
|
||||
# If private mode is disabled
|
||||
# Check whitelist (if configured)
|
||||
if settings.authorized_users_list:
|
||||
return user_id in settings.authorized_users_list
|
||||
|
||||
# If whitelist is not configured, check in database
|
||||
async with AsyncSessionLocal() as session:
|
||||
user = await session.get(User, user_id)
|
||||
# If user exists in database and is not blocked - allow access
|
||||
if user and not user.is_blocked:
|
||||
return True
|
||||
|
||||
# By default - deny access
|
||||
return False
|
||||
|
||||
|
||||
async def is_admin(user_id: int) -> bool:
|
||||
"""
|
||||
Check if user is administrator
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
True if administrator, False otherwise
|
||||
"""
|
||||
# Check in configuration
|
||||
if user_id in settings.admin_ids_list:
|
||||
return True
|
||||
|
||||
# Check in database
|
||||
async with AsyncSessionLocal() as session:
|
||||
user = await session.get(User, user_id)
|
||||
if user and user.is_admin:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def is_owner(user_id: int) -> bool:
|
||||
"""
|
||||
Check if user is owner
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
True if owner, False otherwise
|
||||
"""
|
||||
return user_id == settings.OWNER_ID
|
||||
|
||||
49
bot/modules/access_control/middleware.py
Normal file
49
bot/modules/access_control/middleware.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Access control middleware
|
||||
"""
|
||||
from pyrogram import Client
|
||||
from pyrogram.handlers import MessageHandler, CallbackQueryHandler
|
||||
from bot.modules.access_control.auth import is_authorized
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def access_middleware(client: Client, update, *args, **kwargs):
|
||||
"""
|
||||
Middleware for checking bot access
|
||||
|
||||
Args:
|
||||
client: Pyrogram client
|
||||
update: Telegram update
|
||||
"""
|
||||
user_id = None
|
||||
|
||||
if hasattr(update, 'from_user') and update.from_user:
|
||||
user_id = update.from_user.id
|
||||
elif hasattr(update, 'message') and update.message and update.message.from_user:
|
||||
user_id = update.message.from_user.id
|
||||
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
# Check authorization
|
||||
if not await is_authorized(user_id):
|
||||
logger.warning(f"Unauthorized user access attempt: {user_id}")
|
||||
if hasattr(update, 'message') and update.message:
|
||||
await update.message.reply("❌ You don't have access to this bot")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def setup_middleware(app: Client):
|
||||
"""
|
||||
Setup middleware for application
|
||||
|
||||
Args:
|
||||
app: Pyrogram client
|
||||
"""
|
||||
# Middleware will be applied via decorators in handlers
|
||||
logger.info("Access control middleware configured")
|
||||
|
||||
55
bot/modules/access_control/permissions.py
Normal file
55
bot/modules/access_control/permissions.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Access permissions system
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Callable, Awaitable
|
||||
from pyrogram import Client
|
||||
from pyrogram.types import Message, CallbackQuery
|
||||
from bot.modules.access_control.auth import is_authorized, is_admin, is_owner
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Permission(Enum):
|
||||
"""Access permission types"""
|
||||
USER = "user" # Regular user
|
||||
ADMIN = "admin" # Administrator
|
||||
OWNER = "owner" # Owner
|
||||
|
||||
|
||||
def require_permission(permission: Permission):
|
||||
"""
|
||||
Decorator for checking access permissions
|
||||
|
||||
Args:
|
||||
permission: Required access level
|
||||
"""
|
||||
def decorator(func: Callable):
|
||||
async def wrapper(client: Client, message: Message, *args, **kwargs):
|
||||
user_id = message.from_user.id if message.from_user else None
|
||||
|
||||
if not user_id:
|
||||
await message.reply("❌ Failed to identify user")
|
||||
return
|
||||
|
||||
# Check authorization
|
||||
if not await is_authorized(user_id):
|
||||
await message.reply("❌ You don't have access to this bot")
|
||||
return
|
||||
|
||||
# Check permissions
|
||||
if permission == Permission.OWNER:
|
||||
if not await is_owner(user_id):
|
||||
await message.reply("❌ This command is only available to owner")
|
||||
return
|
||||
elif permission == Permission.ADMIN:
|
||||
if not await is_admin(user_id):
|
||||
await message.reply("❌ This command is only available to administrators")
|
||||
return
|
||||
|
||||
return await func(client, message, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
249
bot/modules/access_control/user_manager.py
Normal file
249
bot/modules/access_control/user_manager.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""
|
||||
User and administrator management
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
from bot.modules.database.session import AsyncSessionLocal
|
||||
from bot.modules.database.models import User
|
||||
from bot.modules.access_control.auth import is_admin, is_owner
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def add_user(user_id: int, username: str = None, first_name: str = None, last_name: str = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
Add user
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
username: Username (if not specified, will be fetched from Telegram API)
|
||||
first_name: First name (if not specified, will be fetched from Telegram API)
|
||||
last_name: Last name (if not specified, will be fetched from Telegram API)
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str)
|
||||
"""
|
||||
try:
|
||||
async with AsyncSessionLocal() as session:
|
||||
# Check existence
|
||||
existing_user = await session.get(User, user_id)
|
||||
if existing_user:
|
||||
# Update user information if missing
|
||||
updated = False
|
||||
if not existing_user.username and username:
|
||||
existing_user.username = username
|
||||
updated = True
|
||||
if not existing_user.first_name and first_name:
|
||||
existing_user.first_name = first_name
|
||||
updated = True
|
||||
if not existing_user.last_name and last_name:
|
||||
existing_user.last_name = last_name
|
||||
updated = True
|
||||
|
||||
# If information is missing, try to get from Telegram API
|
||||
if not existing_user.username or not existing_user.first_name:
|
||||
try:
|
||||
from bot.utils.user_info_updater import update_user_info_from_telegram
|
||||
if await update_user_info_from_telegram(user_id, db_session=session):
|
||||
updated = True
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get user {user_id} information from Telegram: {e}")
|
||||
|
||||
if updated:
|
||||
await session.commit()
|
||||
logger.info(f"User {user_id} information updated")
|
||||
return (True, f"Пользователь {user_id} уже существует, информация обновлена")
|
||||
return (False, f"Пользователь {user_id} уже существует")
|
||||
|
||||
# If username/first_name/last_name not specified, get from Telegram API
|
||||
if not username or not first_name:
|
||||
try:
|
||||
from bot.utils.telegram_user import get_user_info
|
||||
user_info = await get_user_info(user_id)
|
||||
if user_info:
|
||||
if not username:
|
||||
username = user_info.get('username')
|
||||
if not first_name:
|
||||
first_name = user_info.get('first_name')
|
||||
if not last_name:
|
||||
last_name = user_info.get('last_name')
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get user {user_id} information from Telegram: {e}")
|
||||
|
||||
# Create new user
|
||||
user = User(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
is_admin=False,
|
||||
is_blocked=False
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
logger.info(f"User {user_id} added (username: {username})")
|
||||
return (True, f"Пользователь {user_id} успешно добавлен")
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding user: {e}", exc_info=True)
|
||||
return (False, f"Ошибка базы данных: {str(e)}")
|
||||
|
||||
|
||||
async def remove_user(user_id: int) -> Tuple[bool, str]:
|
||||
"""
|
||||
Remove user
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str)
|
||||
"""
|
||||
try:
|
||||
async with AsyncSessionLocal() as session:
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
return (False, f"Пользователь {user_id} не найден в базе данных")
|
||||
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
logger.info(f"User {user_id} removed")
|
||||
return (True, f"Пользователь {user_id} успешно удален")
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing user: {e}", exc_info=True)
|
||||
return (False, f"Ошибка базы данных: {str(e)}")
|
||||
|
||||
|
||||
async def block_user(user_id: int) -> Tuple[bool, str]:
|
||||
"""
|
||||
Block user
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str)
|
||||
"""
|
||||
try:
|
||||
async with AsyncSessionLocal() as session:
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
return (False, f"Пользователь {user_id} не найден в базе данных")
|
||||
|
||||
if user.is_blocked:
|
||||
return (False, f"Пользователь {user_id} уже заблокирован")
|
||||
|
||||
user.is_blocked = True
|
||||
user.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
logger.info(f"User {user_id} blocked")
|
||||
return (True, f"Пользователь {user_id} успешно заблокирован")
|
||||
except Exception as e:
|
||||
logger.error(f"Error blocking user: {e}", exc_info=True)
|
||||
return (False, f"Ошибка базы данных: {str(e)}")
|
||||
|
||||
|
||||
async def unblock_user(user_id: int) -> Tuple[bool, str]:
|
||||
"""
|
||||
Unblock user
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str)
|
||||
"""
|
||||
try:
|
||||
async with AsyncSessionLocal() as session:
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
return (False, f"Пользователь {user_id} не найден в базе данных")
|
||||
|
||||
if not user.is_blocked:
|
||||
return (False, f"Пользователь {user_id} не заблокирован")
|
||||
|
||||
user.is_blocked = False
|
||||
user.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
logger.info(f"User {user_id} unblocked")
|
||||
return (True, f"Пользователь {user_id} успешно разблокирован")
|
||||
except Exception as e:
|
||||
logger.error(f"Error unblocking user: {e}", exc_info=True)
|
||||
return (False, f"Ошибка базы данных: {str(e)}")
|
||||
|
||||
|
||||
async def add_admin(user_id: int, requester_id: int) -> Tuple[bool, str]:
|
||||
"""
|
||||
Assign administrator
|
||||
|
||||
Args:
|
||||
user_id: User ID to assign as admin
|
||||
requester_id: ID of user making the request
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str)
|
||||
"""
|
||||
# Check permissions
|
||||
if not await is_admin(requester_id):
|
||||
return (False, "У вас нет прав администратора")
|
||||
|
||||
try:
|
||||
async with AsyncSessionLocal() as session:
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
# Create user if doesn't exist
|
||||
user = User(user_id=user_id, is_admin=True)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
logger.info(f"User {user_id} created and assigned as administrator")
|
||||
return (True, f"Пользователь {user_id} создан и назначен администратором")
|
||||
else:
|
||||
if user.is_admin:
|
||||
return (False, f"Пользователь {user_id} уже является администратором")
|
||||
user.is_admin = True
|
||||
user.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
logger.info(f"User {user_id} assigned as administrator")
|
||||
return (True, f"Пользователь {user_id} успешно назначен администратором")
|
||||
except Exception as e:
|
||||
logger.error(f"Error assigning administrator: {e}", exc_info=True)
|
||||
return (False, f"Ошибка базы данных: {str(e)}")
|
||||
|
||||
|
||||
async def remove_admin(user_id: int, requester_id: int) -> Tuple[bool, str]:
|
||||
"""
|
||||
Remove administrator privileges
|
||||
|
||||
Args:
|
||||
user_id: User ID to remove admin privileges from
|
||||
requester_id: ID of user making the request
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str)
|
||||
"""
|
||||
# Check permissions
|
||||
if not await is_admin(requester_id):
|
||||
return (False, "У вас нет прав администратора")
|
||||
|
||||
# Protection against self-removal
|
||||
if user_id == requester_id:
|
||||
return (False, "Вы не можете снять права администратора у самого себя")
|
||||
|
||||
try:
|
||||
async with AsyncSessionLocal() as session:
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
return (False, f"Пользователь {user_id} не найден в базе данных")
|
||||
|
||||
if not user.is_admin:
|
||||
return (False, f"Пользователь {user_id} не является администратором")
|
||||
|
||||
user.is_admin = False
|
||||
user.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
logger.info(f"Administrator privileges removed from user {user_id}")
|
||||
return (True, f"Права администратора успешно сняты у пользователя {user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing administrator privileges: {e}", exc_info=True)
|
||||
return (False, f"Ошибка базы данных: {str(e)}")
|
||||
|
||||
4
bot/modules/database/__init__.py
Normal file
4
bot/modules/database/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Database module (ORM)
|
||||
"""
|
||||
|
||||
7
bot/modules/database/models.py
Normal file
7
bot/modules/database/models.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
ORM models for bot (imported from shared)
|
||||
"""
|
||||
from shared.database.models import User, Task, Download
|
||||
|
||||
__all__ = ["User", "Task", "Download"]
|
||||
|
||||
22
bot/modules/database/session.py
Normal file
22
bot/modules/database/session.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Database session management (wrapper over shared module)
|
||||
Uses unified module from shared/database/session.py
|
||||
"""
|
||||
from shared.database.session import (
|
||||
init_db,
|
||||
get_async_session_local,
|
||||
get_engine
|
||||
)
|
||||
|
||||
# For backward compatibility - get session factory
|
||||
def get_AsyncSessionLocal():
|
||||
"""Get session factory (for backward compatibility)"""
|
||||
return get_async_session_local()
|
||||
|
||||
# Create object for backward compatibility
|
||||
AsyncSessionLocal = get_async_session_local()
|
||||
engine = get_engine()
|
||||
|
||||
# Export functions
|
||||
__all__ = ['init_db', 'AsyncSessionLocal', 'engine']
|
||||
|
||||
4
bot/modules/media_loader/__init__.py
Normal file
4
bot/modules/media_loader/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Media loader module
|
||||
"""
|
||||
|
||||
68
bot/modules/media_loader/direct.py
Normal file
68
bot/modules/media_loader/direct.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Direct link downloads
|
||||
"""
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def download_file(url: str, output_path: str, chunk_size: int = 8192) -> bool:
|
||||
"""
|
||||
Download file from direct link
|
||||
|
||||
Args:
|
||||
url: File URL
|
||||
output_path: Path to save
|
||||
chunk_size: Chunk size for download
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Create directory if it doesn't exist
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"Download error: HTTP {response.status}")
|
||||
return False
|
||||
|
||||
async with aiofiles.open(output_path, 'wb') as f:
|
||||
async for chunk in response.content.iter_chunked(chunk_size):
|
||||
await f.write(chunk)
|
||||
|
||||
logger.info(f"File downloaded: {output_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading file: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_file_size(url: str) -> Optional[int]:
|
||||
"""
|
||||
Get file size from URL
|
||||
|
||||
Args:
|
||||
url: File URL
|
||||
|
||||
Returns:
|
||||
File size in bytes or None
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.head(url) as response:
|
||||
if response.status == 200:
|
||||
content_length = response.headers.get('Content-Length')
|
||||
if content_length:
|
||||
return int(content_length)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting file size: {e}")
|
||||
|
||||
return None
|
||||
|
||||
270
bot/modules/media_loader/sender.py
Normal file
270
bot/modules/media_loader/sender.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Sending files to users
|
||||
"""
|
||||
from pathlib import Path
|
||||
from pyrogram import Client
|
||||
from pyrogram.types import Message
|
||||
from typing import Optional
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def download_thumbnail(url: str, output_path: str) -> Optional[str]:
|
||||
"""
|
||||
Download thumbnail from URL
|
||||
|
||||
Args:
|
||||
url: Thumbnail URL
|
||||
output_path: Path to save
|
||||
|
||||
Returns:
|
||||
Path to downloaded file or None
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
async with aiofiles.open(output_path, 'wb') as f:
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
await f.write(chunk)
|
||||
return output_path
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to download thumbnail: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def send_file_to_user(
|
||||
client: Client,
|
||||
chat_id: int,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
thumbnail: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send file to user
|
||||
|
||||
Args:
|
||||
client: Pyrogram client
|
||||
chat_id: Chat ID
|
||||
file_path: Path to file
|
||||
caption: File caption
|
||||
thumbnail: Path to thumbnail or URL
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
thumbnail_path = None
|
||||
try:
|
||||
file = Path(file_path)
|
||||
|
||||
if not file.exists():
|
||||
logger.error(f"File not found: {file_path}")
|
||||
return False
|
||||
|
||||
# Maximum file size for Telegram (2GB)
|
||||
max_size = 2 * 1024 * 1024 * 1024
|
||||
file_size = file.stat().st_size
|
||||
|
||||
# If file is larger than 2GB, split into parts
|
||||
if file_size > max_size:
|
||||
logger.info(f"File too large ({file_size / (1024*1024*1024):.2f} GB), splitting into parts...")
|
||||
return await send_large_file_in_parts(
|
||||
client, chat_id, file_path, caption, thumbnail
|
||||
)
|
||||
|
||||
# Process thumbnail (can be URL or file path)
|
||||
if thumbnail:
|
||||
if thumbnail.startswith(('http://', 'https://')):
|
||||
# This is a URL - download thumbnail
|
||||
thumbnail_path = f"downloads/thumb_{file.stem}.jpg"
|
||||
downloaded = await download_thumbnail(thumbnail, thumbnail_path)
|
||||
if downloaded:
|
||||
thumbnail_path = downloaded
|
||||
else:
|
||||
thumbnail_path = None # Don't use thumbnail if download failed
|
||||
else:
|
||||
# This is a file path
|
||||
thumb_file = Path(thumbnail)
|
||||
if thumb_file.exists():
|
||||
thumbnail_path = thumbnail
|
||||
else:
|
||||
thumbnail_path = None
|
||||
|
||||
# Determine file type
|
||||
if file.suffix.lower() in ['.mp4', '.avi', '.mov', '.mkv', '.webm']:
|
||||
# Video - if no thumbnail, try to generate one
|
||||
if not thumbnail_path:
|
||||
from bot.utils.file_processor import generate_thumbnail
|
||||
thumbnail_path_temp = f"downloads/thumb_{file.stem}.jpg"
|
||||
if await generate_thumbnail(str(file), thumbnail_path_temp):
|
||||
thumbnail_path = thumbnail_path_temp
|
||||
|
||||
await client.send_video(
|
||||
chat_id=chat_id,
|
||||
video=str(file),
|
||||
caption=caption,
|
||||
thumb=thumbnail_path
|
||||
)
|
||||
elif file.suffix.lower() in ['.jpg', '.jpeg', '.png', '.gif', '.webp']:
|
||||
# Image
|
||||
await client.send_photo(
|
||||
chat_id=chat_id,
|
||||
photo=str(file),
|
||||
caption=caption
|
||||
)
|
||||
elif file.suffix.lower() in ['.mp3', '.wav', '.ogg', '.m4a', '.flac']:
|
||||
# Audio
|
||||
await client.send_audio(
|
||||
chat_id=chat_id,
|
||||
audio=str(file),
|
||||
caption=caption,
|
||||
thumb=thumbnail_path
|
||||
)
|
||||
else:
|
||||
# Document
|
||||
await client.send_document(
|
||||
chat_id=chat_id,
|
||||
document=str(file),
|
||||
caption=caption,
|
||||
thumb=thumbnail_path
|
||||
)
|
||||
|
||||
logger.info(f"File sent to user {chat_id}: {file_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending file: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
finally:
|
||||
# Delete temporary thumbnail if it was downloaded
|
||||
if thumbnail_path and thumbnail_path.startswith("downloads/thumb_"):
|
||||
try:
|
||||
thumb_file = Path(thumbnail_path)
|
||||
if thumb_file.exists():
|
||||
thumb_file.unlink()
|
||||
logger.debug(f"Temporary thumbnail deleted: {thumbnail_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete temporary thumbnail: {e}")
|
||||
|
||||
|
||||
async def send_large_file_in_parts(
|
||||
client: Client,
|
||||
chat_id: int,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
thumbnail: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send large file in parts
|
||||
|
||||
Args:
|
||||
client: Pyrogram client
|
||||
chat_id: Chat ID
|
||||
file_path: Path to file
|
||||
caption: File caption
|
||||
thumbnail: Path to thumbnail or URL
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
from bot.utils.file_splitter import split_file, delete_file_parts, get_part_info
|
||||
|
||||
parts = []
|
||||
try:
|
||||
# Split file into parts
|
||||
parts = await split_file(file_path)
|
||||
part_info = get_part_info(parts)
|
||||
total_parts = part_info["total_parts"]
|
||||
|
||||
logger.info(f"Sending file in parts: {total_parts} parts")
|
||||
|
||||
# Send each part
|
||||
for part_num, part_path in enumerate(parts, 1):
|
||||
part_caption = None
|
||||
if caption:
|
||||
part_caption = f"{caption}\n\n📦 Part {part_num} of {total_parts}"
|
||||
else:
|
||||
part_caption = f"📦 Part {part_num} of {total_parts}"
|
||||
|
||||
# Send thumbnail only with first part
|
||||
part_thumbnail = thumbnail if part_num == 1 else None
|
||||
|
||||
try:
|
||||
await client.send_document(
|
||||
chat_id=chat_id,
|
||||
document=str(part_path),
|
||||
caption=part_caption,
|
||||
thumb=part_thumbnail
|
||||
)
|
||||
logger.info(f"Sent part {part_num}/{total_parts}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending part {part_num}: {e}", exc_info=True)
|
||||
# Continue sending other parts
|
||||
continue
|
||||
|
||||
logger.info(f"File sent in parts to user {chat_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending large file in parts: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
finally:
|
||||
# Delete file parts after sending
|
||||
if parts:
|
||||
await delete_file_parts(parts)
|
||||
logger.debug("File parts deleted")
|
||||
|
||||
|
||||
async def delete_file(file_path: str, max_retries: int = 3) -> bool:
|
||||
"""
|
||||
Delete file with retries
|
||||
|
||||
Args:
|
||||
file_path: Path to file
|
||||
max_retries: Maximum number of retries
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
file = Path(file_path)
|
||||
if not file.exists():
|
||||
return True # File already doesn't exist
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
file.unlink()
|
||||
logger.info(f"File deleted: {file_path}")
|
||||
return True
|
||||
except PermissionError as e:
|
||||
# File may be locked by another process
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = (attempt + 1) * 0.5 # Exponential backoff
|
||||
logger.warning(f"File locked, retrying in {wait_time}s: {file_path}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"Failed to delete file after {max_retries} attempts (locked): {file_path}")
|
||||
# Add to cleanup queue for background cleanup task
|
||||
from bot.utils.file_cleanup import add_file_to_cleanup_queue
|
||||
add_file_to_cleanup_queue(str(file_path))
|
||||
return False
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = (attempt + 1) * 0.5
|
||||
logger.warning(f"Error deleting file, retrying in {wait_time}s: {e}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"Error deleting file after {max_retries} attempts: {e}", exc_info=True)
|
||||
# Add to cleanup queue for background cleanup task
|
||||
from bot.utils.file_cleanup import add_file_to_cleanup_queue
|
||||
add_file_to_cleanup_queue(str(file_path))
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
358
bot/modules/media_loader/ytdlp.py
Normal file
358
bot/modules/media_loader/ytdlp.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
Downloads via yt-dlp
|
||||
"""
|
||||
import yt_dlp
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Callable
|
||||
import asyncio
|
||||
import threading
|
||||
import logging
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_progress_hook(progress_callback: Optional[Callable] = None, event_loop=None, cancel_event: Optional[threading.Event] = None, last_update_time: list = None):
|
||||
"""
|
||||
Create progress hook for tracking download progress
|
||||
|
||||
Args:
|
||||
progress_callback: Async callback function for updating progress
|
||||
event_loop: Event loop from main thread (for calling from executor)
|
||||
cancel_event: Event for checking download cancellation
|
||||
last_update_time: List to store last update time (for rate limiting)
|
||||
|
||||
Returns:
|
||||
Hook function for yt-dlp
|
||||
"""
|
||||
if last_update_time is None:
|
||||
last_update_time = [0]
|
||||
|
||||
def progress_hook(d: dict):
|
||||
# Check for cancellation
|
||||
if cancel_event and cancel_event.is_set():
|
||||
raise KeyboardInterrupt("Download cancelled")
|
||||
|
||||
if d.get('status') == 'downloading':
|
||||
percent = 0
|
||||
if 'total_bytes' in d and d['total_bytes']:
|
||||
percent = (d.get('downloaded_bytes', 0) / d['total_bytes']) * 100
|
||||
elif 'total_bytes_estimate' in d and d['total_bytes_estimate']:
|
||||
percent = (d.get('downloaded_bytes', 0) / d['total_bytes_estimate']) * 100
|
||||
|
||||
# Limit update frequency (no more than once per second)
|
||||
current_time = time.time()
|
||||
if progress_callback and percent > 0 and event_loop and (current_time - last_update_time[0] >= 1.0):
|
||||
try:
|
||||
last_update_time[0] = current_time
|
||||
# Use provided event loop for safe call from another thread
|
||||
# run_coroutine_threadsafe doesn't block current thread and doesn't block event loop
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
progress_callback(int(percent)),
|
||||
event_loop
|
||||
)
|
||||
# Don't wait for completion (future.result()) to avoid blocking download
|
||||
except Exception as e:
|
||||
logger.debug(f"Error updating progress: {e}")
|
||||
|
||||
return progress_hook
|
||||
|
||||
|
||||
async def download_media(
|
||||
url: str,
|
||||
output_dir: str = "downloads",
|
||||
quality: str = "best",
|
||||
progress_callback: Optional[Callable] = None,
|
||||
cookies_file: Optional[str] = None,
|
||||
cancel_event: Optional[threading.Event] = None,
|
||||
task_id: Optional[int] = None
|
||||
) -> Optional[Dict]:
|
||||
"""
|
||||
Download media via yt-dlp
|
||||
|
||||
Args:
|
||||
url: Video/media URL
|
||||
output_dir: Directory for saving
|
||||
quality: Video quality (best, worst, 720p, etc.)
|
||||
progress_callback: Function for updating progress (accepts int 0-100)
|
||||
cookies_file: Path to cookies file (optional)
|
||||
cancel_event: Event for cancellation check (optional)
|
||||
task_id: Task ID for unique file naming (optional)
|
||||
|
||||
Returns:
|
||||
Dictionary with downloaded file information or None
|
||||
"""
|
||||
try:
|
||||
# URL validation
|
||||
from bot.utils.helpers import is_valid_url
|
||||
if not is_valid_url(url):
|
||||
logger.error(f"Invalid or unsafe URL: {url}")
|
||||
return None
|
||||
|
||||
# Create directory
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Check free disk space (minimum 1GB)
|
||||
import shutil
|
||||
try:
|
||||
disk_usage = shutil.disk_usage(output_dir)
|
||||
free_space_gb = disk_usage.free / (1024 ** 3)
|
||||
min_free_space_gb = 1.0 # Minimum 1GB free space
|
||||
if free_space_gb < min_free_space_gb:
|
||||
logger.error(f"Insufficient free disk space: {free_space_gb:.2f} GB (minimum {min_free_space_gb} GB required)")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check free disk space: {e}")
|
||||
|
||||
# Get event loop BEFORE starting executor to pass it to progress hook
|
||||
# Use get_running_loop() for explicit check that we're in async context
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# If no running loop, try to get current one (for backward compatibility)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# List to store last progress update time
|
||||
last_update_time = [0]
|
||||
|
||||
# Configure yt-dlp with progress hook that uses correct event loop
|
||||
progress_hook_func = create_progress_hook(
|
||||
progress_callback,
|
||||
event_loop=loop,
|
||||
cancel_event=cancel_event,
|
||||
last_update_time=last_update_time
|
||||
)
|
||||
|
||||
# Form unique filename with task_id to prevent conflicts
|
||||
if task_id:
|
||||
outtmpl = str(Path(output_dir) / f'%(title)s_[task_{task_id}].%(ext)s')
|
||||
else:
|
||||
outtmpl = str(Path(output_dir) / '%(title)s.%(ext)s')
|
||||
|
||||
# Configure format selector for maximum quality with correct aspect ratio
|
||||
# Priority: best video + best audio, or best combined format
|
||||
# This ensures we get the highest quality available while maintaining original proportions
|
||||
if quality == "best":
|
||||
# Format selector for maximum quality:
|
||||
# 1. bestvideo (highest quality video) + bestaudio (highest quality audio) - best quality
|
||||
# 2. best (best combined format if separate streams not available) - fallback
|
||||
# This selector maintains original aspect ratio and resolution
|
||||
format_selector = 'bestvideo+bestaudio/best'
|
||||
else:
|
||||
# Use custom quality if specified
|
||||
format_selector = quality
|
||||
|
||||
ydl_opts = {
|
||||
'format': format_selector,
|
||||
'outtmpl': outtmpl,
|
||||
'quiet': False,
|
||||
'no_warnings': False,
|
||||
'progress_hooks': [progress_hook_func],
|
||||
# Merge video and audio into single file (if separate streams)
|
||||
'merge_output_format': 'mp4',
|
||||
# Don't prefer free formats (they may be lower quality)
|
||||
'prefer_free_formats': False,
|
||||
# Additional options for better quality
|
||||
'writesubtitles': False,
|
||||
'writeautomaticsub': False,
|
||||
'ignoreerrors': False,
|
||||
}
|
||||
|
||||
# Add cookies if specified (for Instagram and other sites)
|
||||
if cookies_file:
|
||||
# Resolve cookies file path (support relative and absolute paths)
|
||||
cookies_path = Path(cookies_file)
|
||||
if not cookies_path.is_absolute():
|
||||
# If path is relative, search relative to project root
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
cookies_path = project_root / cookies_file
|
||||
# Also check current working directory
|
||||
if not cookies_path.exists():
|
||||
cookies_path = Path(cookies_file).resolve()
|
||||
|
||||
if cookies_path.exists():
|
||||
ydl_opts['cookiefile'] = str(cookies_path)
|
||||
logger.info(f"Using cookies from file: {cookies_path}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Cookies file not found: {cookies_path} (original path: {cookies_file}). "
|
||||
f"Continuing without cookies. Check COOKIES_FILE path in configuration."
|
||||
)
|
||||
|
||||
def run_download():
|
||||
"""Synchronous function to execute in separate thread"""
|
||||
# This function runs in a separate thread (ThreadPoolExecutor)
|
||||
# progress hook will be called from this thread and use
|
||||
# run_coroutine_threadsafe for safe call in main event loop
|
||||
try:
|
||||
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||
# Check for cancellation before start
|
||||
if cancel_event and cancel_event.is_set():
|
||||
raise KeyboardInterrupt("Download cancelled")
|
||||
|
||||
# Get video information
|
||||
info = ydl.extract_info(url, download=False)
|
||||
|
||||
# Check for cancellation after getting info
|
||||
if cancel_event and cancel_event.is_set():
|
||||
raise KeyboardInterrupt("Download cancelled")
|
||||
|
||||
# Download (progress hook will be called from this thread)
|
||||
ydl.download([url])
|
||||
|
||||
return info
|
||||
except KeyboardInterrupt:
|
||||
# Interrupt download on cancellation
|
||||
logger.info("Download interrupted")
|
||||
raise
|
||||
|
||||
# Execute in executor for non-blocking download
|
||||
# None uses ThreadPoolExecutor by default
|
||||
# This ensures download doesn't block message processing
|
||||
# Event loop continues processing messages in parallel with download
|
||||
info = await loop.run_in_executor(None, run_download)
|
||||
|
||||
# Search for downloaded file
|
||||
title = info.get('title', 'video')
|
||||
# Clean title from invalid characters
|
||||
title = "".join(c for c in title if c.isalnum() or c in (' ', '-', '_')).strip()
|
||||
ext = info.get('ext', 'mp4')
|
||||
|
||||
logger.info(f"Searching for downloaded file. Title: {title}, ext: {ext}, task_id: {task_id}")
|
||||
|
||||
# Form filename with task_id
|
||||
if task_id:
|
||||
filename = f"{title}_[task_{task_id}].{ext}"
|
||||
else:
|
||||
filename = f"{title}.{ext}"
|
||||
|
||||
file_path = Path(output_dir) / filename
|
||||
logger.debug(f"Expected file path: {file_path}")
|
||||
|
||||
# If file not found at expected path, search in directory
|
||||
if not file_path.exists():
|
||||
logger.info(f"File not found at expected path {file_path}, starting search...")
|
||||
|
||||
# If task_id exists, search for file with this task_id
|
||||
if task_id:
|
||||
# Pattern 1: exact match with task_id
|
||||
pattern = f"*[task_{task_id}].{ext}"
|
||||
files = list(Path(output_dir).glob(pattern))
|
||||
logger.debug(f"Search by pattern '{pattern}': found {len(files)} files")
|
||||
|
||||
if not files:
|
||||
# Pattern 2: search files containing task_id (in case format differs slightly)
|
||||
pattern2 = f"*task_{task_id}*.{ext}"
|
||||
files = list(Path(output_dir).glob(pattern2))
|
||||
logger.debug(f"Search by pattern '{pattern2}': found {len(files)} files")
|
||||
|
||||
if files:
|
||||
# Take newest file from found ones
|
||||
file_path = max(files, key=lambda p: p.stat().st_mtime)
|
||||
logger.info(f"Found file by task_id: {file_path}")
|
||||
else:
|
||||
# If not found by task_id, search newest file with this extension
|
||||
logger.info(f"File with task_id {task_id} not found, searching newest .{ext} file")
|
||||
files = list(Path(output_dir).glob(f"*.{ext}"))
|
||||
if files:
|
||||
# Filter files created recently (last 5 minutes)
|
||||
import time
|
||||
current_time = time.time()
|
||||
recent_files = [
|
||||
f for f in files
|
||||
if (current_time - f.stat().st_mtime) < 300 # 5 minutes
|
||||
]
|
||||
if recent_files:
|
||||
file_path = max(recent_files, key=lambda p: p.stat().st_mtime)
|
||||
logger.info(f"Found recently created file: {file_path}")
|
||||
else:
|
||||
file_path = max(files, key=lambda p: p.stat().st_mtime)
|
||||
logger.warning(f"No recent files found, taking newest: {file_path}")
|
||||
else:
|
||||
# Search file by extension
|
||||
files = list(Path(output_dir).glob(f"*.{ext}"))
|
||||
if files:
|
||||
# Take newest file
|
||||
file_path = max(files, key=lambda p: p.stat().st_mtime)
|
||||
logger.info(f"Found file by time: {file_path}")
|
||||
|
||||
if file_path.exists():
|
||||
file_size = file_path.stat().st_size
|
||||
logger.info(f"File found: {file_path}, size: {file_size / (1024*1024):.2f} MB")
|
||||
return {
|
||||
'file_path': str(file_path),
|
||||
'title': title,
|
||||
'duration': info.get('duration'),
|
||||
'thumbnail': info.get('thumbnail'),
|
||||
'size': file_size
|
||||
}
|
||||
else:
|
||||
# Output list of all files in directory for debugging
|
||||
all_files = list(Path(output_dir).glob("*"))
|
||||
logger.error(
|
||||
f"File not found after download: {file_path}\n"
|
||||
f"Files in downloads directory: {[str(f.name) for f in all_files[:10]]}"
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading via yt-dlp: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def get_media_info(url: str, cookies_file: Optional[str] = None) -> Optional[Dict]:
|
||||
"""
|
||||
Get media information without downloading
|
||||
|
||||
Args:
|
||||
url: Media URL
|
||||
cookies_file: Path to cookies file (optional)
|
||||
|
||||
Returns:
|
||||
Dictionary with information or None
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
ydl_opts = {
|
||||
'quiet': True,
|
||||
'no_warnings': True,
|
||||
}
|
||||
|
||||
# Add cookies if specified
|
||||
if cookies_file:
|
||||
# Resolve cookies file path (support relative and absolute paths)
|
||||
cookies_path = Path(cookies_file)
|
||||
if not cookies_path.is_absolute():
|
||||
# If path is relative, search relative to project root
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
cookies_path = project_root / cookies_file
|
||||
# Also check current working directory
|
||||
if not cookies_path.exists():
|
||||
cookies_path = Path(cookies_file).resolve()
|
||||
|
||||
if cookies_path.exists():
|
||||
ydl_opts['cookiefile'] = str(cookies_path)
|
||||
logger.debug(f"Using cookies to get info: {cookies_path}")
|
||||
else:
|
||||
logger.warning(f"Cookies file not found for get_media_info: {cookies_path}")
|
||||
|
||||
def extract_info_sync():
|
||||
"""Synchronous function for extracting information"""
|
||||
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||
return ydl.extract_info(url, download=False)
|
||||
|
||||
# Run synchronous yt-dlp in executor to avoid blocking event loop
|
||||
info = await loop.run_in_executor(None, extract_info_sync)
|
||||
|
||||
return {
|
||||
'title': info.get('title'),
|
||||
'duration': info.get('duration'),
|
||||
'thumbnail': info.get('thumbnail'),
|
||||
'uploader': info.get('uploader'),
|
||||
'view_count': info.get('view_count'),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting media info: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
4
bot/modules/message_handler/__init__.py
Normal file
4
bot/modules/message_handler/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Message handler module
|
||||
"""
|
||||
|
||||
223
bot/modules/message_handler/callbacks.py
Normal file
223
bot/modules/message_handler/callbacks.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
Callback button handling
|
||||
"""
|
||||
from pyrogram import Client
|
||||
from pyrogram.types import CallbackQuery, InlineKeyboardMarkup, InlineKeyboardButton
|
||||
from pyrogram.handlers import CallbackQueryHandler
|
||||
from pyrogram.errors import MessageNotModified
|
||||
from bot.modules.access_control.auth import is_authorized, is_admin, is_owner
|
||||
from bot.modules.message_handler.commands import get_start_keyboard
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def callback_handler(client: Client, callback_query: CallbackQuery):
|
||||
"""Handle callback queries"""
|
||||
user_id = callback_query.from_user.id
|
||||
data = callback_query.data
|
||||
|
||||
# Check authorization
|
||||
if not await is_authorized(user_id):
|
||||
await callback_query.answer("❌ У вас нет доступа к этому боту", show_alert=True)
|
||||
return
|
||||
|
||||
# Handle different callback data
|
||||
if data == "back":
|
||||
# Return to main menu
|
||||
welcome_text = (
|
||||
"👋 **Привет! Я бот для загрузки медиа-файлов.**\n\n"
|
||||
"📥 **Что я умею:**\n"
|
||||
"• Загружать видео с YouTube, Instagram и других платформ\n"
|
||||
"• Загружать файлы по прямым ссылкам\n"
|
||||
"• Отправлять файлы вам в Telegram\n\n"
|
||||
"**Как использовать:**\n"
|
||||
"Просто отправьте мне ссылку на видео или файл, и я загружу его для вас!\n\n"
|
||||
"Используйте кнопки ниже для управления:"
|
||||
)
|
||||
keyboard = await get_start_keyboard(user_id)
|
||||
await callback_query.edit_message_text(welcome_text, reply_markup=keyboard)
|
||||
await callback_query.answer()
|
||||
|
||||
elif data == "help":
|
||||
# Show help
|
||||
help_text = (
|
||||
"👋 **Привет! Рад помочь!**\n\n"
|
||||
|
||||
"🎯 **Как начать работу:**\n"
|
||||
"Это очень просто! Просто отправьте мне ссылку на видео или файл, и я сразу начну загрузку.\n\n"
|
||||
|
||||
"📥 **Что я умею загружать:**\n"
|
||||
"• 🎬 Видео с YouTube, Instagram, TikTok и других платформ\n"
|
||||
"• 📁 Файлы по прямым ссылкам\n"
|
||||
"• 🎵 Аудио и музыку\n"
|
||||
"• 📸 Изображения и фото\n\n"
|
||||
|
||||
"⌨️ **Основные команды:**\n"
|
||||
"• `/start` - Открыть главное меню с кнопками\n"
|
||||
"• `/help` - Показать эту справку\n"
|
||||
"• `/status` - Посмотреть статус ваших загрузок\n\n"
|
||||
|
||||
"💡 **Совет:** Используйте кнопки в главном меню для быстрого доступа к функциям!"
|
||||
)
|
||||
|
||||
if await is_admin(user_id):
|
||||
help_text += (
|
||||
"\n\n"
|
||||
"👑 **Команды для администраторов:**\n"
|
||||
"• `/adduser <user_id или @username>` - Добавить нового пользователя\n"
|
||||
"• `/removeuser <user_id или @username>` - Удалить пользователя\n"
|
||||
"• `/blockuser <user_id или @username>` - Заблокировать пользователя\n"
|
||||
"• `/listusers` - Посмотреть список всех пользователей\n\n"
|
||||
"💼 **Управление администраторами:**\n"
|
||||
"• `/addadmin <user_id или @username>` - Назначить администратора\n"
|
||||
"• `/removeadmin <user_id или @username>` - Снять права администратора\n"
|
||||
"• `/listadmins` - Список всех администраторов"
|
||||
)
|
||||
|
||||
keyboard = InlineKeyboardMarkup([[
|
||||
InlineKeyboardButton("🔙 Назад", callback_data="back")
|
||||
]])
|
||||
|
||||
await callback_query.edit_message_text(help_text, reply_markup=keyboard)
|
||||
await callback_query.answer()
|
||||
|
||||
elif data == "status":
|
||||
# Show task status
|
||||
from bot.modules.task_scheduler.monitor import get_user_tasks_status
|
||||
from bot.modules.task_scheduler.queue import TaskStatus
|
||||
|
||||
tasks = await get_user_tasks_status(user_id)
|
||||
active_tasks = [t for t in tasks if t.get('status') in ['pending', 'processing']]
|
||||
completed = [t for t in tasks if t.get('status') == 'completed']
|
||||
failed = [t for t in tasks if t.get('status') == 'failed']
|
||||
|
||||
status_text = (
|
||||
"📊 **Статус задач:**\n\n"
|
||||
f"⏳ Активных задач: {len(active_tasks)}\n"
|
||||
f"✅ Завершено: {len(completed)}\n"
|
||||
f"❌ Ошибок: {len(failed)}\n\n"
|
||||
)
|
||||
|
||||
if active_tasks:
|
||||
status_text += "**Активные задачи:**\n"
|
||||
for task in active_tasks[:5]: # Show first 5
|
||||
task_id = task.get('id')
|
||||
progress = task.get('progress', 0)
|
||||
status_text += f"• #{task_id} - {progress}%\n"
|
||||
if len(active_tasks) > 5:
|
||||
status_text += f"... и еще {len(active_tasks) - 5}\n"
|
||||
status_text += "\n💡 Используйте `/cancel <task_id>` для отмены"
|
||||
|
||||
keyboard = InlineKeyboardMarkup([[
|
||||
InlineKeyboardButton("🔄 Обновить", callback_data="status"),
|
||||
InlineKeyboardButton("🔙 Назад", callback_data="back")
|
||||
]])
|
||||
|
||||
try:
|
||||
await callback_query.edit_message_text(status_text, reply_markup=keyboard)
|
||||
await callback_query.answer("✅ Статус обновлен")
|
||||
except MessageNotModified:
|
||||
# If text didn't change, just answer callback
|
||||
await callback_query.answer("✅ Статус актуален")
|
||||
|
||||
elif data == "download":
|
||||
# Download information
|
||||
download_text = (
|
||||
"📥 **Загрузка файлов:**\n\n"
|
||||
"**Поддерживаемые источники:**\n"
|
||||
"• YouTube (видео, плейлисты)\n"
|
||||
"• Instagram (посты, истории)\n"
|
||||
"• Прямые ссылки на файлы\n"
|
||||
"• Другие платформы через yt-dlp\n\n"
|
||||
"**Как использовать:**\n"
|
||||
"Просто отправьте мне ссылку на видео или файл, и я начну загрузку!\n\n"
|
||||
"Примеры:\n"
|
||||
"• https://www.youtube.com/watch?v=...\n"
|
||||
"• https://www.instagram.com/p/...\n"
|
||||
"• https://example.com/file.mp4"
|
||||
)
|
||||
|
||||
keyboard = InlineKeyboardMarkup([[
|
||||
InlineKeyboardButton("🔙 Назад", callback_data="back")
|
||||
]])
|
||||
|
||||
await callback_query.edit_message_text(download_text, reply_markup=keyboard)
|
||||
await callback_query.answer()
|
||||
|
||||
elif data == "admin_users":
|
||||
# User management (admin only)
|
||||
if not await is_admin(user_id):
|
||||
await callback_query.answer("❌ Только для администраторов", show_alert=True)
|
||||
return
|
||||
|
||||
# Determine user status
|
||||
is_owner_user = await is_owner(user_id)
|
||||
|
||||
# Form text and buttons depending on status
|
||||
if is_owner_user:
|
||||
# Main admin - full functionality
|
||||
users_text = (
|
||||
"👥 **Управление пользователями:**\n\n"
|
||||
"**Управление пользователями:**\n"
|
||||
"• /adduser <user_id или @username> - Добавить пользователя\n"
|
||||
"• /removeuser <user_id или @username> - Удалить пользователя\n"
|
||||
"• /blockuser <user_id или @username> - Заблокировать пользователя\n"
|
||||
"• /listusers - Список всех пользователей\n\n"
|
||||
"**Управление администраторами:**\n"
|
||||
"• /addadmin <user_id или @username> - Назначить администратора\n"
|
||||
"• /removeadmin <user_id или @username> - Снять права администратора\n"
|
||||
"• /listadmins - Список всех администраторов\n\n"
|
||||
"⚠️ **Внимание:** Вы не можете снять права администратора у самого себя."
|
||||
)
|
||||
else:
|
||||
# Regular administrator - only user management
|
||||
users_text = (
|
||||
"👥 **Управление пользователями:**\n\n"
|
||||
"**Доступные команды:**\n"
|
||||
"• /adduser <user_id или @username> - Добавить пользователя\n"
|
||||
"• /removeuser <user_id или @username> - Удалить пользователя\n"
|
||||
"• /blockuser <user_id или @username> - Заблокировать пользователя\n"
|
||||
"• /listusers - Список всех пользователей\n\n"
|
||||
"_Управление через веб-интерфейс будет доступно позже_"
|
||||
)
|
||||
|
||||
keyboard = InlineKeyboardMarkup([[
|
||||
InlineKeyboardButton("🔙 Назад", callback_data="back")
|
||||
]])
|
||||
|
||||
await callback_query.edit_message_text(users_text, reply_markup=keyboard)
|
||||
await callback_query.answer()
|
||||
|
||||
elif data == "admin_stats":
|
||||
# Statistics (admin only)
|
||||
if not await is_admin(user_id):
|
||||
await callback_query.answer("❌ Только для администраторов", show_alert=True)
|
||||
return
|
||||
|
||||
stats_text = (
|
||||
"📈 **Статистика:**\n\n"
|
||||
"👥 Всего пользователей: 0\n"
|
||||
"👑 Администраторов: 0\n"
|
||||
"📥 Всего загрузок: 0\n"
|
||||
"✅ Успешных: 0\n"
|
||||
"❌ Ошибок: 0\n\n"
|
||||
"_Статистика будет реализована в следующем этапе_"
|
||||
)
|
||||
|
||||
keyboard = InlineKeyboardMarkup([[
|
||||
InlineKeyboardButton("🔙 Назад", callback_data="back")
|
||||
]])
|
||||
|
||||
await callback_query.edit_message_text(stats_text, reply_markup=keyboard)
|
||||
await callback_query.answer()
|
||||
|
||||
else:
|
||||
await callback_query.answer("❓ Неизвестная команда")
|
||||
|
||||
|
||||
def register_callbacks(app: Client):
|
||||
"""Register all callback handlers"""
|
||||
app.add_handler(CallbackQueryHandler(callback_handler))
|
||||
logger.info("Callback handlers registered")
|
||||
|
||||
747
bot/modules/message_handler/commands.py
Normal file
747
bot/modules/message_handler/commands.py
Normal file
@@ -0,0 +1,747 @@
|
||||
"""
|
||||
Command handling
|
||||
"""
|
||||
from pyrogram import Client
|
||||
from pyrogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton
|
||||
from pyrogram.filters import command
|
||||
from pyrogram.handlers import MessageHandler
|
||||
from bot.modules.access_control.permissions import require_permission, Permission
|
||||
from bot.modules.access_control.user_manager import (
|
||||
add_user, remove_user, block_user, unblock_user,
|
||||
add_admin, remove_admin
|
||||
)
|
||||
from bot.modules.message_handler.filters import is_url_message
|
||||
from bot.utils.helpers import parse_user_id
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_start_keyboard(user_id: int) -> InlineKeyboardMarkup:
|
||||
"""
|
||||
Create keyboard for /start command
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
InlineKeyboardMarkup with buttons
|
||||
"""
|
||||
from bot.modules.access_control.auth import is_admin
|
||||
|
||||
# Base buttons for all users
|
||||
buttons = [
|
||||
[
|
||||
InlineKeyboardButton("📥 Загрузить", callback_data="download"),
|
||||
InlineKeyboardButton("📊 Статус", callback_data="status")
|
||||
],
|
||||
[
|
||||
InlineKeyboardButton("❓ Помощь", callback_data="help")
|
||||
]
|
||||
]
|
||||
|
||||
# Additional buttons for administrators
|
||||
if await is_admin(user_id):
|
||||
buttons.append([
|
||||
InlineKeyboardButton("📈 Статистика", callback_data="admin_stats")
|
||||
])
|
||||
|
||||
return InlineKeyboardMarkup(buttons)
|
||||
|
||||
|
||||
async def start_command(client: Client, message: Message):
|
||||
"""Handle /start command"""
|
||||
from bot.modules.access_control.auth import is_authorized
|
||||
|
||||
# Check authorization
|
||||
if not await is_authorized(message.from_user.id):
|
||||
await message.reply("❌ У вас нет доступа к этому боту")
|
||||
return
|
||||
|
||||
welcome_text = (
|
||||
"👋 **Привет! Я бот для загрузки медиа-файлов.**\n\n"
|
||||
"📥 **Что я умею:**\n"
|
||||
"• Загружать видео с YouTube, Instagram и других платформ\n"
|
||||
"• Загружать файлы по прямым ссылкам\n"
|
||||
"• Отправлять файлы вам в Telegram\n\n"
|
||||
"**Как использовать:**\n"
|
||||
"Просто отправьте мне ссылку на видео или файл, и я загружу его для вас!\n\n"
|
||||
"Используйте кнопки ниже для управления:"
|
||||
)
|
||||
|
||||
keyboard = await get_start_keyboard(message.from_user.id)
|
||||
|
||||
await message.reply(
|
||||
welcome_text,
|
||||
reply_markup=keyboard
|
||||
)
|
||||
|
||||
|
||||
async def help_command(client: Client, message: Message):
|
||||
"""Handle /help command"""
|
||||
from bot.modules.access_control.auth import is_authorized
|
||||
|
||||
# Check authorization
|
||||
if not await is_authorized(message.from_user.id):
|
||||
await message.reply("❌ У вас нет доступа к этому боту")
|
||||
return
|
||||
|
||||
help_text = (
|
||||
"👋 **Привет! Рад помочь!**\n\n"
|
||||
|
||||
"🎯 **Как начать работу:**\n"
|
||||
"Это очень просто! Просто отправьте мне ссылку на видео или файл, и я сразу начну загрузку.\n\n"
|
||||
|
||||
"📥 **Что я умею загружать:**\n"
|
||||
"• 🎬 Видео с YouTube, Instagram, TikTok и других платформ\n"
|
||||
"• 📁 Файлы по прямым ссылкам\n"
|
||||
"• 🎵 Аудио и музыку\n"
|
||||
"• 📸 Изображения и фото\n\n"
|
||||
|
||||
"⌨️ **Основные команды:**\n"
|
||||
"• `/start` - Открыть главное меню с кнопками\n"
|
||||
"• `/help` - Показать эту справку\n"
|
||||
"• `/status` - Посмотреть статус ваших загрузок\n"
|
||||
"• `/cancel <task_id>` - Отменить задачу\n\n"
|
||||
|
||||
"💡 **Совет:** Используйте кнопки в главном меню для быстрого доступа к функциям!"
|
||||
)
|
||||
|
||||
# Add information for administrators
|
||||
from bot.modules.access_control.auth import is_admin
|
||||
if await is_admin(message.from_user.id):
|
||||
help_text += (
|
||||
"\n\n"
|
||||
"👑 **Команды для администраторов:**\n"
|
||||
"• `/adduser <user_id или @username>` - Добавить нового пользователя\n"
|
||||
"• `/removeuser <user_id или @username>` - Удалить пользователя\n"
|
||||
"• `/blockuser <user_id или @username>` - Заблокировать пользователя\n"
|
||||
"• `/listusers` - Посмотреть список всех пользователей\n\n"
|
||||
"💼 **Управление администраторами:**\n"
|
||||
"• `/addadmin <user_id или @username>` - Назначить администратора\n"
|
||||
"• `/removeadmin <user_id или @username>` - Снять права администратора\n"
|
||||
"• `/listadmins` - Список всех администраторов"
|
||||
)
|
||||
await message.reply(help_text)
|
||||
|
||||
|
||||
async def status_command(client: Client, message: Message):
|
||||
"""Handle /status command"""
|
||||
from bot.modules.access_control.auth import is_authorized
|
||||
from bot.modules.task_scheduler.monitor import get_user_tasks_status
|
||||
from bot.modules.task_scheduler.queue import TaskStatus
|
||||
|
||||
# Check authorization
|
||||
if not await is_authorized(message.from_user.id):
|
||||
await message.reply("❌ У вас нет доступа к этому боту")
|
||||
return
|
||||
|
||||
user_id = message.from_user.id
|
||||
tasks = await get_user_tasks_status(user_id)
|
||||
|
||||
if not tasks:
|
||||
await message.reply("📊 У вас нет активных задач")
|
||||
return
|
||||
|
||||
# Filter only active tasks (pending, processing)
|
||||
active_tasks = [
|
||||
t for t in tasks
|
||||
if t.get('status') in ['pending', 'processing']
|
||||
]
|
||||
|
||||
if not active_tasks:
|
||||
await message.reply("📊 У вас нет активных задач")
|
||||
return
|
||||
|
||||
status_text = "📊 **Ваши активные задачи:**\n\n"
|
||||
|
||||
for task in active_tasks[:10]: # Show maximum 10 tasks
|
||||
task_id = task.get('id')
|
||||
status = task.get('status', 'unknown')
|
||||
progress = task.get('progress', 0)
|
||||
url = task.get('url', 'N/A')
|
||||
|
||||
status_emoji = {
|
||||
'pending': '⏳',
|
||||
'processing': '🔄',
|
||||
'completed': '✅',
|
||||
'failed': '❌',
|
||||
'cancelled': '🚫'
|
||||
}.get(status, '❓')
|
||||
|
||||
status_text += (
|
||||
f"{status_emoji} **Задача #{task_id}**\n"
|
||||
f"🔗 {url[:50]}...\n"
|
||||
f"📊 Прогресс: {progress}%\n"
|
||||
f"📝 Статус: {status}\n\n"
|
||||
)
|
||||
|
||||
if len(active_tasks) > 10:
|
||||
status_text += f"... и еще {len(active_tasks) - 10} задач\n\n"
|
||||
|
||||
status_text += "💡 Используйте `/cancel <task_id>` для отмены задачи"
|
||||
|
||||
await message.reply(status_text)
|
||||
|
||||
|
||||
async def cancel_command(client: Client, message: Message):
|
||||
"""Handle /cancel command"""
|
||||
from bot.modules.access_control.auth import is_authorized
|
||||
from bot.modules.task_scheduler.monitor import cancel_user_task
|
||||
from bot.modules.task_scheduler.queue import task_queue
|
||||
|
||||
# Check authorization
|
||||
if not await is_authorized(message.from_user.id):
|
||||
await message.reply("❌ У вас нет доступа к этому боту")
|
||||
return
|
||||
|
||||
if not message.command or len(message.command) < 2:
|
||||
await message.reply("❌ Использование: /cancel <task_id>\n\nИспользуйте /status чтобы увидеть ID ваших задач")
|
||||
return
|
||||
|
||||
try:
|
||||
task_id = int(message.command[1])
|
||||
except ValueError:
|
||||
await message.reply("❌ Неверный формат task_id. Используйте число.")
|
||||
return
|
||||
|
||||
user_id = message.from_user.id
|
||||
|
||||
# Cancel task
|
||||
try:
|
||||
success, message_text = await cancel_user_task(user_id, task_id)
|
||||
if success:
|
||||
await message.reply(f"✅ {message_text}")
|
||||
else:
|
||||
await message.reply(f"❌ {message_text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cancel_command: {e}", exc_info=True)
|
||||
await message.reply(f"❌ Произошла ошибка при отмене задачи: {str(e)}")
|
||||
|
||||
|
||||
# User management commands (admin only)
|
||||
async def adduser_command(client: Client, message: Message):
|
||||
"""Add user"""
|
||||
from bot.modules.access_control.auth import is_admin
|
||||
from bot.utils.helpers import resolve_user_identifier
|
||||
|
||||
# Check access permissions
|
||||
if not await is_admin(message.from_user.id):
|
||||
await message.reply("❌ Эта команда доступна только администраторам")
|
||||
return
|
||||
|
||||
if not message.command or len(message.command) < 2:
|
||||
await message.reply("❌ Использование: /adduser <user_id или @username>")
|
||||
return
|
||||
|
||||
identifier = message.command[1]
|
||||
|
||||
# Resolve identifier (user_id or username)
|
||||
user_id, error_message = await resolve_user_identifier(identifier)
|
||||
if not user_id:
|
||||
await message.reply(f"❌ {error_message}")
|
||||
return
|
||||
|
||||
try:
|
||||
success, message_text = await add_user(user_id)
|
||||
if success:
|
||||
await message.reply(f"✅ {message_text}")
|
||||
else:
|
||||
await message.reply(f"❌ {message_text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in adduser_command: {e}", exc_info=True)
|
||||
await message.reply(f"❌ Произошла ошибка: {str(e)}")
|
||||
|
||||
|
||||
async def removeuser_command(client: Client, message: Message):
|
||||
"""Remove user"""
|
||||
from bot.modules.access_control.auth import is_admin
|
||||
from bot.utils.helpers import resolve_user_identifier
|
||||
|
||||
# Check access permissions
|
||||
if not await is_admin(message.from_user.id):
|
||||
await message.reply("❌ Эта команда доступна только администраторам")
|
||||
return
|
||||
|
||||
if not message.command or len(message.command) < 2:
|
||||
await message.reply("❌ Использование: /removeuser <user_id или @username>")
|
||||
return
|
||||
|
||||
identifier = message.command[1]
|
||||
|
||||
# Resolve identifier (user_id or username)
|
||||
user_id, error_message = await resolve_user_identifier(identifier)
|
||||
if not user_id:
|
||||
await message.reply(f"❌ {error_message}")
|
||||
return
|
||||
|
||||
try:
|
||||
success, message_text = await remove_user(user_id)
|
||||
if success:
|
||||
await message.reply(f"✅ {message_text}")
|
||||
else:
|
||||
await message.reply(f"❌ {message_text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in removeuser_command: {e}", exc_info=True)
|
||||
await message.reply(f"❌ Произошла ошибка: {str(e)}")
|
||||
|
||||
|
||||
async def blockuser_command(client: Client, message: Message):
|
||||
"""Block user"""
|
||||
from bot.modules.access_control.auth import is_admin
|
||||
from bot.utils.helpers import resolve_user_identifier
|
||||
|
||||
# Check access permissions
|
||||
if not await is_admin(message.from_user.id):
|
||||
await message.reply("❌ Эта команда доступна только администраторам")
|
||||
return
|
||||
|
||||
if not message.command or len(message.command) < 2:
|
||||
await message.reply("❌ Использование: /blockuser <user_id или @username>")
|
||||
return
|
||||
|
||||
identifier = message.command[1]
|
||||
|
||||
# Resolve identifier (user_id or username)
|
||||
user_id, error_message = await resolve_user_identifier(identifier)
|
||||
if not user_id:
|
||||
await message.reply(f"❌ {error_message}")
|
||||
return
|
||||
|
||||
try:
|
||||
success, message_text = await block_user(user_id)
|
||||
if success:
|
||||
await message.reply(f"✅ {message_text}")
|
||||
else:
|
||||
await message.reply(f"❌ {message_text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in blockuser_command: {e}", exc_info=True)
|
||||
await message.reply(f"❌ Произошла ошибка: {str(e)}")
|
||||
|
||||
|
||||
async def unblockuser_command(client: Client, message: Message):
|
||||
"""Unblock user"""
|
||||
from bot.modules.access_control.auth import is_admin
|
||||
from bot.utils.helpers import resolve_user_identifier
|
||||
|
||||
# Check access permissions
|
||||
if not await is_admin(message.from_user.id):
|
||||
await message.reply("❌ Эта команда доступна только администраторам")
|
||||
return
|
||||
|
||||
if not message.command or len(message.command) < 2:
|
||||
await message.reply("❌ Использование: /unblockuser <user_id или @username>")
|
||||
return
|
||||
|
||||
identifier = message.command[1]
|
||||
|
||||
# Resolve identifier (user_id or username)
|
||||
user_id, error_message = await resolve_user_identifier(identifier)
|
||||
if not user_id:
|
||||
await message.reply(f"❌ {error_message}")
|
||||
return
|
||||
|
||||
try:
|
||||
success, message_text = await unblock_user(user_id)
|
||||
if success:
|
||||
await message.reply(f"✅ {message_text}")
|
||||
else:
|
||||
await message.reply(f"❌ {message_text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unblockuser_command: {e}", exc_info=True)
|
||||
await message.reply(f"❌ Произошла ошибка: {str(e)}")
|
||||
|
||||
|
||||
async def listusers_command(client: Client, message: Message):
|
||||
"""List users"""
|
||||
from bot.modules.access_control.auth import is_admin
|
||||
from shared.database.models import User
|
||||
from shared.database.session import get_async_session_local
|
||||
from sqlalchemy import select, func, desc
|
||||
|
||||
# Check access permissions
|
||||
if not await is_admin(message.from_user.id):
|
||||
await message.reply("❌ Эта команда доступна только администраторам")
|
||||
return
|
||||
|
||||
try:
|
||||
async with get_async_session_local()() as session:
|
||||
# Get total count
|
||||
count_result = await session.execute(select(func.count(User.user_id)))
|
||||
total_count = count_result.scalar() or 0
|
||||
|
||||
if total_count == 0:
|
||||
await message.reply("📋 Пользователей в базе данных нет")
|
||||
return
|
||||
|
||||
# Get users (limit to 50 for message length)
|
||||
query = select(User).order_by(desc(User.created_at)).limit(50)
|
||||
result = await session.execute(query)
|
||||
users = list(result.scalars().all())
|
||||
|
||||
# Format message
|
||||
text = f"📋 **Список пользователей** (всего: {total_count})\n\n"
|
||||
|
||||
for i, user in enumerate(users, 1):
|
||||
username = f"@{user.username}" if user.username else "-"
|
||||
name = f"{user.first_name or ''} {user.last_name or ''}".strip() or "-"
|
||||
admin_badge = "👑" if user.is_admin else ""
|
||||
blocked_badge = "🚫" if user.is_blocked else ""
|
||||
|
||||
text += (
|
||||
f"{i}. {admin_badge} {blocked_badge} **ID:** `{user.user_id}`\n"
|
||||
f" 👤 {username} ({name})\n"
|
||||
f" 📅 Создан: {user.created_at.strftime('%Y-%m-%d %H:%M') if user.created_at else 'N/A'}\n\n"
|
||||
)
|
||||
|
||||
if total_count > 50:
|
||||
text += f"\n... и еще {total_count - 50} пользователей (показаны первые 50)"
|
||||
|
||||
# Split message if too long (Telegram limit is 4096 characters)
|
||||
if len(text) > 4000:
|
||||
# Send first part
|
||||
await message.reply(text[:4000])
|
||||
# Send remaining users count
|
||||
await message.reply(f"... и еще {total_count - 50} пользователей")
|
||||
else:
|
||||
await message.reply(text)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing users: {e}", exc_info=True)
|
||||
await message.reply("❌ Ошибка при получении списка пользователей")
|
||||
|
||||
|
||||
async def addadmin_command(client: Client, message: Message):
|
||||
"""Assign administrator"""
|
||||
from bot.modules.access_control.auth import is_admin
|
||||
from bot.utils.helpers import resolve_user_identifier
|
||||
|
||||
# Check access permissions
|
||||
if not await is_admin(message.from_user.id):
|
||||
await message.reply("❌ Эта команда доступна только администраторам")
|
||||
return
|
||||
|
||||
if not message.command or len(message.command) < 2:
|
||||
await message.reply("❌ Использование: /addadmin <user_id или @username>")
|
||||
return
|
||||
|
||||
identifier = message.command[1]
|
||||
requester_id = message.from_user.id
|
||||
|
||||
# Resolve identifier (user_id or username)
|
||||
user_id, error_message = await resolve_user_identifier(identifier)
|
||||
if not user_id:
|
||||
await message.reply(f"❌ {error_message}")
|
||||
return
|
||||
|
||||
try:
|
||||
success, message_text = await add_admin(user_id, requester_id)
|
||||
if success:
|
||||
await message.reply(f"✅ {message_text}")
|
||||
else:
|
||||
await message.reply(f"❌ {message_text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in addadmin_command: {e}", exc_info=True)
|
||||
await message.reply(f"❌ Произошла ошибка: {str(e)}")
|
||||
|
||||
|
||||
async def removeadmin_command(client: Client, message: Message):
|
||||
"""Remove administrator privileges"""
|
||||
from bot.modules.access_control.auth import is_admin
|
||||
from bot.utils.helpers import resolve_user_identifier
|
||||
|
||||
# Check access permissions
|
||||
if not await is_admin(message.from_user.id):
|
||||
await message.reply("❌ Эта команда доступна только администраторам")
|
||||
return
|
||||
|
||||
if not message.command or len(message.command) < 2:
|
||||
await message.reply("❌ Использование: /removeadmin <user_id или @username>")
|
||||
return
|
||||
|
||||
identifier = message.command[1]
|
||||
requester_id = message.from_user.id
|
||||
|
||||
# Resolve identifier (user_id or username)
|
||||
user_id, error_message = await resolve_user_identifier(identifier)
|
||||
if not user_id:
|
||||
await message.reply(f"❌ {error_message}")
|
||||
return
|
||||
|
||||
try:
|
||||
success, message_text = await remove_admin(user_id, requester_id)
|
||||
if success:
|
||||
await message.reply(f"✅ {message_text}")
|
||||
else:
|
||||
await message.reply(f"❌ {message_text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in removeadmin_command: {e}", exc_info=True)
|
||||
await message.reply(f"❌ Произошла ошибка: {str(e)}")
|
||||
|
||||
|
||||
async def listadmins_command(client: Client, message: Message):
|
||||
"""List administrators"""
|
||||
from bot.modules.access_control.auth import is_admin
|
||||
from shared.database.models import User
|
||||
from shared.database.session import get_async_session_local
|
||||
from sqlalchemy import select, func, desc
|
||||
|
||||
# Check access permissions
|
||||
if not await is_admin(message.from_user.id):
|
||||
await message.reply("❌ Эта команда доступна только администраторам")
|
||||
return
|
||||
|
||||
try:
|
||||
async with get_async_session_local()() as session:
|
||||
# Get administrators
|
||||
query = select(User).where(User.is_admin == True).order_by(desc(User.created_at))
|
||||
result = await session.execute(query)
|
||||
admins = list(result.scalars().all())
|
||||
|
||||
if not admins:
|
||||
await message.reply("👑 Администраторов в базе данных нет")
|
||||
return
|
||||
|
||||
# Format message
|
||||
text = f"👑 **Список администраторов** (всего: {len(admins)})\n\n"
|
||||
|
||||
for i, admin in enumerate(admins, 1):
|
||||
username = f"@{admin.username}" if admin.username else "-"
|
||||
name = f"{admin.first_name or ''} {admin.last_name or ''}".strip() or "-"
|
||||
blocked_badge = "🚫" if admin.is_blocked else ""
|
||||
|
||||
text += (
|
||||
f"{i}. {blocked_badge} **ID:** `{admin.user_id}`\n"
|
||||
f" 👤 {username} ({name})\n"
|
||||
f" 📅 Создан: {admin.created_at.strftime('%Y-%m-%d %H:%M') if admin.created_at else 'N/A'}\n\n"
|
||||
)
|
||||
|
||||
await message.reply(text)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing administrators: {e}", exc_info=True)
|
||||
await message.reply("❌ Ошибка при получении списка администраторов")
|
||||
|
||||
|
||||
async def login_command(client: Client, message: Message):
|
||||
"""Handle /login command to get OTP code"""
|
||||
from bot.modules.access_control.auth import is_authorized
|
||||
from bot.modules.database.session import AsyncSessionLocal
|
||||
from web.utils.otp import create_otp_code
|
||||
from shared.config import settings
|
||||
|
||||
user_id = message.from_user.id
|
||||
|
||||
# Check authorization
|
||||
if not await is_authorized(user_id):
|
||||
await message.reply("❌ У вас нет доступа к этому боту")
|
||||
return
|
||||
|
||||
try:
|
||||
# Create OTP code
|
||||
async with AsyncSessionLocal() as session:
|
||||
code = await create_otp_code(user_id, session)
|
||||
|
||||
if code:
|
||||
# Form URL for web interface
|
||||
if settings.WEB_HOST == "0.0.0.0":
|
||||
login_url = f"localhost:{settings.WEB_PORT}"
|
||||
else:
|
||||
login_url = f"{settings.WEB_HOST}:{settings.WEB_PORT}"
|
||||
|
||||
await message.reply(
|
||||
f"🔐 **Ваш код для входа в веб-интерфейс:**\n\n"
|
||||
f"**`{code}`**\n\n"
|
||||
f"⏰ Код действителен 10 минут\n\n"
|
||||
f"🌐 Перейдите на http://{login_url}/admin/login и введите этот код\n\n"
|
||||
f"💡 Или используйте ваш User ID: `{user_id}`"
|
||||
)
|
||||
else:
|
||||
await message.reply("❌ Не удалось создать код. Попробуйте позже.")
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка при создании OTP кода: {e}")
|
||||
await message.reply("❌ Произошла ошибка при создании кода. Попробуйте позже.")
|
||||
|
||||
|
||||
async def url_handler(client: Client, message: Message):
|
||||
"""Handle URL messages"""
|
||||
from bot.modules.access_control.auth import is_authorized
|
||||
from bot.modules.task_scheduler.queue import task_queue, Task, TaskStatus
|
||||
from bot.modules.task_scheduler.executor import task_executor, set_app_client
|
||||
import time
|
||||
|
||||
# Check authorization
|
||||
if not await is_authorized(message.from_user.id):
|
||||
await message.reply("❌ У вас нет доступа к этому боту")
|
||||
return
|
||||
|
||||
url = message.text.strip()
|
||||
user_id = message.from_user.id
|
||||
|
||||
# URL validation
|
||||
from bot.utils.helpers import is_valid_url
|
||||
if not is_valid_url(url):
|
||||
await message.reply(
|
||||
"❌ Некорректный или небезопасный URL.\n\n"
|
||||
"Пожалуйста, отправьте валидную ссылку (http:// или https://)"
|
||||
)
|
||||
return
|
||||
|
||||
# Check concurrent tasks count
|
||||
from bot.config import settings
|
||||
active_tasks_count = await task_queue.get_user_active_tasks_count(user_id)
|
||||
if active_tasks_count >= settings.MAX_CONCURRENT_TASKS:
|
||||
await message.reply(
|
||||
f"❌ Превышен лимит одновременных задач ({settings.MAX_CONCURRENT_TASKS}).\n"
|
||||
f"⏳ Дождитесь завершения текущих задач или отмените их через /cancel"
|
||||
)
|
||||
return
|
||||
|
||||
# Set client for task executor
|
||||
set_app_client(client)
|
||||
|
||||
# Generate unique task_id using UUID
|
||||
from bot.utils.helpers import generate_unique_task_id
|
||||
task_id = generate_unique_task_id()
|
||||
|
||||
# Check that such ID doesn't exist yet (in case of collision, though probability is extremely low)
|
||||
existing_task = await task_queue.get_task_by_id(task_id)
|
||||
max_retries = 10
|
||||
retries = 0
|
||||
while existing_task and retries < max_retries:
|
||||
task_id = generate_unique_task_id()
|
||||
existing_task = await task_queue.get_task_by_id(task_id)
|
||||
retries += 1
|
||||
|
||||
if existing_task:
|
||||
# If after 10 attempts still collision (extremely unlikely), log error
|
||||
logger.error(f"Failed to generate unique task_id after {max_retries} attempts")
|
||||
await message.reply("❌ Ошибка при создании задачи. Попробуйте позже.")
|
||||
return
|
||||
|
||||
# Duplicate URL check will be performed atomically in task_queue.add_task()
|
||||
url_normalized = url.strip()
|
||||
|
||||
task = Task(
|
||||
id=task_id,
|
||||
user_id=user_id,
|
||||
task_type="download",
|
||||
url=url_normalized,
|
||||
status=TaskStatus.PENDING
|
||||
)
|
||||
|
||||
# Save task to database BEFORE adding to queue (race condition fix)
|
||||
try:
|
||||
from shared.database.models import Task as DBTask
|
||||
from shared.database.session import get_async_session_local
|
||||
from shared.database.user_helpers import ensure_user_exists
|
||||
from datetime import datetime
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
async with get_async_session_local()() as session:
|
||||
# Ensure user exists before creating task
|
||||
await ensure_user_exists(user_id, session)
|
||||
|
||||
db_task = DBTask(
|
||||
id=task_id,
|
||||
user_id=user_id,
|
||||
task_type=task.task_type,
|
||||
status=task.status.value,
|
||||
url=task.url,
|
||||
progress=0,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
session.add(db_task)
|
||||
await session.commit()
|
||||
logger.info(f"Task {task_id} saved to database from bot")
|
||||
except IntegrityError as e:
|
||||
logger.error(f"IntegrityError saving task {task_id} to database (possibly duplicate ID): {e}", exc_info=True)
|
||||
# Generate new task_id and retry
|
||||
from bot.utils.helpers import generate_unique_task_id
|
||||
task_id = generate_unique_task_id()
|
||||
task.id = task_id
|
||||
try:
|
||||
async with get_async_session_local()() as session:
|
||||
from shared.database.user_helpers import ensure_user_exists
|
||||
# Ensure user exists before creating task
|
||||
await ensure_user_exists(user_id, session)
|
||||
|
||||
db_task = DBTask(
|
||||
id=task_id,
|
||||
user_id=user_id,
|
||||
task_type=task.task_type,
|
||||
status=task.status.value,
|
||||
url=task.url,
|
||||
progress=0,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
session.add(db_task)
|
||||
await session.commit()
|
||||
logger.info(f"Task {task_id} saved to database from bot with new ID")
|
||||
except Exception as e2:
|
||||
logger.error(f"Error saving task {task_id} to database again: {e2}", exc_info=True)
|
||||
await message.reply("❌ Ошибка при создании задачи. Попробуйте позже.")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving task {task_id} to database: {e}", exc_info=True)
|
||||
await message.reply("❌ Ошибка при создании задачи. Попробуйте позже.")
|
||||
return
|
||||
|
||||
# Add to queue (with duplicate URL check) AFTER saving to database
|
||||
success = await task_queue.add_task(task, check_duplicate_url=True)
|
||||
if not success:
|
||||
# If failed to add to queue, remove from database
|
||||
try:
|
||||
async with get_async_session_local()() as session:
|
||||
db_task = await session.get(DBTask, task_id)
|
||||
if db_task:
|
||||
await session.delete(db_task)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting task {task_id} from database after failed queue addition: {e}")
|
||||
await message.reply(
|
||||
f"⚠️ Задача с этим URL уже обрабатывается.\n"
|
||||
f"Дождитесь завершения или отмените предыдущую задачу через /cancel"
|
||||
)
|
||||
return
|
||||
|
||||
# Start executor if not already started
|
||||
if not task_executor._running:
|
||||
await task_executor.start()
|
||||
|
||||
await message.reply(
|
||||
f"✅ Ссылка получена!\n\n"
|
||||
f"🔗 {url}\n\n"
|
||||
f"📥 Загрузка добавлена в очередь. Я начну загрузку в ближайшее время.\n"
|
||||
f"⏳ Вы получите уведомление о статусе загрузки."
|
||||
)
|
||||
|
||||
|
||||
def register_commands(app: Client):
|
||||
"""Register all commands"""
|
||||
# Base commands (for all users)
|
||||
app.add_handler(MessageHandler(start_command, filters=command("start")))
|
||||
app.add_handler(MessageHandler(help_command, filters=command("help")))
|
||||
app.add_handler(MessageHandler(status_command, filters=command("status")))
|
||||
app.add_handler(MessageHandler(cancel_command, filters=command("cancel")))
|
||||
app.add_handler(MessageHandler(login_command, filters=command("login")))
|
||||
|
||||
# User management commands (admin only)
|
||||
app.add_handler(MessageHandler(adduser_command, filters=command("adduser")))
|
||||
app.add_handler(MessageHandler(removeuser_command, filters=command("removeuser")))
|
||||
app.add_handler(MessageHandler(blockuser_command, filters=command("blockuser")))
|
||||
app.add_handler(MessageHandler(unblockuser_command, filters=command("unblockuser")))
|
||||
app.add_handler(MessageHandler(listusers_command, filters=command("listusers")))
|
||||
|
||||
# Administrator management commands (admin only)
|
||||
app.add_handler(MessageHandler(addadmin_command, filters=command("addadmin")))
|
||||
app.add_handler(MessageHandler(removeadmin_command, filters=command("removeadmin")))
|
||||
app.add_handler(MessageHandler(listadmins_command, filters=command("listadmins")))
|
||||
|
||||
# URL message handling
|
||||
app.add_handler(MessageHandler(url_handler, filters=is_url_message))
|
||||
|
||||
logger.info("Commands registered")
|
||||
|
||||
37
bot/modules/message_handler/filters.py
Normal file
37
bot/modules/message_handler/filters.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Message filters
|
||||
"""
|
||||
from pyrogram import Client
|
||||
from pyrogram.types import Message
|
||||
from pyrogram.filters import Filter
|
||||
from bot.utils.helpers import is_valid_url
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class URLFilter(Filter):
|
||||
"""Filter for URL messages"""
|
||||
|
||||
async def __call__(self, client: Client, message: Message) -> bool:
|
||||
if not message.text:
|
||||
return False
|
||||
|
||||
text = message.text.strip()
|
||||
return is_valid_url(text)
|
||||
|
||||
|
||||
# Filter instance
|
||||
is_url_message = URLFilter()
|
||||
|
||||
|
||||
def is_youtube_url(url: str) -> bool:
|
||||
"""Check if URL is YouTube"""
|
||||
youtube_domains = ['youtube.com', 'youtu.be', 'm.youtube.com']
|
||||
return any(domain in url.lower() for domain in youtube_domains)
|
||||
|
||||
|
||||
def is_instagram_url(url: str) -> bool:
|
||||
"""Check if URL is Instagram"""
|
||||
return 'instagram.com' in url.lower()
|
||||
|
||||
4
bot/modules/task_scheduler/__init__.py
Normal file
4
bot/modules/task_scheduler/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Task scheduler module
|
||||
"""
|
||||
|
||||
694
bot/modules/task_scheduler/executor.py
Normal file
694
bot/modules/task_scheduler/executor.py
Normal file
@@ -0,0 +1,694 @@
|
||||
"""
|
||||
Task executor
|
||||
"""
|
||||
import asyncio
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from bot.modules.task_scheduler.queue import task_queue, Task, TaskStatus
|
||||
from bot.modules.media_loader.ytdlp import download_media
|
||||
from bot.modules.media_loader.sender import send_file_to_user, delete_file
|
||||
from bot.modules.message_handler.filters import is_youtube_url, is_instagram_url
|
||||
from pyrogram import Client
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global client for sending messages
|
||||
# Protected by threading.Lock for thread-safe access from different threads
|
||||
_app_client: Optional[Client] = None
|
||||
_app_client_lock = threading.Lock()
|
||||
|
||||
# Dictionary to store message_id for tasks to update messages
|
||||
# Format: {task_id: message_id}
|
||||
# Use size limit to prevent memory leaks
|
||||
_task_messages: dict[int, int] = {}
|
||||
_task_messages_lock = threading.Lock()
|
||||
_MAX_TASK_MESSAGES = 10000 # Maximum number of records
|
||||
|
||||
|
||||
def set_app_client(client: Client) -> None:
|
||||
"""
|
||||
Set client for sending messages (thread-safe)
|
||||
|
||||
Args:
|
||||
client: Pyrogram client for sending messages
|
||||
"""
|
||||
global _app_client
|
||||
with _app_client_lock:
|
||||
_app_client = client
|
||||
|
||||
|
||||
def get_app_client() -> Optional[Client]:
|
||||
"""Get client for sending messages (thread-safe)"""
|
||||
global _app_client
|
||||
with _app_client_lock:
|
||||
return _app_client
|
||||
|
||||
|
||||
def set_task_message(task_id: int, message_id: int) -> None:
|
||||
"""
|
||||
Save message_id for task (thread-safe)
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
message_id: Telegram message ID
|
||||
"""
|
||||
global _task_messages
|
||||
with _task_messages_lock:
|
||||
# Clear old records if limit reached
|
||||
if len(_task_messages) >= _MAX_TASK_MESSAGES:
|
||||
# Remove 10% of oldest records (FIFO)
|
||||
keys_to_remove = list(_task_messages.keys())[:_MAX_TASK_MESSAGES // 10]
|
||||
for key in keys_to_remove:
|
||||
_task_messages.pop(key, None)
|
||||
logger.debug(f"Cleared {len(keys_to_remove)} old records from _task_messages")
|
||||
_task_messages[task_id] = message_id
|
||||
|
||||
|
||||
def get_task_message(task_id: int) -> Optional[int]:
|
||||
"""Get message_id for task (thread-safe)"""
|
||||
global _task_messages
|
||||
with _task_messages_lock:
|
||||
return _task_messages.get(task_id)
|
||||
|
||||
|
||||
def clear_task_message(task_id: int) -> None:
|
||||
"""
|
||||
Remove message_id for task (thread-safe)
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
"""
|
||||
global _task_messages
|
||||
with _task_messages_lock:
|
||||
_task_messages.pop(task_id, None)
|
||||
|
||||
|
||||
async def cleanup_completed_task_messages():
|
||||
"""
|
||||
Periodic cleanup of message_id for completed tasks
|
||||
Runs in background every 30 minutes
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(30 * 60) # 30 minutes
|
||||
from bot.modules.task_scheduler.queue import task_queue, TaskStatus
|
||||
|
||||
# Get all completed tasks
|
||||
all_tasks = await task_queue.get_all_tasks()
|
||||
completed_task_ids = [
|
||||
task.id for task in all_tasks
|
||||
if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]
|
||||
]
|
||||
|
||||
# Remove message_id for completed tasks
|
||||
with _task_messages_lock:
|
||||
removed_count = 0
|
||||
for task_id in completed_task_ids:
|
||||
if task_id in _task_messages:
|
||||
del _task_messages[task_id]
|
||||
removed_count += 1
|
||||
|
||||
if removed_count > 0:
|
||||
logger.debug(f"Cleared {removed_count} message_id for completed tasks")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Message ID cleanup task stopped")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up message_id: {e}", exc_info=True)
|
||||
|
||||
|
||||
class TaskExecutor:
|
||||
"""Task executor"""
|
||||
|
||||
def __init__(self):
|
||||
self._running = False
|
||||
self._workers: list[asyncio.Task] = []
|
||||
self._running_lock = asyncio.Lock() # Protection for _running flag
|
||||
|
||||
async def start(self, num_workers: int = 2):
|
||||
"""
|
||||
Start task executor
|
||||
|
||||
Args:
|
||||
num_workers: Number of workers (default 2 for parallel processing)
|
||||
"""
|
||||
async with self._running_lock:
|
||||
if self._running:
|
||||
logger.warning("Task executor already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
logger.info(f"Starting task executor with {num_workers} workers")
|
||||
|
||||
# Create workers (each works independently)
|
||||
for i in range(num_workers):
|
||||
worker = asyncio.create_task(self._worker(f"worker-{i+1}"))
|
||||
self._workers.append(worker)
|
||||
# Small delay between worker starts for even load distribution
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Start background task for cleaning message_id for completed tasks
|
||||
cleanup_task = asyncio.create_task(cleanup_completed_task_messages())
|
||||
self._workers.append(cleanup_task)
|
||||
logger.info("Started background task for cleaning message_id for completed tasks")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop task executor"""
|
||||
async with self._running_lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
logger.info("Stopping task executor...")
|
||||
|
||||
# Cancel all active tasks
|
||||
# Get all tasks and cancel active ones
|
||||
all_tasks = await task_queue.get_all_tasks()
|
||||
for task in all_tasks:
|
||||
if task.status == TaskStatus.PROCESSING:
|
||||
logger.info(f"Cancelling active task {task.id} on shutdown")
|
||||
await task_queue.update_task_status(task.id, TaskStatus.CANCELLED, error="Bot shutdown")
|
||||
cancel_event = task_queue.get_cancel_event(task.id)
|
||||
cancel_event.set()
|
||||
|
||||
# Wait for all workers to complete
|
||||
await asyncio.gather(*self._workers, return_exceptions=True)
|
||||
self._workers.clear()
|
||||
logger.info("Task executor stopped")
|
||||
|
||||
async def _worker(self, name: str):
|
||||
"""Worker for processing tasks (runs in parallel with other workers)"""
|
||||
logger.info(f"Worker {name} started")
|
||||
|
||||
while True:
|
||||
# Check state with lock protection
|
||||
async with self._running_lock:
|
||||
if not self._running:
|
||||
break
|
||||
try:
|
||||
# Get task from queue (non-blocking)
|
||||
task = await task_queue.get_task()
|
||||
|
||||
if not task:
|
||||
# No tasks, small delay
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
|
||||
# Check for cancellation before starting processing
|
||||
current_task = await task_queue.get_task_by_id(task.id)
|
||||
if current_task and current_task.status == TaskStatus.CANCELLED:
|
||||
logger.info(f"Task {task.id} was cancelled, skipping")
|
||||
continue
|
||||
|
||||
# Update status
|
||||
await task_queue.update_task_status(task.id, TaskStatus.PROCESSING)
|
||||
|
||||
logger.info(f"Worker {name} processing task {task.id}")
|
||||
|
||||
# Execute task (doesn't block other workers and message processing)
|
||||
# Each task executes independently
|
||||
await self._execute_task(task)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Worker {name} stopped")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in worker {name}: {e}", exc_info=True)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
logger.info(f"Worker {name} finished")
|
||||
|
||||
async def _execute_task(self, task: Task):
|
||||
"""
|
||||
Execute task
|
||||
|
||||
Args:
|
||||
task: Task to execute
|
||||
"""
|
||||
try:
|
||||
if task.task_type == "download" and task.url:
|
||||
# Determine download type
|
||||
if is_youtube_url(task.url) or is_instagram_url(task.url) or any(
|
||||
domain in task.url.lower()
|
||||
for domain in ['youtube.com', 'youtu.be', 'instagram.com', 'tiktok.com', 'twitter.com', 'x.com']
|
||||
):
|
||||
# Download via yt-dlp
|
||||
await self._download_with_ytdlp(task)
|
||||
else:
|
||||
# Direct download (to be implemented later)
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error="Direct download not supported yet"
|
||||
)
|
||||
else:
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error="Unknown task type or missing URL"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing task {task.id}: {e}", exc_info=True)
|
||||
|
||||
# Form user-friendly error message
|
||||
error_message = str(e)
|
||||
if "login required" in error_message.lower() or "cookies" in error_message.lower():
|
||||
error_message = (
|
||||
"❌ Authentication required to download this content.\n\n"
|
||||
"💡 Solution: configure cookies in bot configuration.\n"
|
||||
"See instructions in README.md"
|
||||
)
|
||||
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=error_message
|
||||
)
|
||||
|
||||
# Send message to user
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Download error:\n{error_message}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send error message to user {task.user_id}: {e}", exc_info=True)
|
||||
|
||||
async def _download_with_ytdlp(self, task: Task):
|
||||
"""Download via yt-dlp"""
|
||||
# Get cancellation event for this task
|
||||
cancel_event = task_queue.get_cancel_event(task.id)
|
||||
|
||||
try:
|
||||
# Check for cancellation
|
||||
current_task = await task_queue.get_task_by_id(task.id)
|
||||
if current_task and current_task.status == TaskStatus.CANCELLED:
|
||||
logger.info(f"Task {task.id} cancelled, stopping download")
|
||||
task_queue.clear_cancel_event(task.id)
|
||||
return
|
||||
|
||||
# Get media information to check limits
|
||||
from bot.modules.media_loader.ytdlp import get_media_info
|
||||
from shared.config import settings
|
||||
|
||||
media_info = await get_media_info(task.url, cookies_file=settings.COOKIES_FILE)
|
||||
if media_info:
|
||||
# Check duration
|
||||
max_duration = settings.max_duration_minutes_int
|
||||
if max_duration and media_info.get('duration'):
|
||||
duration_minutes = media_info['duration'] / 60
|
||||
if duration_minutes > max_duration:
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=f"Maximum duration exceeded ({max_duration} min)"
|
||||
)
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ File too long ({duration_minutes:.1f} min). "
|
||||
f"Maximum: {max_duration} min."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send duration exceeded message to user {task.user_id}: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
# Function for updating progress with cancellation check
|
||||
# This function is called from another thread (yt-dlp), so we use run_coroutine_threadsafe
|
||||
async def update_progress(percent: int):
|
||||
# Check cancellation when updating progress
|
||||
if cancel_event.is_set():
|
||||
raise asyncio.CancelledError("Task cancelled by user")
|
||||
current_task = await task_queue.get_task_by_id(task.id)
|
||||
if current_task and current_task.status == TaskStatus.CANCELLED:
|
||||
raise asyncio.CancelledError("Task cancelled by user")
|
||||
await task_queue.update_task_status(task.id, TaskStatus.PROCESSING, progress=percent)
|
||||
|
||||
# Update progress message
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
try:
|
||||
from pyrogram.errors import MessageNotModified
|
||||
status_text = (
|
||||
f"📥 **Downloading file**\n\n"
|
||||
f"🔗 {task.url[:50]}...\n"
|
||||
f"📊 Progress: **{percent}%**\n"
|
||||
f"⏳ Please wait..."
|
||||
)
|
||||
try:
|
||||
await app_client.edit_message_text(
|
||||
chat_id=task.user_id,
|
||||
message_id=message_id,
|
||||
text=status_text
|
||||
)
|
||||
except MessageNotModified:
|
||||
pass # Ignore if text didn't change
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to update message: {e}")
|
||||
|
||||
# Save reference to event loop for use in progress hook
|
||||
progress_loop = asyncio.get_event_loop()
|
||||
|
||||
# Send one message about download start (will be updated)
|
||||
app_client = get_app_client()
|
||||
status_message = None
|
||||
if app_client:
|
||||
try:
|
||||
status_text = (
|
||||
f"📥 **Downloading file**\n\n"
|
||||
f"🔗 {task.url[:50]}...\n"
|
||||
f"📊 Progress: **0%**\n"
|
||||
f"⏳ Please wait..."
|
||||
)
|
||||
status_message = await app_client.send_message(
|
||||
task.user_id,
|
||||
status_text
|
||||
)
|
||||
# Save message_id for updates
|
||||
set_task_message(task.id, status_message.id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send notification: {e}")
|
||||
|
||||
# Download media
|
||||
try:
|
||||
logger.info(f"Starting download for task {task.id}, URL: {task.url}")
|
||||
result = await download_media(
|
||||
url=task.url,
|
||||
output_dir="downloads",
|
||||
quality="best",
|
||||
progress_callback=update_progress,
|
||||
cookies_file=settings.COOKIES_FILE,
|
||||
cancel_event=cancel_event,
|
||||
task_id=task.id
|
||||
)
|
||||
|
||||
logger.info(f"Download completed for task {task.id}. Result: {result is not None}, file_path: {result.get('file_path') if result else None}")
|
||||
|
||||
# Save file path to database
|
||||
if result and result.get('file_path'):
|
||||
await task_queue.update_task_file_path(task.id, result['file_path'])
|
||||
logger.info(f"File path saved to DB: {result['file_path']}")
|
||||
|
||||
# Check file size after download
|
||||
max_file_size = settings.max_file_size_bytes
|
||||
if result and max_file_size:
|
||||
file_size = result.get('size', 0)
|
||||
if file_size > max_file_size:
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=f"Maximum file size exceeded ({max_file_size / (1024*1024):.1f} MB)"
|
||||
)
|
||||
# Delete file
|
||||
if result.get('file_path'):
|
||||
await delete_file(result['file_path'])
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ File too large ({file_size / (1024*1024):.1f} MB). "
|
||||
f"Maximum: {max_file_size / (1024*1024):.1f} MB."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send file size exceeded message to user {task.user_id}: {e}", exc_info=True)
|
||||
clear_task_message(task.id)
|
||||
return
|
||||
except (asyncio.CancelledError, KeyboardInterrupt) as e:
|
||||
logger.info(f"Task {task.id} cancelled during download: {e}")
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.CANCELLED,
|
||||
error="Cancelled by user"
|
||||
)
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
# Update cancellation message
|
||||
try:
|
||||
await app_client.edit_message_text(
|
||||
chat_id=task.user_id,
|
||||
message_id=message_id,
|
||||
text=f"🚫 **Task cancelled**\n\nTask #{task.id} was cancelled."
|
||||
)
|
||||
except Exception as edit_error:
|
||||
# If update failed, send new message
|
||||
logger.debug(f"Failed to update cancellation message, sending new: {edit_error}")
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"🚫 Task #{task.id} cancelled"
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send cancellation message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
else:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"🚫 Task #{task.id} cancelled"
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send cancellation message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending cancellation notification for task {task.id}: {e}", exc_info=True)
|
||||
clear_task_message(task.id)
|
||||
task_queue.clear_cancel_event(task.id)
|
||||
return
|
||||
|
||||
# Check for cancellation after download
|
||||
current_task = await task_queue.get_task_by_id(task.id)
|
||||
if current_task and current_task.status == TaskStatus.CANCELLED:
|
||||
logger.info(f"Task {task.id} cancelled after download")
|
||||
# Delete downloaded file if exists
|
||||
if result and result.get('file_path'):
|
||||
await delete_file(result['file_path'])
|
||||
task_queue.clear_cancel_event(task.id)
|
||||
return
|
||||
|
||||
if not result:
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error="Failed to download file"
|
||||
)
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
try:
|
||||
await app_client.edit_message_text(
|
||||
chat_id=task.user_id,
|
||||
message_id=message_id,
|
||||
text="❌ **Download error**\n\nFailed to download file. Check the link and try again."
|
||||
)
|
||||
except Exception as edit_error:
|
||||
logger.debug(f"Failed to update error message, sending new: {edit_error}")
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Error downloading file. Check the link and try again."
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send error message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
else:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Error downloading file. Check the link and try again."
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send error message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending download error notification for task {task.id}: {e}", exc_info=True)
|
||||
clear_task_message(task.id)
|
||||
return
|
||||
|
||||
# Send file to user
|
||||
await task_queue.update_task_status(task.id, TaskStatus.PROCESSING, progress=90)
|
||||
|
||||
# Check that file exists before sending
|
||||
file_path_obj = Path(result['file_path'])
|
||||
if not file_path_obj.exists():
|
||||
logger.error(f"File doesn't exist before sending: {result['file_path']}")
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=f"File not found: {result['file_path']}"
|
||||
)
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Error: file not found after download"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send error message: {e}")
|
||||
clear_task_message(task.id)
|
||||
return
|
||||
|
||||
logger.info(f"Sending file to user {task.user_id}: {result['file_path']}")
|
||||
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
# Form caption
|
||||
caption = f"📥 **{result.get('title', 'File')}**"
|
||||
if result.get('duration'):
|
||||
from bot.utils.helpers import format_duration
|
||||
caption += f"\n⏱ Duration: {format_duration(result['duration'])}"
|
||||
|
||||
# Send file
|
||||
logger.info(f"Calling send_file_to_user for file: {result['file_path']}")
|
||||
success = await send_file_to_user(
|
||||
client=app_client,
|
||||
chat_id=task.user_id,
|
||||
file_path=result['file_path'],
|
||||
caption=caption,
|
||||
thumbnail=result.get('thumbnail')
|
||||
)
|
||||
|
||||
logger.info(f"File sending result: success={success}")
|
||||
|
||||
if success:
|
||||
# Delete file after successful sending
|
||||
await delete_file(result['file_path'])
|
||||
await task_queue.update_task_status(task.id, TaskStatus.COMPLETED, progress=100)
|
||||
task_queue.clear_cancel_event(task.id)
|
||||
|
||||
# Delete download message (file already sent)
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
try:
|
||||
await app_client.delete_messages(
|
||||
chat_id=task.user_id,
|
||||
message_ids=message_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to delete download message for task {task.id}: {e}")
|
||||
clear_task_message(task.id)
|
||||
else:
|
||||
error_msg = "Failed to send file"
|
||||
logger.error(f"File sending error for task {task.id}: {error_msg}")
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=error_msg
|
||||
)
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Error sending file. File downloaded but failed to send."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send file sending error message: {e}")
|
||||
except Exception as send_error:
|
||||
error_msg = f"Error sending file: {str(send_error)}"
|
||||
logger.error(f"Exception sending file for task {task.id}: {send_error}", exc_info=True)
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=error_msg
|
||||
)
|
||||
try:
|
||||
if app_client:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Error sending file: {str(send_error)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send error message: {e}")
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
try:
|
||||
await app_client.edit_message_text(
|
||||
chat_id=task.user_id,
|
||||
message_id=message_id,
|
||||
text="❌ **Send error**\n\nFailed to send file. Try again later."
|
||||
)
|
||||
except Exception as edit_error:
|
||||
logger.debug(f"Failed to update send error message, sending new: {edit_error}")
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
"❌ Error sending file. Try again later."
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send send error message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
else:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
"❌ Error sending file. Try again later."
|
||||
)
|
||||
clear_task_message(task.id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending file: {e}", exc_info=True)
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=f"Send error: {str(e)}"
|
||||
)
|
||||
else:
|
||||
logger.warning("Client not set, file not sent")
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error="Client not available"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading via yt-dlp: {e}", exc_info=True)
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=str(e)
|
||||
)
|
||||
task_queue.clear_cancel_event(task.id)
|
||||
clear_task_message(task.id)
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
try:
|
||||
await app_client.edit_message_text(
|
||||
chat_id=task.user_id,
|
||||
message_id=message_id,
|
||||
text=f"❌ **Error**\n\nAn error occurred: {str(e)}"
|
||||
)
|
||||
except Exception as edit_error:
|
||||
logger.debug(f"Failed to update error message, sending new: {edit_error}")
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ An error occurred: {str(e)}"
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send error message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
else:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ An error occurred: {str(e)}"
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send error message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
except Exception as notify_error:
|
||||
logger.error(f"Error sending error notification for task {task.id}: {notify_error}", exc_info=True)
|
||||
|
||||
|
||||
# Global task executor
|
||||
task_executor = TaskExecutor()
|
||||
|
||||
84
bot/modules/task_scheduler/monitor.py
Normal file
84
bot/modules/task_scheduler/monitor.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Task monitoring
|
||||
"""
|
||||
from typing import Tuple
|
||||
from bot.modules.task_scheduler.queue import task_queue, TaskStatus
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_task_status(task_id: int) -> dict:
|
||||
"""
|
||||
Get task status
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Dictionary with task status
|
||||
"""
|
||||
task = await task_queue.get_task_by_id(task_id)
|
||||
|
||||
if not task:
|
||||
return {"error": "Task not found"}
|
||||
|
||||
return {
|
||||
"id": task.id,
|
||||
"user_id": task.user_id,
|
||||
"task_type": task.task_type,
|
||||
"status": task.status.value,
|
||||
"progress": task.progress,
|
||||
"error_message": task.error_message,
|
||||
"created_at": task.created_at.isoformat() if task.created_at else None
|
||||
}
|
||||
|
||||
|
||||
async def get_user_tasks_status(user_id: int) -> list[dict]:
|
||||
"""
|
||||
Get status of all user tasks
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of user tasks
|
||||
"""
|
||||
tasks = await task_queue.get_user_tasks(user_id)
|
||||
return [await get_task_status(task.id) for task in tasks]
|
||||
|
||||
|
||||
async def cancel_user_task(user_id: int, task_id: int) -> Tuple[bool, str]:
|
||||
"""
|
||||
Cancel user task
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str)
|
||||
"""
|
||||
task = await task_queue.get_task_by_id(task_id)
|
||||
|
||||
if not task:
|
||||
return (False, f"Задача #{task_id} не найдена")
|
||||
|
||||
if task.user_id != user_id:
|
||||
return (False, "Вы можете отменять только свои задачи")
|
||||
|
||||
if task.status == TaskStatus.COMPLETED:
|
||||
return (False, f"Задача #{task_id} уже завершена")
|
||||
|
||||
if task.status == TaskStatus.CANCELLED:
|
||||
return (False, f"Задача #{task_id} уже отменена")
|
||||
|
||||
if task.status == TaskStatus.FAILED:
|
||||
return (False, f"Задача #{task_id} уже завершилась с ошибкой")
|
||||
|
||||
success = await task_queue.cancel_task(task_id)
|
||||
if success:
|
||||
return (True, f"Задача #{task_id} успешно отменена")
|
||||
else:
|
||||
return (False, f"Не удалось отменить задачу #{task_id}. Возможно, она уже завершается.")
|
||||
|
||||
297
bot/modules/task_scheduler/queue.py
Normal file
297
bot/modules/task_scheduler/queue.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Task queue
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""Task statuses"""
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""Task for execution"""
|
||||
id: int
|
||||
user_id: int
|
||||
task_type: str
|
||||
url: Optional[str] = None
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
progress: int = 0
|
||||
error_message: Optional[str] = None
|
||||
created_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
class TaskQueue:
|
||||
"""Task queue"""
|
||||
|
||||
def __init__(self):
|
||||
self._queue: Optional[asyncio.Queue] = None
|
||||
self._tasks: dict[int, Task] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._initialized = False
|
||||
# Dictionary to store cancellation events for each task
|
||||
# Protected by threading.Lock, as access may be from different threads
|
||||
self._cancel_events: dict[int, threading.Event] = {}
|
||||
self._cancel_events_lock = threading.Lock()
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize queue in current event loop"""
|
||||
if not self._initialized:
|
||||
self._queue = asyncio.Queue()
|
||||
self._initialized = True
|
||||
logger.info("Task queue initialized")
|
||||
|
||||
async def add_task(self, task: Task, check_duplicate_url: bool = True) -> bool:
|
||||
"""
|
||||
Add task to queue
|
||||
|
||||
Args:
|
||||
task: Task to add
|
||||
check_duplicate_url: Check for duplicate URLs for this user
|
||||
|
||||
Returns:
|
||||
True if successful, False if duplicate
|
||||
"""
|
||||
if not self._initialized or not self._queue:
|
||||
await self.initialize()
|
||||
|
||||
async with self._lock:
|
||||
# Check for duplicate task_id
|
||||
if task.id in self._tasks:
|
||||
logger.warning(f"Task with ID {task.id} already exists, skipping duplicate")
|
||||
return False
|
||||
|
||||
# Atomic check for duplicate URLs for this user
|
||||
if check_duplicate_url and task.url:
|
||||
url_normalized = task.url.strip()
|
||||
for existing_task in self._tasks.values():
|
||||
if (existing_task.user_id == task.user_id and
|
||||
existing_task.url and existing_task.url.strip() == url_normalized and
|
||||
existing_task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING] and
|
||||
existing_task.id != task.id):
|
||||
logger.warning(f"Task {task.id} with URL {url_normalized} already processing for user {task.user_id} (task {existing_task.id})")
|
||||
return False # Block duplicate
|
||||
|
||||
self._tasks[task.id] = task
|
||||
await self._queue.put(task)
|
||||
logger.info(f"Task {task.id} added to queue")
|
||||
return True
|
||||
|
||||
async def get_task(self) -> Optional[Task]:
|
||||
"""
|
||||
Get task from queue
|
||||
|
||||
Returns:
|
||||
Task or None
|
||||
"""
|
||||
if not self._initialized or not self._queue:
|
||||
await self.initialize()
|
||||
|
||||
try:
|
||||
# Use timeout to avoid blocking indefinitely
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
return task
|
||||
except asyncio.TimeoutError:
|
||||
# Timeout is normal, just no tasks
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task: {e}", exc_info=True)
|
||||
await asyncio.sleep(1) # Small delay before retry
|
||||
return None
|
||||
|
||||
async def update_task_status(self, task_id: int, status: TaskStatus, progress: int = None, error: str = None):
|
||||
"""
|
||||
Update task status
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
status: New status
|
||||
progress: Progress (0-100)
|
||||
error: Error message
|
||||
"""
|
||||
async with self._lock:
|
||||
if task_id in self._tasks:
|
||||
task = self._tasks[task_id]
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = progress
|
||||
if error:
|
||||
task.error_message = error
|
||||
logger.info(f"Task {task_id} updated: {status.value}")
|
||||
|
||||
# Sync with database
|
||||
try:
|
||||
from shared.database.models import Task as DBTask
|
||||
from shared.database.session import get_async_session_local
|
||||
from datetime import datetime
|
||||
async with get_async_session_local()() as session:
|
||||
db_task = await session.get(DBTask, task_id)
|
||||
if db_task:
|
||||
db_task.status = status.value
|
||||
if progress is not None:
|
||||
db_task.progress = progress
|
||||
if error:
|
||||
# Limit error_message length (maximum 1000 characters)
|
||||
db_task.error_message = (error[:1000] if error and len(error) > 1000 else error)
|
||||
if status == TaskStatus.COMPLETED:
|
||||
db_task.completed_at = datetime.utcnow()
|
||||
db_task.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
logger.debug(f"Task {task_id} updated in DB: {status.value}")
|
||||
else:
|
||||
# Task not found in DB - create it for synchronization
|
||||
logger.warning(f"Task {task_id} not found in DB, creating record for synchronization")
|
||||
from bot.modules.task_scheduler.queue import task_queue
|
||||
from shared.database.user_helpers import ensure_user_exists
|
||||
task = await task_queue.get_task_by_id(task_id)
|
||||
if task:
|
||||
# Ensure user exists before creating task
|
||||
await ensure_user_exists(task.user_id, session)
|
||||
|
||||
db_task = DBTask(
|
||||
id=task_id,
|
||||
user_id=task.user_id,
|
||||
task_type=task.task_type,
|
||||
status=status.value,
|
||||
url=task.url,
|
||||
progress=progress if progress is not None else 0,
|
||||
error_message=(error[:1000] if error and len(error) > 1000 else error),
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
session.add(db_task)
|
||||
await session.commit()
|
||||
logger.info(f"Task {task_id} created in DB for synchronization")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update task {task_id} in DB: {e}", exc_info=True)
|
||||
# Try to rollback if session is still open
|
||||
try:
|
||||
async with get_async_session_local()() as session:
|
||||
await session.rollback()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def update_task_file_path(self, task_id: int, file_path: str):
|
||||
"""
|
||||
Update task file path
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
file_path: File path
|
||||
"""
|
||||
# Sync with database
|
||||
try:
|
||||
from shared.database.models import Task as DBTask
|
||||
from shared.database.session import get_async_session_local
|
||||
from datetime import datetime
|
||||
async with get_async_session_local()() as session:
|
||||
db_task = await session.get(DBTask, task_id)
|
||||
if db_task:
|
||||
db_task.file_path = file_path
|
||||
db_task.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
logger.debug(f"Task {task_id} file path updated in DB: {file_path}")
|
||||
else:
|
||||
logger.warning(f"Task {task_id} not found in DB for file_path update")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update task {task_id} file path in DB: {e}", exc_info=True)
|
||||
# Try to rollback if session is still open
|
||||
try:
|
||||
async with get_async_session_local()() as session:
|
||||
await session.rollback()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def get_task_by_id(self, task_id: int) -> Optional[Task]:
|
||||
"""Get task by ID"""
|
||||
async with self._lock:
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
async def cancel_task(self, task_id: int) -> bool:
|
||||
"""
|
||||
Cancel task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
async with self._lock:
|
||||
if task_id in self._tasks:
|
||||
task = self._tasks[task_id]
|
||||
# Can only cancel pending or processing tasks
|
||||
if task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.error_message = "Cancelled by user"
|
||||
# Set cancellation event to interrupt download (thread-safe)
|
||||
with self._cancel_events_lock:
|
||||
if task_id in self._cancel_events:
|
||||
self._cancel_events[task_id].set()
|
||||
logger.info(f"Task {task_id} cancelled")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_cancel_event(self, task_id: int) -> threading.Event:
|
||||
"""
|
||||
Get cancellation event for task (created if doesn't exist)
|
||||
Thread-safe method for use from different threads.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
threading.Event for cancellation check
|
||||
"""
|
||||
with self._cancel_events_lock:
|
||||
if task_id not in self._cancel_events:
|
||||
self._cancel_events[task_id] = threading.Event()
|
||||
return self._cancel_events[task_id]
|
||||
|
||||
def clear_cancel_event(self, task_id: int):
|
||||
"""
|
||||
Clear cancellation event after task completion
|
||||
Thread-safe method for use from different threads.
|
||||
"""
|
||||
with self._cancel_events_lock:
|
||||
if task_id in self._cancel_events:
|
||||
del self._cancel_events[task_id]
|
||||
|
||||
async def get_user_tasks(self, user_id: int) -> list[Task]:
|
||||
"""Get user tasks"""
|
||||
async with self._lock:
|
||||
return [task for task in self._tasks.values() if task.user_id == user_id]
|
||||
|
||||
async def get_user_active_tasks_count(self, user_id: int) -> int:
|
||||
"""Get count of active user tasks"""
|
||||
async with self._lock:
|
||||
active_statuses = [TaskStatus.PENDING, TaskStatus.PROCESSING]
|
||||
return sum(1 for task in self._tasks.values()
|
||||
if task.user_id == user_id and task.status in active_statuses)
|
||||
|
||||
async def get_all_tasks(self) -> list[Task]:
|
||||
"""Get all tasks"""
|
||||
async with self._lock:
|
||||
return list(self._tasks.values())
|
||||
|
||||
|
||||
# Global task queue
|
||||
task_queue = TaskQueue()
|
||||
|
||||
Reference in New Issue
Block a user