Skip to content

Commit 6e8aa99

Browse files
committed
performance optimizations
1 parent 5e22f08 commit 6e8aa99

3 files changed

Lines changed: 174 additions & 11 deletions

File tree

src/main.zig

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@ pub fn main() !void {
1919
});
2020
defer net_server.deinit();
2121

22+
// Persistent HTTP client — reused across all requests for connection pooling.
23+
var client = std.http.Client{ .allocator = allocator };
24+
defer client.deinit();
25+
2226
while (true) {
2327
var connection = net_server.accept() catch |err| {
2428
std.debug.print("Error accepting connection: {}\n", .{err});
2529
continue;
2630
};
27-
// Explicitly handle sequentially for MVP
2831
defer connection.stream.close();
2932

3033
var read_buf: [16 * 1024]u8 = undefined;
@@ -40,7 +43,7 @@ pub fn main() !void {
4043
continue;
4144
};
4245

43-
proxy.handleRequest(allocator, &request, target_host, target_port) catch |err| {
46+
proxy.handleRequest(allocator, &request, &client, target_host, target_port) catch |err| {
4447
std.debug.print("Error handling request: {}\n", .{err});
4548
};
4649
}

src/proxy.zig

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@ const std = @import("std");
22
const http = std.http;
33
const redact = @import("redact.zig");
44

5-
pub fn handleRequest(allocator: std.mem.Allocator, request: *http.Server.Request, target_host: []const u8, target_port: u16) !void {
5+
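/// Maximum byte length of the constructed target URL, which is formatted into
/// a stack buffer of this size. A longer URL makes `std.fmt.bufPrint` fail with
/// `error.NoSpaceLeft`, which `handleRequest` propagates to its caller.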
6+
const max_url_len = 512;
7+
8+
pub fn handleRequest(allocator: std.mem.Allocator, request: *http.Server.Request, client: *std.http.Client, target_host: []const u8, target_port: u16) !void {
69
const method = request.head.method;
710
const uri_str = request.head.target;
811

912
std.debug.print("[PRX] {s} {s}\n", .{ @tagName(method), uri_str });
1013

11-
var client = std.http.Client{ .allocator = allocator };
12-
defer client.deinit();
13-
14-
const target_url_str = try std.fmt.allocPrint(allocator, "http://{s}:{d}{s}", .{ target_host, target_port, uri_str });
15-
defer allocator.free(target_url_str);
14+
// Build the target URL in a stack buffer — no heap allocation for the URL itself.
15+
var url_buf: [max_url_len]u8 = undefined;
16+
const target_url_str = try std.fmt.bufPrint(&url_buf, "http://{s}:{d}{s}", .{ target_host, target_port, uri_str });
1617

1718
const target_uri = try std.Uri.parse(target_url_str);
1819

src/redact.zig

Lines changed: 162 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,103 @@
11
const std = @import("std");
22

3-
/// Redacts SSNs in-place in the given buffer.
4-
/// Matches any sequence of XXX-XX-XXXX (digits only) and replaces the digits with '*'.
3+
// ---------------------------------------------------------------------------
4+
// SIMD-accelerated SSN redaction engine
5+
// ---------------------------------------------------------------------------
6+
// Strategy: scan for '-' characters first (rare in typical payloads) using
7+
// @Vector(16, u8) SIMD loads, then validate the full XXX-XX-XXXX digit
8+
// pattern only at candidate positions. This avoids running the full pattern
// check at most byte positions.
9+
// ---------------------------------------------------------------------------
10+
11+
const vector_len = 16;
12+
const Vector = @Vector(vector_len, u8);
13+
14+
/// Produce a bitmask where bit `i` is set if `vec[i] == '-'`.
///
/// The lane-wise comparison yields one boolean per lane; those are packed
/// into a `u16`, one bit per lane.
inline fn dashMask(vec: Vector) u16 {
    const Lanes = @Vector(vector_len, u1);
    const hits = vec == @as(Vector, @splat('-'));
    return @bitCast(@as(Lanes, @bitCast(hits)));
}
20+
21+
/// Check whether `buf[start .. start + 11]` has the shape `XXX-XX-XXXX`
/// (ASCII digits around two dashes) and, if so, overwrite the nine digits
/// with '*' in place.
///
/// Returns true when a redaction was performed. Returns false, leaving `buf`
/// untouched, when the window does not fit in `buf` or the pattern does not
/// match.
inline fn tryRedactAt(buf: []u8, start: usize) bool {
    const ssn_len = 11;
    const digit_offsets = [_]usize{ 0, 1, 2, 4, 5, 7, 8, 9, 10 };

    if (start + ssn_len > buf.len) return false;
    const window: *[ssn_len]u8 = buf[start..][0..ssn_len];

    // Cheapest rejection first: both separators must be in place.
    const dashes_ok = window[3] == '-' and window[6] == '-';
    if (!dashes_ok) return false;

    // Validate every digit before mutating anything, so a partial match
    // never leaves the buffer half-redacted.
    inline for (digit_offsets) |off| {
        switch (window[off]) {
            '0'...'9' => {},
            else => return false,
        }
    }

    inline for (digit_offsets) |off| window[off] = '*';
    return true;
}
42+
43+
/// Redacts every SSN-shaped sequence in `buffer` in place.
///
/// A match is `XXX-XX-XXXX` where each `X` is an ASCII digit. The nine digits
/// are overwritten with '*' and the two dashes are left untouched. Buffers
/// shorter than 11 bytes are returned unchanged.
///
/// Most of the buffer is scanned `vector_len` bytes at a time for '-' (via
/// `dashMask`); the full pattern is validated with `tryRedactAt` only where a
/// dash could be the first dash of an SSN. The bytes that cannot fill a full
/// SIMD window are handled by a byte-by-byte scalar loop.
pub fn redactSsn(buffer: []u8) void {
    if (buffer.len < 11) return;

    var i: usize = 0;

    // --- SIMD pass: scan one vector at a time for '-' candidates ----------
    // The `+ 10` slack guarantees that an SSN whose first dash sits anywhere
    // in the chunk still ends inside the buffer.
    while (i + vector_len + 10 <= buffer.len) {
        const chunk: Vector = buffer[i..][0..vector_len].*;
        var mask = dashMask(chunk);

        // Default: skip the whole chunk. Every SSN has a dash at offset +3,
        // and each dash in this chunk is tested in that role below, so no
        // SSN whose first dash lies in the chunk can be missed by the skip.
        var next = i + vector_len;

        while (mask != 0) : (mask &= mask - 1) {
            // Lowest set bit = leftmost dash in the chunk not yet tested.
            const bit_pos: u4 = @truncate(@ctz(mask));
            const dash_abs = i + bit_pos;

            // A first dash needs three digits in front of it.
            if (dash_abs < 3) continue;

            const start = dash_abs - 3;
            if (tryRedactAt(buffer, start)) {
                // The buffer changed, so the mask is stale: resume scanning
                // right after the redacted SSN.
                next = start + 11;
                break;
            }
        }

        i = next;
    }

    // --- Scalar tail: bytes that do not fill a full SIMD window -----------
    // An SSN may begin up to three bytes before `i`: its leading digits can
    // sit at the end of a chunk that was skipped for containing no dash,
    // while its first dash lies in the part the SIMD loop never loaded.
    // Back up so that SSN is still found. Re-checking those bytes is
    // harmless: anything already redacted is '*' and cannot match again.
    i = i -| 3;
    while (i + 11 <= buffer.len) {
        if (tryRedactAt(buffer, i)) {
            i += 11;
        } else {
            i += 1;
        }
    }
}
98+
99+
/// Scalar-only reference implementation (kept for benchmarking comparisons).
100+
pub fn redactSsnScalar(buffer: []u8) void {
6101
var i: usize = 0;
7102
while (i + 11 <= buffer.len) {
8103
if (std.ascii.isDigit(buffer[i]) and std.ascii.isDigit(buffer[i + 1]) and std.ascii.isDigit(buffer[i + 2]) and
@@ -28,7 +123,7 @@ pub fn redactSsn(buffer: []u8) void {
28123
}
29124

30125
// ---------------------------------------------------------------------------
31-
// Tests
126+
// Unit Tests
32127
// ---------------------------------------------------------------------------
33128

34129
test "redactSsn - basic multi-SSN redaction" {
@@ -85,3 +180,67 @@ test "redactSsn - SSN with surrounding digits" {
85180
redactSsn(&buf);
86181
try std.testing.expectEqualStrings("9***-**-****0", &buf);
87182
}
183+
184+
test "redactSsn - scalar fallback matches SIMD" {
185+
// Verify both implementations produce identical results on the same input.
186+
const input = "prefix 111-22-3333 mid 444-55-6666 end 777-88-9999 tail".*;
187+
var simd_buf = input;
188+
var scalar_buf = input;
189+
redactSsn(&simd_buf);
190+
redactSsnScalar(&scalar_buf);
191+
try std.testing.expectEqualStrings(&scalar_buf, &simd_buf);
192+
}
193+
194+
test "redactSsn - large buffer with scattered SSNs" {
195+
// 80-byte buffer with SSNs at various offsets to exercise SIMD chunking.
196+
var buf = "aaaaaaaaaa123-45-6789bbbbbbbbbb987-65-4321cccccccccc555-12-9876ddddddddddeeeeee".*;
197+
redactSsn(&buf);
198+
try std.testing.expectEqualStrings("aaaaaaaaaa***-**-****bbbbbbbbbb***-**-****cccccccccc***-**-****ddddddddddeeeeee", &buf);
199+
}
200+
201+
// ---------------------------------------------------------------------------
202+
// Benchmark
203+
// ---------------------------------------------------------------------------
204+
205+
test "bench - redactSsn throughput" {
206+
// 1 MB payload with SSNs every ~100 bytes.
207+
const payload_size = 1024 * 1024;
208+
var buf: [payload_size]u8 = undefined;
209+
210+
// Fill with 'a' and plant SSNs every 100 bytes
211+
@memset(&buf, 'a');
212+
var pos: usize = 50;
213+
while (pos + 11 <= payload_size) {
214+
@memcpy(buf[pos..][0..11], "123-45-6789");
215+
pos += 100;
216+
}
217+
218+
var timer = std.time.Timer.start() catch {
219+
// Timer not available on all platforms — skip silently.
220+
return;
221+
};
222+
223+
const iterations = 100;
224+
var run: usize = 0;
225+
while (run < iterations) : (run += 1) {
226+
// Re-plant SSNs (they get masked each iteration)
227+
pos = 50;
228+
while (pos + 11 <= payload_size) {
229+
@memcpy(buf[pos..][0..11], "123-45-6789");
230+
pos += 100;
231+
}
232+
redactSsn(&buf);
233+
}
234+
235+
const elapsed_ns = timer.read();
236+
const total_bytes = payload_size * iterations;
237+
const ns_per_byte = elapsed_ns / total_bytes;
238+
const mb_per_sec = (@as(f64, @floatFromInt(total_bytes)) / @as(f64, @floatFromInt(elapsed_ns))) * 1_000_000_000.0 / (1024.0 * 1024.0);
239+
240+
std.debug.print("\n[BENCH] SIMD redactSsn: {d} ns/byte, {d:.1} MB/s ({} iterations x {} bytes)\n", .{
241+
ns_per_byte,
242+
mb_per_sec,
243+
iterations,
244+
payload_size,
245+
});
246+
}

0 commit comments

Comments
 (0)