This repository has been archived on 2026-04-27. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
website/src/nercone_website/tools/tls_test/runner.py
T
2026-04-19 11:33:11 +09:00

168 lines
6.0 KiB
Python

from __future__ import annotations
import asyncio
import json
import time
import uuid
import logging
from typing import Any, Callable, Awaitable
from fastapi import WebSocket
from .db import TlsTestDB
from .schemas import ScanResult, Finding, ProgressMessage
# Module-level logger for the TLS test job runner.
logger = logging.getLogger("tls_test.runner")
# Size of the semaphore that bounds how many scans execute at the same time.
MAX_CONCURRENT = 5
# Async progress callback handed to the engine: (phase, detail, progress, severity).
ReportProgress = Callable[[str, str, float, str], Awaitable[None]]
# Scan engine signature: (target, report_progress, report_finding) -> ScanResult.
EngineFn = Callable[[str, ReportProgress, Callable[[Finding], Awaitable[None]]], Awaitable[ScanResult]]
class TlsJobQueue:
    """In-process job queue that runs TLS scans and streams progress to WebSockets.

    Jobs are submitted with :meth:`submit`, picked up by a dispatcher task and
    executed by the injected ``engine`` coroutine. A semaphore sized by
    ``MAX_CONCURRENT`` bounds how many scans run at once. Progress, findings
    and the final outcome are persisted through ``db`` and broadcast as JSON
    text to every WebSocket subscribed to the job's ``test_id``.
    """

    def __init__(self, db: TlsTestDB, engine: EngineFn):
        """Store the collaborators and create empty queue/bookkeeping state.

        Args:
            db: Persistence layer used to record jobs, progress and results.
            engine: Coroutine function that performs the scan for one target.
        """
        self.db = db
        self.engine = engine
        # Pending (test_id, target) pairs waiting for the dispatcher.
        self._queue: asyncio.Queue[tuple[str, str]] = asyncio.Queue()
        self._semaphore = asyncio.Semaphore(MAX_CONCURRENT)
        # test_id -> WebSockets that receive this job's messages.
        self._subscribers: dict[str, set[WebSocket]] = {}
        # NOTE(review): not read anywhere in this class; kept in case other
        # code relies on the attribute existing.
        self._seq: dict[str, int] = {}
        # Strong references to running job tasks so they are not collected.
        self._tasks: set[asyncio.Task] = set()
        self._dispatcher_task: asyncio.Task | None = None
        self._cleanup_task: asyncio.Task | None = None
        self._closed = False

    async def start(self) -> None:
        """Launch the dispatcher and the periodic cleanup background tasks."""
        self._dispatcher_task = asyncio.create_task(self._dispatcher())
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self) -> None:
        """Cancel every background and job task and wait until all have finished.

        The dispatcher and cleanup tasks are awaited together with the job
        tasks, so none of them is still pending when this coroutine returns.
        """
        self._closed = True
        pending = [
            task
            for task in (self._dispatcher_task, self._cleanup_task)
            if task is not None
        ]
        # Snapshot: done-callbacks discard from self._tasks while we wait.
        pending.extend(self._tasks)
        for task in pending:
            task.cancel()
        # return_exceptions=True keeps a cancelled/failed task from aborting
        # the wait for the remaining ones.
        await asyncio.gather(*pending, return_exceptions=True)

    def submit(self, target: str, client_ip: str | None) -> str:
        """Register a new job in the database and enqueue it.

        Args:
            target: Scan target passed on to the engine.
            client_ip: Requesting client's address, stored with the job.

        Returns:
            The generated ``test_id`` (a UUID4 string).
        """
        test_id = str(uuid.uuid4())
        self.db.create_job(test_id, target, client_ip)
        self._queue.put_nowait((test_id, target))
        return test_id

    async def _dispatcher(self) -> None:
        """Pull jobs off the queue and spawn one task per job until stopped."""
        while not self._closed:
            try:
                test_id, target = await self._queue.get()
            except asyncio.CancelledError:
                # Cancellation is the normal shutdown signal from stop().
                return
            task = asyncio.create_task(self._run_one(test_id, target))
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)

    async def _run_one(self, test_id: str, target: str) -> None:
        """Run a single scan, persisting and broadcasting its lifecycle.

        Emits ``started``, ``progress``, ``finding`` and finally ``done`` or
        ``error`` messages to the job's subscribers, then closes them.

        Args:
            test_id: Identifier of the job being executed.
            target: Scan target handed to the engine.

        Raises:
            asyncio.CancelledError: Re-raised when the task is cancelled.
        """
        async with self._semaphore:
            seq = 0

            async def report_progress(
                phase: str, detail: str, progress: float, severity: str = "info"
            ) -> None:
                # Sequence numbers are per job and start at 1.
                nonlocal seq
                seq += 1
                self.db.append_progress(test_id, seq, phase, detail, progress, severity)
                await self._broadcast(
                    test_id,
                    {
                        "type": "progress",
                        "phase": phase,
                        "detail": detail,
                        "progress": progress,
                        "severity": severity,
                    },
                )

            async def report_finding(f: Finding) -> None:
                await self._broadcast(
                    test_id,
                    {"type": "finding", "finding": f.to_dict()},
                )

            try:
                # Inside the try so that a failure while announcing/recording
                # the start is still recorded as an error and the subscribers
                # are still closed by the finally clause.
                await self._broadcast(test_id, {"type": "started", "target": target})
                self.db.update_status(test_id, "running", started_at=int(time.time()))
                result = await self.engine(target, report_progress, report_finding)
                payload = result.to_dict()
                self.db.update_status(
                    test_id,
                    "done",
                    finished_at=int(time.time()),
                    rank=result.rank,
                    score=result.score,
                    result_json=json.dumps(payload, ensure_ascii=False),
                )
                await self._broadcast(
                    test_id,
                    {
                        "type": "done",
                        "redirect": f"/tools/tls-test/results/{test_id}/",
                        "rank": result.rank,
                        "score": result.score,
                    },
                )
            except asyncio.CancelledError:
                # NOTE(review): a cancelled job keeps whatever status it had
                # (typically "running") in the database — confirm whether a
                # terminal status should be written on shutdown.
                raise
            except Exception as e:
                logger.exception("tls-test engine failed for %s", target)
                self.db.update_status(
                    test_id,
                    "error",
                    finished_at=int(time.time()),
                    error_message=str(e),
                )
                await self._broadcast(
                    test_id,
                    {"type": "error", "message": str(e)},
                )
            finally:
                await self._close_subscribers(test_id)

    async def _cleanup_loop(self) -> None:
        """Ask the database to delete expired jobs once an hour until stopped."""
        while not self._closed:
            try:
                self.db.delete_expired()
            except Exception:
                # Keep the loop alive; the next pass retries.
                logger.exception("tls-test expired cleanup failed")
            try:
                await asyncio.sleep(3600)
            except asyncio.CancelledError:
                return

    def add_subscriber(self, test_id: str, ws: WebSocket) -> None:
        """Subscribe ``ws`` to messages for ``test_id``."""
        self._subscribers.setdefault(test_id, set()).add(ws)

    def remove_subscriber(self, test_id: str, ws: WebSocket) -> None:
        """Unsubscribe ``ws`` from ``test_id``; unknown ids/sockets are ignored."""
        subs = self._subscribers.get(test_id)
        if subs is None:
            return
        subs.discard(ws)
        if not subs:
            # Drop the empty set so finished jobs do not accumulate keys.
            self._subscribers.pop(test_id, None)

    async def _broadcast(self, test_id: str, payload: dict[str, Any]) -> None:
        """Send ``payload`` as JSON text to every subscriber of ``test_id``.

        A subscriber whose send raises is removed; the others still receive
        the message.
        """
        # Copy: remove_subscriber() mutates the underlying set during the loop.
        subs = list(self._subscribers.get(test_id, ()))
        if not subs:
            return
        text = json.dumps(payload, ensure_ascii=False)
        for ws in subs:
            try:
                await ws.send_text(text)
            except Exception:
                # Best effort: a failed send just drops that subscriber.
                self.remove_subscriber(test_id, ws)

    async def _close_subscribers(self, test_id: str) -> None:
        """Close every subscriber of ``test_id`` and forget the subscription."""
        subs = list(self._subscribers.get(test_id, ()))
        for ws in subs:
            try:
                await ws.close()
            except Exception:
                # Best effort: a failing close must not prevent closing the
                # remaining sockets.
                pass
        self._subscribers.pop(test_id, None)