#!/usr/bin/env python3
"""Generate a plain (non-Geneve) SNMP pcap for reproducing the tx leak bug.

This creates a single long-lived UDP flow with continuous SNMP GetRequest/GetResponse
pairs. When combined with a `pass` rule matching the source, transactions accumulate
indefinitely because FLOW_ACTION_PASS prevents detection from running on subsequent
packets, so FLOW_SGH_TOCLIENT is never set, blocking transaction cleanup.

Bug: https://redmine.openinfosecfoundation.org/issues/XXXX (TODO: replace XXXX with the ticket number)
Upstream commit: d8ddef4c1485004cfb24d0e4b1c490f185bedc92

Usage:
    python3 gen_plain_snmp_pcap.py [--count 10000] [--output snmp_leak_repro.pcap]

Repro steps:
    1. Generate pcap:  python3 gen_plain_snmp_pcap.py --count 10000
    2. Run suricata:   ./suricata_afp --config-dir dev-runner/repro_config
    3. Replay pcap:    sudo tcpreplay -i SFE_0_TX snmp_leak_repro.pcap
    4. Observe memory growth in stats or via massif
"""
from __future__ import annotations

import argparse
import struct
import socket


def build_ethernet(src_mac: bytes, dst_mac: bytes, ethertype: int) -> bytes:
    """Return an Ethernet II header.

    Layout on the wire: destination MAC, then source MAC, then the
    EtherType packed as a big-endian unsigned 16-bit value.
    """
    addresses = dst_mac + src_mac  # destination comes first on the wire
    type_field = struct.pack("!H", ethertype)
    return addresses + type_field


def build_ipv4(
    src: str, dst: str, proto: int, payload_len: int, ident: int = 0
) -> bytes:
    """Return a 20-byte IPv4 header (no options) with a valid header checksum.

    The header has version 4 / IHL 5, TOS 0, the Don't Fragment flag set,
    and TTL 64.  ``ident`` is truncated to 16 bits.  ``payload_len`` is the
    number of bytes that follow the header; it is added to the header size
    to form the total-length field.
    """
    src_addr = socket.inet_aton(src)
    dst_addr = socket.inet_aton(dst)

    def pack_header(checksum: int) -> bytes:
        return struct.pack(
            "!BBHHHBBH4s4s",
            0x45,  # version 4, IHL 5 32-bit words
            0,  # type of service
            20 + payload_len,  # total length = header + payload
            ident & 0xFFFF,
            0x4000,  # flags: Don't Fragment; fragment offset 0
            64,  # TTL
            proto,
            checksum,
            src_addr,
            dst_addr,
        )

    # The checksum is computed over the header with its checksum field zeroed.
    return pack_header(ip_checksum(pack_header(0)))


def build_udp(src_port: int, dst_port: int, payload: bytes) -> bytes:
    """Return a UDP datagram: an 8-byte header followed by *payload*.

    The length field covers header plus payload.  The checksum field is
    left as 0, i.e. "no checksum" (IPv4 only).
    """
    datagram_len = len(payload) + 8
    return struct.pack("!4H", src_port, dst_port, datagram_len, 0) + payload


def build_snmp_get_request(request_id: int, oids: list[str]) -> bytes:
    """Return a DER-encoded SNMPv2c GetRequest message.

    Every OID in *oids* becomes one varbind whose value is ASN.1 NULL.
    The community string is ``public``, and error-status and error-index
    are both 0.
    """
    null_value = b"\x05\x00"  # ASN.1 NULL: requests carry no values
    bindings = b"".join(
        asn1_sequence(encode_oid(oid) + null_value) for oid in oids
    )
    header_fields = asn1_integer(request_id)
    header_fields += asn1_integer(0)  # error-status
    header_fields += asn1_integer(0)  # error-index
    body = header_fields + asn1_sequence(bindings)
    pdu = b"\xa0" + asn1_length(len(body)) + body  # context tag [0]: GetRequest
    version = asn1_integer(1)  # SNMPv2c
    community = asn1_octet_string(b"public")
    return asn1_sequence(version + community + pdu)


def build_snmp_get_response(
    request_id: int, oids: list[str], base_value: int = 12345
) -> bytes:
    """Return a DER-encoded SNMPv2c GetResponse message.

    The i-th varbind binds ``oids[i]`` to the INTEGER ``base_value + i``.
    The community string is ``public``, and error-status and error-index
    are both 0.
    """
    bindings = b"".join(
        asn1_sequence(encode_oid(oid) + asn1_integer(base_value + offset))
        for offset, oid in enumerate(oids)
    )
    header_fields = asn1_integer(request_id)
    header_fields += asn1_integer(0)  # error-status
    header_fields += asn1_integer(0)  # error-index
    body = header_fields + asn1_sequence(bindings)
    pdu = b"\xa2" + asn1_length(len(body)) + body  # context tag [2]: GetResponse
    version = asn1_integer(1)  # SNMPv2c
    community = asn1_octet_string(b"public")
    return asn1_sequence(version + community + pdu)


def _encode_oid_subidentifier(value: int) -> bytes:
    """Return non-negative *value* as a base-128 OID subidentifier.

    Groups are written most-significant first.  Every octet except the
    last has its high bit set to mark "more octets follow".
    """
    groups = [value & 0x7F]
    value >>= 7
    while value:
        groups.append((value & 0x7F) | 0x80)
        value >>= 7
    return bytes(reversed(groups))


def encode_oid(oid_str: str) -> bytes:
    """Encode a dotted-decimal OID string as a DER OBJECT IDENTIFIER TLV.

    The first two arcs are merged into a single subidentifier
    (``40 * first + second``).  Every subidentifier, including that merged
    one, is base-128 encoded, so arcs of any size are supported
    (e.g. ``2.999.3``).  An OID with fewer than two arcs is padded with
    zero arcs.  Arc ranges (first arc in 0..2, etc.) are not checked.

    Raises:
        ValueError: if an arc is not an integer or is negative.
    """
    arcs = [int(x) for x in oid_str.split(".")]
    if any(arc < 0 for arc in arcs):
        raise ValueError(f"OID arcs must be non-negative: {oid_str!r}")
    if len(arcs) < 2:
        arcs += [0] * (2 - len(arcs))
    # The merged first subidentifier can exceed 127 (e.g. 2.999 -> 1079), so
    # it needs the same multi-octet encoding as the remaining arcs.
    subidentifiers = [40 * arcs[0] + arcs[1], *arcs[2:]]
    encoded = b"".join(_encode_oid_subidentifier(s) for s in subidentifiers)
    return bytes([0x06]) + asn1_length(len(encoded)) + encoded


def asn1_sequence(content: bytes) -> bytes:
    """Wrap *content* in a DER SEQUENCE TLV (tag 0x30)."""
    prefix = b"\x30" + asn1_length(len(content))
    return prefix + content


def asn1_integer(value: int) -> bytes:
    """Encode *value* as a DER INTEGER TLV (tag 0x02).

    The content octets are the minimal big-endian two's-complement form.
    A leading 0x00 appears only when a positive value's top bit would
    otherwise be set.  A leading 0xFF appears only when a negative value's
    top bit would otherwise be clear.  Zero encodes as a single 0x00.
    """
    # For negatives, ~value == -value - 1 is the non-negative number whose
    # bit length (plus one sign bit) decides the signed width.
    magnitude = ~value if value < 0 else value
    width = magnitude.bit_length() // 8 + 1
    content = value.to_bytes(width, "big", signed=True)
    return b"\x02" + asn1_length(len(content)) + content


def asn1_octet_string(value: bytes) -> bytes:
    """Wrap *value* in a DER OCTET STRING TLV (tag 0x04)."""
    prefix = b"\x04" + asn1_length(len(value))
    return prefix + value


def asn1_length(length: int) -> bytes:
    """Encode *length* as a DER length field.

    Values below 128 use the one-octet short form.  Larger values use the
    long form: ``0x80 | N`` followed by the length in N big-endian octets
    (minimal N).
    """
    if length >= 128:
        width = (length.bit_length() + 7) // 8
        return bytes([0x80 | width]) + length.to_bytes(width, "big")
    return bytes([length])


def ip_checksum(header: bytes) -> int:
    """Return the 16-bit Internet checksum (RFC 1071) of *header*.

    The data is summed as big-endian 16-bit words; an odd trailing byte is
    padded with a zero octet.  Carries are folded back into the low 16 bits,
    and the one's complement of the result is returned.  Computing it over
    a header whose checksum field is already correct yields 0.

    The caller's buffer is never modified (``bytearray`` input is copied).
    """
    # Copy so that padding never extends a caller-owned bytearray in place.
    data = bytes(header)
    if len(data) % 2:
        data += b"\x00"
    total = sum(struct.unpack(f"!{len(data) // 2}H", data))
    # Fold until no carry remains; a fixed number of folds is only enough
    # for short inputs.
    while total > 0xFFFF:
        total = (total & 0xFFFF) + (total >> 16)
    return ~total & 0xFFFF


def write_pcap(filename: str, packets: list[tuple[float, bytes]]) -> None:
    """Write *packets* to a classic (microsecond) pcap file, linktype Ethernet.

    Each item of *packets* is ``(timestamp_seconds, frame_bytes)``.  Frames
    are written unmodified, with captured length equal to original length.
    The global header declares version 2.4 and snaplen 65535.  NOTE(review):
    frames longer than 65535 bytes are still written in full, which some
    readers may reject.
    """
    with open(filename, "wb") as f:
        # Global header: magic, major, minor, thiszone, sigfigs, snaplen, linktype
        f.write(struct.pack("<IHHiIII", 0xA1B2C3D4, 2, 4, 0, 0, 65535, 1))
        for ts, data in packets:
            # Round to the nearest microsecond rather than truncating.
            # Float timestamps such as 1700000000.0075 are stored slightly
            # below the intended value, and truncation would lose 1 us.
            # divmod carries a rounded-up 1_000_000 us into the seconds field.
            sec, usec = divmod(round(ts * 1_000_000), 1_000_000)
            f.write(struct.pack("<IIII", sec, usec, len(data), len(data)))
            f.write(data)


def _parse_args() -> argparse.Namespace:
    """Parse the command line and reject values that would give a bogus capture."""
    import math

    parser = argparse.ArgumentParser(
        description="Generate plain SNMP pcap for tx leak repro"
    )
    parser.add_argument(
        "--count", type=int, default=10000,
        help="Number of SNMP request/response pairs (default: 10000)",
    )
    parser.add_argument(
        "--output", default="snmp_leak_repro.pcap",
        help="Output pcap filename (default: snmp_leak_repro.pcap)",
    )
    parser.add_argument(
        "--oids-per-pdu", type=int, default=10,
        help="Number of OID varbinds per SNMP PDU (default: 10)",
    )
    parser.add_argument(
        "--interval-ms", type=float, default=5.0,
        help="Milliseconds between request/response pairs (default: 5.0)",
    )
    args = parser.parse_args()
    if args.count < 0:
        parser.error("--count must be non-negative")
    if args.oids_per_pdu < 0:
        parser.error("--oids-per-pdu must be non-negative")
    # float() accepts 'nan'/'inf'.  Those, or a negative interval, would give
    # non-monotonic or unrepresentable pcap timestamps.
    if not (math.isfinite(args.interval_ms) and args.interval_ms >= 0):
        parser.error("--interval-ms must be a finite, non-negative number")
    return args


def main() -> None:
    """Generate the single-flow SNMP pcap and print a summary of what was written.

    Each of the ``--count`` iterations emits a GetRequest (client -> server)
    and, half an interval later, the matching GetResponse (server -> client)
    on one fixed UDP 5-tuple.
    """
    args = _parse_args()

    # Plain UDP 5-tuple — source is in 10.0.0.0/8 to match the pass rule
    snmp_client = "10.109.200.125"  # NMS (matches pass rule src 10.0.0.0/8)
    snmp_server = "172.16.50.10"    # SNMP agent (outside HOME_NET per the repro config — confirm)
    client_port = 53688             # Fixed = single flow
    server_port = 161               # SNMP

    src_mac = b"\x02\x00\x00\x00\x00\x01"
    dst_mac = b"\x02\x00\x00\x00\x00\x02"

    packets: list[tuple[float, bytes]] = []
    base_ts = 1700000000.0
    interval = args.interval_ms / 1000.0

    for i in range(args.count):
        request_id = i + 1
        ts = base_ts + (i * interval)

        # Build list of OIDs for this PDU
        oids = [
            f"1.3.6.1.2.1.2.2.1.{(j + 1) % 20}.{(i * args.oids_per_pdu + j) % 500}"
            for j in range(args.oids_per_pdu)
        ]

        # GetRequest: client -> server (toserver direction)
        snmp_req = build_snmp_get_request(request_id, oids)
        udp_req = build_udp(client_port, server_port, snmp_req)
        ip_req = build_ipv4(snmp_client, snmp_server, 17, len(udp_req), ident=i)
        eth_req = build_ethernet(src_mac, dst_mac, 0x0800)
        packets.append((ts, eth_req + ip_req + udp_req))

        # GetResponse: server -> client (toclient direction)
        snmp_resp = build_snmp_get_response(request_id, oids, base_value=i * 1000)
        udp_resp = build_udp(server_port, client_port, snmp_resp)
        ip_resp = build_ipv4(
            snmp_server, snmp_client, 17, len(udp_resp), ident=i + 50000
        )
        eth_resp = build_ethernet(dst_mac, src_mac, 0x0800)
        packets.append((ts + (interval / 2), eth_resp + ip_resp + udp_resp))

    write_pcap(args.output, packets)
    print(f"Generated {args.output}:")
    print(f"  {args.count} SNMP request/response pairs ({len(packets)} packets)")
    print(f"  {args.oids_per_pdu} OIDs per PDU")
    print(f"  Flow: {snmp_client}:{client_port} <-> {snmp_server}:{server_port}")
    print(f"  Interval: {args.interval_ms}ms between pairs")
    print(f"  Time span: {args.count * interval:.1f}s")
    print("  Source in 10.0.0.0/8 -> matches 'pass udp 10.0.0.0/8 any -> any any'")


if __name__ == "__main__":
    main()
