This commit is contained in:
2026-04-19 11:33:11 +09:00
parent 867ae25fa0
commit da8d91b87f
43 changed files with 4044 additions and 15 deletions
@@ -0,0 +1,167 @@
from __future__ import annotations
import asyncio
import json
import time
import uuid
import logging
from typing import Any, Callable, Awaitable
from fastapi import WebSocket
from .db import TlsTestDB
from .schemas import ScanResult, Finding, ProgressMessage
logger = logging.getLogger("tls_test.runner")
MAX_CONCURRENT = 5
ReportProgress = Callable[[str, str, float, str], Awaitable[None]]
EngineFn = Callable[[str, ReportProgress, Callable[[Finding], Awaitable[None]]], Awaitable[ScanResult]]
class TlsJobQueue:
def __init__(self, db: TlsTestDB, engine: EngineFn):
self.db = db
self.engine = engine
self._queue: asyncio.Queue[tuple[str, str]] = asyncio.Queue()
self._semaphore = asyncio.Semaphore(MAX_CONCURRENT)
self._subscribers: dict[str, set[WebSocket]] = {}
self._seq: dict[str, int] = {}
self._tasks: set[asyncio.Task] = set()
self._dispatcher_task: asyncio.Task | None = None
self._cleanup_task: asyncio.Task | None = None
self._closed = False
async def start(self) -> None:
self._dispatcher_task = asyncio.create_task(self._dispatcher())
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
async def stop(self) -> None:
self._closed = True
if self._dispatcher_task:
self._dispatcher_task.cancel()
if self._cleanup_task:
self._cleanup_task.cancel()
for t in list(self._tasks):
t.cancel()
await asyncio.gather(*(t for t in self._tasks), return_exceptions=True)
def submit(self, target: str, client_ip: str | None) -> str:
test_id = str(uuid.uuid4())
self.db.create_job(test_id, target, client_ip)
self._queue.put_nowait((test_id, target))
return test_id
async def _dispatcher(self) -> None:
while not self._closed:
try:
test_id, target = await self._queue.get()
except asyncio.CancelledError:
return
task = asyncio.create_task(self._run_one(test_id, target))
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
async def _run_one(self, test_id: str, target: str) -> None:
async with self._semaphore:
await self._broadcast(test_id, {"type": "started", "target": target})
self.db.update_status(test_id, "running", started_at=int(time.time()))
seq_ref = {"n": 0}
async def report_progress(phase: str, detail: str, progress: float, severity: str = "info") -> None:
seq_ref["n"] += 1
self.db.append_progress(test_id, seq_ref["n"], phase, detail, progress, severity)
await self._broadcast(
test_id,
{
"type": "progress",
"phase": phase,
"detail": detail,
"progress": progress,
"severity": severity,
},
)
async def report_finding(f: Finding) -> None:
await self._broadcast(
test_id,
{"type": "finding", "finding": f.to_dict()},
)
try:
result = await self.engine(target, report_progress, report_finding)
payload = result.to_dict()
self.db.update_status(
test_id,
"done",
finished_at=int(time.time()),
rank=result.rank,
score=result.score,
result_json=json.dumps(payload, ensure_ascii=False),
)
await self._broadcast(
test_id,
{
"type": "done",
"redirect": f"/tools/tls-test/results/{test_id}/",
"rank": result.rank,
"score": result.score,
},
)
except asyncio.CancelledError:
raise
except Exception as e:
logger.exception("tls-test engine failed for %s", target)
self.db.update_status(
test_id,
"error",
finished_at=int(time.time()),
error_message=str(e),
)
await self._broadcast(
test_id,
{"type": "error", "message": str(e)},
)
finally:
await self._close_subscribers(test_id)
async def _cleanup_loop(self) -> None:
while not self._closed:
try:
self.db.delete_expired()
except Exception:
logger.exception("tls-test expired cleanup failed")
try:
await asyncio.sleep(3600)
except asyncio.CancelledError:
return
def add_subscriber(self, test_id: str, ws: WebSocket) -> None:
self._subscribers.setdefault(test_id, set()).add(ws)
def remove_subscriber(self, test_id: str, ws: WebSocket) -> None:
subs = self._subscribers.get(test_id)
if subs and ws in subs:
subs.discard(ws)
if not subs:
self._subscribers.pop(test_id, None)
async def _broadcast(self, test_id: str, payload: dict[str, Any]) -> None:
subs = list(self._subscribers.get(test_id, set()))
if not subs:
return
text = json.dumps(payload, ensure_ascii=False)
for ws in subs:
try:
await ws.send_text(text)
except Exception:
self.remove_subscriber(test_id, ws)
async def _close_subscribers(self, test_id: str) -> None:
subs = list(self._subscribers.get(test_id, set()))
for ws in subs:
try:
await ws.close()
except Exception:
pass
self._subscribers.pop(test_id, None)