261 lines
8.1 KiB
Python
261 lines
8.1 KiB
Python
from __future__ import annotations
|
|
import os
|
|
import struct
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from . import constants as C
|
|
|
|
|
|
def _u8(x: int) -> bytes:
    """Serialize ``x`` as a single unsigned byte.

    Raises ``struct.error`` when ``x`` is outside ``0..255``.
    """
    packer = struct.Struct("!B")
    return packer.pack(x)
|
|
|
|
def _u16(x: int) -> bytes:
    """Serialize ``x`` as a 2-byte unsigned integer in network byte order.

    Raises ``struct.error`` when ``x`` is outside ``0..65535``.
    """
    packer = struct.Struct("!H")
    return packer.pack(x)
|
|
|
|
def _u24(x: int) -> bytes:
    """Serialize ``x`` as a 3-byte unsigned integer in network byte order.

    Args:
        x: Value in the range ``0..0xFFFFFF``.

    Returns:
        The three big-endian bytes of ``x``.

    Raises:
        struct.error: If ``x`` does not fit in 24 bits. Previously values in
            ``0x1000000..0xFFFFFFFF`` were silently truncated to their low
            three bytes, producing a wrong length field on the wire.
    """
    packed = struct.pack("!I", x)
    # The value was packed into 4 bytes; a non-zero top byte means it needs
    # more than 24 bits and dropping it would corrupt the encoded number.
    if packed[0]:
        raise struct.error("uint24 format requires 0 <= number <= 16777215")
    return packed[1:]
|
|
|
|
|
|
def build_extension(ext_type: int, data: bytes) -> bytes:
    """Frame ``data`` as one TLS extension.

    The result is the 16-bit extension type, the 16-bit payload length, and
    then the payload itself.
    """
    header = _u16(ext_type) + _u16(len(data))
    return header + data
|
|
|
|
|
|
def ext_server_name(hostname: str) -> bytes:
    """Build the ``server_name`` (SNI) extension for a single DNS hostname.

    Args:
        hostname: DNS name to advertise; it is encoded with the ``idna``
            codec. One trailing dot (fully-qualified form) is removed first,
            because RFC 6066 section 3 requires the HostName to be sent
            without a trailing dot.

    Returns:
        The encoded extension: a server-name list holding one entry made of
        a name-type byte of 0, a 16-bit name length, and the name bytes.
    """
    if hostname.endswith("."):
        hostname = hostname[:-1]
    name = hostname.encode("idna")
    entry = _u8(0) + _u16(len(name)) + name  # name_type 0 = host_name
    server_name_list = _u16(len(entry)) + entry
    return build_extension(C.EXT_SERVER_NAME, server_name_list)
|
|
|
|
|
|
def ext_supported_versions_client(versions: list[int]) -> bytes:
    """Build the client-side ``supported_versions`` extension.

    The payload is a one-byte length (two bytes per version) followed by
    each version as a 16-bit value, in the order given.
    """
    length_prefix = _u8(2 * len(versions))
    encoded_versions = b"".join(map(_u16, versions))
    return build_extension(C.EXT_SUPPORTED_VERSIONS, length_prefix + encoded_versions)
|
|
|
|
|
|
def ext_supported_groups(groups: list[int]) -> bytes:
    """Build the ``supported_groups`` extension.

    The payload is a 16-bit length (two bytes per group) followed by each
    group identifier as a 16-bit value, in the order given.
    """
    length_prefix = _u16(2 * len(groups))
    encoded_groups = b"".join(map(_u16, groups))
    return build_extension(C.EXT_SUPPORTED_GROUPS, length_prefix + encoded_groups)
|
|
|
|
|
|
def ext_signature_algorithms(algos: list[int]) -> bytes:
    """Build the ``signature_algorithms`` extension.

    The payload is a 16-bit length (two bytes per algorithm) followed by
    each algorithm code as a 16-bit value, in the order given.
    """
    length_prefix = _u16(2 * len(algos))
    encoded_algos = b"".join(map(_u16, algos))
    return build_extension(C.EXT_SIGNATURE_ALGORITHMS, length_prefix + encoded_algos)
|
|
|
|
|
|
def ext_ec_point_formats() -> bytes:
    """Build the ``ec_point_formats`` extension with a single format, 0."""
    # One-byte list length (1) followed by the single format byte (0).
    payload = _u8(1) + _u8(0)
    return build_extension(C.EXT_EC_POINT_FORMATS, payload)
|
|
|
|
|
|
def ext_key_share_empty() -> bytes:
    """Build a ``key_share`` extension whose share list is empty."""
    # The payload is only the 16-bit list length, set to zero.
    empty_share_list = _u16(0)
    return build_extension(C.EXT_KEY_SHARE, empty_share_list)
|
|
|
|
|
|
def ext_psk_key_exchange_modes() -> bytes:
    """Build the ``psk_key_exchange_modes`` extension listing the single mode 1."""
    # One-byte list length (1) followed by the single mode byte (1).
    payload = _u8(1) + _u8(1)
    return build_extension(C.EXT_PSK_KEY_EXCHANGE_MODES, payload)
|
|
|
|
|
|
def ext_alpn(protos: list[bytes]) -> bytes:
    """Build the ALPN extension for the given protocol identifiers.

    Each protocol is written as a one-byte length plus its bytes; the whole
    list is then prefixed with its 16-bit length.
    """
    entries = [_u8(len(proto)) + proto for proto in protos]
    protocol_list = b"".join(entries)
    return build_extension(C.EXT_ALPN, _u16(len(protocol_list)) + protocol_list)
|
|
|
|
|
|
def ext_heartbeat_enabled() -> bytes:
    """Build the ``heartbeat`` extension with its one-byte mode set to 1."""
    mode = _u8(1)
    return build_extension(C.EXT_HEARTBEAT, mode)
|
|
|
|
|
|
def ext_renegotiation_info_empty() -> bytes:
    """Build a ``renegotiation_info`` extension carrying a zero-length value."""
    # The payload is a single length byte of 0, i.e. no renegotiation data.
    empty_value = _u8(0)
    return build_extension(C.EXT_RENEGOTIATION_INFO, empty_value)
|
|
|
|
|
|
def ext_status_request() -> bytes:
    """Build the ``status_request`` extension asking for an OCSP response."""
    # status_type=OCSP(1) + empty responder_id_list + empty extensions
    status_type = _u8(1)
    responder_id_list = _u16(0)
    request_extensions = _u16(0)
    payload = status_type + responder_id_list + request_extensions
    return build_extension(C.EXT_STATUS_REQUEST, payload)
|
|
|
|
|
|
def ext_signed_cert_timestamp() -> bytes:
    """Build the ``signed_certificate_timestamp`` extension with an empty body."""
    no_payload = b""
    return build_extension(C.EXT_SIGNED_CERT_TIMESTAMP, no_payload)
|
|
|
|
|
|
def ext_extended_master_secret() -> bytes:
    """Build the ``extended_master_secret`` extension with an empty body."""
    no_payload = b""
    return build_extension(C.EXT_EXTENDED_MASTER_SECRET, no_payload)
|
|
|
|
|
|
def build_client_hello(
    record_version: int,
    client_hello_version: int,
    hostname: str | None,
    cipher_suites: list[int],
    extensions: bytes = b"",
    compression: bytes = b"\x01\x00",  # 1 length, null method
) -> bytes:
    """Assemble a complete TLS record containing a ClientHello.

    Args:
        record_version: Version written into the record-layer header.
        client_hello_version: Version written at the start of the ClientHello.
        hostname: Accepted for interface compatibility; this function does
            not read it (SNI has to be supplied through ``extensions``).
        cipher_suites: Cipher suite codes, each encoded as 16 bits in order.
        extensions: Already-encoded extensions without the total-length
            prefix. When empty, no extensions block is emitted at all.
        compression: Pre-encoded compression-methods vector.

    Returns:
        The record header followed by the handshake header and the
        ClientHello body. The 32-byte random comes from ``os.urandom`` and
        the session id is always empty.
    """
    client_random = os.urandom(32)
    legacy_session_id = b""
    suites = b"".join(map(_u16, cipher_suites))

    fields = [
        _u16(client_hello_version),
        client_random,
        _u8(len(legacy_session_id)),
        legacy_session_id,
        _u16(len(suites)),
        suites,
        compression,
    ]
    if extensions:
        # The extensions block is optional; it gets its own 16-bit length.
        fields += [_u16(len(extensions)), extensions]
    hello = b"".join(fields)

    handshake_header = _u8(C.HS_CLIENT_HELLO) + _u24(len(hello))
    handshake = handshake_header + hello
    record_header = _u8(C.CT_HANDSHAKE) + _u16(record_version) + _u16(len(handshake))
    return record_header + handshake
|
|
|
|
|
|
def default_ch_extensions(
    hostname: str,
    versions: list[int],
    groups: list[int] | None = None,
) -> bytes:
    """Build the default set of ClientHello extensions.

    Args:
        hostname: Target host. A ``server_name`` extension is added only when
            this is non-empty and is not an IP literal.
        versions: Protocol versions being offered. When ``C.TLS_1_3`` is among
            them, ``supported_versions``, ``psk_key_exchange_modes`` and an
            empty ``key_share`` are appended.
        groups: Named groups to advertise. ``None`` or an empty list selects
            the built-in default ``[0x001d, 0x0017, 0x0018]``.

    Returns:
        The concatenated encoded extensions WITHOUT the 16-bit total-length
        prefix; ``build_client_hello`` adds that prefix itself.
    """
    groups = groups or [0x001d, 0x0017, 0x0018]
    parts: list[bytes] = []
    if hostname and not _is_ip_literal(hostname):
        parts.append(ext_server_name(hostname))
    parts.extend([
        ext_ec_point_formats(),
        ext_supported_groups(groups),
        ext_signature_algorithms([
            0x0403, 0x0804, 0x0401, 0x0503, 0x0805, 0x0501,
            0x0603, 0x0806, 0x0601, 0x0807, 0x0808,
        ]),
        ext_renegotiation_info_empty(),
        ext_signed_cert_timestamp(),
        ext_status_request(),
        ext_extended_master_secret(),
        ext_alpn([b"h2", b"http/1.1"]),
    ])
    if C.TLS_1_3 in versions:
        parts.extend([
            ext_supported_versions_client(versions),
            ext_psk_key_exchange_modes(),
            ext_key_share_empty(),
        ])
    # The former ``... if False else exts`` return had an unreachable
    # length-prefixed branch; the unprefixed bytes are what was always returned.
    return b"".join(parts)
|
|
|
|
|
|
def _is_ip_literal(host: str) -> bool:
    """Return True when ``host`` parses as an IPv4 or IPv6 address.

    Leading and trailing square brackets are removed before parsing, so a
    bracketed form such as ``[::1]`` is accepted.
    """
    import ipaddress

    candidate = host.strip("[]")
    try:
        ipaddress.ip_address(candidate)
    except ValueError:
        return False
    return True
|
|
|
|
|
|
@dataclass
class ParsedServerHello:
    """Result of decoding the first TLS record received from a server.

    The first six fields are always supplied by the parser. The remaining
    ones default to ``None`` and are only filled in when a ServerHello was
    actually decoded.
    """

    # Version taken from the record-layer header.
    record_version: int
    # Version field of the ServerHello; the parser stores 0 when no
    # ServerHello was decoded.
    server_version: int
    # Selected cipher suite, or None when no ServerHello was decoded.
    cipher_suite: int | None
    # First two bytes of an alert record body, or None for non-alert records.
    alert: tuple[int, int] | None
    # The raw bytes that were handed to the parser.
    raw_record: bytes
    # Handshake message type, or None when the record held no handshake header.
    handshake_type: int | None

    # 32-byte random from the ServerHello.
    server_random: bytes | None = None
    # Session id from the ServerHello.
    session_id: bytes | None = None
    # Extension type -> raw extension payload.
    extensions: dict[int, bytes] | None = None
    # Version from the supported_versions extension when present, otherwise
    # the ServerHello version field.
    negotiated_version: int | None = None
    # First 16 bits of the key_share extension payload, when present.
    key_share_group: int | None = None
|
|
|
|
|
|
def _unparsed_response(
    rec_ver: int,
    data: bytes,
    alert: tuple[int, int] | None = None,
    handshake_type: int | None = None,
) -> ParsedServerHello:
    """Build the placeholder result used when no ServerHello was decoded."""
    return ParsedServerHello(
        record_version=rec_ver,
        server_version=0,
        cipher_suite=None,
        alert=alert,
        raw_record=data,
        handshake_type=handshake_type,
    )


def parse_server_response(data: bytes) -> ParsedServerHello | None:
    """Parse the first TLS record of a server reply.

    Args:
        data: Raw bytes received from the server, starting at a record header.

    Returns:
        ``None`` when ``data`` is shorter than a 5-byte record header, or
        when a ServerHello is cut off before its cipher suite and compression
        method. Otherwise a ``ParsedServerHello``:

        * for an alert record with at least two body bytes, ``alert`` holds
          those two bytes;
        * for any other non-handshake record, a handshake record too short
          for a handshake header, or a handshake message that is not a
          ServerHello of at least 38 bytes, a placeholder with
          ``server_version=0`` and ``cipher_suite=None`` is returned;
        * for a ServerHello, all fields are populated. Extensions are decoded
          best-effort: lengths declared by the peer are clamped to the bytes
          actually available, so truncated input never raises.
    """
    if len(data) < 5:
        return None
    ct = data[0]
    rec_ver = (data[1] << 8) | data[2]
    rec_len = (data[3] << 8) | data[4]
    # Slicing tolerates a short read: body is simply whatever has arrived.
    body = data[5:5 + rec_len]

    if ct == C.CT_ALERT and len(body) >= 2:
        return _unparsed_response(rec_ver, data, alert=(body[0], body[1]))
    if ct != C.CT_HANDSHAKE or len(body) < 4:
        return _unparsed_response(rec_ver, data)

    hs_type = body[0]
    hs_len = (body[1] << 16) | (body[2] << 8) | body[3]
    hs = body[4:4 + hs_len]
    # 38 = version (2) + random (32) + session-id length (1)
    #      + cipher suite (2) + compression method (1)
    if hs_type != C.HS_SERVER_HELLO or len(hs) < 38:
        return _unparsed_response(rec_ver, data, handshake_type=hs_type)

    server_version = (hs[0] << 8) | hs[1]
    server_random = hs[2:34]
    sess_len = hs[34]
    off = 35 + sess_len
    if len(hs) < off + 3:
        return None
    session_id = hs[35:off]
    cs = (hs[off] << 8) | hs[off + 1]
    off += 2
    # comp method (1 byte)
    off += 1

    ext_map: dict[int, bytes] = {}
    negotiated_version = server_version
    key_share_group = None
    if off + 2 <= len(hs):
        ext_total = (hs[off] << 8) | hs[off + 1]
        off += 2
        # ext_total comes from the peer and may exceed what was received;
        # without the clamp the header reads below raise IndexError.
        ext_end = min(off + ext_total, len(hs))
        while off + 4 <= ext_end:
            et = (hs[off] << 8) | hs[off + 1]
            el = (hs[off + 2] << 8) | hs[off + 3]
            off += 4
            # Never read an extension payload past the extensions block.
            ext_data = hs[off:min(off + el, ext_end)]
            off += el
            ext_map[et] = ext_data
            if et == C.EXT_SUPPORTED_VERSIONS and len(ext_data) >= 2:
                negotiated_version = (ext_data[0] << 8) | ext_data[1]
            elif et == C.EXT_KEY_SHARE and len(ext_data) >= 2:
                key_share_group = (ext_data[0] << 8) | ext_data[1]

    return ParsedServerHello(
        record_version=rec_ver,
        server_version=server_version,
        cipher_suite=cs,
        alert=None,
        raw_record=data,
        handshake_type=hs_type,
        server_random=server_random,
        session_id=session_id,
        extensions=ext_map,
        negotiated_version=negotiated_version,
        key_share_group=key_share_group,
    )
|
|
|
|
|
|
def build_ssl2_client_hello(ciphers: list[int] | None = None) -> bytes:
    """Craft an SSLv2 ClientHello used to detect DROWN / SSLv2 support.

    Args:
        ciphers: SSLv2 cipher specs, each a 3-byte value. ``None`` selects a
            minimal built-in set; an explicit empty list is sent as-is.

    Returns:
        The 2-byte record header (high bit set, 15-bit body length) followed
        by the ClientHello body. The 16-byte challenge comes from
        ``os.urandom``.

    Raises:
        ValueError: If a cipher spec does not fit in 3 bytes, or the message
            body is too long for the 15-bit record length. Both cases used to
            produce a silently corrupted message.
        struct.error: If a cipher spec is negative or above 32 bits, or the
            cipher list exceeds what its 16-bit length field can express.
    """
    # SSLv2 cipher specs are 3-byte. A minimal set:
    if ciphers is None:
        # SSL_CK_RC4_128_WITH_MD5, SSL_CK_RC4_128_EXPORT40_WITH_MD5, SSL_CK_DES_192_EDE3_CBC_WITH_MD5
        ciphers = [0x010080, 0x020080, 0x0700c0]

    encoded_specs = []
    for spec in ciphers:
        packed = struct.pack("!I", spec)
        # A non-zero top byte means the spec needs more than 3 bytes.
        if packed[0]:
            raise ValueError(f"SSLv2 cipher spec {spec:#x} does not fit in 3 bytes")
        encoded_specs.append(packed[1:])
    cipher_bytes = b"".join(encoded_specs)

    challenge = os.urandom(16)
    fixed_fields = struct.pack(
        "!BHHHH",
        1,                  # MSG-CLIENT-HELLO
        0x0002,             # SSLv2 version
        len(cipher_bytes),  # cipher specs length
        0,                  # session-id length
        len(challenge),     # challenge length
    )
    body = fixed_fields + cipher_bytes + challenge

    # The 2-byte header keeps its top bit as a flag, leaving 15 bits of length.
    if len(body) > 0x7FFF:
        raise ValueError("SSLv2 ClientHello body exceeds the 15-bit record length")
    header = struct.pack("!H", 0x8000 | len(body))
    return header + body
|