11#!/bin/python
22import base64
3+ import collections
34import io
45import itertools
6+ import os
57import pathlib
8+ import subprocess
69import struct
710import sys
811
@@ -106,6 +109,7 @@ def layer3(data: bytes) -> str:
106109
107110def layer4 (data : bytes ) -> str :
108111 """Layer 4/6: Network Traffic."""
112+ mask = (1 << 16 ) - 1
109113
110114 def ones_complement (x : int ) -> int :
111115 assert x <= mask , x
@@ -116,33 +120,47 @@ def checksum_ok(block: bytes) -> bool:
116120 checksum = 0
117121 for word in words :
118122 checksum += ones_complement (word )
119- if checksum >> 16 :
123+ if checksum > mask :
120124 checksum = (checksum & mask ) + 1
121- return not ones_complement (checksum )
125+ return ones_complement (checksum ) == 0
122126
123- mask = (1 << 16 ) - 1
124127 want_from = bytes (bytearray ([10 , 1 , 1 , 10 ])) # 10.1.1.10
125128 want_to = bytes (bytearray ([10 , 1 , 1 , 200 ])) # 10.1.1.200
126- i = io .BytesIO (data )
127- o = io .BytesIO ()
128- while i .tell () < len (data ):
129- ip4_header = i .read1 (20 )
130- udp_header = i .read1 (8 )
131- src = ip4_header [12 :16 ]
132- dst = ip4_header [16 :20 ]
129+ want_port = 42069
130+
131+ raw_data_in = io .BytesIO (data )
132+ data_out = io .BytesIO ()
133+ valid_count = 0
134+ packets_count = 0
135+
136+ while raw_data_in .tell () < len (data ):
137+ ipv4_header = raw_data_in .read1 (20 )
138+ udp_header = raw_data_in .read1 (8 )
133139 udp_length = int .from_bytes (udp_header [4 :6 ])
134- ip4_length = int .from_bytes (ip4_header [2 :4 ])
135- assert ip4_length == udp_length + 20
136- udp_data = i .read1 (udp_length - 8 )
140+ udp_data = raw_data_in .read1 (udp_length - 8 )
141+
142+ src = ipv4_header [12 :16 ]
143+ dst = ipv4_header [16 :20 ]
144+ udp_to_port = int .from_bytes (udp_header [2 :4 ])
145+
146+ ipv4_length = int .from_bytes (ipv4_header [2 :4 ])
147+ assert ipv4_length == udp_length + 20
148+
137149 udp_pseudo_header = src + dst + (17 ).to_bytes (2 ) + udp_length .to_bytes (2 ) + udp_header + udp_data
150+ # Pad to 16 bits.
138151 if udp_length % 2 :
139152 udp_pseudo_header += bytes (bytearray ([0 ]))
140- if src == want_from and dst == want_to and checksum_ok (ip4_header ) and checksum_ok (udp_pseudo_header ):
141- o .write (udp_data )
142- return o .getvalue ().decode ()
153+
154+ packets_count += 1
155+ if src == want_from and dst == want_to and udp_to_port == want_port and checksum_ok (ipv4_header ) and checksum_ok (udp_pseudo_header ):
156+ data_out .write (udp_data )
157+ valid_count += 1
158+ print ("Valid packets:" , valid_count , "out of" , packets_count )
159+ return data_out .getvalue ().decode ()
143160
144161
145162def layer5 (data : bytes ) -> str :
163+ """Layer 5/6: Advanced Encryption Standard."""
146164 # - First 32 bytes: The 256-bit key encrypting key (KEK).
147165 # - Next 8 bytes: The 64-bit initialization vector (IV) for
148166 # the wrapped key.
@@ -178,19 +196,111 @@ def aes_unwrap_key_and_iv(kek, wrapped):
178196 assert kiv == got_iv
179197 assert len (decrypted_data_key ) == 32
180198
181- iv = int .from_bytes (div )
182- ctr = Counter .new (AES .block_size * 8 , initial_value = iv )
183-
184- cipher = AES .new (decrypted_data_key , mode = AES .MODE_CTR , counter = ctr )
185- decrypted = cipher .decrypt (encrypted )
186- unpad = lambda s : s [:- ord (s [len (s )- 1 :])]
187- decrypted = unpad (decrypted )
188-
189- print (decrypted )
190- return decrypted .decode ()
191-
192-
193- PARTS = [starting , layer0 , layer1 , layer2 , layer3 , layer4 , layer5 ]
194-
195- print (get_instructions (int (sys .argv [1 ])))
199+ cmd = ["openssl" , "enc" , "-aes-256-ctr" , "-d" ]
200+ cmd .extend (["-K" , decrypted_data_key .hex ()])
201+ cmd .extend (["-iv" , div .hex ()])
202+ proc = subprocess .run (cmd , capture_output = True , input = encrypted )
203+ return proc .stdout .decode ()
204+
205+
206+ def layer6 (data : bytes ) -> str :
207+ """Layer 6/6: Virtual Machine."""
208+ memory = dict (enumerate (bytearray (data )))
209+ print (" " .join (f"{ i :08b} " for i in data [:15 ]))
210+ print (" " .join (f"{ i :08b} " for i in data [15 :30 ]))
211+ data_out = io .BytesIO ()
212+ reg = {i : 0 for i in "abcdef" } | {f"l{ i } " : 0 for i in "abcd" } | {i : 0 for i in ["ptr" , "pc" ]}
213+
214+ regmap = {
215+ False : dict (enumerate ("abcdef" , start = 1 )), # short
216+ True : dict (enumerate (["la" , "lb" , "lc" , "ld" , "ptr" , "pc" ], start = 1 )), # long
217+ }
218+
219+ def memread (i : int , long : bool ) -> int :
220+ if 1 <= i <= 6 :
221+ return reg [regmap [long ][i ]]
222+ elif i == 7 and not long :
223+ target = reg ["ptr" ] + reg ["c" ]
224+ if target > len (memory ):
225+ print (reg ["ptr" ], reg ["c" ], f"{ target = } { len (memory )= } " )
226+ return memory .get (target , 0 )
227+ raise ValueError ()
228+
229+ def memwrite (i : int , long : bool , val : int ) -> int :
230+ if 1 <= i <= 6 :
231+ target = regmap [long ][i ]
232+ print (f"memwrite { target } = { val } " )
233+ reg [target ] = val
234+ elif i == 7 and not long :
235+ memory [reg ["ptr" ] + reg ["c" ]] = val
236+ else :
237+ raise ValueError (f"Invalid write to destination { i } { long = } " )
238+
239+ OPS = {
240+ 0xC2 : "ADD" , 0xE1 : "APTR" , 0xC1 : "CMP" , 0x01 : "HALT" ,
241+ 0x21 : "JEZ" , 0x22 : "JNZ" , 0x02 : "OUT" , 0xC3 : "SUB" , 0xC4 : "XOR" ,
242+ }
243+
244+ def get_pc (size : int ) -> int :
245+ pc = reg ["pc" ]
246+ val = int .from_bytes (bytes (bytearray ([memory [i ] for i in range (pc , pc + size )])))
247+ reg ["pc" ] += size
248+ if size == 1 :
249+ print (f"@{ pc :3} read { val :08b} = { val } " )
250+ else :
251+ print (f"@{ pc :3} read { val :32b} = { val } " )
252+ return val
253+
254+ for i in range (1000 ):
255+ op = get_pc (1 )
256+ print (f"{ i + 1 } : @{ reg ["pc" ]:4} : { op :3} = { op :02x} = { op :08b} " )
257+ # print(OPS.get(op, "MV"))
258+ match op :
259+ case 0xC2 : # (1 byte) ADD a <- b
260+ reg ["a" ] = (reg ["a" ] + reg ["b" ]) % 256
261+ case 0xE1 : # 0x__ (2 bytes) APTR imm8
262+ reg ["ptr" ] += get_pc (1 )
263+ print (f"APTR { reg ["ptr" ]= } " )
264+ case 0xC1 : # (1 byte) CMP
265+ reg ["f" ] = 0 if reg ["a" ] == reg ["b" ] else 1
266+ case 0x01 : # (1 byte) HALT
267+ break
268+ case 0x21 : # 0x__ 0x__ 0x__ 0x__ (5 bytes) JEZ imm32
269+ imm = get_pc (4 )
270+ if reg ["f" ] == 0 :
271+ reg ["pc" ] = imm
272+ case 0x22 : # 0x__ 0x__ 0x__ 0x__ (5 bytes) JNZ imm32
273+ imm = get_pc (4 )
274+ if reg ["f" ] != 0 :
275+ reg ["pc" ] = imm
276+ case 0x02 : # (1 byte) OUT a
277+ data_out .write (reg ["a" ].to_bytes (1 ))
278+ case 0xC3 : # (1 byte) SUB a <- b
279+ reg ["a" ] = (reg ["a" ] - reg ["b" ]) % 256
280+ case 0xC4 : # (1 byte) XOR a <- b
281+ reg ["a" ] = reg ["a" ] ^ reg ["b" ]
282+ case _:
283+ # 0b01DDDSSS: # (1 byte) MV {dest} <- {src}
284+ # 0b10DDDSSS: # (1 byte) MV32 {dest} <- {src}
285+ # 0b01DDD000: # 0x__ (2 bytes) MVI {dest} <- imm8
286+ # 0b10DDD000: # 0x__ 0x__ 0x__ 0x__ (5 bytes) MVI32 {dest} <- imm32
287+ long = (op & 0b11000000 ) == 0b10000000
288+ src = op & 0b111
289+ dst = (op >> 3 ) & 0b111
290+ if src != 0 :
291+ val = memread (src , long )
292+ print (f"Read { val } from { src } { long } " )
293+ else :
294+ val = get_pc (4 if long else 1 )
295+ memwrite (dst , long , val )
296+
297+ return data_out .getvalue ().decode ()
298+
299+
300+
301+ PARTS = [starting , layer0 , layer1 , layer2 , layer3 , layer4 , layer5 , layer6 ]
302+
303+ if len (sys .argv ) > 1 :
304+ print (get_instructions (int (sys .argv [1 ])))
305+ layer5 (get_data (5 ))
196306
0 commit comments