Add source
This commit is contained in:
297
bot/modules/task_scheduler/queue.py
Normal file
297
bot/modules/task_scheduler/queue.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Task queue
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""Task statuses"""
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""Task for execution"""
|
||||
id: int
|
||||
user_id: int
|
||||
task_type: str
|
||||
url: Optional[str] = None
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
progress: int = 0
|
||||
error_message: Optional[str] = None
|
||||
created_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
class TaskQueue:
|
||||
"""Task queue"""
|
||||
|
||||
def __init__(self):
|
||||
self._queue: Optional[asyncio.Queue] = None
|
||||
self._tasks: dict[int, Task] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._initialized = False
|
||||
# Dictionary to store cancellation events for each task
|
||||
# Protected by threading.Lock, as access may be from different threads
|
||||
self._cancel_events: dict[int, threading.Event] = {}
|
||||
self._cancel_events_lock = threading.Lock()
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize queue in current event loop"""
|
||||
if not self._initialized:
|
||||
self._queue = asyncio.Queue()
|
||||
self._initialized = True
|
||||
logger.info("Task queue initialized")
|
||||
|
||||
async def add_task(self, task: Task, check_duplicate_url: bool = True) -> bool:
|
||||
"""
|
||||
Add task to queue
|
||||
|
||||
Args:
|
||||
task: Task to add
|
||||
check_duplicate_url: Check for duplicate URLs for this user
|
||||
|
||||
Returns:
|
||||
True if successful, False if duplicate
|
||||
"""
|
||||
if not self._initialized or not self._queue:
|
||||
await self.initialize()
|
||||
|
||||
async with self._lock:
|
||||
# Check for duplicate task_id
|
||||
if task.id in self._tasks:
|
||||
logger.warning(f"Task with ID {task.id} already exists, skipping duplicate")
|
||||
return False
|
||||
|
||||
# Atomic check for duplicate URLs for this user
|
||||
if check_duplicate_url and task.url:
|
||||
url_normalized = task.url.strip()
|
||||
for existing_task in self._tasks.values():
|
||||
if (existing_task.user_id == task.user_id and
|
||||
existing_task.url and existing_task.url.strip() == url_normalized and
|
||||
existing_task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING] and
|
||||
existing_task.id != task.id):
|
||||
logger.warning(f"Task {task.id} with URL {url_normalized} already processing for user {task.user_id} (task {existing_task.id})")
|
||||
return False # Block duplicate
|
||||
|
||||
self._tasks[task.id] = task
|
||||
await self._queue.put(task)
|
||||
logger.info(f"Task {task.id} added to queue")
|
||||
return True
|
||||
|
||||
async def get_task(self) -> Optional[Task]:
|
||||
"""
|
||||
Get task from queue
|
||||
|
||||
Returns:
|
||||
Task or None
|
||||
"""
|
||||
if not self._initialized or not self._queue:
|
||||
await self.initialize()
|
||||
|
||||
try:
|
||||
# Use timeout to avoid blocking indefinitely
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
return task
|
||||
except asyncio.TimeoutError:
|
||||
# Timeout is normal, just no tasks
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task: {e}", exc_info=True)
|
||||
await asyncio.sleep(1) # Small delay before retry
|
||||
return None
|
||||
|
||||
async def update_task_status(self, task_id: int, status: TaskStatus, progress: int = None, error: str = None):
|
||||
"""
|
||||
Update task status
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
status: New status
|
||||
progress: Progress (0-100)
|
||||
error: Error message
|
||||
"""
|
||||
async with self._lock:
|
||||
if task_id in self._tasks:
|
||||
task = self._tasks[task_id]
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = progress
|
||||
if error:
|
||||
task.error_message = error
|
||||
logger.info(f"Task {task_id} updated: {status.value}")
|
||||
|
||||
# Sync with database
|
||||
try:
|
||||
from shared.database.models import Task as DBTask
|
||||
from shared.database.session import get_async_session_local
|
||||
from datetime import datetime
|
||||
async with get_async_session_local()() as session:
|
||||
db_task = await session.get(DBTask, task_id)
|
||||
if db_task:
|
||||
db_task.status = status.value
|
||||
if progress is not None:
|
||||
db_task.progress = progress
|
||||
if error:
|
||||
# Limit error_message length (maximum 1000 characters)
|
||||
db_task.error_message = (error[:1000] if error and len(error) > 1000 else error)
|
||||
if status == TaskStatus.COMPLETED:
|
||||
db_task.completed_at = datetime.utcnow()
|
||||
db_task.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
logger.debug(f"Task {task_id} updated in DB: {status.value}")
|
||||
else:
|
||||
# Task not found in DB - create it for synchronization
|
||||
logger.warning(f"Task {task_id} not found in DB, creating record for synchronization")
|
||||
from bot.modules.task_scheduler.queue import task_queue
|
||||
from shared.database.user_helpers import ensure_user_exists
|
||||
task = await task_queue.get_task_by_id(task_id)
|
||||
if task:
|
||||
# Ensure user exists before creating task
|
||||
await ensure_user_exists(task.user_id, session)
|
||||
|
||||
db_task = DBTask(
|
||||
id=task_id,
|
||||
user_id=task.user_id,
|
||||
task_type=task.task_type,
|
||||
status=status.value,
|
||||
url=task.url,
|
||||
progress=progress if progress is not None else 0,
|
||||
error_message=(error[:1000] if error and len(error) > 1000 else error),
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
session.add(db_task)
|
||||
await session.commit()
|
||||
logger.info(f"Task {task_id} created in DB for synchronization")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update task {task_id} in DB: {e}", exc_info=True)
|
||||
# Try to rollback if session is still open
|
||||
try:
|
||||
async with get_async_session_local()() as session:
|
||||
await session.rollback()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def update_task_file_path(self, task_id: int, file_path: str):
|
||||
"""
|
||||
Update task file path
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
file_path: File path
|
||||
"""
|
||||
# Sync with database
|
||||
try:
|
||||
from shared.database.models import Task as DBTask
|
||||
from shared.database.session import get_async_session_local
|
||||
from datetime import datetime
|
||||
async with get_async_session_local()() as session:
|
||||
db_task = await session.get(DBTask, task_id)
|
||||
if db_task:
|
||||
db_task.file_path = file_path
|
||||
db_task.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
logger.debug(f"Task {task_id} file path updated in DB: {file_path}")
|
||||
else:
|
||||
logger.warning(f"Task {task_id} not found in DB for file_path update")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update task {task_id} file path in DB: {e}", exc_info=True)
|
||||
# Try to rollback if session is still open
|
||||
try:
|
||||
async with get_async_session_local()() as session:
|
||||
await session.rollback()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def get_task_by_id(self, task_id: int) -> Optional[Task]:
|
||||
"""Get task by ID"""
|
||||
async with self._lock:
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
async def cancel_task(self, task_id: int) -> bool:
|
||||
"""
|
||||
Cancel task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
async with self._lock:
|
||||
if task_id in self._tasks:
|
||||
task = self._tasks[task_id]
|
||||
# Can only cancel pending or processing tasks
|
||||
if task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.error_message = "Cancelled by user"
|
||||
# Set cancellation event to interrupt download (thread-safe)
|
||||
with self._cancel_events_lock:
|
||||
if task_id in self._cancel_events:
|
||||
self._cancel_events[task_id].set()
|
||||
logger.info(f"Task {task_id} cancelled")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_cancel_event(self, task_id: int) -> threading.Event:
|
||||
"""
|
||||
Get cancellation event for task (created if doesn't exist)
|
||||
Thread-safe method for use from different threads.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
threading.Event for cancellation check
|
||||
"""
|
||||
with self._cancel_events_lock:
|
||||
if task_id not in self._cancel_events:
|
||||
self._cancel_events[task_id] = threading.Event()
|
||||
return self._cancel_events[task_id]
|
||||
|
||||
def clear_cancel_event(self, task_id: int):
|
||||
"""
|
||||
Clear cancellation event after task completion
|
||||
Thread-safe method for use from different threads.
|
||||
"""
|
||||
with self._cancel_events_lock:
|
||||
if task_id in self._cancel_events:
|
||||
del self._cancel_events[task_id]
|
||||
|
||||
async def get_user_tasks(self, user_id: int) -> list[Task]:
|
||||
"""Get user tasks"""
|
||||
async with self._lock:
|
||||
return [task for task in self._tasks.values() if task.user_id == user_id]
|
||||
|
||||
async def get_user_active_tasks_count(self, user_id: int) -> int:
|
||||
"""Get count of active user tasks"""
|
||||
async with self._lock:
|
||||
active_statuses = [TaskStatus.PENDING, TaskStatus.PROCESSING]
|
||||
return sum(1 for task in self._tasks.values()
|
||||
if task.user_id == user_id and task.status in active_statuses)
|
||||
|
||||
async def get_all_tasks(self) -> list[Task]:
|
||||
"""Get all tasks"""
|
||||
async with self._lock:
|
||||
return list(self._tasks.values())
|
||||
|
||||
|
||||
# Global task queue
|
||||
task_queue = TaskQueue()
|
||||
|
||||
Reference in New Issue
Block a user