diff --git a/docker-compose.test.yml b/docker-compose.test.yml index f7d6f5f..4eab11d 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -77,3 +77,24 @@ services: interval: 1s timeout: 3s retries: 5 + + nats-token-auth: + image: nats:2.10-alpine + network_mode: host + command: + - "--addr" + - "127.0.0.1" + - "--port" + - "14225" + - "--http_port" + - "18225" + - "--auth" + - "test_token_123" + - "-D" + - "--server_name" + - "nats-token-auth" + healthcheck: + test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://127.0.0.1:18225/healthz"] + interval: 1s + timeout: 3s + retries: 5 diff --git a/src/connection.zig b/src/connection.zig index 7995f6b..8cc2ae3 100644 --- a/src/connection.zig +++ b/src/connection.zig @@ -116,7 +116,36 @@ pub const ConnectionError = error{ NotConnected, ManualReconnect, StaleConnection, -} || PublishError || std.Thread.SpawnError || std.posix.WriteError || std.posix.ReadError; +} || PublishError || ProtocolError || std.Thread.SpawnError || std.posix.WriteError || std.posix.ReadError; + +// Protocol-specific errors from server -ERR messages (matching nats.go approach) +pub const ProtocolError = error{ + // Authentication/Authorization errors + AuthorizationViolation, // "Authorization Violation" + AuthExpired, // "User Authentication Expired" + AuthRevoked, // "User Authentication Revoked" + AccountAuthExpired, // "Account Authentication Expired" + PermissionViolation, // "Permissions Violation" + + // Connection/Limit errors + MaxConnectionsExceeded, // "maximum connections exceeded" + ConnectionThrottling, // "Connection throttling is active" + MaxPayloadViolation, // "Maximum Payload Violation" + MaxSubscriptionsExceeded, // "maximum subscriptions exceeded" + + // Protocol errors + SecureConnectionRequired, // "Secure Connection - TLS Required" + InvalidClientProtocol, // "invalid client protocol" + UnknownProtocolOperation, // "Unknown Protocol Operation" + InvalidPublishSubject, // "Invalid Publish Subject" + NoRespondersRequiresHeaders, // "no responders requires headers support" + + // Account errors + FailedAccountRegistration, // "Failed Account Registration" + + // Generic fallback + UnknownServerError, // For unrecognized -ERR messages +}; pub const ConnectionStatus = enum { closed, @@ -193,6 +222,10 @@ pub const ConnectionOptions = struct { max_scratch_size: usize = 1024 * 1024 * 10, ping_interval_ms: u64 = 120000, // 2 minutes default, 0 = disabled max_pings_out: u32 = 2, // max unanswered keep-alive PINGs + + // Authentication + token: ?[]const u8 = null, + token_handler: ?*const fn () []const u8 = null, }; pub const Connection = struct { @@ -371,6 +404,8 @@ pub const Connection = struct { } fn connectToServer(self: *Self) !void { + errdefer self.close(); + // This is called from initial connect() - needs to manage its own mutex self.mutex.lock(); defer self.mutex.unlock(); @@ -379,7 +414,6 @@ pub const Connection = struct { // Get server using C library's GetNextServer algorithm const selected_server = try self.server_pool.getNextServer(self.options.reconnect.max_reconnect, self.current_server) orelse { - self.status = .closed; return ConnectionError.ConnectionFailed; }; @@ -388,11 +422,7 @@ pub const Connection = struct { selected_server.reconnects += 1; // Establish connection (under mutex for consistent state management) - self.establishConnection(selected_server) catch |err| { - self.status = .closed; - self.cleanupFailedConnection(err, true); - return err; - }; + try self.establishConnection(selected_server); // Socket is now established and connection state is set up self.should_stop.store(false, .monotonic); @@ -404,12 +434,7 @@ pub const Connection = struct { self.flusher_thread = try std.Thread.spawn(.{}, flusherLoop, .{self}); // Wait for handshake completion - self.waitForHandshakeCompletion() catch |err| { - // Clean up failed handshake state and close socket - self.status = .closed; - self.cleanupFailedConnection(err, true); - return err; - }; + try self.waitForHandshakeCompletion(); // Handshake completed successfully self.status = .connected; @@ -1209,6 +1234,12 @@ pub const Connection = struct { // Get client name from options or use default const client_name = self.options.name orelse build_options.name; + // Get authentication token (dynamic handler takes precedence) + const auth_token = if (self.options.token_handler) |handler| + handler() + else + self.options.token; + // Create CONNECT JSON object const connect_obj = .{ .verbose = self.options.verbose, @@ -1219,6 +1250,7 @@ pub const Connection = struct { .lang = build_options.lang, .version = build_options.version, .protocol = 1, + .auth_token = auth_token, }; try buffer.writer().writeAll("CONNECT "); @@ -1308,6 +1340,53 @@ pub const Connection = struct { // No action needed for now } + /// Maps -ERR message to specific ProtocolError (similar to nats.go approach) + fn parseProtocolError(err_msg: []const u8, allocator: std.mem.Allocator) ProtocolError { + const lower_err = std.ascii.allocLowerString(allocator, err_msg) catch return ProtocolError.UnknownServerError; + defer allocator.free(lower_err); + + // Authentication/Authorization errors + if (std.mem.containsAtLeast(u8, lower_err, 1, "authorization violation")) { + return ProtocolError.AuthorizationViolation; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "user authentication expired")) { + return ProtocolError.AuthExpired; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "user authentication revoked")) { + return ProtocolError.AuthRevoked; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "account authentication expired")) { + return ProtocolError.AccountAuthExpired; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "permissions violation")) { + return ProtocolError.PermissionViolation; + } + // Connection/Limit errors + else if (std.mem.containsAtLeast(u8, lower_err, 1, "maximum connections exceeded")) { + return ProtocolError.MaxConnectionsExceeded; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "connection throttling")) { + return ProtocolError.ConnectionThrottling; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "maximum payload violation")) { + return ProtocolError.MaxPayloadViolation; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "maximum subscriptions exceeded")) { + return ProtocolError.MaxSubscriptionsExceeded; + } + // Protocol errors + else if (std.mem.containsAtLeast(u8, lower_err, 1, "secure connection") and + std.mem.containsAtLeast(u8, lower_err, 1, "tls required")) + { + return ProtocolError.SecureConnectionRequired; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "invalid client protocol")) { + return ProtocolError.InvalidClientProtocol; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "unknown protocol operation")) { + return ProtocolError.UnknownProtocolOperation; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "invalid publish subject")) { + return ProtocolError.InvalidPublishSubject; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "no responders requires headers")) { + return ProtocolError.NoRespondersRequiresHeaders; + } else if (std.mem.containsAtLeast(u8, lower_err, 1, "failed account registration")) { + return ProtocolError.FailedAccountRegistration; + } + + return ProtocolError.UnknownServerError; // Unrecognized error + } + pub fn processErr(self: *Self, err_msg: []const u8) !void { if (self.should_stop.load(.acquire)) { return error.ShouldStop; @@ -1320,14 +1399,18 @@ pub const Connection = struct { self.mutex.lock(); defer self.mutex.unlock(); - log.err("Received -ERR: {s}", .{err_msg}); + // Parse the protocol error once + const protocol_err = parseProtocolError(err_msg, self.allocator); + + log.err("Server protocol error: {} - {s}", .{ protocol_err, err_msg }); // Handle handshake failure if (self.handshake_state.isWaiting()) { - self.handshake_error = ConnectionError.AuthFailed; + // Propagate specific protocol errors to client + self.handshake_error = protocol_err; self.handshake_state = .failed; self.handshake_cond.broadcast(); // Signal handshake failure - log.debug("Handshake failed due to server error: {s}", .{err_msg}); + log.debug("Handshake failed: {}", .{protocol_err}); return; } diff --git a/src/root.zig b/src/root.zig index f1800d8..9b16ff5 100644 --- a/src/root.zig +++ b/src/root.zig @@ -18,6 +18,7 @@ pub const Connection = @import("connection.zig").Connection; pub const ConnectionOptions = @import("connection.zig").ConnectionOptions; pub const ConnectionStatus = @import("connection.zig").ConnectionStatus; pub const ConnectionError = @import("connection.zig").ConnectionError; +pub const ProtocolError = @import("connection.zig").ProtocolError; pub const PublishError = @import("connection.zig").PublishError; pub const Message = @import("message.zig").Message; pub const Subscription = @import("subscription.zig").Subscription; diff --git a/tests/all_tests.zig b/tests/all_tests.zig index 2a52376..ace1161 100644 --- a/tests/all_tests.zig +++ b/tests/all_tests.zig @@ -5,6 +5,7 @@ const utils = @import("utils.zig"); test { _ = @import("socket_test.zig"); _ = @import("minimal_test.zig"); + _ = @import("auth_test.zig"); _ = @import("headers_test.zig"); _ = @import("subscribe_test.zig"); _ = @import("autounsubscribe_test.zig"); diff --git a/tests/auth_test.zig b/tests/auth_test.zig new file mode 100644 index 0000000..500cfb7 --- /dev/null +++ b/tests/auth_test.zig @@ -0,0 +1,99 @@ +const std = @import("std"); +const nats = @import("nats"); +const utils = @import("utils.zig"); + +const log = std.log.default; + +test "token authentication success" { + // Test against actual NATS server with token auth (port 14225) + const opts = nats.ConnectionOptions{ + .token = "test_token_123", + }; + + const conn = try utils.createConnection(.token_auth, opts); + defer utils.closeConnection(conn); + + // If we reach here, authentication succeeded + // Test basic publish/subscribe to verify connection works + try conn.publish("test.auth.success", "authenticated message"); + try conn.flush(); +} + +test "token handler authentication" { + // Test token handler callback against real server + const TestTokenHandler = struct { + fn getToken() []const u8 { + return "test_token_123"; // Return valid token for auth server + } + }; + + const opts = nats.ConnectionOptions{ + .token_handler = TestTokenHandler.getToken, + }; + + const conn = try utils.createConnection(.token_auth, opts); + defer utils.closeConnection(conn); + + // If we reach here, the token handler was called and authentication succeeded + try conn.publish("test.auth.handler", "handler authenticated"); + try conn.flush(); +} + +test "token handler takes precedence over static token" { + // Test that dynamic token handler takes precedence over static token + const TestTokenHandler = struct { + fn getToken() []const u8 { + return "test_token_123"; // Valid token (handler wins) + } + }; + + const opts = nats.ConnectionOptions{ + .token = "invalid_static_token", // Invalid static token + .token_handler = TestTokenHandler.getToken, + }; + + // Should succeed because handler returns valid token + const conn = try utils.createConnection(.token_auth, opts); + defer utils.closeConnection(conn); + + // Authentication succeeded, proving handler took precedence + try conn.publish("test.auth.precedence", "handler wins"); + try conn.flush(); +} + +test "token authentication failure" { + // Test authentication failure with invalid token and short timeout + const opts = nats.ConnectionOptions{ + .token = "invalid_token", + .timeout_ms = 2000, // 2 second timeout + }; + + // This should fail with AuthFailed error + const result = utils.createConnection(.token_auth, opts); + + if (result) |conn| { + defer utils.closeConnection(conn); + // Should not reach here + std.log.err("Connection unexpectedly succeeded with invalid token", .{}); + try std.testing.expect(false); + } else |err| { + std.log.info("Got error: {}", .{err}); + // Now we get specific protocol errors + try std.testing.expect(err == nats.ProtocolError.AuthorizationViolation); + } +} + +test "no authentication options against auth server" { + // Test connection without token to auth server (should fail) + const opts = nats.ConnectionOptions{}; + + const result = utils.createConnection(.token_auth, opts); + + if (result) |conn| { + defer utils.closeConnection(conn); + // Should not reach here + try std.testing.expect(false); + } else |err| { + try std.testing.expect(err == nats.ProtocolError.AuthorizationViolation); + } +} diff --git a/tests/utils.zig b/tests/utils.zig index 5061d40..19cc828 100644 --- a/tests/utils.zig +++ b/tests/utils.zig @@ -7,7 +7,8 @@ pub const Node = enum(u16) { node1 = 14222, node2 = 14223, node3 = 14224, - unknown = 14225, + token_auth = 14225, + unknown = 14226, }; pub fn createConnection(node: Node, opts: nats.ConnectionOptions) !*nats.Connection {