This repository was archived on 2026-04-27. You can view its files and clone it, but you cannot open issues or pull requests, or push commits.
Files
website/src/nercone_website/tools/tls_test/protocol/wire.py
T
2026-04-19 11:33:11 +09:00

261 lines
8.1 KiB
Python

from __future__ import annotations
import os
import struct
from dataclasses import dataclass
from typing import Any
from . import constants as C
def _u8(x: int) -> bytes:
return struct.pack("!B", x)
def _u16(x: int) -> bytes:
return struct.pack("!H", x)
def _u24(x: int) -> bytes:
return struct.pack("!I", x)[1:]
def build_extension(ext_type: int, data: bytes) -> bytes:
    """Frame *data* as one TLS extension: 2-byte type, 2-byte length, then the data.

    ``_u16`` raises ``struct.error`` if the type or the length does not fit in 16 bits.
    """
    type_field = _u16(ext_type)
    length_field = _u16(len(data))
    return type_field + length_field + data
def ext_server_name(hostname: str) -> bytes:
    """Build a server_name (SNI) extension carrying a single host_name entry.

    The name is IDNA-encoded with Python's ``idna`` codec. A single trailing
    dot (fully-qualified form, e.g. ``"example.com."``) is removed first:
    the codec would otherwise keep it, and RFC 6066 section 3 requires the
    HostName to be sent without a trailing dot.

    Raises:
        UnicodeError: if the ``idna`` codec rejects the name (for example an
            empty or over-long label).
    """
    if hostname.endswith("."):
        hostname = hostname[:-1]
    name = hostname.encode("idna")
    # ServerName entry: name_type host_name(0), then the 2-byte-length HostName.
    entry = _u8(0) + _u16(len(name)) + name
    server_name_list = _u16(len(entry)) + entry
    return build_extension(C.EXT_SERVER_NAME, server_name_list)
def ext_supported_versions_client(versions: list[int]) -> bytes:
    """ClientHello form of supported_versions.

    The body is a 1-byte vector length followed by one 2-byte version code
    per entry, in the order given.
    """
    vector_len = _u8(len(versions) * 2)
    version_codes = b"".join(map(_u16, versions))
    return build_extension(C.EXT_SUPPORTED_VERSIONS, vector_len + version_codes)
def ext_supported_groups(groups: list[int]) -> bytes:
    """supported_groups extension: 2-byte vector length, then one 2-byte group code per entry."""
    vector_len = _u16(len(groups) * 2)
    group_codes = b"".join(map(_u16, groups))
    return build_extension(C.EXT_SUPPORTED_GROUPS, vector_len + group_codes)
def ext_signature_algorithms(algos: list[int]) -> bytes:
    """signature_algorithms extension: 2-byte vector length, then one 2-byte scheme code per entry."""
    vector_len = _u16(len(algos) * 2)
    scheme_codes = b"".join(map(_u16, algos))
    return build_extension(C.EXT_SIGNATURE_ALGORITHMS, vector_len + scheme_codes)
def ext_ec_point_formats() -> bytes:
    """ec_point_formats extension listing a single format code, 0 (uncompressed)."""
    formats = _u8(0)
    return build_extension(C.EXT_EC_POINT_FORMATS, _u8(len(formats)) + formats)
def ext_key_share_empty() -> bytes:
    """key_share extension whose client_shares vector is empty (2-byte length of 0).

    RFC 8446 section 4.2.8 allows an empty vector to ask the server to pick a
    group, at the cost of an extra round trip.
    """
    no_shares = b""
    return build_extension(C.EXT_KEY_SHARE, _u16(len(no_shares)) + no_shares)
def ext_psk_key_exchange_modes() -> bytes:
    """psk_key_exchange_modes extension offering one mode, code 1 (psk_dhe_ke)."""
    modes = _u8(1)
    return build_extension(C.EXT_PSK_KEY_EXCHANGE_MODES, _u8(len(modes)) + modes)
def ext_alpn(protos: list[bytes]) -> bytes:
    """ALPN extension: 2-byte list length, then each protocol name with a 1-byte length prefix.

    Names are emitted in the order given, e.g. ``[b"h2", b"http/1.1"]``.
    """
    entries = []
    for proto in protos:
        entries.append(_u8(len(proto)) + proto)
    name_list = b"".join(entries)
    return build_extension(C.EXT_ALPN, _u16(len(name_list)) + name_list)
def ext_heartbeat_enabled() -> bytes:
    """heartbeat extension with mode byte 1 (peer_allowed_to_send)."""
    peer_allowed_to_send = 1
    return build_extension(C.EXT_HEARTBEAT, _u8(peer_allowed_to_send))
def ext_renegotiation_info_empty() -> bytes:
    """renegotiation_info extension with an empty renegotiated_connection (1-byte length 0)."""
    renegotiated_connection = b""
    body = _u8(len(renegotiated_connection)) + renegotiated_connection
    return build_extension(C.EXT_RENEGOTIATION_INFO, body)
def ext_status_request() -> bytes:
    """status_request extension asking for OCSP stapling.

    Body: status_type OCSP (1), an empty responder_id_list and empty
    request extensions (each a 2-byte zero length).
    """
    status_type_ocsp = _u8(1)
    responder_id_list = _u16(0)
    request_extensions = _u16(0)
    return build_extension(
        C.EXT_STATUS_REQUEST,
        status_type_ocsp + responder_id_list + request_extensions,
    )
def ext_signed_cert_timestamp() -> bytes:
    """signed_certificate_timestamp extension with an empty body."""
    empty = bytes()
    return build_extension(C.EXT_SIGNED_CERT_TIMESTAMP, empty)
def ext_extended_master_secret() -> bytes:
    """extended_master_secret extension with an empty body."""
    empty = bytes()
    return build_extension(C.EXT_EXTENDED_MASTER_SECRET, empty)
def build_client_hello(
    record_version: int,
    client_hello_version: int,
    hostname: str | None,
    cipher_suites: list[int],
    extensions: bytes = b"",
    compression: bytes = b"\x01\x00",  # 1 length, null method
) -> bytes:
    """Assemble a ClientHello wrapped in a single handshake record.

    Args:
        record_version: version code written into the record-layer header.
        client_hello_version: version code written at the start of the
            ClientHello body.
        hostname: not read by this function; any server_name extension must
            already be included in *extensions*.
        cipher_suites: 2-byte cipher suite codes, emitted in the given order.
        extensions: already-framed extensions concatenated together, without
            the 2-byte total length (that prefix is added here). When empty,
            the extensions field is omitted entirely.
        compression: pre-encoded compression methods (length byte + methods).

    Returns:
        Record header + handshake header + ClientHello body. The random is
        fresh from ``os.urandom`` and the session id is empty.
    """
    suites = b"".join(map(_u16, cipher_suites))
    session_id = b""
    fields = [
        _u16(client_hello_version),
        os.urandom(32),  # client random
        _u8(len(session_id)),
        session_id,
        _u16(len(suites)),
        suites,
        compression,
    ]
    if extensions:
        fields.append(_u16(len(extensions)))
        fields.append(extensions)
    hello_body = b"".join(fields)
    handshake = _u8(C.HS_CLIENT_HELLO) + _u24(len(hello_body)) + hello_body
    header = _u8(C.CT_HANDSHAKE) + _u16(record_version) + _u16(len(handshake))
    return header + handshake
def default_ch_extensions(hostname: str | None, versions: list[int], groups: list[int] | None = None) -> bytes:
    """Build the default set of ClientHello extensions.

    Args:
        hostname: server name for SNI. No server_name extension is added
            when it is empty/None or an IP literal.
        versions: version codes being offered. supported_versions,
            psk_key_exchange_modes and an empty key_share are appended only
            when ``C.TLS_1_3`` is among them.
        groups: group codes for supported_groups. When None or empty, the
            list ``[0x001d, 0x0017, 0x0018]`` is used.

    Returns:
        The framed extensions concatenated together, without the 2-byte
        total-length prefix (``build_client_hello`` adds that prefix itself).
    """
    groups = groups or [0x001d, 0x0017, 0x0018]
    parts: list[bytes] = []
    if hostname and not _is_ip_literal(hostname):
        # IP literals are not valid server_name values, so SNI is skipped for them.
        parts.append(ext_server_name(hostname))
    parts.append(ext_ec_point_formats())
    parts.append(ext_supported_groups(groups))
    parts.append(ext_signature_algorithms([
        0x0403, 0x0804, 0x0401, 0x0503, 0x0805, 0x0501,
        0x0603, 0x0806, 0x0601, 0x0807, 0x0808,
    ]))
    parts.append(ext_renegotiation_info_empty())
    parts.append(ext_signed_cert_timestamp())
    parts.append(ext_status_request())
    parts.append(ext_extended_master_secret())
    parts.append(ext_alpn([b"h2", b"http/1.1"]))
    if C.TLS_1_3 in versions:
        parts.append(ext_supported_versions_client(versions))
        parts.append(ext_psk_key_exchange_modes())
        parts.append(ext_key_share_empty())
    return b"".join(parts)
def _is_ip_literal(host: str) -> bool:
import ipaddress
try:
ipaddress.ip_address(host.strip("[]"))
return True
except ValueError:
return False
@dataclass
class ParsedServerHello:
    """Result of parsing the first record of a server's reply.

    Produced by ``parse_server_response``. When the record does not contain
    a ServerHello (an alert, a different handshake message, or another record
    type), ``server_version`` is 0, ``cipher_suite`` is None and the fields
    with defaults stay None.
    """

    # Version code from the 5-byte record header.
    record_version: int
    # Version code at the start of the ServerHello body; 0 when no ServerHello was parsed.
    server_version: int
    # Cipher suite code chosen in the ServerHello; None when no ServerHello was parsed.
    cipher_suite: int | None
    # First two body bytes of an alert record (level, description); otherwise None.
    alert: tuple[int, int] | None
    # The entire input buffer given to the parser (may extend past the first record).
    raw_record: bytes
    # Handshake message type byte; None for non-handshake records.
    handshake_type: int | None
    # The 32 bytes following the ServerHello version field.
    server_random: bytes | None = None
    # Session id echoed by the server (possibly empty).
    session_id: bytes | None = None
    # Extension type -> raw extension data; a repeated type keeps the last value seen.
    extensions: dict[int, bytes] | None = None
    # Version from a supported_versions extension if present, otherwise server_version.
    negotiated_version: int | None = None
    # First two bytes of the key_share extension data (the group code), if present.
    key_share_group: int | None = None
def parse_server_response(data: bytes) -> ParsedServerHello | None:
    """Parse the first TLS record in *data*, looking for a ServerHello.

    Returns:
        None if *data* is shorter than the 5-byte record header, or if a
        ServerHello is cut off before its cipher_suite / compression_method
        fields. Otherwise a ``ParsedServerHello``:

        * alert record with at least 2 body bytes: ``alert`` is set;
        * any other non-handshake record, or a handshake record with fewer
          than 4 body bytes: ``handshake_type`` is None;
        * a handshake message that is not a ServerHello, or one shorter than
          the 38-byte fixed part: ``handshake_type`` is set, ``server_version``
          is 0 and ``cipher_suite`` is None;
        * a ServerHello: all fields filled in.

    Record and handshake lengths are honoured only as far as bytes are
    actually present, so partially received data is parsed best-effort.
    Extensions are read only within the bytes present; a truncated or
    inconsistent extensions block no longer raises ``IndexError``.
    """
    if len(data) < 5:
        return None
    ct = data[0]
    rec_ver = (data[1] << 8) | data[2]
    rec_len = (data[3] << 8) | data[4]
    # Shorter than rec_len when only part of the record has been received.
    body = data[5:5 + rec_len]

    def no_hello(handshake_type: int | None, alert: tuple[int, int] | None = None) -> ParsedServerHello:
        # Result for records that do not carry a parseable ServerHello.
        return ParsedServerHello(
            record_version=rec_ver,
            server_version=0,
            cipher_suite=None,
            alert=alert,
            raw_record=data,
            handshake_type=handshake_type,
        )

    if ct == C.CT_ALERT and len(body) >= 2:
        return no_hello(None, alert=(body[0], body[1]))
    if ct != C.CT_HANDSHAKE or len(body) < 4:
        return no_hello(None)

    hs_type = body[0]
    hs_len = (body[1] << 16) | (body[2] << 8) | body[3]
    hs = body[4:4 + hs_len]
    # 38 = version(2) + random(32) + session_id length(1)
    #      + cipher_suite(2) + compression_method(1), with an empty session id.
    if hs_type != C.HS_SERVER_HELLO or len(hs) < 38:
        return no_hello(hs_type)

    server_version = (hs[0] << 8) | hs[1]
    server_random = hs[2:34]
    sess_len = hs[34]
    off = 35 + sess_len
    if len(hs) < off + 3:  # cipher_suite + compression_method not fully present
        return None
    session_id = hs[35:off]
    cs = (hs[off] << 8) | hs[off + 1]
    off += 3  # skip cipher_suite (2 bytes) and compression_method (1 byte)

    ext_map: dict[int, bytes] = {}
    negotiated_version = server_version
    key_share_group = None
    if off + 2 <= len(hs):
        ext_total = (hs[off] << 8) | hs[off + 1]
        off += 2
        # Clamp to the bytes actually present: a truncated message or a bogus
        # length must not make the header reads below index past hs.
        ext_end = min(off + ext_total, len(hs))
        while off + 4 <= ext_end:
            et = (hs[off] << 8) | hs[off + 1]
            el = (hs[off + 2] << 8) | hs[off + 3]
            off += 4
            # Never read extension data from beyond the extensions block.
            ext_data = hs[off:min(off + el, ext_end)]
            off += el
            ext_map[et] = ext_data
            if et == C.EXT_SUPPORTED_VERSIONS and len(ext_data) >= 2:
                negotiated_version = (ext_data[0] << 8) | ext_data[1]
            elif et == C.EXT_KEY_SHARE and len(ext_data) >= 2:
                key_share_group = (ext_data[0] << 8) | ext_data[1]
    return ParsedServerHello(
        record_version=rec_ver,
        server_version=server_version,
        cipher_suite=cs,
        alert=None,
        raw_record=data,
        handshake_type=hs_type,
        server_random=server_random,
        session_id=session_id,
        extensions=ext_map,
        negotiated_version=negotiated_version,
        key_share_group=key_share_group,
    )
def build_ssl2_client_hello(ciphers: list[int] | None = None) -> bytes:
    """Craft an SSLv2 ClientHello used to detect DROWN / SSLv2 support.

    Args:
        ciphers: 3-byte SSLv2 cipher-spec codes. Defaults to
            SSL_CK_RC4_128_WITH_MD5, SSL_CK_RC4_128_EXPORT40_WITH_MD5 and
            SSL_CK_DES_192_EDE3_CBC_WITH_MD5.

    Returns:
        The message with its 2-byte SSLv2 record header (high bit set, 15-bit
        length), an empty session id and a fresh 16-byte random challenge.

    Raises:
        OverflowError: a cipher code is negative or does not fit in 3 bytes.
        ValueError: the body is longer than the 15-bit header can express
            (0x7FFF bytes).
    """
    if ciphers is None:
        ciphers = [0x010080, 0x020080, 0x0700c0]
    challenge = os.urandom(16)
    # Each SSLv2 cipher spec is exactly 3 bytes; out-of-range codes raise
    # instead of being silently truncated.
    cipher_bytes = b"".join(c.to_bytes(3, "big") for c in ciphers)
    fixed = struct.pack(
        "!BHHHH",
        1,                  # MSG-CLIENT-HELLO
        0x0002,             # SSLv2 version
        len(cipher_bytes),  # cipher specs length
        0,                  # session-id length
        len(challenge),     # challenge length
    ) if len(cipher_bytes) <= 0xFFFF else b""
    body_len = 9 + len(cipher_bytes) + len(challenge)
    if body_len > 0x7FFF:
        # Bit 15 of the header is the 2-byte-header flag, so longer bodies
        # cannot be framed; OR-ing would corrupt the length instead.
        raise ValueError(f"SSLv2 ClientHello body too long for a 2-byte header: {body_len} bytes")
    body = fixed + cipher_bytes + challenge
    return struct.pack("!H", 0x8000 | len(body)) + body