--
This commit is contained in:
@@ -0,0 +1,167 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from .db import TlsTestDB
|
||||
from .schemas import ScanResult, Finding, ProgressMessage
|
||||
|
||||
logger = logging.getLogger("tls_test.runner")
|
||||
|
||||
MAX_CONCURRENT = 5
|
||||
|
||||
ReportProgress = Callable[[str, str, float, str], Awaitable[None]]
|
||||
EngineFn = Callable[[str, ReportProgress, Callable[[Finding], Awaitable[None]]], Awaitable[ScanResult]]
|
||||
|
||||
|
||||
class TlsJobQueue:
|
||||
def __init__(self, db: TlsTestDB, engine: EngineFn):
|
||||
self.db = db
|
||||
self.engine = engine
|
||||
self._queue: asyncio.Queue[tuple[str, str]] = asyncio.Queue()
|
||||
self._semaphore = asyncio.Semaphore(MAX_CONCURRENT)
|
||||
self._subscribers: dict[str, set[WebSocket]] = {}
|
||||
self._seq: dict[str, int] = {}
|
||||
self._tasks: set[asyncio.Task] = set()
|
||||
self._dispatcher_task: asyncio.Task | None = None
|
||||
self._cleanup_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
self._dispatcher_task = asyncio.create_task(self._dispatcher())
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._closed = True
|
||||
if self._dispatcher_task:
|
||||
self._dispatcher_task.cancel()
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
for t in list(self._tasks):
|
||||
t.cancel()
|
||||
await asyncio.gather(*(t for t in self._tasks), return_exceptions=True)
|
||||
|
||||
def submit(self, target: str, client_ip: str | None) -> str:
|
||||
test_id = str(uuid.uuid4())
|
||||
self.db.create_job(test_id, target, client_ip)
|
||||
self._queue.put_nowait((test_id, target))
|
||||
return test_id
|
||||
|
||||
async def _dispatcher(self) -> None:
|
||||
while not self._closed:
|
||||
try:
|
||||
test_id, target = await self._queue.get()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
task = asyncio.create_task(self._run_one(test_id, target))
|
||||
self._tasks.add(task)
|
||||
task.add_done_callback(self._tasks.discard)
|
||||
|
||||
async def _run_one(self, test_id: str, target: str) -> None:
|
||||
async with self._semaphore:
|
||||
await self._broadcast(test_id, {"type": "started", "target": target})
|
||||
self.db.update_status(test_id, "running", started_at=int(time.time()))
|
||||
seq_ref = {"n": 0}
|
||||
|
||||
async def report_progress(phase: str, detail: str, progress: float, severity: str = "info") -> None:
|
||||
seq_ref["n"] += 1
|
||||
self.db.append_progress(test_id, seq_ref["n"], phase, detail, progress, severity)
|
||||
await self._broadcast(
|
||||
test_id,
|
||||
{
|
||||
"type": "progress",
|
||||
"phase": phase,
|
||||
"detail": detail,
|
||||
"progress": progress,
|
||||
"severity": severity,
|
||||
},
|
||||
)
|
||||
|
||||
async def report_finding(f: Finding) -> None:
|
||||
await self._broadcast(
|
||||
test_id,
|
||||
{"type": "finding", "finding": f.to_dict()},
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self.engine(target, report_progress, report_finding)
|
||||
payload = result.to_dict()
|
||||
self.db.update_status(
|
||||
test_id,
|
||||
"done",
|
||||
finished_at=int(time.time()),
|
||||
rank=result.rank,
|
||||
score=result.score,
|
||||
result_json=json.dumps(payload, ensure_ascii=False),
|
||||
)
|
||||
await self._broadcast(
|
||||
test_id,
|
||||
{
|
||||
"type": "done",
|
||||
"redirect": f"/tools/tls-test/results/{test_id}/",
|
||||
"rank": result.rank,
|
||||
"score": result.score,
|
||||
},
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("tls-test engine failed for %s", target)
|
||||
self.db.update_status(
|
||||
test_id,
|
||||
"error",
|
||||
finished_at=int(time.time()),
|
||||
error_message=str(e),
|
||||
)
|
||||
await self._broadcast(
|
||||
test_id,
|
||||
{"type": "error", "message": str(e)},
|
||||
)
|
||||
finally:
|
||||
await self._close_subscribers(test_id)
|
||||
|
||||
async def _cleanup_loop(self) -> None:
|
||||
while not self._closed:
|
||||
try:
|
||||
self.db.delete_expired()
|
||||
except Exception:
|
||||
logger.exception("tls-test expired cleanup failed")
|
||||
try:
|
||||
await asyncio.sleep(3600)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
def add_subscriber(self, test_id: str, ws: WebSocket) -> None:
|
||||
self._subscribers.setdefault(test_id, set()).add(ws)
|
||||
|
||||
def remove_subscriber(self, test_id: str, ws: WebSocket) -> None:
|
||||
subs = self._subscribers.get(test_id)
|
||||
if subs and ws in subs:
|
||||
subs.discard(ws)
|
||||
if not subs:
|
||||
self._subscribers.pop(test_id, None)
|
||||
|
||||
async def _broadcast(self, test_id: str, payload: dict[str, Any]) -> None:
|
||||
subs = list(self._subscribers.get(test_id, set()))
|
||||
if not subs:
|
||||
return
|
||||
text = json.dumps(payload, ensure_ascii=False)
|
||||
for ws in subs:
|
||||
try:
|
||||
await ws.send_text(text)
|
||||
except Exception:
|
||||
self.remove_subscriber(test_id, ws)
|
||||
|
||||
async def _close_subscribers(self, test_id: str) -> None:
|
||||
subs = list(self._subscribers.get(test_id, set()))
|
||||
for ws in subs:
|
||||
try:
|
||||
await ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._subscribers.pop(test_id, None)
|
||||
Reference in New Issue
Block a user