diff --git a/src/collection.zig b/src/collection.zig index d437831..3f2d544 100644 --- a/src/collection.zig +++ b/src/collection.zig @@ -1471,6 +1471,7 @@ pub const Database = struct { data_dir_len: usize, alloc: std.mem.Allocator, mu: std.Thread.RwLock, + auth: @import("auth.zig").AuthStore, pub const TenantQuota = struct { max_collections: u32 = std.math.maxInt(u32), @@ -1503,6 +1504,7 @@ pub const Database = struct { try db.cdc.start(); db.alloc = alloc; db.mu = .{}; + db.auth = .{}; const n = @min(resolved_data_dir.len, 255); @memcpy(db.data_dir_buf[0..n], resolved_data_dir[0..n]); diff --git a/src/main.zig b/src/main.zig index 42f51c8..f90732c 100644 --- a/src/main.zig +++ b/src/main.zig @@ -18,6 +18,7 @@ pub fn main() !void { var use_wire: bool = true; // wire protocol by default var use_http: bool = false; var unix_path: ?[]const u8 = null; + var auth_key: ?[]const u8 = null; // Replication flags var repl_enabled: bool = false; @@ -44,6 +45,9 @@ pub fn main() !void { } else if (std.mem.eql(u8, args[i], "--unix") and i + 1 < args.len) { i += 1; unix_path = args[i]; + } else if (std.mem.eql(u8, args[i], "--auth-key") and i + 1 < args.len) { + i += 1; + auth_key = args[i]; } else if (std.mem.eql(u8, args[i], "--replicate")) { repl_enabled = true; } else if (std.mem.eql(u8, args[i], "--node-id") and i + 1 < args.len) { @@ -67,6 +71,7 @@ pub fn main() !void { \\ --http HTTP REST API \\ --both run wire + HTTP (wire on port, HTTP on port+1) \\ --unix also listen on a Unix domain socket + \\ --auth-key require this API key for all requests \\ \\Replication (Calvin deterministic): \\ --replicate enable Calvin replication @@ -87,6 +92,9 @@ pub fn main() !void { \\ , .{}); return; + } else { + std.log.err("unknown flag: {s}", .{args[i]}); + return error.InvalidArgument; } } @@ -101,6 +109,12 @@ pub fn main() !void { const db = try collection.Database.open(alloc, data_dir); defer db.close(); + // ── configure auth ──────────────────────────────────────────────────── + if (auth_key) |key| { + _ = db.auth.addKey(key, "cli", .admin); + std.log.info("Auth enabled (--auth-key)", .{}); + } + // ── replication setup ───────────────────────────────────────────────── if (repl_enabled) { std.log.info("Calvin replication: node={d} leader={} repl_port={d}", .{ diff --git a/src/server.zig b/src/server.zig index 081d658..7766b01 100644 --- a/src/server.zig +++ b/src/server.zig @@ -12,6 +12,7 @@ /// GET /context/:col smart context discovery (q, limit query params) const std = @import("std"); const activity = @import("activity.zig"); +const auth = @import("auth.zig"); const collection = @import("collection.zig"); const Database = collection.Database; @@ -233,6 +234,14 @@ fn dispatch(srv: *Server, raw: []const u8, alloc: std.mem.Allocator) usize { return ok(getBodyBuf()[0..fbs.pos]); } + // ── Auth gate — public endpoints above, protected endpoints below ──── + if (srv.db.auth.isEnabled()) { + const api_key = auth.AuthStore.extractHttpKey(raw) orelse + return err(401, "unauthorized — missing X-Api-Key header"); + if (srv.db.auth.verify(api_key) == null) + return err(401, "unauthorized — invalid API key"); + } + if (std.mem.eql(u8, path, "/billing") and std.mem.eql(u8, method, "GET")) return handleBillingLog(srv); @@ -760,6 +769,7 @@ fn err(code: u16, msg: []const u8) usize { const body = std.fmt.bufPrint(&scratch, "{{\"error\":\"{s}\"}}", .{msg}) catch msg; const status = switch (code) { 400 => "Bad Request", + 401 => "Unauthorized", 429 => "Too Many Requests", 404 => "Not Found", else => "Internal Server Error",