Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions docker-compose.test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,24 @@ services:
interval: 1s
timeout: 3s
retries: 5

nats-token-auth:
image: nats:2.10-alpine
network_mode: host
command:
- "--addr"
- "127.0.0.1"
- "--port"
- "14225"
- "--http_port"
- "18225"
- "--auth"
- "test_token_123"
- "-D"
- "--server_name"
- "nats-token-auth"
healthcheck:
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://127.0.0.1:18225/healthz"]
interval: 1s
timeout: 3s
retries: 5
115 changes: 99 additions & 16 deletions src/connection.zig
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,36 @@ pub const ConnectionError = error{
NotConnected,
ManualReconnect,
StaleConnection,
} || PublishError || std.Thread.SpawnError || std.posix.WriteError || std.posix.ReadError;
} || PublishError || ProtocolError || std.Thread.SpawnError || std.posix.WriteError || std.posix.ReadError;

// Protocol-specific errors from server -ERR messages (matching nats.go approach)
pub const ProtocolError = error{
// Authentication/Authorization errors
AuthorizationViolation, // "Authorization Violation"
AuthExpired, // "User Authentication Expired"
AuthRevoked, // "User Authentication Revoked"
AccountAuthExpired, // "Account Authentication Expired"
PermissionViolation, // "Permissions Violation"

// Connection/Limit errors
MaxConnectionsExceeded, // "maximum connections exceeded"
ConnectionThrottling, // "Connection throttling is active"
MaxPayloadViolation, // "Maximum Payload Violation"
MaxSubscriptionsExceeded, // "maximum subscriptions exceeded"

// Protocol errors
SecureConnectionRequired, // "Secure Connection - TLS Required"
InvalidClientProtocol, // "invalid client protocol"
UnknownProtocolOperation, // "Unknown Protocol Operation"
InvalidPublishSubject, // "Invalid Publish Subject"
NoRespondersRequiresHeaders, // "no responders requires headers support"

// Account errors
FailedAccountRegistration, // "Failed Account Registration"

// Generic fallback
UnknownServerError, // For unrecognized -ERR messages
};

pub const ConnectionStatus = enum {
closed,
Expand Down Expand Up @@ -193,6 +222,10 @@ pub const ConnectionOptions = struct {
max_scratch_size: usize = 1024 * 1024 * 10,
ping_interval_ms: u64 = 120000, // 2 minutes default, 0 = disabled
max_pings_out: u32 = 2, // max unanswered keep-alive PINGs

// Authentication
token: ?[]const u8 = null,
token_handler: ?*const fn () []const u8 = null,
};

pub const Connection = struct {
Expand Down Expand Up @@ -371,6 +404,8 @@ pub const Connection = struct {
}

fn connectToServer(self: *Self) !void {
errdefer self.close();

// This is called from initial connect() - needs to manage its own mutex
self.mutex.lock();
defer self.mutex.unlock();
Expand All @@ -379,7 +414,6 @@ pub const Connection = struct {

// Get server using C library's GetNextServer algorithm
const selected_server = try self.server_pool.getNextServer(self.options.reconnect.max_reconnect, self.current_server) orelse {
self.status = .closed;
return ConnectionError.ConnectionFailed;
};

Expand All @@ -388,11 +422,7 @@ pub const Connection = struct {
selected_server.reconnects += 1;

// Establish connection (under mutex for consistent state management)
self.establishConnection(selected_server) catch |err| {
self.status = .closed;
self.cleanupFailedConnection(err, true);
return err;
};
try self.establishConnection(selected_server);

// Socket is now established and connection state is set up
self.should_stop.store(false, .monotonic);
Expand All @@ -404,12 +434,7 @@ pub const Connection = struct {
self.flusher_thread = try std.Thread.spawn(.{}, flusherLoop, .{self});

// Wait for handshake completion
self.waitForHandshakeCompletion() catch |err| {
// Clean up failed handshake state and close socket
self.status = .closed;
self.cleanupFailedConnection(err, true);
return err;
};
try self.waitForHandshakeCompletion();

// Handshake completed successfully
self.status = .connected;
Expand Down Expand Up @@ -1209,6 +1234,12 @@ pub const Connection = struct {
// Get client name from options or use default
const client_name = self.options.name orelse build_options.name;

// Get authentication token (dynamic handler takes precedence)
const auth_token = if (self.options.token_handler) |handler|
handler()
else
self.options.token;

// Create CONNECT JSON object
const connect_obj = .{
.verbose = self.options.verbose,
Expand All @@ -1219,6 +1250,7 @@ pub const Connection = struct {
.lang = build_options.lang,
.version = build_options.version,
.protocol = 1,
.auth_token = auth_token,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Ensure we don’t emit "auth_token": null

If auth_token is null and stringify emits nulls, some servers reject the CONNECT. Set the option explicitly.

-        try std.json.stringify(connect_obj, .{}, buffer.writer());
+        try std.json.stringify(connect_obj, .{ .emit_null_optional_fields = false }, buffer.writer());

🌐 Web query:

In current Zig std (0.12.x), what is the default for std.json.stringify option `emit_null_optional_fields`? Is it false by default?

💡 Result:

No — in Zig std 0.12.x the default for StringifyOptions.emit_null_optional_fields is true. [1][2]

Sources:
[1] std/json/stringify.zig (browseable) — shows emit_null_optional_fields: bool = true.
[2] std/lib source copy (Fossies) — documents Options.emit_null_optional_fields and the behavior.


Don't emit "auth_token": null — explicitly disable emitting null optional fields.

Zig std 0.12.x defaults emit_null_optional_fields = true, so stringify will emit nulls; set it to false.

-        try std.json.stringify(connect_obj, .{}, buffer.writer());
+        try std.json.stringify(connect_obj, .{ .emit_null_optional_fields = false }, buffer.writer());
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
.auth_token = auth_token,
.auth_token = auth_token,
};
try std.json.stringify(connect_obj, .{ .emit_null_optional_fields = false }, buffer.writer());
🤖 Prompt for AI Agents
In src/connection.zig around line 1253, the JSON output is emitting
"auth_token": null because Zig 0.12.x defaults emit_null_optional_fields = true;
update the stringify call or the StringifyOptions used there to set
emit_null_optional_fields = false so optional fields that are null are omitted
from output (i.e., explicitly disable emitting null optional fields in the
options passed to std.json.stringify/format).

};

try buffer.writer().writeAll("CONNECT ");
Expand Down Expand Up @@ -1308,6 +1340,53 @@ pub const Connection = struct {
// No action needed for now
}

/// Maps -ERR message to specific ProtocolError (similar to nats.go approach)
fn parseProtocolError(err_msg: []const u8, allocator: std.mem.Allocator) ProtocolError {
const lower_err = std.ascii.allocLowerString(allocator, err_msg) catch return ProtocolError.UnknownServerError;
defer allocator.free(lower_err);

// Authentication/Authorization errors
if (std.mem.containsAtLeast(u8, lower_err, 1, "authorization violation")) {
return ProtocolError.AuthorizationViolation;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "user authentication expired")) {
return ProtocolError.AuthExpired;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "user authentication revoked")) {
return ProtocolError.AuthRevoked;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "account authentication expired")) {
return ProtocolError.AccountAuthExpired;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "permissions violation")) {
return ProtocolError.PermissionViolation;
}
// Connection/Limit errors
else if (std.mem.containsAtLeast(u8, lower_err, 1, "maximum connections exceeded")) {
return ProtocolError.MaxConnectionsExceeded;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "connection throttling")) {
return ProtocolError.ConnectionThrottling;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "maximum payload violation")) {
return ProtocolError.MaxPayloadViolation;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "maximum subscriptions exceeded")) {
return ProtocolError.MaxSubscriptionsExceeded;
}
// Protocol errors
else if (std.mem.containsAtLeast(u8, lower_err, 1, "secure connection") and
std.mem.containsAtLeast(u8, lower_err, 1, "tls required"))
{
return ProtocolError.SecureConnectionRequired;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "invalid client protocol")) {
return ProtocolError.InvalidClientProtocol;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "unknown protocol operation")) {
return ProtocolError.UnknownProtocolOperation;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "invalid publish subject")) {
return ProtocolError.InvalidPublishSubject;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "no responders requires headers")) {
return ProtocolError.NoRespondersRequiresHeaders;
} else if (std.mem.containsAtLeast(u8, lower_err, 1, "failed account registration")) {
return ProtocolError.FailedAccountRegistration;
}

return ProtocolError.UnknownServerError; // Unrecognized error
}

pub fn processErr(self: *Self, err_msg: []const u8) !void {
if (self.should_stop.load(.acquire)) {
return error.ShouldStop;
Expand All @@ -1320,14 +1399,18 @@ pub const Connection = struct {
self.mutex.lock();
defer self.mutex.unlock();

log.err("Received -ERR: {s}", .{err_msg});
// Parse the protocol error once
const protocol_err = parseProtocolError(err_msg, self.allocator);

log.err("Server protocol error: {} - {s}", .{ protocol_err, err_msg });

// Handle handshake failure
if (self.handshake_state.isWaiting()) {
self.handshake_error = ConnectionError.AuthFailed;
// Propagate specific protocol errors to client
self.handshake_error = protocol_err;
self.handshake_state = .failed;
self.handshake_cond.broadcast(); // Signal handshake failure
log.debug("Handshake failed due to server error: {s}", .{err_msg});
log.debug("Handshake failed: {}", .{protocol_err});
return;
}

Expand Down
1 change: 1 addition & 0 deletions src/root.zig
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub const Connection = @import("connection.zig").Connection;
pub const ConnectionOptions = @import("connection.zig").ConnectionOptions;
pub const ConnectionStatus = @import("connection.zig").ConnectionStatus;
pub const ConnectionError = @import("connection.zig").ConnectionError;
pub const ProtocolError = @import("connection.zig").ProtocolError;
pub const PublishError = @import("connection.zig").PublishError;
pub const Message = @import("message.zig").Message;
pub const Subscription = @import("subscription.zig").Subscription;
Expand Down
1 change: 1 addition & 0 deletions tests/all_tests.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const utils = @import("utils.zig");
test {
_ = @import("socket_test.zig");
_ = @import("minimal_test.zig");
_ = @import("auth_test.zig");
_ = @import("headers_test.zig");
_ = @import("subscribe_test.zig");
_ = @import("autounsubscribe_test.zig");
Expand Down
99 changes: 99 additions & 0 deletions tests/auth_test.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
const std = @import("std");
const nats = @import("nats");
const utils = @import("utils.zig");

const log = std.log.default;

test "token authentication success" {
// Test against actual NATS server with token auth (port 14225)
const opts = nats.ConnectionOptions{
.token = "test_token_123",
};

const conn = try utils.createConnection(.token_auth, opts);
defer utils.closeConnection(conn);

// If we reach here, authentication succeeded
// Test basic publish/subscribe to verify connection works
try conn.publish("test.auth.success", "authenticated message");
try conn.flush();
}

test "token handler authentication" {
// Test token handler callback against real server
const TestTokenHandler = struct {
fn getToken() []const u8 {
return "test_token_123"; // Return valid token for auth server
}
};

const opts = nats.ConnectionOptions{
.token_handler = TestTokenHandler.getToken,
};

const conn = try utils.createConnection(.token_auth, opts);
defer utils.closeConnection(conn);

// If we reach here, the token handler was called and authentication succeeded
try conn.publish("test.auth.handler", "handler authenticated");
try conn.flush();
}

test "token handler takes precedence over static token" {
// Test that dynamic token handler takes precedence over static token
const TestTokenHandler = struct {
fn getToken() []const u8 {
return "test_token_123"; // Valid token (handler wins)
}
};

const opts = nats.ConnectionOptions{
.token = "invalid_static_token", // Invalid static token
.token_handler = TestTokenHandler.getToken,
};

// Should succeed because handler returns valid token
const conn = try utils.createConnection(.token_auth, opts);
defer utils.closeConnection(conn);

// Authentication succeeded, proving handler took precedence
try conn.publish("test.auth.precedence", "handler wins");
try conn.flush();
}

test "token authentication failure" {
// Test authentication failure with invalid token and short timeout
const opts = nats.ConnectionOptions{
.token = "invalid_token",
.timeout_ms = 2000, // 2 second timeout
};

// This should fail with AuthFailed error
const result = utils.createConnection(.token_auth, opts);

if (result) |conn| {
defer utils.closeConnection(conn);
// Should not reach here
std.log.err("Connection unexpectedly succeeded with invalid token", .{});
try std.testing.expect(false);
} else |err| {
std.log.info("Got error: {}", .{err});
// Now we get specific protocol errors
try std.testing.expect(err == nats.ProtocolError.AuthorizationViolation);
}
}

test "no authentication options against auth server" {
// Test connection without token to auth server (should fail)
const opts = nats.ConnectionOptions{};

const result = utils.createConnection(.token_auth, opts);

if (result) |conn| {
defer utils.closeConnection(conn);
// Should not reach here
try std.testing.expect(false);
} else |err| {
try std.testing.expect(err == nats.ProtocolError.AuthorizationViolation);
}
}
3 changes: 2 additions & 1 deletion tests/utils.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ pub const Node = enum(u16) {
node1 = 14222,
node2 = 14223,
node3 = 14224,
unknown = 14225,
token_auth = 14225,
unknown = 14226,
};

pub fn createConnection(node: Node, opts: nats.ConnectionOptions) !*nats.Connection {
Expand Down