From 9d01ac23c8f8070841c5310c99063e6afedf2611 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 15 Apr 2026 07:50:14 +0000 Subject: [PATCH 1/2] Initial plan From 89a8b251bc789a40111d42913cf4a6e47edcce07 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 15 Apr 2026 08:40:26 +0000 Subject: [PATCH 2/2] Add TLS support: shim chroma_server_tls_enabled, Go/Java URL scheme propagation Agent-Logs-Url: https://github.com/amikos-tech/chroma-go-local/sessions/187dd255-b99b-4339-9576-faec7646b9d7 Co-authored-by: tazarov <1157440+tazarov@users.noreply.github.com> --- config.go | 8 +++ internal/runtime/chroma.go | 21 +++++- internal/runtime/chroma_test.go | 27 +++++++ internal/runtime/config.go | 30 ++++++++ .../local/core/ServerConfigBuilder.java | 19 +++++ .../chroma/local/core/ServerSession.java | 11 ++- .../local/core/ServerConfigBuilderTest.java | 27 +++++++ .../chroma/local/core/ServerSessionTest.java | 71 +++++++++++++++---- .../chroma/local/jna/JnaChromaRuntime.java | 7 ++ .../local/panama/PanamaChromaRuntime.java | 21 +++++- shim/src/lib.rs | 59 ++++++++++++++- 11 files changed, 280 insertions(+), 21 deletions(-) diff --git a/config.go b/config.go index 5a76891..81e9c31 100644 --- a/config.go +++ b/config.go @@ -46,6 +46,14 @@ func WithOpenTelemetry(endpoint, serviceName string) ServerOption { return runtime.WithOpenTelemetry(endpoint, serviceName) } +func WithTLSCertPath(certPath string) ServerOption { + return runtime.WithTLSCertPath(certPath) +} + +func WithTLSKeyPath(keyPath string) ServerOption { + return runtime.WithTLSKeyPath(keyPath) +} + func WithRawYAML(yaml string) ServerOption { return runtime.WithRawYAML(yaml) } diff --git a/internal/runtime/chroma.go b/internal/runtime/chroma.go index ffa8cd2..92e9cc4 100644 --- a/internal/runtime/chroma.go +++ b/internal/runtime/chroma.go @@ -27,6 +27,7 @@ var ( chromaServerPort func(uintptr) int32 chromaServerAddress func(uintptr) *byte chromaServerPersistPath func(uintptr) *byte + chromaServerTLSEnabled func(uintptr) int32 chromaServerStop func(uintptr) int32 chromaServerFree func(uintptr) chromaEmbeddedStart func(*byte) uintptr @@ -94,6 +95,7 @@ func registerFunctions() error { {&chromaServerPort, "chroma_server_port"}, {&chromaServerAddress, "chroma_server_address"}, {&chromaServerPersistPath, "chroma_server_persist_path"}, + {&chromaServerTLSEnabled, "chroma_server_tls_enabled"}, {&chromaServerStop, "chroma_server_stop"}, {&chromaServerFree, "chroma_server_free"}, {&chromaEmbeddedStart, "chroma_embedded_start"}, @@ -220,6 +222,7 @@ type Server struct { addr string config StartServerConfig persistPath string + tls bool } // StartServerConfig contains configuration options for starting a server. @@ -253,6 +256,7 @@ func StartServer(config StartServerConfig) (*Server, error) { var port int32 addr := "" persistPath := "" + tlsEnabled := false func() { ffiMu.Lock() defer ffiMu.Unlock() @@ -265,6 +269,7 @@ func StartServer(config StartServerConfig) (*Server, error) { if persistPathPtr != nil { persistPath = goStringFromPtr(persistPathPtr) } + tlsEnabled = chromaServerTLSEnabled(handle) > 0 }() resolvedPersistPath, persistPathErr := normalizePersistPath(persistPath) @@ -293,6 +298,7 @@ func StartServer(config StartServerConfig) (*Server, error) { addr: addr, config: config, persistPath: resolvedPersistPath, + tls: tlsEnabled, } goruntime.SetFinalizer(server, func(s *Server) { @@ -316,11 +322,22 @@ func (s *Server) Address() string { return s.addr } -// URL returns the full URL of the server (e.g., "http://127.0.0.1:8000"). +// TLS returns true if the server was started with TLS configured. +func (s *Server) TLS() bool { + s.stateMu.RLock() + defer s.stateMu.RUnlock() + return s.tls +} + +// URL returns the full URL of the server (e.g., "http://127.0.0.1:8000" or "https://127.0.0.1:8000"). func (s *Server) URL() string { s.stateMu.RLock() defer s.stateMu.RUnlock() - return fmt.Sprintf("http://%s:%d", s.addr, s.port) + scheme := "http" + if s.tls { + scheme = "https" + } + return fmt.Sprintf("%s://%s:%d", scheme, s.addr, s.port) } // Stop gracefully stops the server. diff --git a/internal/runtime/chroma_test.go b/internal/runtime/chroma_test.go index 83e5a53..fd32dcb 100644 --- a/internal/runtime/chroma_test.go +++ b/internal/runtime/chroma_test.go @@ -175,6 +175,33 @@ func TestServerConfigToYAML(t *testing.T) { } } +func TestServerConfigTLSYAML(t *testing.T) { + cfg := DefaultServerConfig() + WithTLSCertPath("/etc/certs/server.crt")(cfg) + WithTLSKeyPath("/etc/certs/server.key")(cfg) + + yaml := cfg.toYAML() + + if !strings.Contains(yaml, `tls_cert_path: "/etc/certs/server.crt"`) { + t.Errorf("YAML missing tls_cert_path, got: %s", yaml) + } + if !strings.Contains(yaml, `tls_key_path: "/etc/certs/server.key"`) { + t.Errorf("YAML missing tls_key_path, got: %s", yaml) + } +} + +func TestServerConfigTLSYAML_NotIncludedWhenEmpty(t *testing.T) { + cfg := DefaultServerConfig() + yaml := cfg.toYAML() + + if strings.Contains(yaml, "tls_cert_path") { + t.Errorf("YAML should not contain tls_cert_path when not set, got: %s", yaml) + } + if strings.Contains(yaml, "tls_key_path") { + t.Errorf("YAML should not contain tls_key_path when not set, got: %s", yaml) + } +} + func TestServerConfigWithOptions(t *testing.T) { cfg := DefaultServerConfig() diff --git a/internal/runtime/config.go b/internal/runtime/config.go index e870ade..55771da 100644 --- a/internal/runtime/config.go +++ b/internal/runtime/config.go @@ -22,6 +22,14 @@ type ServerConfig struct { OTelEndpoint string OTelServiceName string + // TLS (optional) + // TLSCertPath is the path to the PEM-encoded TLS certificate file. + // When set, chroma_server_tls_enabled will report TLS as enabled and + // Server.URL() will return an https:// URL. + TLSCertPath string + // TLSKeyPath is the path to the PEM-encoded TLS private key file. + TLSKeyPath string + // Raw config (takes precedence if set) rawYAML string } @@ -98,6 +106,21 @@ func WithOpenTelemetry(endpoint, serviceName string) ServerOption { } } +// WithTLSCertPath sets the path to the PEM-encoded TLS certificate file. +// When a cert path is provided, Server.URL() will return an https:// URL. +func WithTLSCertPath(certPath string) ServerOption { + return func(c *ServerConfig) { + c.TLSCertPath = certPath + } +} + +// WithTLSKeyPath sets the path to the PEM-encoded TLS private key file. +func WithTLSKeyPath(keyPath string) ServerOption { + return func(c *ServerConfig) { + c.TLSKeyPath = keyPath + } +} + // WithRawYAML sets a raw YAML config string (overrides all other options). func WithRawYAML(yaml string) ServerOption { return func(c *ServerConfig) { @@ -135,6 +158,13 @@ func (c *ServerConfig) toYAML() string { } } + if c.TLSCertPath != "" { + fmt.Fprintf(&b, "tls_cert_path: %q\n", c.TLSCertPath) + } + if c.TLSKeyPath != "" { + fmt.Fprintf(&b, "tls_key_path: %q\n", c.TLSKeyPath) + } + return b.String() } diff --git a/java/core/src/main/java/tech/amikos/chroma/local/core/ServerConfigBuilder.java b/java/core/src/main/java/tech/amikos/chroma/local/core/ServerConfigBuilder.java index e0738e1..6c7f425 100644 --- a/java/core/src/main/java/tech/amikos/chroma/local/core/ServerConfigBuilder.java +++ b/java/core/src/main/java/tech/amikos/chroma/local/core/ServerConfigBuilder.java @@ -18,6 +18,8 @@ public final class ServerConfigBuilder { private List corsAllowOrigins; private String otelEndpoint; private String otelServiceName; + private String tlsCertPath; + private String tlsKeyPath; private String rawYaml; public ServerConfigBuilder port(int port) { @@ -65,6 +67,16 @@ public ServerConfigBuilder otelServiceName(String otelServiceName) { return this; } + public ServerConfigBuilder tlsCertPath(String tlsCertPath) { + this.tlsCertPath = tlsCertPath; + return this; + } + + public ServerConfigBuilder tlsKeyPath(String tlsKeyPath) { + this.tlsKeyPath = tlsKeyPath; + return this; + } + public ServerConfigBuilder rawYaml(String rawYaml) { this.rawYaml = rawYaml; return this; @@ -124,6 +136,13 @@ private String toYaml() { map.put("open_telemetry", otel); } + if (tlsCertPath != null && !tlsCertPath.isBlank()) { + map.put("tls_cert_path", tlsCertPath); + } + if (tlsKeyPath != null && !tlsKeyPath.isBlank()) { + map.put("tls_key_path", tlsKeyPath); + } + Yaml yaml = new Yaml(options); return yaml.dump(map); } diff --git a/java/core/src/main/java/tech/amikos/chroma/local/core/ServerSession.java b/java/core/src/main/java/tech/amikos/chroma/local/core/ServerSession.java index 150b2ca..e62963a 100644 --- a/java/core/src/main/java/tech/amikos/chroma/local/core/ServerSession.java +++ b/java/core/src/main/java/tech/amikos/chroma/local/core/ServerSession.java @@ -2,6 +2,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.BooleanSupplier; import java.util.function.Function; import java.util.function.LongConsumer; import java.util.function.LongFunction; @@ -16,6 +17,7 @@ public final class ServerSession implements AutoCloseable { private final LongToIntFunction portAccessor; private final LongFunction addressAccessor; private final LongFunction persistPathAccessor; + private final BooleanSupplier tlsEnabledAccessor; private final Function> backupAction; private final Function> rebuildAction; private final Function> compactCollectionAction; @@ -26,6 +28,7 @@ public final class ServerSession implements AutoCloseable { public ServerSession(long handle, LongConsumer stopAction, LongConsumer freeAction, LongToIntFunction portAccessor, LongFunction addressAccessor, LongFunction persistPathAccessor, + BooleanSupplier tlsEnabledAccessor, Function> backupAction, Function> rebuildAction, Function> compactCollectionAction, @@ -38,6 +41,7 @@ public ServerSession(long handle, LongConsumer stopAction, LongConsumer freeActi if (portAccessor == null) throw new IllegalArgumentException("portAccessor must be set"); if (addressAccessor == null) throw new IllegalArgumentException("addressAccessor must be set"); if (persistPathAccessor == null) throw new IllegalArgumentException("persistPathAccessor must be set"); + if (tlsEnabledAccessor == null) throw new IllegalArgumentException("tlsEnabledAccessor must be set"); if (backupAction == null) throw new IllegalArgumentException("backupAction must be set"); if (rebuildAction == null) throw new IllegalArgumentException("rebuildAction must be set"); if (compactCollectionAction == null) throw new IllegalArgumentException("compactCollectionAction must be set"); @@ -50,6 +54,7 @@ public ServerSession(long handle, LongConsumer stopAction, LongConsumer freeActi this.portAccessor = portAccessor; this.addressAccessor = addressAccessor; this.persistPathAccessor = persistPathAccessor; + this.tlsEnabledAccessor = tlsEnabledAccessor; this.backupAction = backupAction; this.rebuildAction = rebuildAction; this.compactCollectionAction = compactCollectionAction; @@ -74,9 +79,11 @@ private void ensureOpen() { public String persistPath() { ensureOpen(); return persistPathAccessor.apply(handle); } - // TLS not yet supported + public boolean tlsEnabled() { ensureOpen(); return tlsEnabledAccessor.getAsBoolean(); } + public String url() { - return "http://" + address() + ":" + port(); + String scheme = tlsEnabledAccessor.getAsBoolean() ? "https" : "http"; + return scheme + "://" + address() + ":" + port(); } @Override diff --git a/java/core/src/test/java/tech/amikos/chroma/local/core/ServerConfigBuilderTest.java b/java/core/src/test/java/tech/amikos/chroma/local/core/ServerConfigBuilderTest.java index 14fcd7c..f75bde6 100644 --- a/java/core/src/test/java/tech/amikos/chroma/local/core/ServerConfigBuilderTest.java +++ b/java/core/src/test/java/tech/amikos/chroma/local/core/ServerConfigBuilderTest.java @@ -206,4 +206,31 @@ void defaultBuildContainsAllRequiredKeys() { assertTrue(map.containsKey("sqlite_filename")); assertTrue(map.containsKey("allow_reset")); } + + @Test + void tlsCertPathAppearsInYaml() { + String yaml = new ServerConfigBuilder().tlsCertPath("/etc/certs/server.crt").build(); + Map map = parseYaml(yaml); + assertEquals("/etc/certs/server.crt", map.get("tls_cert_path")); + assertFalse(map.containsKey("tls_key_path")); + } + + @Test + void tlsKeyPathAppearsInYaml() { + String yaml = new ServerConfigBuilder() + .tlsCertPath("/etc/certs/server.crt") + .tlsKeyPath("/etc/certs/server.key") + .build(); + Map map = parseYaml(yaml); + assertEquals("/etc/certs/server.crt", map.get("tls_cert_path")); + assertEquals("/etc/certs/server.key", map.get("tls_key_path")); + } + + @Test + void tlsFieldsAbsentWhenNotSet() { + String yaml = new ServerConfigBuilder().build(); + Map map = parseYaml(yaml); + assertFalse(map.containsKey("tls_cert_path")); + assertFalse(map.containsKey("tls_key_path")); + } } diff --git a/java/core/src/test/java/tech/amikos/chroma/local/core/ServerSessionTest.java b/java/core/src/test/java/tech/amikos/chroma/local/core/ServerSessionTest.java index c28e65c..6911411 100644 --- a/java/core/src/test/java/tech/amikos/chroma/local/core/ServerSessionTest.java +++ b/java/core/src/test/java/tech/amikos/chroma/local/core/ServerSessionTest.java @@ -32,6 +32,7 @@ private ServerSession createSession(long handle) { h -> 8000, h -> "localhost", h -> "/data", + () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -41,49 +42,56 @@ private ServerSession createSession(long handle) { @Test void constructorRejectsZeroHandle() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(0L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", STUB_BACKUP, + () -> new ServerSession(0L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL)); } @Test void constructorRejectsNullStopAction() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(1L, null, h -> {}, h -> 0, h -> "", h -> "", STUB_BACKUP, + () -> new ServerSession(1L, null, h -> {}, h -> 0, h -> "", h -> "", () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL)); } @Test void constructorRejectsNullFreeAction() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(1L, h -> {}, null, h -> 0, h -> "", h -> "", STUB_BACKUP, + () -> new ServerSession(1L, h -> {}, null, h -> 0, h -> "", h -> "", () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL)); } @Test void constructorRejectsNullPortAccessor() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(1L, h -> {}, h -> {}, null, h -> "", h -> "", STUB_BACKUP, + () -> new ServerSession(1L, h -> {}, h -> {}, null, h -> "", h -> "", () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL)); } @Test void constructorRejectsNullAddressAccessor() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, null, h -> "", STUB_BACKUP, + () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, null, h -> "", () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL)); } @Test void constructorRejectsNullPersistPathAccessor() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", null, STUB_BACKUP, + () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", null, () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL)); } @Test void constructorRejectsNullBackupAction() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", null, + () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", () -> false, null, + STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL)); + } + + @Test + void constructorRejectsNullTlsEnabledAccessor() { + assertThrows(IllegalArgumentException.class, + () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", null, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL)); } @@ -108,6 +116,7 @@ void port_callsPortAccessor() { h -> {}, h -> {}, h -> { receivedHandle.set(h); return 9090; }, h -> "host", h -> "/path", + () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -125,6 +134,7 @@ void address_callsAddressAccessor() { h -> 8000, h -> { receivedHandle.set(h); return "0.0.0.0"; }, h -> "/path", + () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -141,6 +151,7 @@ void persistPath_callsPersistPathAccessor() { h -> {}, h -> {}, h -> 8000, h -> "host", h -> { receivedHandle.set(h); return "/my/data"; }, + () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -157,6 +168,7 @@ void url_returnsHttpUrl() { h -> 8080, h -> "127.0.0.1", h -> "/data", + () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -164,6 +176,28 @@ void url_returnsHttpUrl() { assertEquals("http://127.0.0.1:8080", session.url()); } + @Test + void url_returnsHttpsUrl_whenTlsEnabled() { + ServerSession session = new ServerSession( + 1L, + h -> {}, h -> {}, + h -> 8443, + h -> "127.0.0.1", + h -> "/data", + () -> true, + STUB_BACKUP, + STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, + STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL + ); + assertEquals("https://127.0.0.1:8443", session.url()); + } + + @Test + void tlsEnabled_returnsFalse_byDefault() { + ServerSession session = createSession(1L); + assertEquals(false, session.tlsEnabled()); + } + @Test void close_callsStopThenFree() { AtomicInteger stopCalls = new AtomicInteger(); @@ -173,6 +207,7 @@ void close_callsStopThenFree() { h -> stopCalls.incrementAndGet(), h -> freeCalls.incrementAndGet(), h -> 8000, h -> "host", h -> "/path", + () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -191,6 +226,7 @@ void close_isIdempotent() { h -> stopCalls.incrementAndGet(), h -> freeCalls.incrementAndGet(), h -> 8000, h -> "host", h -> "/path", + () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -210,6 +246,7 @@ void close_freesEvenIfStopFails() { h -> { throw new RuntimeException("stop failed"); }, h -> freeCalls.incrementAndGet(), h -> 8000, h -> "host", h -> "/path", + () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -226,6 +263,7 @@ void close_remainsClosedAfterStopFailure() { h -> { stopCalls.incrementAndGet(); throw new RuntimeException("stop failed"); }, h -> {}, h -> 8000, h -> "host", h -> "/path", + () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -244,6 +282,7 @@ void close_freeRunsEvenWhenBothFail() { h -> { throw new RuntimeException("stop failed"); }, h -> { throw new RuntimeException("free failed"); }, h -> 8000, h -> "host", h -> "/path", + () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -316,6 +355,7 @@ void backupDelegatesAndInvalidatesSession() { 42L, h -> {}, h -> {}, h -> 8000, h -> "host", h -> "/path", + () -> false, opts -> { capturedOpts.set(opts); return fakeResult; }, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -343,6 +383,7 @@ void closeAfterBackupIsNoOp() { h -> stopCalls.incrementAndGet(), h -> freeCalls.incrementAndGet(), h -> 8000, h -> "host", h -> "/path", + () -> false, opts -> new BackupResult<>(manifest, null), STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -360,6 +401,7 @@ void backupFailureStillInvalidatesSession() { 42L, h -> {}, h -> {}, h -> 8000, h -> "host", h -> "/path", + () -> false, opts -> { throw new RuntimeException("backup failed"); }, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL @@ -382,6 +424,7 @@ void backupPreValidationFailureLeavesSessionOpen() { h -> stopCalls.incrementAndGet(), h -> freeCalls.incrementAndGet(), h -> 8000, h -> "host", h -> "/path", + () -> false, opts -> { throw new BackupExecutor.PreValidationFailure( new IllegalArgumentException("dest inside source")); }, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, @@ -405,35 +448,35 @@ void backupPreValidationFailureLeavesSessionOpen() { @Test void constructorRejectsNullRebuildAction() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", STUB_BACKUP, + () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", () -> false, STUB_BACKUP, null, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL)); } @Test void constructorRejectsNullCompactCollectionAction() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", STUB_BACKUP, + () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", () -> false, STUB_BACKUP, STUB_REBUILD, null, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL)); } @Test void constructorRejectsNullCompactAllAction() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", STUB_BACKUP, + () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, null, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL)); } @Test void constructorRejectsNullPruneWalCollectionAction() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", STUB_BACKUP, + () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, null, STUB_PRUNE_ALL)); } @Test void constructorRejectsNullPruneWalAllAction() { assertThrows(IllegalArgumentException.class, - () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", STUB_BACKUP, + () -> new ServerSession(1L, h -> {}, h -> {}, h -> 0, h -> "", h -> "", () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, STUB_PRUNE_COLLECTION, null)); } @@ -447,6 +490,7 @@ void rebuildCollection_delegatesAndInvalidatesSession() { ServerSession session = new ServerSession( 42L, h -> {}, h -> {}, h -> 8000, h -> "host", h -> "/path", + () -> false, STUB_BACKUP, opts -> { capturedOpts.set(opts); return fakeResult; }, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, @@ -465,6 +509,7 @@ void rebuildCollection_delegatesAndInvalidatesSession() { void rebuildCollection_failureStillInvalidatesSession() { ServerSession session = new ServerSession( 42L, h -> {}, h -> {}, h -> 8000, h -> "host", h -> "/path", + () -> false, STUB_BACKUP, opts -> { throw new RuntimeException("op failed"); }, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, @@ -493,6 +538,7 @@ void compactAll_delegatesAndInvalidatesSession() { ServerSession session = new ServerSession( 42L, h -> {}, h -> {}, h -> 8000, h -> "host", h -> "/path", + () -> false, STUB_BACKUP, STUB_REBUILD, STUB_COMPACT_COLLECTION, req -> fakeResult, STUB_PRUNE_COLLECTION, STUB_PRUNE_ALL); @@ -516,6 +562,7 @@ void closeAfterMaintenanceIsNoOp() { h -> stopCalls.incrementAndGet(), h -> freeCalls.incrementAndGet(), h -> 8000, h -> "host", h -> "/path", + () -> false, STUB_BACKUP, opts -> fakeResult, STUB_COMPACT_COLLECTION, STUB_COMPACT_ALL, diff --git a/java/jna/src/main/java/tech/amikos/chroma/local/jna/JnaChromaRuntime.java b/java/jna/src/main/java/tech/amikos/chroma/local/jna/JnaChromaRuntime.java index 67031ba..d88b864 100644 --- a/java/jna/src/main/java/tech/amikos/chroma/local/jna/JnaChromaRuntime.java +++ b/java/jna/src/main/java/tech/amikos/chroma/local/jna/JnaChromaRuntime.java @@ -49,6 +49,8 @@ private interface JnaBindings extends Library { Pointer chroma_server_address(Pointer handle); + int chroma_server_tls_enabled(Pointer handle); + Pointer chroma_embedded_persist_path(Pointer handle); Pointer chroma_server_persist_path(Pointer handle); @@ -142,6 +144,7 @@ protected ServerSession doStartServer(String configYaml) { this::serverPort, this::serverAddress, this::serverPersistPath, + () -> serverTlsEnabled(handle), opts -> BackupExecutor.execute(BackupMode.SERVER, persistPath, version, opts, () -> { try { serverStop(handle); } finally { serverFree(handle); } }, () -> doStartServer(configYaml)), @@ -185,6 +188,10 @@ private String serverAddress(long handle) { () -> Pointer.nativeValue(bindings.chroma_server_address(new Pointer(handle)))); } + private boolean serverTlsEnabled(long handle) { + return callFfiInt(() -> bindings.chroma_server_tls_enabled(new Pointer(handle))) > 0; + } + private String embeddedPersistPath(long handle) { return callFfiBorrowedString( () -> Pointer.nativeValue(bindings.chroma_embedded_persist_path(new Pointer(handle)))); diff --git a/java/panama/src/main/java/tech/amikos/chroma/local/panama/PanamaChromaRuntime.java b/java/panama/src/main/java/tech/amikos/chroma/local/panama/PanamaChromaRuntime.java index bb96e83..cccade1 100644 --- a/java/panama/src/main/java/tech/amikos/chroma/local/panama/PanamaChromaRuntime.java +++ b/java/panama/src/main/java/tech/amikos/chroma/local/panama/PanamaChromaRuntime.java @@ -43,7 +43,8 @@ private record Ffi( MethodHandle serverFree, MethodHandle serverPort, MethodHandle serverAddress, - MethodHandle serverPersistPath) {} + MethodHandle serverPersistPath, + MethodHandle serverTlsEnabled) {} private final Arena arena; private final Ffi ffi; @@ -112,7 +113,10 @@ public static PanamaChromaRuntime init(String libraryPath) { FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS)), linker.downcallHandle( requireSymbol(library, "chroma_server_persist_path"), - FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS))); + FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS)), + linker.downcallHandle( + requireSymbol(library, "chroma_server_tls_enabled"), + FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS))); PanamaChromaRuntime runtime = new PanamaChromaRuntime(arena, ffi); initialized = true; @@ -284,6 +288,7 @@ protected ServerSession doStartServer(String configYaml) { this::serverPort, this::serverAddress, this::serverPersistPath, + () -> serverTlsEnabled(handle), opts -> BackupExecutor.execute(BackupMode.SERVER, persistPath, version, opts, () -> { try { serverStop(handle); } finally { serverFree(handle); } }, () -> doStartServer(configYaml)), @@ -386,6 +391,18 @@ private String serverPersistPath(long handleAddress) { }); } + private boolean serverTlsEnabled(long handleAddress) { + return callFfiInt(() -> { + try { + return (int) ffi.serverTlsEnabled().invokeExact( + MemorySegment.ofAddress(handleAddress)); + } catch (Throwable t) { + if (t instanceof Error error) throw error; + throw new ChromaException("failed to read server TLS state", t); + } + }) > 0; + } + // Same invokeExact constraint as serverFree — cannot use callFfiFree. private void embeddedFree(long handleAddress) { if (handleAddress == 0L) return; diff --git a/shim/src/lib.rs b/shim/src/lib.rs index 3483d74..8e63fff 100644 --- a/shim/src/lib.rs +++ b/shim/src/lib.rs @@ -169,6 +169,34 @@ struct ServerHandle { port: u16, listen_address: CString, persist_path: CString, + tls_enabled: bool, +} + +/// Minimal shim-level TLS fields parsed from the YAML config. +/// These keys are not part of FrontendServerConfig but are recognised by the shim +/// to expose TLS state via chroma_server_tls_enabled. +#[derive(Deserialize, Default)] +struct ShimTlsConfig { + #[serde(default)] + tls_cert_path: Option, + /// Parsed but not yet acted upon — reserved for future TLS termination support. + #[serde(default)] + #[allow(dead_code)] + tls_key_path: Option, +} + +fn tls_enabled_from_string(yaml_str: &str) -> bool { + let tls_cfg: ShimTlsConfig = figment::Figment::from(Yaml::string(yaml_str)) + .extract() + .unwrap_or_default(); + tls_cfg.tls_cert_path.as_deref().is_some_and(|p| !p.is_empty()) +} + +fn tls_enabled_from_path(path: &str) -> bool { + let tls_cfg: ShimTlsConfig = figment::Figment::from(Yaml::file(path)) + .extract() + .unwrap_or_default(); + tls_cfg.tls_cert_path.as_deref().is_some_and(|p| !p.is_empty()) } struct EmbeddedHandle { @@ -2720,7 +2748,8 @@ pub unsafe extern "C" fn chroma_server_start(config_path: *const c_char) -> *mut } }; - start_server_with_config(config) + let tls_enabled = tls_enabled_from_path(&path); + start_server_with_config(config, tls_enabled) }) } @@ -2751,11 +2780,12 @@ pub unsafe extern "C" fn chroma_server_start_from_string( } }; - start_server_with_config(config) + let tls_enabled = tls_enabled_from_string(&yaml); + start_server_with_config(config, tls_enabled) }) } -fn start_server_with_config(config: FrontendServerConfig) -> *mut c_void { +fn start_server_with_config(config: FrontendServerConfig, tls_enabled: bool) -> *mut c_void { let runtime = match Runtime::new() { Ok(r) => r, Err(e) => { @@ -2834,6 +2864,7 @@ fn start_server_with_config(config: FrontendServerConfig) -> *mut c_void { port, listen_address, persist_path, + tls_enabled, }); Box::into_raw(handle) as *mut c_void @@ -2890,6 +2921,28 @@ pub unsafe extern "C" fn chroma_server_persist_path(handle: *mut c_void) -> *con }) } +/// Return whether TLS is enabled on the server handle. +/// Returns 1 if TLS is enabled, 0 if not, -1 on invalid handle. +/// +/// TLS is considered enabled when a `tls_cert_path` was present in the config used to +/// start the server. Note that the Chroma frontend does not perform TLS termination +/// itself; an external proxy or future shim-level implementation is required to serve +/// HTTPS traffic. +/// +/// # Safety +/// `handle` must be a valid handle from `chroma_server_start*` or NULL. +#[no_mangle] +pub unsafe extern "C" fn chroma_server_tls_enabled(handle: *mut c_void) -> i32 { + ffi_guard_minus_one!({ + if handle.is_null() { + set_last_error("handle is null"); + return -1; + } + let server = &*(handle as *const ServerHandle); + if server.tls_enabled { 1 } else { 0 } + }) +} + /// Stop the server gracefully. /// Returns SUCCESS on success, error code on failure. ///