#!/usr/bin/env python3
"""
Send a DNS query whose QNAME uses a forward compression pointer.

Default layout (--layout whole):

  offset 12:  c0 12                 QNAME: pointer to offset 18
  offset 14:  00 01 00 01           QTYPE=A, QCLASS=IN
  offset 18:  03 'www' ... 00       real target name, outside the question

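The --layout suffix mode keeps the first label inline in the QNAME and follows
it with a forward pointer to the remaining labels, which are placed after QCLASS.

The --layout second-question mode instead sets QDCOUNT=2 and makes the first
question's QNAME point forward to the second question's QNAME.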

These packets are intentionally malformed/non-standard, but useful for testing
whether a resolver's DNS name parser accepts forward compression pointers.
"""

import argparse
import random
import socket
import struct
import sys

from scapy.all import DNS, IP, IPv6, Raw, UDP, hexdump, sr1  # type: ignore

# Mnemonic -> numeric QTYPE for the record types that --type accepts by name
# (looked up case-insensitively by parse_qtype).
QTYPES = dict(
    A=1,
    NS=2,
    CNAME=5,
    SOA=6,
    PTR=12,
    MX=15,
    TXT=16,
    AAAA=28,
    SRV=33,
    HTTPS=65,
    ANY=255,
)


def encode_name(name: str) -> bytes:
    """Encode *name* as an uncompressed DNS wire-format name (RFC 1035 3.1).

    Trailing dots are ignored, so ``"example.com."`` and ``"example.com"``
    encode identically. An empty name (or ``"."``) encodes as the root name,
    a single zero octet.

    Raises:
        ValueError: if a label is empty (e.g. ``"a..b"``) or longer than 63
            octets, or if the encoded name exceeds 255 octets.
        UnicodeEncodeError: (a ValueError subclass) if a label is not ASCII.
    """
    name = name.rstrip(".")
    if not name:
        return b"\x00"
    out = bytearray()
    for label in name.split("."):
        raw = label.encode("ascii")
        if not raw:
            # A zero length octet is the root label. Emitting one here would
            # silently end the name early, e.g. "a..b" would go out as "a".
            raise ValueError(f"empty label in name: {name!r}")
        if len(raw) > 63:
            raise ValueError(f"label too long: {label!r}")
        out.append(len(raw))
        out.extend(raw)
    out.append(0)
    # RFC 1035 limits a name to 255 octets, counting length octets and the
    # terminating root octet.
    if len(out) > 255:
        raise ValueError(f"name too long: {len(out)} octets (max 255)")
    return bytes(out)


def parse_qtype(value: str) -> int:
    """Convert a QTYPE mnemonic or number to its 16-bit numeric value.

    Accepted forms:
    - a name from QTYPES, case-insensitive;
    - the RFC 3597 generic form ``TYPEnnn``;
    - any integer literal that ``int(x, 0)`` understands (``28``, ``0x1c``).

    Raises:
        ValueError: if the value is not recognised or is outside 0..65535
            (QTYPE is packed into an unsigned 16-bit field).
    """
    text = value.strip()
    upper = text.upper()
    if upper in QTYPES:
        return QTYPES[upper]
    digits = upper[4:]
    if upper.startswith("TYPE") and digits.isdecimal():
        number = int(digits)
    else:
        number = int(text, 0)
    if not 0 <= number <= 0xFFFF:
        raise ValueError(f"QTYPE out of range 0..65535: {value!r}")
    return number


def make_header(dns_id: int, rd: bool, qdcount: int = 1) -> bytes:
    """Pack a 12-byte DNS query header.

    The only flag ever set is RD (0x0100), and only when *rd* is truthy.
    QDCOUNT is *qdcount*; ANCOUNT, NSCOUNT and ARCOUNT are always zero.
    """
    rd_bit = 0x0100
    fields = (dns_id, rd_bit if rd else 0, qdcount, 0, 0, 0)
    return struct.pack("!6H", *fields)


def make_normal_query(name: str, qtype: int, qclass: int, dns_id: int, rd: bool) -> bytes:
    """Build an ordinary one-question query with an uncompressed QNAME.

    main() sends this first when --also-normal is given, for comparison with
    the forward-pointer variant.
    """
    header = make_header(dns_id, rd)
    qname = encode_name(name)
    return b"".join((header, qname, struct.pack("!HH", qtype, qclass)))


def _compression_pointer(offset: int) -> bytes:
    """Return the 2-byte RFC 1035 compression pointer (top bits 11) to *offset*.

    Raises:
        ValueError: if *offset* does not fit in the pointer's 14-bit field.
    """
    if not 0 <= offset <= 0x3FFF:
        raise ValueError("target offset too large for DNS compression pointer")
    return struct.pack("!H", 0xC000 | offset)


def make_forward_ptr_query(
    name: str, qtype: int, qclass: int, dns_id: int, rd: bool, layout: str
) -> bytes:
    """Build a query whose first QNAME reaches its labels via a forward pointer.

    Offsets count from the start of the DNS message; the header is 12 bytes.

    * ``whole``: QNAME is a bare pointer to offset 18, where the full encoded
      *name* is appended after the question's QTYPE/QCLASS.
    * ``suffix``: QNAME is the first label of *name* followed by a pointer to
      the remaining labels, which are appended after QTYPE/QCLASS.
    * ``second-question``: QDCOUNT=2. The first QNAME is a bare pointer to the
      second question's uncompressed QNAME.

    Raises:
        ValueError: for an unknown layout; for a ``suffix`` name with fewer
            than two labels, an empty label or an over-long first label; or
            for anything encode_name() rejects in the pointed-to name.
    """
    if layout not in ("whole", "suffix", "second-question"):
        raise ValueError(f"unknown layout: {layout}")

    tail = struct.pack("!HH", qtype, qclass)  # QTYPE + QCLASS, 4 bytes

    if layout == "whole":
        # The pointed-to name starts right after this question's QTYPE/QCLASS:
        # 12 (header) + 2 (pointer) + 4 = offset 18.
        pointer = _compression_pointer(12 + 2 + len(tail))
        return make_header(dns_id, rd) + pointer + tail + encode_name(name)

    if layout == "suffix":
        labels = name.rstrip(".").split(".")
        if len(labels) < 2:
            raise ValueError("suffix layout requires a name with at least two labels")
        if not all(labels):
            # A zero-length label is the root terminator. It would end the
            # QNAME (or the pointed-to suffix) before the intended labels.
            raise ValueError(f"empty label in name: {name!r}")
        first = labels[0].encode("ascii")
        if len(first) > 63:
            raise ValueError(f"label too long: {labels[0]!r}")
        prefix = bytes([len(first)]) + first
        # The suffix starts after the inline label, the pointer and QTYPE/QCLASS.
        pointer = _compression_pointer(12 + len(prefix) + 2 + len(tail))
        suffix = encode_name(".".join(labels[1:]))
        return make_header(dns_id, rd) + prefix + pointer + tail + suffix

    # second-question: the second QNAME starts right after the first
    # question's pointer and QTYPE/QCLASS.
    pointer = _compression_pointer(12 + 2 + len(tail))
    first_question = pointer + tail
    second_question = encode_name(name) + tail
    return make_header(dns_id, rd, qdcount=2) + first_question + second_question


def send_socket(payload: bytes, server: str, port: int, timeout: float, use_tcp: bool) -> bytes | None:
    family = socket.AF_INET6 if ":" in server else socket.AF_INET
    socktype = socket.SOCK_STREAM if use_tcp else socket.SOCK_DGRAM
    with socket.socket(family, socktype) as s:
        s.settimeout(timeout)
        if use_tcp:
            s.connect((server, port))
            s.sendall(struct.pack("!H", len(payload)) + payload)
            hdr = s.recv(2)
            if len(hdr) != 2:
                return None
            length = struct.unpack("!H", hdr)[0]
            data = bytearray()
            while len(data) < length:
                chunk = s.recv(length - len(data))
                if not chunk:
                    break
                data.extend(chunk)
            return bytes(data)
        s.sendto(payload, (server, port))
        return s.recvfrom(65535)[0]


def send_scapy_l3(payload: bytes, server: str, port: int, timeout: float) -> bytes | None:
    """Send *payload* as IP/UDP/Raw with Scapy's sr1() and return the reply's UDP payload.

    A random source port in 1024..65535 is used. Returns None when sr1()
    gets no answer within *timeout*, or when the answer's IP payload is not
    UDP (for example an ICMP error quoting the request).
    """
    ip = IPv6(dst=server) if ":" in server else IP(dst=server)
    pkt = ip / UDP(dport=port, sport=random.randint(1024, 65535)) / Raw(payload)
    resp = sr1(pkt, timeout=timeout, verbose=False)
    if resp is None:
        return None
    udp = resp.payload
    if not isinstance(udp, UDP):
        return None
    # Scapy dissects traffic from UDP port 53 as a DNS layer rather than Raw,
    # so take the UDP payload bytes whatever layer they were decoded as.
    return bytes(udp.payload)


def print_dns_header(data: bytes) -> None:
    """Print a one-line summary of the 12-byte DNS header at the start of *data*.

    Prints the ID, the raw flags word, the RCODE (low 4 bits of the flags)
    and the four section counts. Input shorter than a header is reported
    as a short response instead.
    """
    size = len(data)
    if size >= 12:
        dns_id, flags, qd, an, ns, ar = struct.unpack_from("!6H", data)
        rcode = flags & 0xF
        print(
            f"id=0x{dns_id:04x} flags=0x{flags:04x} "
            f"rcode={rcode} qd={qd} an={an} ns={ns} ar={ar}"
        )
    else:
        print(f"short response: {size} bytes")


def main() -> int:
    """Command-line entry point: build the query or queries, send each, dump replies.

    Returns 0 once every query has been attempted, whether or not it was
    answered. Returns 2 if sending fails with an OSError other than a timeout.
    Invalid arguments, including names the chosen layout cannot encode, are
    reported via argparse (which exits with status 2) before anything is sent.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("server", help="resolver IP address")
    ap.add_argument("name", help="query name, e.g. www.example.com")
    ap.add_argument("--port", type=int, default=53)
    ap.add_argument("--type", default="A", help="QTYPE name or number (default: A)")
    ap.add_argument("--class", dest="qclass", type=int, default=1, help="QCLASS (default: 1/IN)")
    ap.add_argument("--id", dest="dns_id", type=lambda x: int(x, 0), default=None)
    ap.add_argument("--no-rd", action="store_true", help="clear recursion-desired bit")
    ap.add_argument("--timeout", type=float, default=2.0)
    ap.add_argument("--tcp", action="store_true", help="send over TCP using DNS length prefix")
    ap.add_argument(
        "--layout",
        choices=("whole", "suffix", "second-question"),
        default="whole",
        help="forward pointer layout (default: whole)",
    )
    ap.add_argument(
        "--mode",
        choices=("socket", "scapy-l3"),
        default="socket",
        help="socket avoids raw-socket privileges; scapy-l3 sends IP/UDP/Raw with Scapy",
    )
    ap.add_argument("--also-normal", action="store_true", help="send a normal query first for comparison")
    ap.add_argument("--decode", action="store_true", help="try Scapy DNS decode of responses")
    args = ap.parse_args()

    # Reject unsupported combinations and out-of-range 16-bit values before
    # anything is printed or sent. Otherwise they surface mid-run as uncaught
    # ValueError/struct.error/OverflowError tracebacks.
    if args.tcp and args.mode == "scapy-l3":
        ap.error("--tcp is only supported with --mode socket")
    if not 0 <= args.port <= 0xFFFF:
        ap.error("--port must be between 0 and 65535")
    if args.dns_id is not None and not 0 <= args.dns_id <= 0xFFFF:
        ap.error("--id must be between 0 and 65535")
    if not 0 <= args.qclass <= 0xFFFF:
        ap.error("--class must be between 0 and 65535")

    rd = not args.no_rd
    tests: list[tuple[str, bytes]] = []
    try:
        qtype = parse_qtype(args.type)
        dns_id = args.dns_id if args.dns_id is not None else random.randint(0, 65535)
        if args.also_normal:
            tests.append(("normal", make_normal_query(args.name, qtype, args.qclass, dns_id, rd)))
            # Give the forward-pointer query a different ID from the normal one.
            dns_id = (dns_id + 1) & 0xFFFF
        tests.append(
            (
                f"forward-pointer/{args.layout}",
                make_forward_ptr_query(args.name, qtype, args.qclass, dns_id, rd, args.layout),
            )
        )
    except (ValueError, struct.error) as e:
        # Unknown or out-of-range QTYPE, unencodable labels, layout constraints.
        ap.error(str(e))

    for label, payload in tests:
        print(f"\n=== {label} query ===")
        print(f"sending {len(payload)} DNS bytes to {args.server}:{args.port}")
        hexdump(payload)
        try:
            if args.mode == "scapy-l3":
                resp = send_scapy_l3(payload, args.server, args.port, args.timeout)
            else:
                resp = send_socket(payload, args.server, args.port, args.timeout, args.tcp)
        except socket.timeout:
            resp = None
        except OSError as e:
            print(f"send failed: {e}")
            return 2

        if resp is None:
            print("no response")
            continue

        print(f"\nresponse: {len(resp)} DNS bytes")
        print_dns_header(resp)
        hexdump(resp)
        if args.decode:
            print("\nScapy DNS decode:")
            try:
                DNS(resp).show()
            except Exception as e:  # best effort: report any dissection failure, keep going
                print(f"decode failed: {e}")

    return 0


if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the interpreter.
    raise SystemExit(main())
