import subprocess
from typing import Optional

import rcssmin
import rjsmin
from fastapi import Response
from fastapi.responses import PlainTextResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from .logger import log_access


def _get_server_version() -> str:
    """Return the short hash of the current git HEAD, or ``"unknown"``.

    Runs ``/usr/bin/git rev-parse --short HEAD`` in the process's working
    directory. A missing git binary would otherwise make importing this
    module fail, so that case degrades to ``"unknown"``.
    """
    try:
        result = subprocess.run(
            ["/usr/bin/git", "rev-parse", "--short", "HEAD"],
            text=True,
            capture_output=True,
        )
    except (OSError, subprocess.SubprocessError):
        return "unknown"
    return result.stdout.strip() or "unknown"


server_version = _get_server_version()
onion_hostname = "4sbb7xhdn4meuesnqvcreewk6sjnvchrsx4lpnxmnjhz2soat74finid.onion"
# Hostnames (and their subdomains) this server answers for.
hostnames = ["localhost", "nercone.dev", "d-g-c.net", "diamondgotcat.net", onion_hostname]

# Content types the browser must always revalidate.
_NO_CACHE_TYPES = ("text/html", "text/css", "text/javascript", "application/javascript")
# Content types that are passed through the JavaScript minifier.
_JS_TYPES = ("text/javascript", "application/javascript")


def _request_hostname(scope: Scope) -> str:
    """Extract the normalised hostname (no port, lowercase, no trailing dot) from the Host header."""
    headers = dict(scope.get("headers", []))
    # latin-1 decodes any byte sequence, so a malformed Host header cannot raise here.
    host = headers.get(b"host", b"").decode("latin-1")
    return host.split(":")[0].strip().rstrip(".").lower()


def _match_base_hostname(hostname: str) -> Optional[str]:
    """Return the allowed hostname that *hostname* equals or is a subdomain of.

    Matching is per DNS label: ``foo.nercone.dev`` matches ``nercone.dev``,
    but ``evilnercone.dev`` does not. The longest matching candidate wins.
    Returns ``None`` when the hostname is not allowed.
    """
    matches = [c for c in hostnames if hostname == c or hostname.endswith("." + c)]
    return max(matches, key=len, default=None)


def _subdomain_of(hostname: str, base: str) -> str:
    """Return the labels of *hostname* in front of *base*, or ``""`` if there are none."""
    if hostname == base:
        return ""
    return hostname[: -len(base) - 1]


def _replay_body(body: bytes, receive: Receive) -> Receive:
    """Build a ``receive`` callable that yields *body* once, then defers to *receive*.

    Returning the cached ``http.request`` message forever would never
    suspend and never report ``http.disconnect``. Anything that loops on
    ``receive()`` waiting for a disconnect would then spin the event loop.
    """
    replayed = False

    async def replay() -> Message:
        nonlocal replayed
        if not replayed:
            replayed = True
            return {"type": "http.request", "body": body, "more_body": False}
        return await receive()

    return replay


def _minify(minifier, body: bytes) -> bytes:
    """Run *minifier* over *body* as UTF-8 text; return *body* unchanged if it fails."""
    try:
        return minifier(body.decode("utf-8", errors="replace")).encode("utf-8")
    except Exception:
        # Minification is only an optimisation; never fail a request over it.
        return body


def _needs_content_length(scope: Scope, response: Response) -> bool:
    """Decide whether ``Content-Length`` should be recomputed from the buffered body.

    It is skipped for 1xx/204/304 responses, which carry no body. It is also
    skipped for HEAD responses whose app sent headers only (e.g. a file
    response): there the app's own Content-Length describes the GET body,
    and replacing it with 0 would be wrong.
    """
    status = response.status_code
    if status < 200 or status in (204, 304):
        return False
    if scope.get("method", "").upper() == "HEAD" and not response.body:
        return False
    return True


class Middleware:
    """ASGI middleware that enforces a hostname allow-list and maps subdomains to paths.

    For HTTP requests it:
      * rejects hosts outside ``hostnames`` with HTTP 400;
      * serves a request for ``a.b.<allowed host>/p`` from ``/b/a/p`` and falls
        back to ``/p`` when that yields status >= 400;
      * retries a 404 on ``/x/`` as ``/x``;
      * adds site-wide headers (Server, Onion-Location, CORS, Cache-Control)
        and minifies CSS/JavaScript bodies.
    Responses are fully buffered before being sent. Non-HTTP scopes pass
    through untouched.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        # Lifespan and websocket scopes go straight to the wrapped app.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        hostname = _request_hostname(scope)
        scope["log"] = log_access(scope)

        base = _match_base_hostname(hostname)
        if base is None:
            # Message: "Access with a hostname that is not allowed."
            response = PlainTextResponse("許可されていないホスト名でのアクセスです。", status_code=400)
            await self._send(response, scope, receive, send)
            return

        subdomain = _subdomain_of(hostname, base)
        # The body can only be read from the server once; each app run gets a replay of it.
        body = await self._read_body(receive)

        if subdomain in ("", "www"):
            response = await self._get_response(scope, _replay_body(body, receive), scope["path"])
            await self._send(response, scope, receive, send)
            return

        original_path = scope["path"] if scope["path"].strip() else "/"
        # "a.b" -> "/b/a": subdomain labels become path segments, outermost first.
        subdomain_path = f"/{'/'.join(subdomain.split('.')[::-1])}{original_path}"
        response = await self._get_response(scope, _replay_body(body, receive), subdomain_path)
        if response.status_code >= 400:
            response = await self._get_response(scope, _replay_body(body, receive), original_path)
        await self._send(response, scope, receive, send)

    async def _get_response(self, scope: Scope, receive: Receive, path: str) -> Response:
        """Run the wrapped app for *path* and buffer its entire response.

        A 404 for a path ending in ``/`` (other than ``/``) is retried once
        with the trailing slashes removed. The returned ``Response`` carries
        the app's status and headers verbatim.
        """
        status_code = 200
        resp_headers = []
        body_parts = []

        async def capture_send(message: Message) -> None:
            nonlocal status_code, resp_headers
            if message["type"] == "http.response.start":
                status_code = message["status"]
                resp_headers = [(k, v) for k, v in message.get("headers", [])]
            elif message["type"] == "http.response.body":
                body_parts.append(message.get("body", b""))

        request_body = await self._read_body(receive)
        await self.app(dict(scope, path=path), _replay_body(request_body, receive), capture_send)

        if status_code == 404 and path != "/" and path.endswith("/"):
            # Fresh replay: the app above already consumed the previous one.
            return await self._get_response(
                scope, _replay_body(request_body, receive), path.rstrip("/") or "/"
            )

        response = Response(content=b"".join(body_parts), status_code=status_code)
        # Replace the headers Response.__init__ generated (e.g. its own
        # content-length) with exactly what the app sent, avoiding duplicates.
        response.headers.raw[:] = resp_headers
        return response

    async def _read_body(self, receive: Receive) -> bytes:
        """Drain ``http.request`` messages from *receive* and return the concatenated body."""
        chunks = []
        while True:
            message = await receive()
            chunks.append(message.get("body", b""))
            if not message.get("more_body", False):
                return b"".join(chunks)

    async def _send(self, response: Response, scope, receive, send):
        """Add site-wide headers, minify CSS/JS bodies, and send *response*."""
        headers = response.headers
        content_type = headers.get("content-type", "")
        headers["Server"] = f"nercone.dev ({server_version})"
        headers["Onion-Location"] = f"http://{onion_hostname}/"
        # Only apply permissive CORS when the app did not choose its own policy.
        if "access-control-allow-origin" not in headers:
            headers["Access-Control-Allow-Origin"] = "*"
            headers["Access-Control-Allow-Methods"] = "*"
            headers["Access-Control-Allow-Headers"] = "*"
        if content_type.startswith(_NO_CACHE_TYPES):
            headers["Cache-Control"] = "no-cache"
        else:
            headers["Cache-Control"] = "public, max-age=3600"

        if "text/css" in content_type:
            response.body = _minify(rcssmin.cssmin, response.body)
        elif content_type.startswith(_JS_TYPES):
            response.body = _minify(rjsmin.jsmin, response.body)

        if _needs_content_length(scope, response):
            headers["Content-Length"] = str(len(response.body))
        await response(scope, receive, send)