"""Task queue.

In-memory asyncio task queue with best-effort mirroring of task state into the
database.
"""

import asyncio
import logging
import threading
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

logger = logging.getLogger(__name__)

# Maximum length of error_message persisted to the database.
_MAX_DB_ERROR_LENGTH = 1000


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Same value as the deprecated ``datetime.utcnow()``; kept naive so it stays
    comparable with timestamps produced before this helper existed.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class TaskStatus(Enum):
    """Task statuses"""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


# Statuses in which a task is still queued or running (may be cancelled,
# and blocks duplicate URLs for the same user).
_ACTIVE_STATUSES = frozenset({TaskStatus.PENDING, TaskStatus.PROCESSING})


@dataclass
class Task:
    """Task for execution, tracked in memory by :class:`TaskQueue`."""

    id: int
    user_id: int
    task_type: str
    url: Optional[str] = None
    status: TaskStatus = TaskStatus.PENDING
    progress: int = 0  # 0-100, set via TaskQueue.update_task_status
    error_message: Optional[str] = None
    created_at: Optional[datetime] = None  # naive UTC; filled in __post_init__

    def __post_init__(self) -> None:
        if self.created_at is None:
            self.created_at = _utcnow()


class TaskQueue:
    """Asyncio task queue plus an in-memory registry of every task added.

    ``_tasks`` and the queue itself are guarded by an ``asyncio.Lock``;
    cancellation events are guarded by a ``threading.Lock`` because they are
    read from worker threads.
    """

    def __init__(self):
        self._queue: Optional[asyncio.Queue] = None
        self._tasks: dict[int, Task] = {}
        self._lock = asyncio.Lock()
        self._initialized = False
        # Cancellation events per task id. Protected by threading.Lock, as
        # access may come from different threads.
        self._cancel_events: dict[int, threading.Event] = {}
        self._cancel_events_lock = threading.Lock()

    async def initialize(self):
        """Initialize queue in current event loop"""
        if not self._initialized:
            self._queue = asyncio.Queue()
            self._initialized = True
            logger.info("Task queue initialized")

    async def _ensure_initialized(self) -> None:
        """Lazily create the underlying asyncio.Queue on first use."""
        if not self._initialized or self._queue is None:
            await self.initialize()

    def _find_active_duplicate(
        self, user_id: int, url_normalized: str, task_id: int
    ) -> Optional[Task]:
        """Return another active task of ``user_id`` with the same stripped URL.

        Caller must hold ``self._lock``.
        """
        for existing in self._tasks.values():
            if (
                existing.user_id == user_id
                and existing.url
                and existing.url.strip() == url_normalized
                and existing.status in _ACTIVE_STATUSES
                and existing.id != task_id
            ):
                return existing
        return None

    async def add_task(self, task: Task, check_duplicate_url: bool = True) -> bool:
        """
        Add task to queue

        Args:
            task: Task to add
            check_duplicate_url: Reject the task if the same user already has
                a pending/processing task with the same (stripped) URL

        Returns:
            True if added, False if the id or URL is a duplicate
        """
        await self._ensure_initialized()

        async with self._lock:
            if task.id in self._tasks:
                logger.warning(f"Task with ID {task.id} already exists, skipping duplicate")
                return False

            # Duplicate-URL check runs under the lock so check+insert is atomic.
            if check_duplicate_url and task.url:
                url_normalized = task.url.strip()
                existing = self._find_active_duplicate(task.user_id, url_normalized, task.id)
                if existing is not None:
                    logger.warning(
                        f"Task {task.id} with URL {url_normalized} already processing "
                        f"for user {task.user_id} (task {existing.id})"
                    )
                    return False

            self._tasks[task.id] = task
            await self._queue.put(task)
            logger.info(f"Task {task.id} added to queue")
            return True

    async def get_task(self) -> Optional[Task]:
        """
        Get task from queue

        Waits at most one second so callers can poll without blocking forever.

        Returns:
            Task, or None on timeout / error
        """
        await self._ensure_initialized()

        try:
            return await asyncio.wait_for(self._queue.get(), timeout=1.0)
        except asyncio.TimeoutError:
            # Timeout is normal, just no tasks
            return None
        except Exception as e:
            logger.error(f"Error getting task: {e}", exc_info=True)
            await asyncio.sleep(1)  # Small delay before retry
            return None

    async def update_task_status(
        self,
        task_id: int,
        status: TaskStatus,
        progress: Optional[int] = None,
        error: Optional[str] = None,
    ) -> None:
        """
        Update task status in memory and (best effort) in the database

        Args:
            task_id: Task ID
            status: New status
            progress: Progress (0-100); left unchanged when None
            error: Error message; left unchanged when empty/None
        """
        async with self._lock:
            task = self._tasks.get(task_id)
            if task is not None:
                task.status = status
                if progress is not None:
                    task.progress = progress
                if error:
                    task.error_message = error
                logger.info(f"Task {task_id} updated: {status.value}")

        # DB I/O runs outside self._lock: the lock is non-reentrant and must
        # not be held across network round-trips.
        await self._sync_status_to_db(task_id, status, progress, error, task)

    async def _sync_status_to_db(
        self,
        task_id: int,
        status: TaskStatus,
        progress: Optional[int],
        error: Optional[str],
        task: Optional[Task],
    ) -> None:
        """Mirror a status change into the DB; create the row from ``task`` if missing.

        Failures are logged and swallowed (best effort).
        """
        try:
            from shared.database.models import Task as DBTask
            from shared.database.session import get_async_session_local

            now = _utcnow()
            truncated_error = error[:_MAX_DB_ERROR_LENGTH] if error else error

            async with get_async_session_local()() as session:
                db_task = await session.get(DBTask, task_id)
                if db_task:
                    db_task.status = status.value
                    if progress is not None:
                        db_task.progress = progress
                    if error:
                        db_task.error_message = truncated_error
                    if status == TaskStatus.COMPLETED:
                        db_task.completed_at = now
                    db_task.updated_at = now
                    await session.commit()
                    logger.debug(f"Task {task_id} updated in DB: {status.value}")
                    return

                # Task not found in DB - create it for synchronization
                logger.warning(f"Task {task_id} not found in DB, creating record for synchronization")
                if task is None:
                    # No in-memory data to build the row from.
                    return

                from shared.database.user_helpers import ensure_user_exists

                # Ensure user exists before creating task
                await ensure_user_exists(task.user_id, session)
                session.add(
                    DBTask(
                        id=task_id,
                        user_id=task.user_id,
                        task_type=task.task_type,
                        status=status.value,
                        url=task.url,
                        progress=progress if progress is not None else 0,
                        error_message=truncated_error,
                        created_at=now,
                        updated_at=now,
                    )
                )
                await session.commit()
                logger.info(f"Task {task_id} created in DB for synchronization")
        except Exception as e:
            # The session's own `async with` already closed (and rolled back)
            # the failed transaction; nothing else to clean up.
            logger.warning(f"Failed to update task {task_id} in DB: {e}", exc_info=True)

    async def update_task_file_path(self, task_id: int, file_path: str) -> None:
        """
        Update task file path in the database (best effort)

        Args:
            task_id: Task ID
            file_path: File path
        """
        try:
            from shared.database.models import Task as DBTask
            from shared.database.session import get_async_session_local

            async with get_async_session_local()() as session:
                db_task = await session.get(DBTask, task_id)
                if db_task:
                    db_task.file_path = file_path
                    db_task.updated_at = _utcnow()
                    await session.commit()
                    logger.debug(f"Task {task_id} file path updated in DB: {file_path}")
                else:
                    logger.warning(f"Task {task_id} not found in DB for file_path update")
        except Exception as e:
            # The failed session was closed by its `async with`; just log.
            logger.warning(f"Failed to update task {task_id} file path in DB: {e}", exc_info=True)

    async def get_task_by_id(self, task_id: int) -> Optional[Task]:
        """Get task by ID"""
        async with self._lock:
            return self._tasks.get(task_id)

    async def cancel_task(self, task_id: int) -> bool:
        """
        Cancel a pending or processing task

        The task's cancellation event is created if it does not exist yet and
        then set, so a worker that calls ``get_cancel_event`` later still sees
        the cancellation. Workers should call ``clear_cancel_event`` when done.

        Args:
            task_id: Task ID

        Returns:
            True if the task was cancelled
        """
        async with self._lock:
            task = self._tasks.get(task_id)
            # Can only cancel pending or processing tasks
            if task is None or task.status not in _ACTIVE_STATUSES:
                return False

            task.status = TaskStatus.CANCELLED
            task.error_message = "Cancelled by user"

            # Set cancellation event to interrupt download (thread-safe)
            with self._cancel_events_lock:
                event = self._cancel_events.get(task_id)
                if event is None:
                    event = self._cancel_events[task_id] = threading.Event()
                event.set()

            logger.info(f"Task {task_id} cancelled")
            return True

    def get_cancel_event(self, task_id: int) -> threading.Event:
        """
        Get cancellation event for task (created if it doesn't exist)

        Thread-safe; may be called from worker threads.

        Args:
            task_id: Task ID

        Returns:
            threading.Event for cancellation check
        """
        with self._cancel_events_lock:
            event = self._cancel_events.get(task_id)
            if event is None:
                event = self._cancel_events[task_id] = threading.Event()
            return event

    def clear_cancel_event(self, task_id: int) -> None:
        """
        Drop the cancellation event after task completion

        Thread-safe; may be called from worker threads.
        """
        with self._cancel_events_lock:
            self._cancel_events.pop(task_id, None)

    async def get_user_tasks(self, user_id: int) -> list[Task]:
        """Get all tasks (any status) of a user"""
        async with self._lock:
            return [task for task in self._tasks.values() if task.user_id == user_id]

    async def get_user_active_tasks_count(self, user_id: int) -> int:
        """Get count of pending/processing tasks of a user"""
        async with self._lock:
            return sum(
                1
                for task in self._tasks.values()
                if task.user_id == user_id and task.status in _ACTIVE_STATUSES
            )

    async def get_all_tasks(self) -> list[Task]:
        """Get all tasks"""
        async with self._lock:
            return list(self._tasks.values())


# Global task queue
task_queue = TaskQueue()