Add source
This commit is contained in:
4
bot/modules/task_scheduler/__init__.py
Normal file
4
bot/modules/task_scheduler/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Task scheduler module
|
||||
"""
|
||||
|
||||
694
bot/modules/task_scheduler/executor.py
Normal file
694
bot/modules/task_scheduler/executor.py
Normal file
@@ -0,0 +1,694 @@
|
||||
"""
|
||||
Task executor
|
||||
"""
|
||||
import asyncio
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from bot.modules.task_scheduler.queue import task_queue, Task, TaskStatus
|
||||
from bot.modules.media_loader.ytdlp import download_media
|
||||
from bot.modules.media_loader.sender import send_file_to_user, delete_file
|
||||
from bot.modules.message_handler.filters import is_youtube_url, is_instagram_url
|
||||
from pyrogram import Client
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global client for sending messages
|
||||
# Protected by threading.Lock for thread-safe access from different threads
|
||||
_app_client: Optional[Client] = None
|
||||
_app_client_lock = threading.Lock()
|
||||
|
||||
# Dictionary to store message_id for tasks to update messages
|
||||
# Format: {task_id: message_id}
|
||||
# Use size limit to prevent memory leaks
|
||||
_task_messages: dict[int, int] = {}
|
||||
_task_messages_lock = threading.Lock()
|
||||
_MAX_TASK_MESSAGES = 10000 # Maximum number of records
|
||||
|
||||
|
||||
def set_app_client(client: Client) -> None:
|
||||
"""
|
||||
Set client for sending messages (thread-safe)
|
||||
|
||||
Args:
|
||||
client: Pyrogram client for sending messages
|
||||
"""
|
||||
global _app_client
|
||||
with _app_client_lock:
|
||||
_app_client = client
|
||||
|
||||
|
||||
def get_app_client() -> Optional[Client]:
|
||||
"""Get client for sending messages (thread-safe)"""
|
||||
global _app_client
|
||||
with _app_client_lock:
|
||||
return _app_client
|
||||
|
||||
|
||||
def set_task_message(task_id: int, message_id: int) -> None:
|
||||
"""
|
||||
Save message_id for task (thread-safe)
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
message_id: Telegram message ID
|
||||
"""
|
||||
global _task_messages
|
||||
with _task_messages_lock:
|
||||
# Clear old records if limit reached
|
||||
if len(_task_messages) >= _MAX_TASK_MESSAGES:
|
||||
# Remove 10% of oldest records (FIFO)
|
||||
keys_to_remove = list(_task_messages.keys())[:_MAX_TASK_MESSAGES // 10]
|
||||
for key in keys_to_remove:
|
||||
_task_messages.pop(key, None)
|
||||
logger.debug(f"Cleared {len(keys_to_remove)} old records from _task_messages")
|
||||
_task_messages[task_id] = message_id
|
||||
|
||||
|
||||
def get_task_message(task_id: int) -> Optional[int]:
|
||||
"""Get message_id for task (thread-safe)"""
|
||||
global _task_messages
|
||||
with _task_messages_lock:
|
||||
return _task_messages.get(task_id)
|
||||
|
||||
|
||||
def clear_task_message(task_id: int) -> None:
|
||||
"""
|
||||
Remove message_id for task (thread-safe)
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
"""
|
||||
global _task_messages
|
||||
with _task_messages_lock:
|
||||
_task_messages.pop(task_id, None)
|
||||
|
||||
|
||||
async def cleanup_completed_task_messages():
|
||||
"""
|
||||
Periodic cleanup of message_id for completed tasks
|
||||
Runs in background every 30 minutes
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(30 * 60) # 30 minutes
|
||||
from bot.modules.task_scheduler.queue import task_queue, TaskStatus
|
||||
|
||||
# Get all completed tasks
|
||||
all_tasks = await task_queue.get_all_tasks()
|
||||
completed_task_ids = [
|
||||
task.id for task in all_tasks
|
||||
if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]
|
||||
]
|
||||
|
||||
# Remove message_id for completed tasks
|
||||
with _task_messages_lock:
|
||||
removed_count = 0
|
||||
for task_id in completed_task_ids:
|
||||
if task_id in _task_messages:
|
||||
del _task_messages[task_id]
|
||||
removed_count += 1
|
||||
|
||||
if removed_count > 0:
|
||||
logger.debug(f"Cleared {removed_count} message_id for completed tasks")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Message ID cleanup task stopped")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up message_id: {e}", exc_info=True)
|
||||
|
||||
|
||||
class TaskExecutor:
|
||||
"""Task executor"""
|
||||
|
||||
def __init__(self):
|
||||
self._running = False
|
||||
self._workers: list[asyncio.Task] = []
|
||||
self._running_lock = asyncio.Lock() # Protection for _running flag
|
||||
|
||||
async def start(self, num_workers: int = 2):
|
||||
"""
|
||||
Start task executor
|
||||
|
||||
Args:
|
||||
num_workers: Number of workers (default 2 for parallel processing)
|
||||
"""
|
||||
async with self._running_lock:
|
||||
if self._running:
|
||||
logger.warning("Task executor already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
logger.info(f"Starting task executor with {num_workers} workers")
|
||||
|
||||
# Create workers (each works independently)
|
||||
for i in range(num_workers):
|
||||
worker = asyncio.create_task(self._worker(f"worker-{i+1}"))
|
||||
self._workers.append(worker)
|
||||
# Small delay between worker starts for even load distribution
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Start background task for cleaning message_id for completed tasks
|
||||
cleanup_task = asyncio.create_task(cleanup_completed_task_messages())
|
||||
self._workers.append(cleanup_task)
|
||||
logger.info("Started background task for cleaning message_id for completed tasks")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop task executor"""
|
||||
async with self._running_lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
logger.info("Stopping task executor...")
|
||||
|
||||
# Cancel all active tasks
|
||||
# Get all tasks and cancel active ones
|
||||
all_tasks = await task_queue.get_all_tasks()
|
||||
for task in all_tasks:
|
||||
if task.status == TaskStatus.PROCESSING:
|
||||
logger.info(f"Cancelling active task {task.id} on shutdown")
|
||||
await task_queue.update_task_status(task.id, TaskStatus.CANCELLED, error="Bot shutdown")
|
||||
cancel_event = task_queue.get_cancel_event(task.id)
|
||||
cancel_event.set()
|
||||
|
||||
# Wait for all workers to complete
|
||||
await asyncio.gather(*self._workers, return_exceptions=True)
|
||||
self._workers.clear()
|
||||
logger.info("Task executor stopped")
|
||||
|
||||
async def _worker(self, name: str):
|
||||
"""Worker for processing tasks (runs in parallel with other workers)"""
|
||||
logger.info(f"Worker {name} started")
|
||||
|
||||
while True:
|
||||
# Check state with lock protection
|
||||
async with self._running_lock:
|
||||
if not self._running:
|
||||
break
|
||||
try:
|
||||
# Get task from queue (non-blocking)
|
||||
task = await task_queue.get_task()
|
||||
|
||||
if not task:
|
||||
# No tasks, small delay
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
|
||||
# Check for cancellation before starting processing
|
||||
current_task = await task_queue.get_task_by_id(task.id)
|
||||
if current_task and current_task.status == TaskStatus.CANCELLED:
|
||||
logger.info(f"Task {task.id} was cancelled, skipping")
|
||||
continue
|
||||
|
||||
# Update status
|
||||
await task_queue.update_task_status(task.id, TaskStatus.PROCESSING)
|
||||
|
||||
logger.info(f"Worker {name} processing task {task.id}")
|
||||
|
||||
# Execute task (doesn't block other workers and message processing)
|
||||
# Each task executes independently
|
||||
await self._execute_task(task)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Worker {name} stopped")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in worker {name}: {e}", exc_info=True)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
logger.info(f"Worker {name} finished")
|
||||
|
||||
async def _execute_task(self, task: Task):
|
||||
"""
|
||||
Execute task
|
||||
|
||||
Args:
|
||||
task: Task to execute
|
||||
"""
|
||||
try:
|
||||
if task.task_type == "download" and task.url:
|
||||
# Determine download type
|
||||
if is_youtube_url(task.url) or is_instagram_url(task.url) or any(
|
||||
domain in task.url.lower()
|
||||
for domain in ['youtube.com', 'youtu.be', 'instagram.com', 'tiktok.com', 'twitter.com', 'x.com']
|
||||
):
|
||||
# Download via yt-dlp
|
||||
await self._download_with_ytdlp(task)
|
||||
else:
|
||||
# Direct download (to be implemented later)
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error="Direct download not supported yet"
|
||||
)
|
||||
else:
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error="Unknown task type or missing URL"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing task {task.id}: {e}", exc_info=True)
|
||||
|
||||
# Form user-friendly error message
|
||||
error_message = str(e)
|
||||
if "login required" in error_message.lower() or "cookies" in error_message.lower():
|
||||
error_message = (
|
||||
"❌ Authentication required to download this content.\n\n"
|
||||
"💡 Solution: configure cookies in bot configuration.\n"
|
||||
"See instructions in README.md"
|
||||
)
|
||||
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=error_message
|
||||
)
|
||||
|
||||
# Send message to user
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Download error:\n{error_message}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send error message to user {task.user_id}: {e}", exc_info=True)
|
||||
|
||||
async def _download_with_ytdlp(self, task: Task):
|
||||
"""Download via yt-dlp"""
|
||||
# Get cancellation event for this task
|
||||
cancel_event = task_queue.get_cancel_event(task.id)
|
||||
|
||||
try:
|
||||
# Check for cancellation
|
||||
current_task = await task_queue.get_task_by_id(task.id)
|
||||
if current_task and current_task.status == TaskStatus.CANCELLED:
|
||||
logger.info(f"Task {task.id} cancelled, stopping download")
|
||||
task_queue.clear_cancel_event(task.id)
|
||||
return
|
||||
|
||||
# Get media information to check limits
|
||||
from bot.modules.media_loader.ytdlp import get_media_info
|
||||
from shared.config import settings
|
||||
|
||||
media_info = await get_media_info(task.url, cookies_file=settings.COOKIES_FILE)
|
||||
if media_info:
|
||||
# Check duration
|
||||
max_duration = settings.max_duration_minutes_int
|
||||
if max_duration and media_info.get('duration'):
|
||||
duration_minutes = media_info['duration'] / 60
|
||||
if duration_minutes > max_duration:
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=f"Maximum duration exceeded ({max_duration} min)"
|
||||
)
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ File too long ({duration_minutes:.1f} min). "
|
||||
f"Maximum: {max_duration} min."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send duration exceeded message to user {task.user_id}: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
# Function for updating progress with cancellation check
|
||||
# This function is called from another thread (yt-dlp), so we use run_coroutine_threadsafe
|
||||
async def update_progress(percent: int):
|
||||
# Check cancellation when updating progress
|
||||
if cancel_event.is_set():
|
||||
raise asyncio.CancelledError("Task cancelled by user")
|
||||
current_task = await task_queue.get_task_by_id(task.id)
|
||||
if current_task and current_task.status == TaskStatus.CANCELLED:
|
||||
raise asyncio.CancelledError("Task cancelled by user")
|
||||
await task_queue.update_task_status(task.id, TaskStatus.PROCESSING, progress=percent)
|
||||
|
||||
# Update progress message
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
try:
|
||||
from pyrogram.errors import MessageNotModified
|
||||
status_text = (
|
||||
f"📥 **Downloading file**\n\n"
|
||||
f"🔗 {task.url[:50]}...\n"
|
||||
f"📊 Progress: **{percent}%**\n"
|
||||
f"⏳ Please wait..."
|
||||
)
|
||||
try:
|
||||
await app_client.edit_message_text(
|
||||
chat_id=task.user_id,
|
||||
message_id=message_id,
|
||||
text=status_text
|
||||
)
|
||||
except MessageNotModified:
|
||||
pass # Ignore if text didn't change
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to update message: {e}")
|
||||
|
||||
# Save reference to event loop for use in progress hook
|
||||
progress_loop = asyncio.get_event_loop()
|
||||
|
||||
# Send one message about download start (will be updated)
|
||||
app_client = get_app_client()
|
||||
status_message = None
|
||||
if app_client:
|
||||
try:
|
||||
status_text = (
|
||||
f"📥 **Downloading file**\n\n"
|
||||
f"🔗 {task.url[:50]}...\n"
|
||||
f"📊 Progress: **0%**\n"
|
||||
f"⏳ Please wait..."
|
||||
)
|
||||
status_message = await app_client.send_message(
|
||||
task.user_id,
|
||||
status_text
|
||||
)
|
||||
# Save message_id for updates
|
||||
set_task_message(task.id, status_message.id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send notification: {e}")
|
||||
|
||||
# Download media
|
||||
try:
|
||||
logger.info(f"Starting download for task {task.id}, URL: {task.url}")
|
||||
result = await download_media(
|
||||
url=task.url,
|
||||
output_dir="downloads",
|
||||
quality="best",
|
||||
progress_callback=update_progress,
|
||||
cookies_file=settings.COOKIES_FILE,
|
||||
cancel_event=cancel_event,
|
||||
task_id=task.id
|
||||
)
|
||||
|
||||
logger.info(f"Download completed for task {task.id}. Result: {result is not None}, file_path: {result.get('file_path') if result else None}")
|
||||
|
||||
# Save file path to database
|
||||
if result and result.get('file_path'):
|
||||
await task_queue.update_task_file_path(task.id, result['file_path'])
|
||||
logger.info(f"File path saved to DB: {result['file_path']}")
|
||||
|
||||
# Check file size after download
|
||||
max_file_size = settings.max_file_size_bytes
|
||||
if result and max_file_size:
|
||||
file_size = result.get('size', 0)
|
||||
if file_size > max_file_size:
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=f"Maximum file size exceeded ({max_file_size / (1024*1024):.1f} MB)"
|
||||
)
|
||||
# Delete file
|
||||
if result.get('file_path'):
|
||||
await delete_file(result['file_path'])
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ File too large ({file_size / (1024*1024):.1f} MB). "
|
||||
f"Maximum: {max_file_size / (1024*1024):.1f} MB."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send file size exceeded message to user {task.user_id}: {e}", exc_info=True)
|
||||
clear_task_message(task.id)
|
||||
return
|
||||
except (asyncio.CancelledError, KeyboardInterrupt) as e:
|
||||
logger.info(f"Task {task.id} cancelled during download: {e}")
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.CANCELLED,
|
||||
error="Cancelled by user"
|
||||
)
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
# Update cancellation message
|
||||
try:
|
||||
await app_client.edit_message_text(
|
||||
chat_id=task.user_id,
|
||||
message_id=message_id,
|
||||
text=f"🚫 **Task cancelled**\n\nTask #{task.id} was cancelled."
|
||||
)
|
||||
except Exception as edit_error:
|
||||
# If update failed, send new message
|
||||
logger.debug(f"Failed to update cancellation message, sending new: {edit_error}")
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"🚫 Task #{task.id} cancelled"
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send cancellation message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
else:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"🚫 Task #{task.id} cancelled"
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send cancellation message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending cancellation notification for task {task.id}: {e}", exc_info=True)
|
||||
clear_task_message(task.id)
|
||||
task_queue.clear_cancel_event(task.id)
|
||||
return
|
||||
|
||||
# Check for cancellation after download
|
||||
current_task = await task_queue.get_task_by_id(task.id)
|
||||
if current_task and current_task.status == TaskStatus.CANCELLED:
|
||||
logger.info(f"Task {task.id} cancelled after download")
|
||||
# Delete downloaded file if exists
|
||||
if result and result.get('file_path'):
|
||||
await delete_file(result['file_path'])
|
||||
task_queue.clear_cancel_event(task.id)
|
||||
return
|
||||
|
||||
if not result:
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error="Failed to download file"
|
||||
)
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
try:
|
||||
await app_client.edit_message_text(
|
||||
chat_id=task.user_id,
|
||||
message_id=message_id,
|
||||
text="❌ **Download error**\n\nFailed to download file. Check the link and try again."
|
||||
)
|
||||
except Exception as edit_error:
|
||||
logger.debug(f"Failed to update error message, sending new: {edit_error}")
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Error downloading file. Check the link and try again."
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send error message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
else:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Error downloading file. Check the link and try again."
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send error message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending download error notification for task {task.id}: {e}", exc_info=True)
|
||||
clear_task_message(task.id)
|
||||
return
|
||||
|
||||
# Send file to user
|
||||
await task_queue.update_task_status(task.id, TaskStatus.PROCESSING, progress=90)
|
||||
|
||||
# Check that file exists before sending
|
||||
file_path_obj = Path(result['file_path'])
|
||||
if not file_path_obj.exists():
|
||||
logger.error(f"File doesn't exist before sending: {result['file_path']}")
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=f"File not found: {result['file_path']}"
|
||||
)
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Error: file not found after download"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send error message: {e}")
|
||||
clear_task_message(task.id)
|
||||
return
|
||||
|
||||
logger.info(f"Sending file to user {task.user_id}: {result['file_path']}")
|
||||
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
# Form caption
|
||||
caption = f"📥 **{result.get('title', 'File')}**"
|
||||
if result.get('duration'):
|
||||
from bot.utils.helpers import format_duration
|
||||
caption += f"\n⏱ Duration: {format_duration(result['duration'])}"
|
||||
|
||||
# Send file
|
||||
logger.info(f"Calling send_file_to_user for file: {result['file_path']}")
|
||||
success = await send_file_to_user(
|
||||
client=app_client,
|
||||
chat_id=task.user_id,
|
||||
file_path=result['file_path'],
|
||||
caption=caption,
|
||||
thumbnail=result.get('thumbnail')
|
||||
)
|
||||
|
||||
logger.info(f"File sending result: success={success}")
|
||||
|
||||
if success:
|
||||
# Delete file after successful sending
|
||||
await delete_file(result['file_path'])
|
||||
await task_queue.update_task_status(task.id, TaskStatus.COMPLETED, progress=100)
|
||||
task_queue.clear_cancel_event(task.id)
|
||||
|
||||
# Delete download message (file already sent)
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
try:
|
||||
await app_client.delete_messages(
|
||||
chat_id=task.user_id,
|
||||
message_ids=message_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to delete download message for task {task.id}: {e}")
|
||||
clear_task_message(task.id)
|
||||
else:
|
||||
error_msg = "Failed to send file"
|
||||
logger.error(f"File sending error for task {task.id}: {error_msg}")
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=error_msg
|
||||
)
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Error sending file. File downloaded but failed to send."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send file sending error message: {e}")
|
||||
except Exception as send_error:
|
||||
error_msg = f"Error sending file: {str(send_error)}"
|
||||
logger.error(f"Exception sending file for task {task.id}: {send_error}", exc_info=True)
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=error_msg
|
||||
)
|
||||
try:
|
||||
if app_client:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ Error sending file: {str(send_error)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send error message: {e}")
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
try:
|
||||
await app_client.edit_message_text(
|
||||
chat_id=task.user_id,
|
||||
message_id=message_id,
|
||||
text="❌ **Send error**\n\nFailed to send file. Try again later."
|
||||
)
|
||||
except Exception as edit_error:
|
||||
logger.debug(f"Failed to update send error message, sending new: {edit_error}")
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
"❌ Error sending file. Try again later."
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send send error message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
else:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
"❌ Error sending file. Try again later."
|
||||
)
|
||||
clear_task_message(task.id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending file: {e}", exc_info=True)
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=f"Send error: {str(e)}"
|
||||
)
|
||||
else:
|
||||
logger.warning("Client not set, file not sent")
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error="Client not available"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading via yt-dlp: {e}", exc_info=True)
|
||||
await task_queue.update_task_status(
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=str(e)
|
||||
)
|
||||
task_queue.clear_cancel_event(task.id)
|
||||
clear_task_message(task.id)
|
||||
app_client = get_app_client()
|
||||
if app_client:
|
||||
try:
|
||||
message_id = get_task_message(task.id)
|
||||
if message_id:
|
||||
try:
|
||||
await app_client.edit_message_text(
|
||||
chat_id=task.user_id,
|
||||
message_id=message_id,
|
||||
text=f"❌ **Error**\n\nAn error occurred: {str(e)}"
|
||||
)
|
||||
except Exception as edit_error:
|
||||
logger.debug(f"Failed to update error message, sending new: {edit_error}")
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ An error occurred: {str(e)}"
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send error message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
else:
|
||||
try:
|
||||
await app_client.send_message(
|
||||
task.user_id,
|
||||
f"❌ An error occurred: {str(e)}"
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send error message to user {task.user_id}: {send_error}", exc_info=True)
|
||||
except Exception as notify_error:
|
||||
logger.error(f"Error sending error notification for task {task.id}: {notify_error}", exc_info=True)
|
||||
|
||||
|
||||
# Global task executor
|
||||
task_executor = TaskExecutor()
|
||||
|
||||
84
bot/modules/task_scheduler/monitor.py
Normal file
84
bot/modules/task_scheduler/monitor.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Task monitoring
|
||||
"""
|
||||
from typing import Tuple
|
||||
from bot.modules.task_scheduler.queue import task_queue, TaskStatus
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_task_status(task_id: int) -> dict:
|
||||
"""
|
||||
Get task status
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Dictionary with task status
|
||||
"""
|
||||
task = await task_queue.get_task_by_id(task_id)
|
||||
|
||||
if not task:
|
||||
return {"error": "Task not found"}
|
||||
|
||||
return {
|
||||
"id": task.id,
|
||||
"user_id": task.user_id,
|
||||
"task_type": task.task_type,
|
||||
"status": task.status.value,
|
||||
"progress": task.progress,
|
||||
"error_message": task.error_message,
|
||||
"created_at": task.created_at.isoformat() if task.created_at else None
|
||||
}
|
||||
|
||||
|
||||
async def get_user_tasks_status(user_id: int) -> list[dict]:
|
||||
"""
|
||||
Get status of all user tasks
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of user tasks
|
||||
"""
|
||||
tasks = await task_queue.get_user_tasks(user_id)
|
||||
return [await get_task_status(task.id) for task in tasks]
|
||||
|
||||
|
||||
async def cancel_user_task(user_id: int, task_id: int) -> Tuple[bool, str]:
|
||||
"""
|
||||
Cancel user task
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str)
|
||||
"""
|
||||
task = await task_queue.get_task_by_id(task_id)
|
||||
|
||||
if not task:
|
||||
return (False, f"Задача #{task_id} не найдена")
|
||||
|
||||
if task.user_id != user_id:
|
||||
return (False, "Вы можете отменять только свои задачи")
|
||||
|
||||
if task.status == TaskStatus.COMPLETED:
|
||||
return (False, f"Задача #{task_id} уже завершена")
|
||||
|
||||
if task.status == TaskStatus.CANCELLED:
|
||||
return (False, f"Задача #{task_id} уже отменена")
|
||||
|
||||
if task.status == TaskStatus.FAILED:
|
||||
return (False, f"Задача #{task_id} уже завершилась с ошибкой")
|
||||
|
||||
success = await task_queue.cancel_task(task_id)
|
||||
if success:
|
||||
return (True, f"Задача #{task_id} успешно отменена")
|
||||
else:
|
||||
return (False, f"Не удалось отменить задачу #{task_id}. Возможно, она уже завершается.")
|
||||
|
||||
297
bot/modules/task_scheduler/queue.py
Normal file
297
bot/modules/task_scheduler/queue.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Task queue
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""Task statuses"""
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""Task for execution"""
|
||||
id: int
|
||||
user_id: int
|
||||
task_type: str
|
||||
url: Optional[str] = None
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
progress: int = 0
|
||||
error_message: Optional[str] = None
|
||||
created_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
class TaskQueue:
|
||||
"""Task queue"""
|
||||
|
||||
def __init__(self):
|
||||
self._queue: Optional[asyncio.Queue] = None
|
||||
self._tasks: dict[int, Task] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._initialized = False
|
||||
# Dictionary to store cancellation events for each task
|
||||
# Protected by threading.Lock, as access may be from different threads
|
||||
self._cancel_events: dict[int, threading.Event] = {}
|
||||
self._cancel_events_lock = threading.Lock()
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize queue in current event loop"""
|
||||
if not self._initialized:
|
||||
self._queue = asyncio.Queue()
|
||||
self._initialized = True
|
||||
logger.info("Task queue initialized")
|
||||
|
||||
async def add_task(self, task: Task, check_duplicate_url: bool = True) -> bool:
|
||||
"""
|
||||
Add task to queue
|
||||
|
||||
Args:
|
||||
task: Task to add
|
||||
check_duplicate_url: Check for duplicate URLs for this user
|
||||
|
||||
Returns:
|
||||
True if successful, False if duplicate
|
||||
"""
|
||||
if not self._initialized or not self._queue:
|
||||
await self.initialize()
|
||||
|
||||
async with self._lock:
|
||||
# Check for duplicate task_id
|
||||
if task.id in self._tasks:
|
||||
logger.warning(f"Task with ID {task.id} already exists, skipping duplicate")
|
||||
return False
|
||||
|
||||
# Atomic check for duplicate URLs for this user
|
||||
if check_duplicate_url and task.url:
|
||||
url_normalized = task.url.strip()
|
||||
for existing_task in self._tasks.values():
|
||||
if (existing_task.user_id == task.user_id and
|
||||
existing_task.url and existing_task.url.strip() == url_normalized and
|
||||
existing_task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING] and
|
||||
existing_task.id != task.id):
|
||||
logger.warning(f"Task {task.id} with URL {url_normalized} already processing for user {task.user_id} (task {existing_task.id})")
|
||||
return False # Block duplicate
|
||||
|
||||
self._tasks[task.id] = task
|
||||
await self._queue.put(task)
|
||||
logger.info(f"Task {task.id} added to queue")
|
||||
return True
|
||||
|
||||
async def get_task(self) -> Optional[Task]:
|
||||
"""
|
||||
Get task from queue
|
||||
|
||||
Returns:
|
||||
Task or None
|
||||
"""
|
||||
if not self._initialized or not self._queue:
|
||||
await self.initialize()
|
||||
|
||||
try:
|
||||
# Use timeout to avoid blocking indefinitely
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
return task
|
||||
except asyncio.TimeoutError:
|
||||
# Timeout is normal, just no tasks
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task: {e}", exc_info=True)
|
||||
await asyncio.sleep(1) # Small delay before retry
|
||||
return None
|
||||
|
||||
async def update_task_status(self, task_id: int, status: TaskStatus, progress: int = None, error: str = None):
|
||||
"""
|
||||
Update task status
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
status: New status
|
||||
progress: Progress (0-100)
|
||||
error: Error message
|
||||
"""
|
||||
async with self._lock:
|
||||
if task_id in self._tasks:
|
||||
task = self._tasks[task_id]
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = progress
|
||||
if error:
|
||||
task.error_message = error
|
||||
logger.info(f"Task {task_id} updated: {status.value}")
|
||||
|
||||
# Sync with database
|
||||
try:
|
||||
from shared.database.models import Task as DBTask
|
||||
from shared.database.session import get_async_session_local
|
||||
from datetime import datetime
|
||||
async with get_async_session_local()() as session:
|
||||
db_task = await session.get(DBTask, task_id)
|
||||
if db_task:
|
||||
db_task.status = status.value
|
||||
if progress is not None:
|
||||
db_task.progress = progress
|
||||
if error:
|
||||
# Limit error_message length (maximum 1000 characters)
|
||||
db_task.error_message = (error[:1000] if error and len(error) > 1000 else error)
|
||||
if status == TaskStatus.COMPLETED:
|
||||
db_task.completed_at = datetime.utcnow()
|
||||
db_task.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
logger.debug(f"Task {task_id} updated in DB: {status.value}")
|
||||
else:
|
||||
# Task not found in DB - create it for synchronization
|
||||
logger.warning(f"Task {task_id} not found in DB, creating record for synchronization")
|
||||
from bot.modules.task_scheduler.queue import task_queue
|
||||
from shared.database.user_helpers import ensure_user_exists
|
||||
task = await task_queue.get_task_by_id(task_id)
|
||||
if task:
|
||||
# Ensure user exists before creating task
|
||||
await ensure_user_exists(task.user_id, session)
|
||||
|
||||
db_task = DBTask(
|
||||
id=task_id,
|
||||
user_id=task.user_id,
|
||||
task_type=task.task_type,
|
||||
status=status.value,
|
||||
url=task.url,
|
||||
progress=progress if progress is not None else 0,
|
||||
error_message=(error[:1000] if error and len(error) > 1000 else error),
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
session.add(db_task)
|
||||
await session.commit()
|
||||
logger.info(f"Task {task_id} created in DB for synchronization")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update task {task_id} in DB: {e}", exc_info=True)
|
||||
# Try to rollback if session is still open
|
||||
try:
|
||||
async with get_async_session_local()() as session:
|
||||
await session.rollback()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def update_task_file_path(self, task_id: int, file_path: str):
|
||||
"""
|
||||
Update task file path
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
file_path: File path
|
||||
"""
|
||||
# Sync with database
|
||||
try:
|
||||
from shared.database.models import Task as DBTask
|
||||
from shared.database.session import get_async_session_local
|
||||
from datetime import datetime
|
||||
async with get_async_session_local()() as session:
|
||||
db_task = await session.get(DBTask, task_id)
|
||||
if db_task:
|
||||
db_task.file_path = file_path
|
||||
db_task.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
logger.debug(f"Task {task_id} file path updated in DB: {file_path}")
|
||||
else:
|
||||
logger.warning(f"Task {task_id} not found in DB for file_path update")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update task {task_id} file path in DB: {e}", exc_info=True)
|
||||
# Try to rollback if session is still open
|
||||
try:
|
||||
async with get_async_session_local()() as session:
|
||||
await session.rollback()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def get_task_by_id(self, task_id: int) -> Optional[Task]:
|
||||
"""Get task by ID"""
|
||||
async with self._lock:
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
async def cancel_task(self, task_id: int) -> bool:
|
||||
"""
|
||||
Cancel task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
async with self._lock:
|
||||
if task_id in self._tasks:
|
||||
task = self._tasks[task_id]
|
||||
# Can only cancel pending or processing tasks
|
||||
if task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.error_message = "Cancelled by user"
|
||||
# Set cancellation event to interrupt download (thread-safe)
|
||||
with self._cancel_events_lock:
|
||||
if task_id in self._cancel_events:
|
||||
self._cancel_events[task_id].set()
|
||||
logger.info(f"Task {task_id} cancelled")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_cancel_event(self, task_id: int) -> threading.Event:
|
||||
"""
|
||||
Get cancellation event for task (created if doesn't exist)
|
||||
Thread-safe method for use from different threads.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
threading.Event for cancellation check
|
||||
"""
|
||||
with self._cancel_events_lock:
|
||||
if task_id not in self._cancel_events:
|
||||
self._cancel_events[task_id] = threading.Event()
|
||||
return self._cancel_events[task_id]
|
||||
|
||||
def clear_cancel_event(self, task_id: int):
|
||||
"""
|
||||
Clear cancellation event after task completion
|
||||
Thread-safe method for use from different threads.
|
||||
"""
|
||||
with self._cancel_events_lock:
|
||||
if task_id in self._cancel_events:
|
||||
del self._cancel_events[task_id]
|
||||
|
||||
async def get_user_tasks(self, user_id: int) -> list[Task]:
|
||||
"""Get user tasks"""
|
||||
async with self._lock:
|
||||
return [task for task in self._tasks.values() if task.user_id == user_id]
|
||||
|
||||
async def get_user_active_tasks_count(self, user_id: int) -> int:
|
||||
"""Get count of active user tasks"""
|
||||
async with self._lock:
|
||||
active_statuses = [TaskStatus.PENDING, TaskStatus.PROCESSING]
|
||||
return sum(1 for task in self._tasks.values()
|
||||
if task.user_id == user_id and task.status in active_statuses)
|
||||
|
||||
async def get_all_tasks(self) -> list[Task]:
|
||||
"""Get all tasks"""
|
||||
async with self._lock:
|
||||
return list(self._tasks.values())
|
||||
|
||||
|
||||
# Global task queue
|
||||
task_queue = TaskQueue()
|
||||
|
||||
Reference in New Issue
Block a user