Skip to content

Commit 9d5d30a

Browse files
authored
Merge pull request #144 from tursodatabase/request-headers
allow client to pass arbitrary http headers in the sql-over-http requests
2 parents a9a8fad + 2480b3f commit 9d5d30a

3 files changed

Lines changed: 41 additions & 15 deletions

File tree

libsql/internal/http/driver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ import (
66
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/hranaV2"
77
)
88

9-
func Connect(url, jwt, host string, schemaDb bool, remoteEncryptionKey string) driver.Conn {
10-
return hranaV2.Connect(url, jwt, host, schemaDb, remoteEncryptionKey)
9+
func Connect(url, jwt, host string, schemaDb bool, remoteEncryptionKey string, requestHeaders map[string]string) driver.Conn {
10+
return hranaV2.Connect(url, jwt, host, schemaDb, remoteEncryptionKey, requestHeaders)
1111
}

libsql/internal/http/hranaV2/hranaV2.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ func init() {
3737
commitHash = "unknown"
3838
}
3939

40-
func Connect(url, jwt, host string, schemaDb bool, encryptionKey string) driver.Conn {
41-
return &hranaV2Conn{url, jwt, host, schemaDb, encryptionKey, "", false, 0}
40+
func Connect(url, jwt, host string, schemaDb bool, encryptionKey string, requestHeaders map[string]string) driver.Conn {
41+
return &hranaV2Conn{url, jwt, host, schemaDb, encryptionKey, requestHeaders, "", false, 0}
4242
}
4343

4444
type hranaV2Stmt struct {
@@ -88,6 +88,7 @@ type hranaV2Conn struct {
8888
host string
8989
schemaDb bool
9090
remoteEncryptionKey string
91+
requestHeaders map[string]string
9192
baton string
9293
streamClosed bool
9394
replicationIndex uint64
@@ -123,11 +124,11 @@ func (h *hranaV2Conn) PrepareContext(ctx context.Context, query string) (driver.
123124

124125
func (h *hranaV2Conn) Close() error {
125126
if h.baton != "" {
126-
go func(baton, url, jwt, host, encryptionKey string) {
127+
go func(baton, url, jwt, host, encryptionKey string, requestHeaders map[string]string) {
127128
msg := hrana.PipelineRequest{Baton: baton}
128129
msg.Add(hrana.CloseStream())
129-
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey)
130-
}(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey)
130+
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey, requestHeaders)
131+
}(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey, h.requestHeaders)
131132
}
132133
return nil
133134
}
@@ -175,7 +176,7 @@ func (h *hranaV2Conn) sendPipelineRequest(ctx context.Context, msg *hrana.Pipeli
175176
if h.replicationIndex > 0 {
176177
addReplicationIndex(msg, h.replicationIndex)
177178
}
178-
result, streamClosed, err := sendPipelineRequest(ctx, msg, h.url, h.jwt, h.host, h.remoteEncryptionKey)
179+
result, streamClosed, err := sendPipelineRequest(ctx, msg, h.url, h.jwt, h.host, h.remoteEncryptionKey, h.requestHeaders)
179180
if streamClosed {
180181
h.streamClosed = true
181182
}
@@ -232,7 +233,7 @@ func getReplicationIndex(response *hrana.PipelineResponse) uint64 {
232233
return replicationIndex
233234
}
234235

235-
func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url string, jwt string, host string, remoteEncryptionKey string) (result hrana.PipelineResponse, streamClosed bool, err error) {
236+
func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url string, jwt string, host string, remoteEncryptionKey string, requestHeaders map[string]string) (result hrana.PipelineResponse, streamClosed bool, err error) {
236237
reqBody, err := json.Marshal(msg)
237238
if err != nil {
238239
return hrana.PipelineResponse{}, false, err
@@ -254,6 +255,9 @@ func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url st
254255
}
255256

256257
req.Host = host
258+
for name, value := range requestHeaders {
259+
req.Header.Set(name, value)
260+
}
257261
resp, err := http.DefaultClient.Do(req)
258262
if err != nil {
259263
return hrana.PipelineResponse{}, false, err
@@ -597,11 +601,11 @@ func (h *hranaV2Conn) QueryContext(ctx context.Context, query string, args []dri
597601

598602
func (h *hranaV2Conn) closeStream() {
599603
if h.baton != "" {
600-
go func(baton, url, jwt, host, encryptionKey string) {
604+
go func(baton, url, jwt, host, encryptionKey string, requestHeaders map[string]string) {
601605
msg := hrana.PipelineRequest{Baton: baton}
602606
msg.Add(hrana.CloseStream())
603-
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey)
604-
}(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey)
607+
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey, requestHeaders)
608+
}(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey, h.requestHeaders)
605609
h.baton = ""
606610
}
607611
}

libsql/sql.go

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ type config struct {
1919
proxy *string
2020
schemaDb *bool
2121
remoteEncryptionKey *string
22+
requestHeaders map[string]string
2223
}
2324

2425
type Option interface {
@@ -90,6 +91,26 @@ func WithRemoteEncryptionKey(key string) Option {
9091
})
9192
}
9293

94+
// WithRequestHeaders attaches arbitrary HTTP headers to every request the
95+
// driver sends to the remote server. Passing the `Host` key (case-insensitive)
96+
// has no effect.
97+
func WithRequestHeaders(headers map[string]string) Option {
98+
return option(func(o *config) error {
99+
if o.requestHeaders != nil {
100+
return fmt.Errorf("requestHeaders already set")
101+
}
102+
if len(headers) == 0 {
103+
return fmt.Errorf("requestHeaders must not be empty")
104+
}
105+
copied := make(map[string]string, len(headers))
106+
for k, v := range headers {
107+
copied[k] = v
108+
}
109+
o.requestHeaders = copied
110+
return nil
111+
})
112+
}
113+
93114
func (c config) connector(dbPath string) (driver.Connector, error) {
94115
u, err := url.Parse(dbPath)
95116
if err != nil {
@@ -182,7 +203,7 @@ func (c config) connector(dbPath string) (driver.Connector, error) {
182203
return wsConnector{url: u.String(), authToken: authToken}, nil
183204
}
184205
if u.Scheme == "https" || u.Scheme == "http" {
185-
return httpConnector{url: u.String(), authToken: authToken, host: host, schemaDb: schemaDb, remoteEncryptionKey: encryptionKey}, nil
206+
return httpConnector{url: u.String(), authToken: authToken, host: host, schemaDb: schemaDb, remoteEncryptionKey: encryptionKey, requestHeaders: c.requestHeaders}, nil
186207
}
187208

188209
return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file://, https://, http://, wss:// and ws://", u.Scheme)
@@ -208,10 +229,11 @@ type httpConnector struct {
208229
host string
209230
schemaDb bool
210231
remoteEncryptionKey string
232+
requestHeaders map[string]string
211233
}
212234

213235
func (c httpConnector) Connect(_ctx context.Context) (driver.Conn, error) {
214-
return http.Connect(c.url, c.authToken, c.host, c.schemaDb, c.remoteEncryptionKey), nil
236+
return http.Connect(c.url, c.authToken, c.host, c.schemaDb, c.remoteEncryptionKey, c.requestHeaders), nil
215237
}
216238

217239
func (c httpConnector) Driver() driver.Driver {
@@ -360,7 +382,7 @@ func (d Driver) Open(dbUrl string) (driver.Conn, error) {
360382
return ws.Connect(u.String(), jwt)
361383
}
362384
if u.Scheme == "https" || u.Scheme == "http" {
363-
return http.Connect(u.String(), jwt, u.Host, false, ""), nil
385+
return http.Connect(u.String(), jwt, u.Host, false, "", nil), nil
364386
}
365387

366388
return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file://, https://, http://, wss:// and ws://", u.Scheme)

0 commit comments

Comments
 (0)