Skip to content

Commit 2e3ca28

Browse files
committed
Tom's Data Onion: unwrap layer 6
1 parent d2eecb7 commit 2e3ca28

1 file changed

Lines changed: 141 additions & 31 deletions

File tree

other/data_onion/onion.py

Lines changed: 141 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#!/bin/python
22
import base64
3+
import collections
34
import io
45
import itertools
6+
import os
57
import pathlib
8+
import subprocess
69
import struct
710
import sys
811

@@ -106,6 +109,7 @@ def layer3(data: bytes) -> str:
106109

107110
def layer4(data: bytes) -> str:
108111
"""Layer 4/6: Network Traffic."""
112+
mask = (1 << 16) - 1
109113

110114
def ones_complement(x: int) -> int:
111115
assert x <= mask, x
@@ -116,33 +120,47 @@ def checksum_ok(block: bytes) -> bool:
116120
checksum = 0
117121
for word in words:
118122
checksum += ones_complement(word)
119-
if checksum >> 16:
123+
if checksum > mask:
120124
checksum = (checksum & mask) + 1
121-
return not ones_complement(checksum)
125+
return ones_complement(checksum) == 0
122126

123-
mask = (1 << 16) - 1
124127
want_from = bytes(bytearray([10, 1, 1, 10])) # 10.1.1.10
125128
want_to = bytes(bytearray([10, 1, 1, 200])) # 10.1.1.200
126-
i = io.BytesIO(data)
127-
o = io.BytesIO()
128-
while i.tell() < len(data):
129-
ip4_header = i.read1(20)
130-
udp_header = i.read1(8)
131-
src = ip4_header[12:16]
132-
dst = ip4_header[16:20]
129+
want_port = 42069
130+
131+
raw_data_in = io.BytesIO(data)
132+
data_out = io.BytesIO()
133+
valid_count = 0
134+
packets_count = 0
135+
136+
while raw_data_in.tell() < len(data):
137+
ipv4_header = raw_data_in.read1(20)
138+
udp_header = raw_data_in.read1(8)
133139
udp_length = int.from_bytes(udp_header[4:6])
134-
ip4_length = int.from_bytes(ip4_header[2:4])
135-
assert ip4_length == udp_length + 20
136-
udp_data = i.read1(udp_length - 8)
140+
udp_data = raw_data_in.read1(udp_length - 8)
141+
142+
src = ipv4_header[12:16]
143+
dst = ipv4_header[16:20]
144+
udp_to_port = int.from_bytes(udp_header[2:4])
145+
146+
ipv4_length = int.from_bytes(ipv4_header[2:4])
147+
assert ipv4_length == udp_length + 20
148+
137149
udp_pseudo_header = src + dst + (17).to_bytes(2) + udp_length.to_bytes(2) + udp_header + udp_data
150+
# Pad to 16 bits.
138151
if udp_length % 2:
139152
udp_pseudo_header += bytes(bytearray([0]))
140-
if src == want_from and dst == want_to and checksum_ok(ip4_header) and checksum_ok(udp_pseudo_header):
141-
o.write(udp_data)
142-
return o.getvalue().decode()
153+
154+
packets_count += 1
155+
if src == want_from and dst == want_to and udp_to_port == want_port and checksum_ok(ipv4_header) and checksum_ok(udp_pseudo_header):
156+
data_out.write(udp_data)
157+
valid_count += 1
158+
print("Valid packets:", valid_count, "out of", packets_count)
159+
return data_out.getvalue().decode()
143160

144161

145162
def layer5(data: bytes) -> str:
163+
"""Layer 5/6: Advanced Encryption Standard."""
146164
# - First 32 bytes: The 256-bit key encrypting key (KEK).
147165
# - Next 8 bytes: The 64-bit initialization vector (IV) for
148166
# the wrapped key.
@@ -178,19 +196,111 @@ def aes_unwrap_key_and_iv(kek, wrapped):
178196
assert kiv == got_iv
179197
assert len(decrypted_data_key) == 32
180198

181-
iv = int.from_bytes(div)
182-
ctr = Counter.new(AES.block_size * 8, initial_value=iv)
183-
184-
cipher = AES.new(decrypted_data_key, mode=AES.MODE_CTR, counter=ctr)
185-
decrypted = cipher.decrypt(encrypted)
186-
unpad = lambda s: s[:-ord(s[len(s)-1:])]
187-
decrypted = unpad(decrypted)
188-
189-
print(decrypted)
190-
return decrypted.decode()
191-
192-
193-
PARTS = [starting, layer0, layer1, layer2, layer3, layer4, layer5]
194-
195-
print(get_instructions(int(sys.argv[1])))
199+
cmd = ["openssl", "enc", "-aes-256-ctr", "-d"]
200+
cmd.extend(["-K", decrypted_data_key.hex()])
201+
cmd.extend(["-iv", div.hex()])
202+
proc = subprocess.run(cmd, capture_output=True, input=encrypted)
203+
return proc.stdout.decode()
204+
205+
206+
def layer6(data: bytes) -> str:
207+
"""Layer 6/6: Virtual Machine."""
208+
memory = dict(enumerate(bytearray(data)))
209+
print(" ".join(f"{i:08b}" for i in data[:15]))
210+
print(" ".join(f"{i:08b}" for i in data[15:30]))
211+
data_out = io.BytesIO()
212+
reg = {i: 0 for i in "abcdef"} | {f"l{i}": 0 for i in "abcd"} | {i: 0 for i in ["ptr", "pc"]}
213+
214+
regmap = {
215+
False: dict(enumerate("abcdef", start=1)), # short
216+
True: dict(enumerate(["la", "lb", "lc", "ld", "ptr", "pc"], start=1)), # long
217+
}
218+
219+
def memread(i: int, long: bool) -> int:
220+
if 1 <= i <= 6:
221+
return reg[regmap[long][i]]
222+
elif i == 7 and not long:
223+
target = reg["ptr"] + reg["c"]
224+
if target > len(memory):
225+
print(reg["ptr"], reg["c"], f"{target=} {len(memory)=}")
226+
return memory.get(target, 0)
227+
raise ValueError()
228+
229+
def memwrite(i: int, long: bool, val: int) -> int:
230+
if 1 <= i <= 6:
231+
target = regmap[long][i]
232+
print(f"memwrite {target} = {val}")
233+
reg[target] = val
234+
elif i == 7 and not long:
235+
memory[reg["ptr"] + reg["c"]] = val
236+
else:
237+
raise ValueError(f"Invalid write to destination {i} {long=}")
238+
239+
OPS = {
240+
0xC2: "ADD", 0xE1: "APTR", 0xC1: "CMP", 0x01: "HALT",
241+
0x21: "JEZ", 0x22: "JNZ", 0x02: "OUT", 0xC3: "SUB", 0xC4: "XOR",
242+
}
243+
244+
def get_pc(size: int) -> int:
245+
pc = reg["pc"]
246+
val = int.from_bytes(bytes(bytearray([memory[i] for i in range(pc, pc + size)])))
247+
reg["pc"] += size
248+
if size == 1:
249+
print(f"@{pc:3} read {val:08b} = {val}")
250+
else:
251+
print(f"@{pc:3} read {val:32b} = {val}")
252+
return val
253+
254+
for i in range(1000):
255+
op = get_pc(1)
256+
print(f"{i + 1}: @{reg["pc"]:4}: {op:3} = {op:02x} = {op:08b}")
257+
# print(OPS.get(op, "MV"))
258+
match op:
259+
case 0xC2: # (1 byte) ADD a <- b
260+
reg["a"] = (reg["a"] + reg["b"]) % 256
261+
case 0xE1: # 0x__ (2 bytes) APTR imm8
262+
reg["ptr"] += get_pc(1)
263+
print(f"APTR {reg["ptr"]=}")
264+
case 0xC1: # (1 byte) CMP
265+
reg["f"] = 0 if reg["a"] == reg["b"] else 1
266+
case 0x01: # (1 byte) HALT
267+
break
268+
case 0x21: # 0x__ 0x__ 0x__ 0x__ (5 bytes) JEZ imm32
269+
imm = get_pc(4)
270+
if reg["f"] == 0:
271+
reg["pc"] = imm
272+
case 0x22: # 0x__ 0x__ 0x__ 0x__ (5 bytes) JNZ imm32
273+
imm = get_pc(4)
274+
if reg["f"] != 0:
275+
reg["pc"] = imm
276+
case 0x02: # (1 byte) OUT a
277+
data_out.write(reg["a"].to_bytes(1))
278+
case 0xC3: # (1 byte) SUB a <- b
279+
reg["a"] = (reg["a"] - reg["b"]) % 256
280+
case 0xC4: # (1 byte) XOR a <- b
281+
reg["a"] = reg["a"] ^ reg["b"]
282+
case _:
283+
# 0b01DDDSSS: # (1 byte) MV {dest} <- {src}
284+
# 0b10DDDSSS: # (1 byte) MV32 {dest} <- {src}
285+
# 0b01DDD000: # 0x__ (2 bytes) MVI {dest} <- imm8
286+
# 0b10DDD000: # 0x__ 0x__ 0x__ 0x__ (5 bytes) MVI32 {dest} <- imm32
287+
long = (op & 0b11000000) == 0b10000000
288+
src = op & 0b111
289+
dst = (op >> 3) & 0b111
290+
if src != 0:
291+
val = memread(src, long)
292+
print(f"Read {val} from {src} {long}")
293+
else:
294+
val = get_pc(4 if long else 1)
295+
memwrite(dst, long, val)
296+
297+
return data_out.getvalue().decode()
298+
299+
300+
301+
PARTS = [starting, layer0, layer1, layer2, layer3, layer4, layer5, layer6]
302+
303+
if len(sys.argv) > 1:
304+
print(get_instructions(int(sys.argv[1])))
305+
layer5(get_data(5))
196306

0 commit comments

Comments
 (0)