168 lines
6.0 KiB
Python
168 lines
6.0 KiB
Python
from __future__ import annotations
|
|
import asyncio
|
|
import json
|
|
import time
|
|
import uuid
|
|
import logging
|
|
from typing import Any, Callable, Awaitable
|
|
|
|
from fastapi import WebSocket
|
|
|
|
from .db import TlsTestDB
|
|
from .schemas import ScanResult, Finding, ProgressMessage
|
|
|
|
# Logger shared by every TlsJobQueue in this module.
logger = logging.getLogger("tls_test.runner")

# Upper bound on scans that may execute the engine at the same time.
MAX_CONCURRENT = 5

# Progress callback handed to the engine: (phase, detail, progress, severity).
ReportProgress = Callable[[str, str, float, str], Awaitable[None]]

# Scan engine: called as engine(target, report_progress, report_finding) and
# resolving to the final ScanResult.
EngineFn = Callable[
    [str, ReportProgress, Callable[[Finding], Awaitable[None]]],
    Awaitable[ScanResult],
]
class TlsJobQueue:
    """Run TLS scans in the background and stream their progress to WebSockets.

    ``submit`` persists a job through ``db`` and enqueues it. The dispatcher
    task launched by ``start`` turns each queued job into an asyncio task, and
    at most ``MAX_CONCURRENT`` of those tasks run ``engine`` at the same time.

    Every event for a job (started, progress, finding, done/error) is sent as a
    JSON text frame to the WebSockets subscribed to that job's ``test_id``.
    Those sockets are closed when the job ends.
    """

    def __init__(self, db: TlsTestDB, engine: EngineFn):
        """Create an idle queue; call :meth:`start` to begin processing.

        Args:
            db: Persistence layer used to create jobs and record their status,
                progress steps and results.
            engine: Coroutine function that performs a single scan.
        """
        self.db = db
        self.engine = engine
        # (test_id, target) pairs waiting to be turned into tasks.
        self._queue: asyncio.Queue[tuple[str, str]] = asyncio.Queue()
        # Caps how many engine runs are in flight at once.
        self._semaphore = asyncio.Semaphore(MAX_CONCURRENT)
        # test_id -> WebSockets listening for that job's events.
        self._subscribers: dict[str, set[WebSocket]] = {}
        # NOTE(review): not read anywhere in this class; progress sequence
        # numbers are tracked per job inside _run_one.
        self._seq: dict[str, int] = {}
        # Strong references to in-flight job tasks; asyncio only keeps weak
        # references, so an unreferenced task could be garbage-collected.
        self._tasks: set[asyncio.Task] = set()
        self._dispatcher_task: asyncio.Task | None = None
        self._cleanup_task: asyncio.Task | None = None
        self._closed = False

    async def start(self) -> None:
        """Launch the dispatcher and the hourly expired-job cleanup loop."""
        self._dispatcher_task = asyncio.create_task(self._dispatcher())
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self) -> None:
        """Cancel the background loops and all running jobs, then wait for them.

        Jobs still waiting in the queue are never started. They keep whatever
        status ``db.create_job`` gave them.
        """
        self._closed = True
        background = [
            task
            for task in (self._dispatcher_task, self._cleanup_task)
            if task is not None
        ]
        jobs = list(self._tasks)
        for task in (*background, *jobs):
            task.cancel()
        # Await the background loops too, not only the job tasks, so nothing
        # is still unwinding when shutdown continues. Exceptions, including
        # CancelledError, are collected instead of raised.
        await asyncio.gather(*background, *jobs, return_exceptions=True)

    def submit(self, target: str, client_ip: str | None) -> str:
        """Persist a new scan job for ``target`` and enqueue it.

        Args:
            target: What to scan; passed unchanged to the engine.
            client_ip: Requesting client's address, or None; stored via
                ``db.create_job``.

        Returns:
            The newly generated job id (a UUID4 string).
        """
        test_id = str(uuid.uuid4())
        self.db.create_job(test_id, target, client_ip)
        self._queue.put_nowait((test_id, target))
        return test_id

    async def _dispatcher(self) -> None:
        """Turn queued jobs into tasks until the queue is closed or cancelled.

        Every queued job gets a task immediately. The concurrency limit is
        enforced inside ``_run_one`` by the semaphore.
        """
        while not self._closed:
            try:
                test_id, target = await self._queue.get()
            except asyncio.CancelledError:
                return
            task = asyncio.create_task(self._run_one(test_id, target))
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)

    def _record_failure(self, test_id: str, message: str) -> None:
        """Best-effort: mark ``test_id`` as ``error``; log instead of raising."""
        try:
            self.db.update_status(
                test_id,
                "error",
                finished_at=int(time.time()),
                error_message=message,
            )
        except Exception:
            logger.exception("tls-test could not record failure for %s", test_id)

    async def _run_one(self, test_id: str, target: str) -> None:
        """Run one scan under the concurrency limit and publish its outcome.

        Marks the job ``running`` and invokes the engine with progress and
        finding callbacks. It then records either the ``done`` result or an
        ``error``. Subscribers for ``test_id`` are always closed when the job
        ends.
        """
        async with self._semaphore:
            seq = 0  # progress sequence number persisted with each step

            async def report_progress(
                phase: str, detail: str, progress: float, severity: str = "info"
            ) -> None:
                """Persist one progress step and forward it to subscribers."""
                nonlocal seq
                seq += 1
                self.db.append_progress(test_id, seq, phase, detail, progress, severity)
                await self._broadcast(
                    test_id,
                    {
                        "type": "progress",
                        "phase": phase,
                        "detail": detail,
                        "progress": progress,
                        "severity": severity,
                    },
                )

            async def report_finding(f: Finding) -> None:
                """Forward one finding to subscribers (not persisted here)."""
                await self._broadcast(
                    test_id,
                    {"type": "finding", "finding": f.to_dict()},
                )

            try:
                # These two steps are inside the try so that a failure is
                # recorded as an error and subscribers are still closed,
                # instead of the exception escaping the task unobserved.
                await self._broadcast(test_id, {"type": "started", "target": target})
                self.db.update_status(test_id, "running", started_at=int(time.time()))

                result = await self.engine(target, report_progress, report_finding)
                payload = result.to_dict()
                self.db.update_status(
                    test_id,
                    "done",
                    finished_at=int(time.time()),
                    rank=result.rank,
                    score=result.score,
                    result_json=json.dumps(payload, ensure_ascii=False),
                )
                await self._broadcast(
                    test_id,
                    {
                        "type": "done",
                        "redirect": f"/tools/tls-test/results/{test_id}/",
                        "rank": result.rank,
                        "score": result.score,
                    },
                )
            except asyncio.CancelledError:
                # stop() (or another cancel) interrupted the scan. Record it so
                # the job is not left "running" forever, then keep cancelling.
                self._record_failure(test_id, "scan cancelled")
                raise
            except Exception as e:
                logger.exception("tls-test engine failed for %s", target)
                self._record_failure(test_id, str(e))
                await self._broadcast(
                    test_id,
                    {"type": "error", "message": str(e)},
                )
            finally:
                await self._close_subscribers(test_id)

    async def _cleanup_loop(self) -> None:
        """Call ``db.delete_expired`` once an hour until closed or cancelled."""
        while not self._closed:
            try:
                self.db.delete_expired()
            except Exception:
                # Keep the loop alive; the next pass retries.
                logger.exception("tls-test expired cleanup failed")
            try:
                await asyncio.sleep(3600)
            except asyncio.CancelledError:
                return

    def add_subscriber(self, test_id: str, ws: WebSocket) -> None:
        """Register ``ws`` to receive events for ``test_id``."""
        self._subscribers.setdefault(test_id, set()).add(ws)

    def remove_subscriber(self, test_id: str, ws: WebSocket) -> None:
        """Unregister ``ws``; drop the job's entry once nobody is listening."""
        subs = self._subscribers.get(test_id)
        if subs is None:
            return
        subs.discard(ws)
        if not subs:
            self._subscribers.pop(test_id, None)

    async def _broadcast(self, test_id: str, payload: dict[str, Any]) -> None:
        """Send ``payload`` as JSON text to every subscriber of ``test_id``.

        A socket whose send raises is unsubscribed. Other sockets are
        unaffected.
        """
        subs = list(self._subscribers.get(test_id, set()))
        if not subs:
            return
        text = json.dumps(payload, ensure_ascii=False)
        for ws in subs:
            try:
                await ws.send_text(text)
            except Exception:
                self.remove_subscriber(test_id, ws)

    async def _close_subscribers(self, test_id: str) -> None:
        """Close every subscriber socket of ``test_id`` and forget them.

        Close errors are deliberately ignored, e.g. when the client has
        already disconnected.
        """
        subs = list(self._subscribers.get(test_id, set()))
        for ws in subs:
            try:
                await ws.close()
            except Exception:
                pass
        self._subscribers.pop(test_id, None)