Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ func WithOpenTelemetry(endpoint, serviceName string) ServerOption {
return runtime.WithOpenTelemetry(endpoint, serviceName)
}

func WithTLSCertPath(certPath string) ServerOption {
return runtime.WithTLSCertPath(certPath)
}

func WithTLSKeyPath(keyPath string) ServerOption {
return runtime.WithTLSKeyPath(keyPath)
}

func WithRawYAML(yaml string) ServerOption {
return runtime.WithRawYAML(yaml)
}
21 changes: 19 additions & 2 deletions internal/runtime/chroma.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var (
chromaServerPort func(uintptr) int32
chromaServerAddress func(uintptr) *byte
chromaServerPersistPath func(uintptr) *byte
chromaServerTLSEnabled func(uintptr) int32
chromaServerStop func(uintptr) int32
chromaServerFree func(uintptr)
chromaEmbeddedStart func(*byte) uintptr
Expand Down Expand Up @@ -94,6 +95,7 @@ func registerFunctions() error {
{&chromaServerPort, "chroma_server_port"},
{&chromaServerAddress, "chroma_server_address"},
{&chromaServerPersistPath, "chroma_server_persist_path"},
{&chromaServerTLSEnabled, "chroma_server_tls_enabled"},
{&chromaServerStop, "chroma_server_stop"},
{&chromaServerFree, "chroma_server_free"},
{&chromaEmbeddedStart, "chroma_embedded_start"},
Expand Down Expand Up @@ -220,6 +222,7 @@ type Server struct {
addr string
config StartServerConfig
persistPath string
tls bool
}

// StartServerConfig contains configuration options for starting a server.
Expand Down Expand Up @@ -253,6 +256,7 @@ func StartServer(config StartServerConfig) (*Server, error) {
var port int32
addr := ""
persistPath := ""
tlsEnabled := false
func() {
ffiMu.Lock()
defer ffiMu.Unlock()
Expand All @@ -265,6 +269,7 @@ func StartServer(config StartServerConfig) (*Server, error) {
if persistPathPtr != nil {
persistPath = goStringFromPtr(persistPathPtr)
}
tlsEnabled = chromaServerTLSEnabled(handle) > 0
}()

resolvedPersistPath, persistPathErr := normalizePersistPath(persistPath)
Expand Down Expand Up @@ -293,6 +298,7 @@ func StartServer(config StartServerConfig) (*Server, error) {
addr: addr,
config: config,
persistPath: resolvedPersistPath,
tls: tlsEnabled,
}

goruntime.SetFinalizer(server, func(s *Server) {
Expand All @@ -316,11 +322,22 @@ func (s *Server) Address() string {
return s.addr
}

// URL returns the full URL of the server (e.g., "http://127.0.0.1:8000").
// TLS returns true if the server was started with TLS configured.
func (s *Server) TLS() bool {
s.stateMu.RLock()
defer s.stateMu.RUnlock()
return s.tls
}

// URL returns the full URL of the server (e.g., "http://127.0.0.1:8000" or "https://127.0.0.1:8000").
func (s *Server) URL() string {
s.stateMu.RLock()
defer s.stateMu.RUnlock()
return fmt.Sprintf("http://%s:%d", s.addr, s.port)
scheme := "http"
if s.tls {
scheme = "https"
}
return fmt.Sprintf("%s://%s:%d", scheme, s.addr, s.port)
}

// Stop gracefully stops the server.
Expand Down
27 changes: 27 additions & 0 deletions internal/runtime/chroma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,33 @@ func TestServerConfigToYAML(t *testing.T) {
}
}

func TestServerConfigTLSYAML(t *testing.T) {
cfg := DefaultServerConfig()
WithTLSCertPath("/etc/certs/server.crt")(cfg)
WithTLSKeyPath("/etc/certs/server.key")(cfg)

yaml := cfg.toYAML()

if !strings.Contains(yaml, `tls_cert_path: "/etc/certs/server.crt"`) {
t.Errorf("YAML missing tls_cert_path, got: %s", yaml)
}
if !strings.Contains(yaml, `tls_key_path: "/etc/certs/server.key"`) {
t.Errorf("YAML missing tls_key_path, got: %s", yaml)
}
}

func TestServerConfigTLSYAML_NotIncludedWhenEmpty(t *testing.T) {
cfg := DefaultServerConfig()
yaml := cfg.toYAML()

if strings.Contains(yaml, "tls_cert_path") {
t.Errorf("YAML should not contain tls_cert_path when not set, got: %s", yaml)
}
if strings.Contains(yaml, "tls_key_path") {
t.Errorf("YAML should not contain tls_key_path when not set, got: %s", yaml)
}
}

func TestServerConfigWithOptions(t *testing.T) {
cfg := DefaultServerConfig()

Expand Down
30 changes: 30 additions & 0 deletions internal/runtime/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ type ServerConfig struct {
OTelEndpoint string
OTelServiceName string

// TLS (optional)
// TLSCertPath is the path to the PEM-encoded TLS certificate file.
// When set, chroma_server_tls_enabled will report TLS as enabled and
// Server.URL() will return an https:// URL.
TLSCertPath string
// TLSKeyPath is the path to the PEM-encoded TLS private key file.
TLSKeyPath string

// Raw config (takes precedence if set)
rawYAML string
}
Expand Down Expand Up @@ -98,6 +106,21 @@ func WithOpenTelemetry(endpoint, serviceName string) ServerOption {
}
}

// WithTLSCertPath sets the path to the PEM-encoded TLS certificate file.
// When a cert path is provided, Server.URL() will return an https:// URL.
func WithTLSCertPath(certPath string) ServerOption {
return func(c *ServerConfig) {
c.TLSCertPath = certPath
}
}

// WithTLSKeyPath sets the path to the PEM-encoded TLS private key file.
func WithTLSKeyPath(keyPath string) ServerOption {
return func(c *ServerConfig) {
c.TLSKeyPath = keyPath
}
}

// WithRawYAML sets a raw YAML config string (overrides all other options).
func WithRawYAML(yaml string) ServerOption {
return func(c *ServerConfig) {
Expand Down Expand Up @@ -135,6 +158,13 @@ func (c *ServerConfig) toYAML() string {
}
}

if c.TLSCertPath != "" {
fmt.Fprintf(&b, "tls_cert_path: %q\n", c.TLSCertPath)
}
if c.TLSKeyPath != "" {
fmt.Fprintf(&b, "tls_key_path: %q\n", c.TLSKeyPath)
}

return b.String()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ public final class ServerConfigBuilder {
private List<String> corsAllowOrigins;
private String otelEndpoint;
private String otelServiceName;
private String tlsCertPath;
private String tlsKeyPath;
private String rawYaml;

public ServerConfigBuilder port(int port) {
Expand Down Expand Up @@ -65,6 +67,16 @@ public ServerConfigBuilder otelServiceName(String otelServiceName) {
return this;
}

public ServerConfigBuilder tlsCertPath(String tlsCertPath) {
this.tlsCertPath = tlsCertPath;
return this;
}

public ServerConfigBuilder tlsKeyPath(String tlsKeyPath) {
this.tlsKeyPath = tlsKeyPath;
return this;
}

public ServerConfigBuilder rawYaml(String rawYaml) {
this.rawYaml = rawYaml;
return this;
Expand Down Expand Up @@ -124,6 +136,13 @@ private String toYaml() {
map.put("open_telemetry", otel);
}

if (tlsCertPath != null && !tlsCertPath.isBlank()) {
map.put("tls_cert_path", tlsCertPath);
}
if (tlsKeyPath != null && !tlsKeyPath.isBlank()) {
map.put("tls_key_path", tlsKeyPath);
}

Yaml yaml = new Yaml(options);
return yaml.dump(map);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BooleanSupplier;
import java.util.function.Function;
import java.util.function.LongConsumer;
import java.util.function.LongFunction;
Expand All @@ -16,6 +17,7 @@ public final class ServerSession implements AutoCloseable {
private final LongToIntFunction portAccessor;
private final LongFunction<String> addressAccessor;
private final LongFunction<String> persistPathAccessor;
private final BooleanSupplier tlsEnabledAccessor;
private final Function<BackupOptions, BackupResult<ServerSession>> backupAction;
private final Function<RebuildOptions, MaintenanceResult<RebuildCollectionResult, ServerSession>> rebuildAction;
private final Function<CompactCollectionRequest, MaintenanceResult<CompactionResult, ServerSession>> compactCollectionAction;
Expand All @@ -26,6 +28,7 @@ public final class ServerSession implements AutoCloseable {
public ServerSession(long handle, LongConsumer stopAction, LongConsumer freeAction,
LongToIntFunction portAccessor, LongFunction<String> addressAccessor,
LongFunction<String> persistPathAccessor,
BooleanSupplier tlsEnabledAccessor,
Function<BackupOptions, BackupResult<ServerSession>> backupAction,
Function<RebuildOptions, MaintenanceResult<RebuildCollectionResult, ServerSession>> rebuildAction,
Function<CompactCollectionRequest, MaintenanceResult<CompactionResult, ServerSession>> compactCollectionAction,
Expand All @@ -38,6 +41,7 @@ public ServerSession(long handle, LongConsumer stopAction, LongConsumer freeActi
if (portAccessor == null) throw new IllegalArgumentException("portAccessor must be set");
if (addressAccessor == null) throw new IllegalArgumentException("addressAccessor must be set");
if (persistPathAccessor == null) throw new IllegalArgumentException("persistPathAccessor must be set");
if (tlsEnabledAccessor == null) throw new IllegalArgumentException("tlsEnabledAccessor must be set");
if (backupAction == null) throw new IllegalArgumentException("backupAction must be set");
if (rebuildAction == null) throw new IllegalArgumentException("rebuildAction must be set");
if (compactCollectionAction == null) throw new IllegalArgumentException("compactCollectionAction must be set");
Expand All @@ -50,6 +54,7 @@ public ServerSession(long handle, LongConsumer stopAction, LongConsumer freeActi
this.portAccessor = portAccessor;
this.addressAccessor = addressAccessor;
this.persistPathAccessor = persistPathAccessor;
this.tlsEnabledAccessor = tlsEnabledAccessor;
this.backupAction = backupAction;
this.rebuildAction = rebuildAction;
this.compactCollectionAction = compactCollectionAction;
Expand All @@ -74,9 +79,11 @@ private void ensureOpen() {

public String persistPath() { ensureOpen(); return persistPathAccessor.apply(handle); }

// TLS not yet supported
public boolean tlsEnabled() { ensureOpen(); return tlsEnabledAccessor.getAsBoolean(); }

public String url() {
return "http://" + address() + ":" + port();
String scheme = tlsEnabledAccessor.getAsBoolean() ? "https" : "http";
return scheme + "://" + address() + ":" + port();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,31 @@ void defaultBuildContainsAllRequiredKeys() {
assertTrue(map.containsKey("sqlite_filename"));
assertTrue(map.containsKey("allow_reset"));
}

@Test
void tlsCertPathAppearsInYaml() {
String yaml = new ServerConfigBuilder().tlsCertPath("/etc/certs/server.crt").build();
Map<String, Object> map = parseYaml(yaml);
assertEquals("/etc/certs/server.crt", map.get("tls_cert_path"));
assertFalse(map.containsKey("tls_key_path"));
}

@Test
void tlsKeyPathAppearsInYaml() {
String yaml = new ServerConfigBuilder()
.tlsCertPath("/etc/certs/server.crt")
.tlsKeyPath("/etc/certs/server.key")
.build();
Map<String, Object> map = parseYaml(yaml);
assertEquals("/etc/certs/server.crt", map.get("tls_cert_path"));
assertEquals("/etc/certs/server.key", map.get("tls_key_path"));
}

@Test
void tlsFieldsAbsentWhenNotSet() {
String yaml = new ServerConfigBuilder().build();
Map<String, Object> map = parseYaml(yaml);
assertFalse(map.containsKey("tls_cert_path"));
assertFalse(map.containsKey("tls_key_path"));
}
}
Loading
Loading