# --- File-viewer header captured with the paste (not part of the source) ---
# Files | 2025-12-04 00:12:56 +03:00 | 298 lines | 11 KiB | Python
"""
Task queue
"""
import asyncio
from typing import Optional
from enum import Enum
from dataclasses import dataclass
from datetime import datetime
import threading
import logging
logger = logging.getLogger(__name__)
class TaskStatus(Enum):
"""Task statuses"""
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@dataclass
class Task:
"""Task for execution"""
id: int
user_id: int
task_type: str
url: Optional[str] = None
status: TaskStatus = TaskStatus.PENDING
progress: int = 0
error_message: Optional[str] = None
created_at: datetime = None
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow()
class TaskQueue:
"""Task queue"""
def __init__(self):
self._queue: Optional[asyncio.Queue] = None
self._tasks: dict[int, Task] = {}
self._lock = asyncio.Lock()
self._initialized = False
# Dictionary to store cancellation events for each task
# Protected by threading.Lock, as access may be from different threads
self._cancel_events: dict[int, threading.Event] = {}
self._cancel_events_lock = threading.Lock()
async def initialize(self):
"""Initialize queue in current event loop"""
if not self._initialized:
self._queue = asyncio.Queue()
self._initialized = True
logger.info("Task queue initialized")
async def add_task(self, task: Task, check_duplicate_url: bool = True) -> bool:
"""
Add task to queue
Args:
task: Task to add
check_duplicate_url: Check for duplicate URLs for this user
Returns:
True if successful, False if duplicate
"""
if not self._initialized or not self._queue:
await self.initialize()
async with self._lock:
# Check for duplicate task_id
if task.id in self._tasks:
logger.warning(f"Task with ID {task.id} already exists, skipping duplicate")
return False
# Atomic check for duplicate URLs for this user
if check_duplicate_url and task.url:
url_normalized = task.url.strip()
for existing_task in self._tasks.values():
if (existing_task.user_id == task.user_id and
existing_task.url and existing_task.url.strip() == url_normalized and
existing_task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING] and
existing_task.id != task.id):
logger.warning(f"Task {task.id} with URL {url_normalized} already processing for user {task.user_id} (task {existing_task.id})")
return False # Block duplicate
self._tasks[task.id] = task
await self._queue.put(task)
logger.info(f"Task {task.id} added to queue")
return True
async def get_task(self) -> Optional[Task]:
"""
Get task from queue
Returns:
Task or None
"""
if not self._initialized or not self._queue:
await self.initialize()
try:
# Use timeout to avoid blocking indefinitely
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
return task
except asyncio.TimeoutError:
# Timeout is normal, just no tasks
return None
except Exception as e:
logger.error(f"Error getting task: {e}", exc_info=True)
await asyncio.sleep(1) # Small delay before retry
return None
async def update_task_status(self, task_id: int, status: TaskStatus, progress: int = None, error: str = None):
"""
Update task status
Args:
task_id: Task ID
status: New status
progress: Progress (0-100)
error: Error message
"""
async with self._lock:
if task_id in self._tasks:
task = self._tasks[task_id]
task.status = status
if progress is not None:
task.progress = progress
if error:
task.error_message = error
logger.info(f"Task {task_id} updated: {status.value}")
# Sync with database
try:
from shared.database.models import Task as DBTask
from shared.database.session import get_async_session_local
from datetime import datetime
async with get_async_session_local()() as session:
db_task = await session.get(DBTask, task_id)
if db_task:
db_task.status = status.value
if progress is not None:
db_task.progress = progress
if error:
# Limit error_message length (maximum 1000 characters)
db_task.error_message = (error[:1000] if error and len(error) > 1000 else error)
if status == TaskStatus.COMPLETED:
db_task.completed_at = datetime.utcnow()
db_task.updated_at = datetime.utcnow()
await session.commit()
logger.debug(f"Task {task_id} updated in DB: {status.value}")
else:
# Task not found in DB - create it for synchronization
logger.warning(f"Task {task_id} not found in DB, creating record for synchronization")
from bot.modules.task_scheduler.queue import task_queue
from shared.database.user_helpers import ensure_user_exists
task = await task_queue.get_task_by_id(task_id)
if task:
# Ensure user exists before creating task
await ensure_user_exists(task.user_id, session)
db_task = DBTask(
id=task_id,
user_id=task.user_id,
task_type=task.task_type,
status=status.value,
url=task.url,
progress=progress if progress is not None else 0,
error_message=(error[:1000] if error and len(error) > 1000 else error),
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
session.add(db_task)
await session.commit()
logger.info(f"Task {task_id} created in DB for synchronization")
except Exception as e:
logger.warning(f"Failed to update task {task_id} in DB: {e}", exc_info=True)
# Try to rollback if session is still open
try:
async with get_async_session_local()() as session:
await session.rollback()
except:
pass
async def update_task_file_path(self, task_id: int, file_path: str):
"""
Update task file path
Args:
task_id: Task ID
file_path: File path
"""
# Sync with database
try:
from shared.database.models import Task as DBTask
from shared.database.session import get_async_session_local
from datetime import datetime
async with get_async_session_local()() as session:
db_task = await session.get(DBTask, task_id)
if db_task:
db_task.file_path = file_path
db_task.updated_at = datetime.utcnow()
await session.commit()
logger.debug(f"Task {task_id} file path updated in DB: {file_path}")
else:
logger.warning(f"Task {task_id} not found in DB for file_path update")
except Exception as e:
logger.warning(f"Failed to update task {task_id} file path in DB: {e}", exc_info=True)
# Try to rollback if session is still open
try:
async with get_async_session_local()() as session:
await session.rollback()
except:
pass
async def get_task_by_id(self, task_id: int) -> Optional[Task]:
"""Get task by ID"""
async with self._lock:
return self._tasks.get(task_id)
async def cancel_task(self, task_id: int) -> bool:
"""
Cancel task
Args:
task_id: Task ID
Returns:
True if successful
"""
async with self._lock:
if task_id in self._tasks:
task = self._tasks[task_id]
# Can only cancel pending or processing tasks
if task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
task.status = TaskStatus.CANCELLED
task.error_message = "Cancelled by user"
# Set cancellation event to interrupt download (thread-safe)
with self._cancel_events_lock:
if task_id in self._cancel_events:
self._cancel_events[task_id].set()
logger.info(f"Task {task_id} cancelled")
return True
return False
def get_cancel_event(self, task_id: int) -> threading.Event:
"""
Get cancellation event for task (created if doesn't exist)
Thread-safe method for use from different threads.
Args:
task_id: Task ID
Returns:
threading.Event for cancellation check
"""
with self._cancel_events_lock:
if task_id not in self._cancel_events:
self._cancel_events[task_id] = threading.Event()
return self._cancel_events[task_id]
def clear_cancel_event(self, task_id: int):
"""
Clear cancellation event after task completion
Thread-safe method for use from different threads.
"""
with self._cancel_events_lock:
if task_id in self._cancel_events:
del self._cancel_events[task_id]
async def get_user_tasks(self, user_id: int) -> list[Task]:
"""Get user tasks"""
async with self._lock:
return [task for task in self._tasks.values() if task.user_id == user_id]
async def get_user_active_tasks_count(self, user_id: int) -> int:
"""Get count of active user tasks"""
async with self._lock:
active_statuses = [TaskStatus.PENDING, TaskStatus.PROCESSING]
return sum(1 for task in self._tasks.values()
if task.user_id == user_id and task.status in active_statuses)
async def get_all_tasks(self) -> list[Task]:
"""Get all tasks"""
async with self._lock:
return list(self._tasks.values())
# Global task queue
task_queue = TaskQueue()