Skip to content

Commit 0a64a6a

Browse files
authored
Merge pull request #2 from ChainSafe/te/shuffling
2 parents cc2c135 + 5ad9ae3 commit 0a64a6a

8 files changed

Lines changed: 809 additions & 142 deletions

File tree

build.zig

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,15 @@ pub fn build(b: *std.Build) void {
4343

4444
const sharedLib = b.addSharedLibrary(.{
4545
.name = "state_transition_utils",
46-
.root_source_file = b.path("src/state_transition_utils.zig"),
46+
.root_source_file = b.path("src/root_c_abi.zig"),
4747
.target = target,
4848
.optimize = optimize,
4949
});
5050
b.installArtifact(sharedLib);
5151

52+
// need libc for threading capability
53+
sharedLib.linkLibC();
54+
5255
// This *creates* a Run step in the build graph, to be executed when another
5356
// step is evaluated that depends on it. The next line below will establish
5457
// such a dependency.
@@ -73,7 +76,7 @@ pub fn build(b: *std.Build) void {
7376
run_step.dependOn(&run_cmd.step);
7477

7578
const shared_lib_unit_tests = b.addTest(.{
76-
.root_source_file = b.path("src/state_transition_utils.zig"),
79+
.root_source_file = b.path("src/root_c_abi.zig"),
7780
.target = target,
7881
.optimize = optimize,
7982
});

src/error.zig

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
pub const ErrorCode = struct {
2+
pub const Success: c_uint = 0;
3+
pub const InvalidInput: c_uint = 1;
4+
pub const Error: c_uint = 2;
5+
pub const TooManyThreadError: c_uint = 2;
6+
pub const MemoryError: c_uint = 3;
7+
pub const ThreadError: c_uint = 4;
8+
pub const InvalidPointerError: c_uint = 5;
9+
pub const Pending: c_uint = 10;
10+
};

src/root_c_abi.zig

Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
const std = @import("std");
2+
const Mutex = std.Thread.Mutex;
3+
pub const PubkeyIndexMap = @import("pubkey_index_map.zig").PubkeyIndexMap;
4+
const PUBKEY_INDEX_MAP_KEY_SIZE = @import("pubkey_index_map.zig").PUBKEY_INDEX_MAP_KEY_SIZE;
5+
const innerShuffleList = @import("shuffle.zig").innerShuffleList;
6+
const SEED_SIZE = @import("shuffle.zig").SEED_SIZE;
7+
const ErrorCode = @import("error.zig").ErrorCode;
8+
9+
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
10+
// this special index 4,294,967,295 is used to mark a not found
11+
const NOT_FOUND_INDEX: c_uint = 0xffffffff;
12+
13+
/// C-ABI functions for PubkeyIndexMap
14+
/// create an instance of PubkeyIndexMap
15+
/// this returns a pointer to the instance in the heap which we can use in the following functions
16+
export fn createPubkeyIndexMap() u64 {
17+
const allocator = gpa.allocator();
18+
const instance_ptr = PubkeyIndexMap.init(allocator) catch return 0;
19+
return @intFromPtr(instance_ptr);
20+
}
21+
22+
/// destroy an instance of PubkeyIndexMap
23+
export fn destroyPubkeyIndexMap(nbr_ptr: u64) void {
24+
const instance_ptr: *PubkeyIndexMap = @ptrFromInt(nbr_ptr);
25+
instance_ptr.deinit();
26+
}
27+
28+
/// synchronize this special index to Bun
29+
export fn getNotFoundIndex() c_uint {
30+
return NOT_FOUND_INDEX;
31+
}
32+
33+
/// set a value to the specified PubkeyIndexMap instance
34+
export fn pubkeyIndexMapSet(nbr_ptr: u64, key: [*c]const u8, key_length: c_uint, value: c_uint) c_uint {
35+
if (key_length != PUBKEY_INDEX_MAP_KEY_SIZE) {
36+
return ErrorCode.InvalidInput;
37+
}
38+
const instance_ptr: *PubkeyIndexMap = @ptrFromInt(nbr_ptr);
39+
instance_ptr.set(key[0..key_length], value) catch return ErrorCode.Error;
40+
return ErrorCode.Success;
41+
}
42+
43+
/// get a value from the specified PubkeyIndexMap instance
44+
export fn pubkeyIndexMapGet(nbr_ptr: u64, key: [*c]const u8, key_length: c_uint) c_uint {
45+
if (key_length != PUBKEY_INDEX_MAP_KEY_SIZE) {
46+
return NOT_FOUND_INDEX;
47+
}
48+
const instance_ptr: *PubkeyIndexMap = @ptrFromInt(nbr_ptr);
49+
const value = instance_ptr.get(key[0..key_length]) orelse return NOT_FOUND_INDEX;
50+
return value;
51+
}
52+
53+
/// clear all values from the specified PubkeyIndexMap instance
54+
export fn pubkeyIndexMapClear(nbr_ptr: u64) void {
55+
const instance_ptr: *PubkeyIndexMap = @ptrFromInt(nbr_ptr);
56+
instance_ptr.clear();
57+
}
58+
59+
/// clone the specified PubkeyIndexMap instance
60+
/// this returns a pointer to the new instance in the heap
61+
export fn pubkeyIndexMapClone(nbr_ptr: u64) u64 {
62+
const instance_ptr: *PubkeyIndexMap = @ptrFromInt(nbr_ptr);
63+
const clone_ptr = instance_ptr.clone() catch return 0;
64+
return @intFromPtr(clone_ptr);
65+
}
66+
67+
/// check if the specified PubkeyIndexMap instance has the specified key
68+
export fn pubkeyIndexMapHas(nbr_ptr: u64, key: [*c]const u8, key_length: c_uint) bool {
69+
if (key_length != PUBKEY_INDEX_MAP_KEY_SIZE) {
70+
return false;
71+
}
72+
const instance_ptr: *PubkeyIndexMap = @ptrFromInt(nbr_ptr);
73+
return instance_ptr.has(key[0..key_length]);
74+
}
75+
76+
/// delete the specified key from the specified PubkeyIndexMap instance
77+
export fn pubkeyIndexMapDelete(nbr_ptr: u64, key: [*c]const u8, key_length: c_uint) bool {
78+
if (key_length != PUBKEY_INDEX_MAP_KEY_SIZE) {
79+
return false;
80+
}
81+
const instance_ptr: *PubkeyIndexMap = @ptrFromInt(nbr_ptr);
82+
return instance_ptr.delete(key[0..key_length]);
83+
}
84+
85+
/// get the size of the specified PubkeyIndexMap instance
86+
export fn pubkeyIndexMapSize(nbr_ptr: u64) c_uint {
87+
const instance_ptr: *PubkeyIndexMap = @ptrFromInt(nbr_ptr);
88+
return instance_ptr.size();
89+
}
90+
91+
/// C-ABI functions for shuffle_list
92+
/// on Ethereum consensus, shuffling is called once per epoch so this is more than enough
93+
/// don't want to have too big value here so that we can detect issue sooner
94+
const MAX_ASYNC_RESULT_SIZE = 4;
95+
var mutex: Mutex = Mutex{};
96+
var async_result_pointer_indices: [MAX_ASYNC_RESULT_SIZE]u64 = [_]u64{0} ** MAX_ASYNC_RESULT_SIZE;
97+
var async_result_index: usize = 0;
98+
const Status = enum {
99+
Pending,
100+
Done,
101+
Error,
102+
};
103+
104+
/// object to store result from another thread and for bun to poll
105+
const AsyncResult = struct {
106+
allocator: std.mem.Allocator,
107+
status: Status,
108+
mutex: Mutex,
109+
110+
// can put any result here but no need for shuffling apis
111+
pub fn init(allocator: std.mem.Allocator) !*@This() {
112+
const instance_ptr = try allocator.create(@This());
113+
instance_ptr.allocator = allocator;
114+
instance_ptr.status = Status.Pending;
115+
instance_ptr.mutex = Mutex{};
116+
return instance_ptr;
117+
}
118+
119+
pub fn updateStatus(self: *@This(), new_status: Status) void {
120+
self.mutex.lock();
121+
defer self.mutex.unlock();
122+
self.status = new_status;
123+
}
124+
125+
// Get status safely while locking the mutex
126+
pub fn getStatus(self: *@This()) Status {
127+
self.mutex.lock();
128+
defer self.mutex.unlock();
129+
return self.status;
130+
}
131+
132+
pub fn deinit(self: *@This()) void {
133+
self.allocator.destroy(self);
134+
}
135+
};
136+
137+
/// shuffle the `active_indices` array in place asynchronously
138+
/// return an u64 which is the index within `MAX_ASYNC_RESULT_SIZE`
139+
/// consumer needs to poll the AsyncResult via pollAsyncResult() using that index and
140+
/// then release the AsyncResult via releaseAsyncResult() when done
141+
export fn asyncShuffleList(active_indices: [*c]u32, len: usize, seed: [*c]const u8, seed_len: usize, rounds: u8) usize {
142+
const forwards = true;
143+
return doAsyncShuffleList(active_indices, len, seed, seed_len, rounds, forwards);
144+
}
145+
146+
/// unshuffle the `active_indices` array in place asynchronously
147+
/// return an u64 which is the index within `MAX_ASYNC_RESULT_SIZE`
148+
/// consumer needs to poll the AsyncResult via pollAsyncResult() using that index and
149+
/// then release the AsyncResult via releaseAsyncResult() when done
150+
export fn asyncUnshuffleList(active_indices: [*c]u32, len: usize, seed: [*c]const u8, seed_len: usize, rounds: u8) usize {
151+
const forwards = false;
152+
return doAsyncShuffleList(active_indices, len, seed, seed_len, rounds, forwards);
153+
}
154+
155+
fn doAsyncShuffleList(active_indices: [*c]u32, len: usize, seed: [*c]const u8, seed_len: usize, rounds: u8, forwards: bool) usize {
156+
if (len == 0 or seed_len == 0) {
157+
return ErrorCode.InvalidInput;
158+
}
159+
mutex.lock();
160+
defer mutex.unlock();
161+
// too many threads on-going for async result
162+
if (async_result_pointer_indices[(async_result_index + 1) % MAX_ASYNC_RESULT_SIZE] != 0) {
163+
return ErrorCode.TooManyThreadError;
164+
}
165+
async_result_index += 1;
166+
const pointer_index = async_result_index % MAX_ASYNC_RESULT_SIZE;
167+
168+
const allocator = gpa.allocator();
169+
const result = AsyncResult.init(allocator) catch return ErrorCode.MemoryError;
170+
async_result_pointer_indices[pointer_index] = @intFromPtr(result);
171+
172+
// this is called really sparsely, so we can just spawn new thread instead of using a thread pool like in blst-z
173+
const thread = std.Thread.spawn(.{}, struct {
174+
pub fn run(_active_indices: [*c]u32, _len: usize, _seed: [*c]const u8, _seed_len: usize, _rounds: u8, _forwards: bool, _result: *AsyncResult) void {
175+
innerShuffleList(
176+
_active_indices[0.._len],
177+
_seed[0.._seed_len],
178+
_rounds,
179+
_forwards,
180+
) catch {
181+
_result.updateStatus(Status.Error);
182+
return;
183+
};
184+
_result.updateStatus(Status.Done);
185+
}
186+
}.run, .{ active_indices, len, seed, seed_len, rounds, forwards, result }) catch return ErrorCode.ThreadError;
187+
188+
thread.detach();
189+
190+
return pointer_index;
191+
}
192+
193+
/// bun to store a pointer index
194+
/// zig to get pointer u64 from async_result_pointer_indices and restore AsyncResult pointer
195+
/// then release it
196+
export fn releaseAsyncResult(pointer_index_param: usize) void {
197+
mutex.lock();
198+
defer mutex.unlock();
199+
const pointer_index = pointer_index_param % MAX_ASYNC_RESULT_SIZE;
200+
const async_result_ptr = async_result_pointer_indices[pointer_index];
201+
// avoid double-free
202+
if (async_result_ptr == 0) {
203+
return;
204+
}
205+
const result_ptr: *AsyncResult = @ptrFromInt(async_result_ptr);
206+
result_ptr.deinit();
207+
// native pointer cannot be 0 https://zig.guide/language-basics/pointers/
208+
async_result_pointer_indices[pointer_index] = 0;
209+
}
210+
211+
/// bun to store a pointer index
212+
/// zig to get pointer u64 from async_result_pointer_indices and restore AsyncResult pointer
213+
/// then check value inside it
214+
export fn pollAsyncResult(pointer_index_param: usize) c_uint {
215+
mutex.lock();
216+
defer mutex.unlock();
217+
const pointer_index = pointer_index_param % MAX_ASYNC_RESULT_SIZE;
218+
const async_result_ptr = async_result_pointer_indices[pointer_index];
219+
// native pointer cannot be 0 https://zig.guide/language-basics/pointers/
220+
if (async_result_ptr == 0) {
221+
return ErrorCode.InvalidPointerError;
222+
}
223+
const result_ptr: *AsyncResult = @ptrFromInt(async_result_ptr);
224+
const status = result_ptr.getStatus();
225+
if (status == Status.Done) {
226+
return ErrorCode.Success;
227+
} else if (status == Status.Error) {
228+
return ErrorCode.Error;
229+
}
230+
return ErrorCode.Pending;
231+
}
232+
233+
/// shuffle the `active_indices` array in place synchronously
234+
export fn shuffleList(active_indices: [*c]u32, len: usize, seed: [*c]u8, seed_len: usize, rounds: u8) c_uint {
235+
const forwards = true;
236+
return doShuffleList(active_indices, len, seed, seed_len, rounds, forwards);
237+
}
238+
239+
/// unshuffle the `active_indices` array in place synchronously
240+
export fn unshuffleList(active_indices: [*c]u32, len: usize, seed: [*c]u8, seed_len: usize, rounds: u8) c_uint {
241+
const forwards = false;
242+
return doShuffleList(active_indices, len, seed, seed_len, rounds, forwards);
243+
}
244+
245+
export fn doShuffleList(active_indices: [*c]u32, len: usize, seed: [*c]u8, seed_len: usize, rounds: u8, forwards: bool) c_uint {
246+
if (len == 0 or seed_len == 0) {
247+
return ErrorCode.InvalidInput;
248+
}
249+
250+
innerShuffleList(
251+
active_indices[0..len],
252+
seed[0..seed_len],
253+
rounds,
254+
forwards,
255+
) catch return ErrorCode.Error;
256+
return ErrorCode.Success;
257+
}
258+
259+
test "PubkeyIndexMap C-ABI functions" {
260+
const map = createPubkeyIndexMap();
261+
defer destroyPubkeyIndexMap(map);
262+
263+
var key: [PUBKEY_INDEX_MAP_KEY_SIZE]u8 = [_]u8{5} ** PUBKEY_INDEX_MAP_KEY_SIZE;
264+
const value = 42;
265+
_ = pubkeyIndexMapSet(map, &key[0], key.len, value);
266+
var result = pubkeyIndexMapGet(map, &key[0], key.len);
267+
try std.testing.expect(result == value);
268+
269+
// change key
270+
key[1] = 1;
271+
result = pubkeyIndexMapGet(map, &key[0], key.len);
272+
try std.testing.expect(result == 0xffffffff);
273+
274+
// new instance with same value
275+
const key2: [PUBKEY_INDEX_MAP_KEY_SIZE]u8 = [_]u8{5} ** PUBKEY_INDEX_MAP_KEY_SIZE;
276+
result = pubkeyIndexMapGet(map, &key2[0], key2.len);
277+
try std.testing.expect(result == value);
278+
279+
// has
280+
try std.testing.expect(pubkeyIndexMapHas(map, &key2[0], key2.len));
281+
282+
// size
283+
try std.testing.expectEqual(1, pubkeyIndexMapSize(map));
284+
const new_key = ([_]u8{255} ** PUBKEY_INDEX_MAP_KEY_SIZE)[0..];
285+
_ = pubkeyIndexMapSet(map, &new_key[0], new_key.len, 100);
286+
try std.testing.expectEqual(2, pubkeyIndexMapSize(map));
287+
288+
// delete
289+
const missing_key = ([_]u8{254} ** PUBKEY_INDEX_MAP_KEY_SIZE)[0..];
290+
var del_res = pubkeyIndexMapDelete(map, &missing_key[0], missing_key.len);
291+
try std.testing.expect(!del_res);
292+
del_res = pubkeyIndexMapDelete(map, &new_key[0], new_key.len);
293+
try std.testing.expect(del_res);
294+
try std.testing.expectEqual(1, pubkeyIndexMapSize(map));
295+
296+
// clone
297+
const cloned_map = pubkeyIndexMapClone(map);
298+
defer destroyPubkeyIndexMap(cloned_map);
299+
try std.testing.expectEqual(1, pubkeyIndexMapSize(cloned_map));
300+
result = pubkeyIndexMapGet(cloned_map, &key2[0], key2.len);
301+
try std.testing.expect(result == value);
302+
303+
// clear
304+
pubkeyIndexMapClear(map);
305+
try std.testing.expectEqual(0, pubkeyIndexMapSize(map));
306+
307+
// cloned instance is not affected
308+
try std.testing.expectEqual(1, pubkeyIndexMapSize(cloned_map));
309+
}
310+
311+
// more tests for async shuffle and unshuffle at bun side
312+
test "asyncShuffleList - issue single thread and poll the result" {
313+
var input = [_]u32{ 0, 1, 2, 3, 4, 5, 6, 7, 8 };
314+
var seed = [_]u8{0} ** SEED_SIZE;
315+
const rounds = 32;
316+
317+
const pointer_index = asyncUnshuffleList(&input[0], input.len, &seed[0], seed.len, rounds);
318+
defer releaseAsyncResult(pointer_index);
319+
320+
// poll the AsyncResult, this should happen in less than 100ms or the test wil fail
321+
const start = std.time.milliTimestamp();
322+
while (std.time.milliTimestamp() - start < 100) {
323+
const status = pollAsyncResult(pointer_index);
324+
if (status == ErrorCode.Success) {
325+
const expected = [_]u32{ 6, 2, 3, 5, 1, 7, 8, 0, 4 };
326+
try std.testing.expectEqualSlices(u32, expected[0..], input[0..]);
327+
return;
328+
}
329+
std.time.sleep(10 * std.time.ns_per_ms);
330+
}
331+
332+
// after 100ms and still pending, this is a failure
333+
try std.testing.expect(false);
334+
}

0 commit comments

Comments
 (0)