from __future__ import annotations

import ipaddress
import os
import struct
from dataclasses import dataclass
from typing import Any

from . import constants as C


def _u8(x: int) -> bytes:
    """Encode *x* as one unsigned byte."""
    return struct.pack("!B", x)


def _u16(x: int) -> bytes:
    """Encode *x* as a 2-byte big-endian unsigned integer."""
    return struct.pack("!H", x)


def _u24(x: int) -> bytes:
    """Encode *x* as a 3-byte big-endian unsigned integer."""
    # Pack as a 4-byte uint32 and drop the most significant byte.
    return struct.pack("!I", x)[1:]


def build_extension(ext_type: int, data: bytes) -> bytes:
    """Wrap *data* as an extension: type (uint16) + length (uint16) + data."""
    return _u16(ext_type) + _u16(len(data)) + data


def ext_server_name(hostname: str) -> bytes:
    """Build a server_name (SNI) extension with a single host_name entry.

    The hostname is IDNA-encoded. A trailing dot (fully-qualified form) is
    removed first, because RFC 6066 section 3 requires the host_name to be
    sent without a trailing dot.
    """
    name = hostname.rstrip(".").encode("idna")
    entry = _u8(0) + _u16(len(name)) + name  # name_type 0 = host_name
    server_name_list = _u16(len(entry)) + entry
    return build_extension(C.EXT_SERVER_NAME, server_name_list)


def ext_supported_versions_client(versions: list[int]) -> bytes:
    """Build the ClientHello form of supported_versions listing *versions*."""
    # In the ClientHello form the version-list length is a single byte.
    body = _u8(2 * len(versions)) + b"".join(_u16(v) for v in versions)
    return build_extension(C.EXT_SUPPORTED_VERSIONS, body)


def ext_supported_groups(groups: list[int]) -> bytes:
    """Build a supported_groups extension listing *groups* code points."""
    body = _u16(2 * len(groups)) + b"".join(_u16(g) for g in groups)
    return build_extension(C.EXT_SUPPORTED_GROUPS, body)


def ext_signature_algorithms(algos: list[int]) -> bytes:
    """Build a signature_algorithms extension listing *algos* code points."""
    body = _u16(2 * len(algos)) + b"".join(_u16(a) for a in algos)
    return build_extension(C.EXT_SIGNATURE_ALGORITHMS, body)


def ext_ec_point_formats() -> bytes:
    """Build an ec_point_formats extension offering only format 0."""
    return build_extension(C.EXT_EC_POINT_FORMATS, _u8(1) + _u8(0))


def ext_key_share_empty() -> bytes:
    """Build a key_share extension whose client_shares list is empty."""
    return build_extension(C.EXT_KEY_SHARE, _u16(0))


def ext_psk_key_exchange_modes() -> bytes:
    """Build a psk_key_exchange_modes extension offering the single mode 1."""
    return build_extension(C.EXT_PSK_KEY_EXCHANGE_MODES, _u8(1) + _u8(1))


def ext_alpn(protos: list[bytes]) -> bytes:
    """Build an ALPN extension from raw protocol names (e.g. b"h2")."""
    inner = b"".join(_u8(len(p)) + p for p in protos)
    return build_extension(C.EXT_ALPN, _u16(len(inner)) + inner)


def ext_heartbeat_enabled() -> bytes:
    """Build a heartbeat extension whose mode byte is 1."""
    return build_extension(C.EXT_HEARTBEAT, _u8(1))


def ext_renegotiation_info_empty() -> bytes:
    """Build a renegotiation_info extension with zero-length content."""
    return build_extension(C.EXT_RENEGOTIATION_INFO, _u8(0))


def ext_status_request() -> bytes:
    """Build a status_request extension requesting OCSP stapling."""
    # status_type=OCSP(1) + empty responder_id_list + empty extensions
    body = _u8(1) + _u16(0) + _u16(0)
    return build_extension(C.EXT_STATUS_REQUEST, body)


def ext_signed_cert_timestamp() -> bytes:
    """Build an empty signed_certificate_timestamp extension."""
    return build_extension(C.EXT_SIGNED_CERT_TIMESTAMP, b"")


def ext_extended_master_secret() -> bytes:
    """Build an empty extended_master_secret extension."""
    return build_extension(C.EXT_EXTENDED_MASTER_SECRET, b"")


def build_client_hello(
    record_version: int,
    client_hello_version: int,
    hostname: str | None,
    cipher_suites: list[int],
    extensions: bytes = b"",
    compression: bytes = b"\x01\x00",  # 1 length, null method
) -> bytes:
    """Build one handshake record containing a ClientHello.

    Args:
        record_version: Version placed in the record-layer header.
        client_hello_version: Version field inside the ClientHello body.
        hostname: Not used by this function; SNI must be supplied inside
            *extensions* (e.g. via default_ch_extensions).
        cipher_suites: Cipher suite code points, in the order to offer them.
        extensions: Concatenated extension blocks WITHOUT the outer uint16
            length prefix; this function adds it. When empty, the
            extensions field is omitted entirely.
        compression: Pre-encoded compression_methods vector (length byte
            followed by the methods).

    Returns:
        The full record: content type, record version, length, then the
        handshake header (type + uint24 length) and ClientHello body.
    """
    random_bytes = os.urandom(32)
    session_id = b""  # always sent empty
    cs_bytes = b"".join(_u16(c) for c in cipher_suites)
    body = (
        _u16(client_hello_version)
        + random_bytes
        + _u8(len(session_id)) + session_id
        + _u16(len(cs_bytes)) + cs_bytes
        + compression
    )
    if extensions:
        body += _u16(len(extensions)) + extensions
    handshake = _u8(C.HS_CLIENT_HELLO) + _u24(len(body)) + body
    return _u8(C.CT_HANDSHAKE) + _u16(record_version) + _u16(len(handshake)) + handshake


def default_ch_extensions(
    hostname: str, versions: list[int], groups: list[int] | None = None
) -> bytes:
    """Build the default ClientHello extension set.

    SNI is included only when *hostname* is non-empty and not an IP literal.
    The TLS 1.3-only extensions (supported_versions, psk_key_exchange_modes,
    an empty key_share) are added only when C.TLS_1_3 is in *versions*.

    Args:
        hostname: Target host name; may be empty or an IP literal.
        versions: Protocol versions to advertise via supported_versions.
        groups: Named-group code points; a falsy value (None or an empty
            list) selects the built-in default list.

    Returns:
        The concatenated extensions WITHOUT the outer length prefix, ready
        to pass as build_client_hello(extensions=...).
    """
    groups = groups or [0x001D, 0x0017, 0x0018]
    parts: list[bytes] = []
    if hostname and not _is_ip_literal(hostname):
        parts.append(ext_server_name(hostname))
    parts.extend(
        [
            ext_ec_point_formats(),
            ext_supported_groups(groups),
            ext_signature_algorithms(
                [
                    0x0403, 0x0804, 0x0401,
                    0x0503, 0x0805, 0x0501,
                    0x0603, 0x0806, 0x0601,
                    0x0807, 0x0808,
                ]
            ),
            ext_renegotiation_info_empty(),
            ext_signed_cert_timestamp(),
            ext_status_request(),
            ext_extended_master_secret(),
            ext_alpn([b"h2", b"http/1.1"]),
        ]
    )
    if C.TLS_1_3 in versions:
        parts.append(ext_supported_versions_client(versions))
        parts.append(ext_psk_key_exchange_modes())
        parts.append(ext_key_share_empty())
    # No length prefix here: build_client_hello prepends it.
    return b"".join(parts)


def _is_ip_literal(host: str) -> bool:
    """Return True if *host* parses as an IP address (brackets stripped)."""
    try:
        ipaddress.ip_address(host.strip("[]"))
    except ValueError:
        return False
    return True


@dataclass
class ParsedServerHello:
    """Summary of the first record of a server's response."""

    record_version: int  # version from the record-layer header
    server_version: int  # ServerHello version field; 0 if not a ServerHello
    cipher_suite: int | None  # selected cipher suite; None if not a ServerHello
    alert: tuple[int, int] | None  # (level, description) for alert records
    raw_record: bytes  # the full input buffer that was parsed
    handshake_type: int | None  # handshake message type, when one was seen
    server_random: bytes | None = None  # 32-byte ServerHello random
    session_id: bytes | None = None  # echoed session id
    extensions: dict[int, bytes] | None = None  # extension type -> raw data
    # supported_versions value if present, otherwise server_version.
    negotiated_version: int | None = None
    # First two bytes of the key_share extension data, if present.
    key_share_group: int | None = None


def _read_u16(buf: bytes, off: int) -> int:
    """Read a big-endian uint16 from *buf* at *off*; caller checks bounds."""
    return (buf[off] << 8) | buf[off + 1]


def parse_server_response(data: bytes) -> ParsedServerHello | None:
    """Parse the first TLS record in *data*.

    Returns None when *data* is shorter than a 5-byte record header, or when
    a ServerHello's session id leaves no room for the cipher suite and
    compression method. Otherwise returns a ParsedServerHello:

    * alert record with at least 2 body bytes: ``alert`` is set;
    * other non-handshake record, or handshake record under 4 body bytes:
      only ``record_version`` and ``raw_record`` are meaningful;
    * handshake message that is not a ServerHello, or is under 38 bytes:
      ``handshake_type`` is set;
    * ServerHello: every field is populated.

    Length fields sent by the peer are never trusted beyond the bytes
    actually present, so truncated or malformed input does not raise
    IndexError.
    """
    if len(data) < 5:
        return None
    ct = data[0]
    rec_ver = _read_u16(data, 1)
    rec_len = _read_u16(data, 3)
    body = data[5:5 + rec_len]

    if ct == C.CT_ALERT and len(body) >= 2:
        return ParsedServerHello(
            record_version=rec_ver, server_version=0, cipher_suite=None,
            alert=(body[0], body[1]), raw_record=data, handshake_type=None,
        )
    if ct != C.CT_HANDSHAKE or len(body) < 4:
        return ParsedServerHello(
            record_version=rec_ver, server_version=0, cipher_suite=None,
            alert=None, raw_record=data, handshake_type=None,
        )

    hs_type = body[0]
    hs_len = (body[1] << 16) | (body[2] << 8) | body[3]
    hs = body[4:4 + hs_len]
    # 38 = version(2) + random(32) + session_id length(1) + cipher(2) + comp(1)
    if hs_type != C.HS_SERVER_HELLO or len(hs) < 38:
        return ParsedServerHello(
            record_version=rec_ver, server_version=0, cipher_suite=None,
            alert=None, raw_record=data, handshake_type=hs_type,
        )

    server_version = _read_u16(hs, 0)
    server_random = hs[2:34]
    sess_len = hs[34]
    off = 35 + sess_len
    if len(hs) < off + 3:
        return None
    session_id = hs[35:35 + sess_len]
    cs = _read_u16(hs, off)
    off += 2
    off += 1  # compression method (1 byte), ignored

    ext_map: dict[int, bytes] = {}
    negotiated_version = server_version
    key_share_group = None
    if off + 2 <= len(hs):
        ext_total = _read_u16(hs, off)
        off += 2
        # Clamp to the bytes actually present: a truncated or lying length
        # field must not drive the reads below past the end of `hs`.
        ext_end = min(off + ext_total, len(hs))
        while off + 4 <= ext_end:
            et = _read_u16(hs, off)
            el = _read_u16(hs, off + 2)
            off += 4
            ext_data = hs[off:off + el]
            off += el
            ext_map[et] = ext_data
            if et == C.EXT_SUPPORTED_VERSIONS and len(ext_data) >= 2:
                negotiated_version = _read_u16(ext_data, 0)
            elif et == C.EXT_KEY_SHARE and len(ext_data) >= 2:
                key_share_group = _read_u16(ext_data, 0)

    return ParsedServerHello(
        record_version=rec_ver,
        server_version=server_version,
        cipher_suite=cs,
        alert=None,
        raw_record=data,
        handshake_type=hs_type,
        server_random=server_random,
        session_id=session_id,
        extensions=ext_map,
        negotiated_version=negotiated_version,
        key_share_group=key_share_group,
    )


def build_ssl2_client_hello(ciphers: list[int] | None = None) -> bytes:
    """Craft an SSLv2 ClientHello used to detect DROWN / SSLv2 support.

    Args:
        ciphers: 3-byte SSLv2 cipher-spec code points; None selects a
            built-in set of three.

    Returns:
        A 2-byte SSLv2 record header (high bit set, then body length)
        followed by the CLIENT-HELLO body with a random 16-byte challenge
        and an empty session id.
    """
    # SSLv2 cipher specs are 3-byte. A minimal set:
    if ciphers is None:
        # SSL_CK_RC4_128_WITH_MD5, SSL_CK_RC4_128_EXPORT40_WITH_MD5, SSL_CK_DES_192_EDE3_CBC_WITH_MD5
        ciphers = [0x010080, 0x020080, 0x0700C0]
    challenge = os.urandom(16)
    cipher_bytes = b"".join(struct.pack("!I", c)[1:] for c in ciphers)
    body = (
        _u8(1)  # MSG-CLIENT-HELLO
        + _u16(0x0002)  # SSLv2 version
        + _u16(len(cipher_bytes))  # cipher specs length
        + _u16(0)  # session-id length
        + _u16(len(challenge))  # challenge length
        + cipher_bytes
        + challenge
    )
    header = struct.pack("!H", 0x8000 | len(body))
    return header + body