Add source

This commit is contained in:
2025-12-04 00:12:56 +03:00
parent b75875df5e
commit 0cb7045e7a
75 changed files with 9055 additions and 0 deletions

View File

@@ -0,0 +1,297 @@
"""
Task queue
"""
import asyncio
from typing import Optional
from enum import Enum
from dataclasses import dataclass
from datetime import datetime
import threading
import logging
logger = logging.getLogger(__name__)
class TaskStatus(Enum):
    """Lifecycle states of a queued task; the string values are what gets persisted to the DB."""
    PENDING = "pending"        # queued, not yet picked up by a worker
    PROCESSING = "processing"  # currently being executed
    COMPLETED = "completed"    # finished successfully
    FAILED = "failed"          # finished with an error
    CANCELLED = "cancelled"    # stopped by user request
@dataclass
class Task:
    """A unit of work tracked by the in-memory task queue."""
    id: int
    user_id: int
    task_type: str
    url: Optional[str] = None              # source URL, if the task downloads something
    status: TaskStatus = TaskStatus.PENDING
    progress: int = 0                      # 0-100
    error_message: Optional[str] = None
    # Fixed annotation: the default is None, so the type must be Optional.
    created_at: Optional[datetime] = None

    def __post_init__(self):
        # Stamp the creation time when the caller did not supply one.
        # NOTE(review): datetime.utcnow() is naive and deprecated since 3.12;
        # kept as-is for consistency with the rest of this module.
        if self.created_at is None:
            self.created_at = datetime.utcnow()
class TaskQueue:
    """In-memory queue of user tasks with cross-thread cancellation support."""

    def __init__(self):
        # The asyncio.Queue is created lazily by initialize() so that it is
        # bound to the event loop that actually consumes it.
        self._queue: Optional[asyncio.Queue] = None
        self._initialized = False
        # All known tasks keyed by task id, guarded by the asyncio lock.
        self._tasks: "dict[int, Task]" = {}
        self._lock = asyncio.Lock()
        # Per-task cancellation flags. Guarded by a threading.Lock because
        # they may be touched from worker threads, not only the event loop.
        self._cancel_events: "dict[int, threading.Event]" = {}
        self._cancel_events_lock = threading.Lock()
async def initialize(self):
"""Initialize queue in current event loop"""
if not self._initialized:
self._queue = asyncio.Queue()
self._initialized = True
logger.info("Task queue initialized")
async def add_task(self, task: Task, check_duplicate_url: bool = True) -> bool:
"""
Add task to queue
Args:
task: Task to add
check_duplicate_url: Check for duplicate URLs for this user
Returns:
True if successful, False if duplicate
"""
if not self._initialized or not self._queue:
await self.initialize()
async with self._lock:
# Check for duplicate task_id
if task.id in self._tasks:
logger.warning(f"Task with ID {task.id} already exists, skipping duplicate")
return False
# Atomic check for duplicate URLs for this user
if check_duplicate_url and task.url:
url_normalized = task.url.strip()
for existing_task in self._tasks.values():
if (existing_task.user_id == task.user_id and
existing_task.url and existing_task.url.strip() == url_normalized and
existing_task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING] and
existing_task.id != task.id):
logger.warning(f"Task {task.id} with URL {url_normalized} already processing for user {task.user_id} (task {existing_task.id})")
return False # Block duplicate
self._tasks[task.id] = task
await self._queue.put(task)
logger.info(f"Task {task.id} added to queue")
return True
async def get_task(self) -> Optional[Task]:
"""
Get task from queue
Returns:
Task or None
"""
if not self._initialized or not self._queue:
await self.initialize()
try:
# Use timeout to avoid blocking indefinitely
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
return task
except asyncio.TimeoutError:
# Timeout is normal, just no tasks
return None
except Exception as e:
logger.error(f"Error getting task: {e}", exc_info=True)
await asyncio.sleep(1) # Small delay before retry
return None
async def update_task_status(self, task_id: int, status: TaskStatus, progress: int = None, error: str = None):
"""
Update task status
Args:
task_id: Task ID
status: New status
progress: Progress (0-100)
error: Error message
"""
async with self._lock:
if task_id in self._tasks:
task = self._tasks[task_id]
task.status = status
if progress is not None:
task.progress = progress
if error:
task.error_message = error
logger.info(f"Task {task_id} updated: {status.value}")
# Sync with database
try:
from shared.database.models import Task as DBTask
from shared.database.session import get_async_session_local
from datetime import datetime
async with get_async_session_local()() as session:
db_task = await session.get(DBTask, task_id)
if db_task:
db_task.status = status.value
if progress is not None:
db_task.progress = progress
if error:
# Limit error_message length (maximum 1000 characters)
db_task.error_message = (error[:1000] if error and len(error) > 1000 else error)
if status == TaskStatus.COMPLETED:
db_task.completed_at = datetime.utcnow()
db_task.updated_at = datetime.utcnow()
await session.commit()
logger.debug(f"Task {task_id} updated in DB: {status.value}")
else:
# Task not found in DB - create it for synchronization
logger.warning(f"Task {task_id} not found in DB, creating record for synchronization")
from bot.modules.task_scheduler.queue import task_queue
from shared.database.user_helpers import ensure_user_exists
task = await task_queue.get_task_by_id(task_id)
if task:
# Ensure user exists before creating task
await ensure_user_exists(task.user_id, session)
db_task = DBTask(
id=task_id,
user_id=task.user_id,
task_type=task.task_type,
status=status.value,
url=task.url,
progress=progress if progress is not None else 0,
error_message=(error[:1000] if error and len(error) > 1000 else error),
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
session.add(db_task)
await session.commit()
logger.info(f"Task {task_id} created in DB for synchronization")
except Exception as e:
logger.warning(f"Failed to update task {task_id} in DB: {e}", exc_info=True)
# Try to rollback if session is still open
try:
async with get_async_session_local()() as session:
await session.rollback()
except:
pass
async def update_task_file_path(self, task_id: int, file_path: str):
"""
Update task file path
Args:
task_id: Task ID
file_path: File path
"""
# Sync with database
try:
from shared.database.models import Task as DBTask
from shared.database.session import get_async_session_local
from datetime import datetime
async with get_async_session_local()() as session:
db_task = await session.get(DBTask, task_id)
if db_task:
db_task.file_path = file_path
db_task.updated_at = datetime.utcnow()
await session.commit()
logger.debug(f"Task {task_id} file path updated in DB: {file_path}")
else:
logger.warning(f"Task {task_id} not found in DB for file_path update")
except Exception as e:
logger.warning(f"Failed to update task {task_id} file path in DB: {e}", exc_info=True)
# Try to rollback if session is still open
try:
async with get_async_session_local()() as session:
await session.rollback()
except:
pass
async def get_task_by_id(self, task_id: int) -> Optional[Task]:
"""Get task by ID"""
async with self._lock:
return self._tasks.get(task_id)
async def cancel_task(self, task_id: int) -> bool:
"""
Cancel task
Args:
task_id: Task ID
Returns:
True if successful
"""
async with self._lock:
if task_id in self._tasks:
task = self._tasks[task_id]
# Can only cancel pending or processing tasks
if task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
task.status = TaskStatus.CANCELLED
task.error_message = "Cancelled by user"
# Set cancellation event to interrupt download (thread-safe)
with self._cancel_events_lock:
if task_id in self._cancel_events:
self._cancel_events[task_id].set()
logger.info(f"Task {task_id} cancelled")
return True
return False
def get_cancel_event(self, task_id: int) -> threading.Event:
"""
Get cancellation event for task (created if doesn't exist)
Thread-safe method for use from different threads.
Args:
task_id: Task ID
Returns:
threading.Event for cancellation check
"""
with self._cancel_events_lock:
if task_id not in self._cancel_events:
self._cancel_events[task_id] = threading.Event()
return self._cancel_events[task_id]
def clear_cancel_event(self, task_id: int):
"""
Clear cancellation event after task completion
Thread-safe method for use from different threads.
"""
with self._cancel_events_lock:
if task_id in self._cancel_events:
del self._cancel_events[task_id]
async def get_user_tasks(self, user_id: int) -> list[Task]:
"""Get user tasks"""
async with self._lock:
return [task for task in self._tasks.values() if task.user_id == user_id]
async def get_user_active_tasks_count(self, user_id: int) -> int:
"""Get count of active user tasks"""
async with self._lock:
active_statuses = [TaskStatus.PENDING, TaskStatus.PROCESSING]
return sum(1 for task in self._tasks.values()
if task.user_id == user_id and task.status in active_statuses)
async def get_all_tasks(self) -> list[Task]:
"""Get all tasks"""
async with self._lock:
return list(self._tasks.values())
# Global task queue: module-level singleton instance shared by importers of
# this module.
task_queue = TaskQueue()