from __future__ import annotations import asyncio import json import time import uuid import logging from typing import Any, Callable, Awaitable from fastapi import WebSocket from .db import TlsTestDB from .schemas import ScanResult, Finding, ProgressMessage logger = logging.getLogger("tls_test.runner") MAX_CONCURRENT = 5 ReportProgress = Callable[[str, str, float, str], Awaitable[None]] EngineFn = Callable[[str, ReportProgress, Callable[[Finding], Awaitable[None]]], Awaitable[ScanResult]] class TlsJobQueue: def __init__(self, db: TlsTestDB, engine: EngineFn): self.db = db self.engine = engine self._queue: asyncio.Queue[tuple[str, str]] = asyncio.Queue() self._semaphore = asyncio.Semaphore(MAX_CONCURRENT) self._subscribers: dict[str, set[WebSocket]] = {} self._seq: dict[str, int] = {} self._tasks: set[asyncio.Task] = set() self._dispatcher_task: asyncio.Task | None = None self._cleanup_task: asyncio.Task | None = None self._closed = False async def start(self) -> None: self._dispatcher_task = asyncio.create_task(self._dispatcher()) self._cleanup_task = asyncio.create_task(self._cleanup_loop()) async def stop(self) -> None: self._closed = True if self._dispatcher_task: self._dispatcher_task.cancel() if self._cleanup_task: self._cleanup_task.cancel() for t in list(self._tasks): t.cancel() await asyncio.gather(*(t for t in self._tasks), return_exceptions=True) def submit(self, target: str, client_ip: str | None) -> str: test_id = str(uuid.uuid4()) self.db.create_job(test_id, target, client_ip) self._queue.put_nowait((test_id, target)) return test_id async def _dispatcher(self) -> None: while not self._closed: try: test_id, target = await self._queue.get() except asyncio.CancelledError: return task = asyncio.create_task(self._run_one(test_id, target)) self._tasks.add(task) task.add_done_callback(self._tasks.discard) async def _run_one(self, test_id: str, target: str) -> None: async with self._semaphore: await self._broadcast(test_id, {"type": "started", "target": target}) self.db.update_status(test_id, "running", started_at=int(time.time())) seq_ref = {"n": 0} async def report_progress(phase: str, detail: str, progress: float, severity: str = "info") -> None: seq_ref["n"] += 1 self.db.append_progress(test_id, seq_ref["n"], phase, detail, progress, severity) await self._broadcast( test_id, { "type": "progress", "phase": phase, "detail": detail, "progress": progress, "severity": severity, }, ) async def report_finding(f: Finding) -> None: await self._broadcast( test_id, {"type": "finding", "finding": f.to_dict()}, ) try: result = await self.engine(target, report_progress, report_finding) payload = result.to_dict() self.db.update_status( test_id, "done", finished_at=int(time.time()), rank=result.rank, score=result.score, result_json=json.dumps(payload, ensure_ascii=False), ) await self._broadcast( test_id, { "type": "done", "redirect": f"/tools/tls-test/results/{test_id}/", "rank": result.rank, "score": result.score, }, ) except asyncio.CancelledError: raise except Exception as e: logger.exception("tls-test engine failed for %s", target) self.db.update_status( test_id, "error", finished_at=int(time.time()), error_message=str(e), ) await self._broadcast( test_id, {"type": "error", "message": str(e)}, ) finally: await self._close_subscribers(test_id) async def _cleanup_loop(self) -> None: while not self._closed: try: self.db.delete_expired() except Exception: logger.exception("tls-test expired cleanup failed") try: await asyncio.sleep(3600) except asyncio.CancelledError: return def add_subscriber(self, test_id: str, ws: WebSocket) -> None: self._subscribers.setdefault(test_id, set()).add(ws) def remove_subscriber(self, test_id: str, ws: WebSocket) -> None: subs = self._subscribers.get(test_id) if subs and ws in subs: subs.discard(ws) if not subs: self._subscribers.pop(test_id, None) async def _broadcast(self, test_id: str, payload: dict[str, Any]) -> None: subs = list(self._subscribers.get(test_id, set())) if not subs: return text = json.dumps(payload, ensure_ascii=False) for ws in subs: try: await ws.send_text(text) except Exception: self.remove_subscriber(test_id, ws) async def _close_subscribers(self, test_id: str) -> None: subs = list(self._subscribers.get(test_id, set())) for ws in subs: try: await ws.close() except Exception: pass self._subscribers.pop(test_id, None)