Skip to content

Commit ea96874

Browse files
committed
feat: add GET endpoint by primitive kind and name
1 parent 7cf064f commit ea96874

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+484
-264
lines changed

internal/server/admin.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ func adminRouter(s *Server) (chi.Router, error) {
3737

3838
r.Put("/{kind}/{name}", func(w http.ResponseWriter, r *http.Request) { createOrUpdatePrimitives(s, w, r) })
3939
r.Delete("/{kind}/{name}", func(w http.ResponseWriter, r *http.Request) { deletePrimitives(s, w, r) })
40+
r.Get("/{kind}/{name}", func(w http.ResponseWriter, r *http.Request) { getPrimitiveByName(s, w, r) })
4041

4142
return r, nil
4243
}
@@ -165,3 +166,77 @@ func createOrUpdatePrimitives(s *Server, w http.ResponseWriter, r *http.Request)
165166
return
166167
}
167168
}
169+
170+
func getPrimitiveByName(s *Server, w http.ResponseWriter, r *http.Request) {
171+
kind := chi.URLParam(r, "kind")
172+
name := chi.URLParam(r, "name")
173+
ctx := r.Context()
174+
175+
switch strings.ToLower(kind) {
176+
case "source":
177+
source, ok := s.ResourceMgr.GetSource(name)
178+
if !ok {
179+
err := fmt.Errorf("%s with name %q does not exist", kind, name)
180+
s.logger.DebugContext(ctx, err.Error())
181+
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
182+
return
183+
}
184+
m := source.ToConfig()
185+
render.JSON(w, r, m)
186+
case "authservice":
187+
as, ok := s.ResourceMgr.GetAuthService(name)
188+
if !ok {
189+
err := fmt.Errorf("%s with name %q does not exist", kind, name)
190+
s.logger.DebugContext(ctx, err.Error())
191+
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
192+
return
193+
}
194+
m := as.ToConfig()
195+
render.JSON(w, r, m)
196+
case "embeddingmodel":
197+
em, ok := s.ResourceMgr.GetEmbeddingModel(name)
198+
if !ok {
199+
err := fmt.Errorf("%s with name %q does not exist", kind, name)
200+
s.logger.DebugContext(ctx, err.Error())
201+
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
202+
return
203+
}
204+
m := em.ToConfig()
205+
render.JSON(w, r, m)
206+
case "tool":
207+
tool, ok := s.ResourceMgr.GetTool(name)
208+
if !ok {
209+
err := fmt.Errorf("%s with name %q does not exist", kind, name)
210+
s.logger.DebugContext(ctx, err.Error())
211+
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
212+
return
213+
}
214+
m := tool.ToConfig()
215+
render.JSON(w, r, m)
216+
case "toolset":
217+
ts, ok := s.ResourceMgr.GetToolset(name)
218+
if !ok {
219+
err := fmt.Errorf("%s with name %q does not exist", kind, name)
220+
s.logger.DebugContext(ctx, err.Error())
221+
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
222+
return
223+
}
224+
m := ts.ToConfig()
225+
render.JSON(w, r, m)
226+
case "prompt":
227+
prompt, ok := s.ResourceMgr.GetPrompt(name)
228+
if !ok {
229+
err := fmt.Errorf("%s with name %q does not exist", kind, name)
230+
s.logger.DebugContext(ctx, err.Error())
231+
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
232+
return
233+
}
234+
m := prompt.ToConfig()
235+
render.JSON(w, r, m)
236+
default:
237+
err := fmt.Errorf("invalid primitive kind provided")
238+
s.logger.DebugContext(ctx, err.Error())
239+
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
240+
return
241+
}
242+
}

internal/server/admin_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ package server
1717
import (
1818
"bytes"
1919
"context"
20+
"encoding/json"
2021
"fmt"
2122
"net/http"
23+
"reflect"
2224
"testing"
2325

2426
"github.com/goccy/go-yaml"
@@ -225,3 +227,104 @@ func TestAdminDeleteEndpoint(t *testing.T) {
225227
})
226228
}
227229
}
230+
231+
func TestAdminGetEndpoint(t *testing.T) {
232+
mockSources := map[string]sources.Source{"test-source": testutils.MockSource{MockSourceConfig: testutils.MockSourceConfig{Foo: "foo"}}}
233+
mockAuthServices := map[string]auth.AuthService{"test-auth-service": testutils.MockAuthService{MockAuthServiceConfig: testutils.MockAuthServiceConfig{Foo: "foo"}}}
234+
mockEmbeddingModel := map[string]embeddingmodels.EmbeddingModel{"test-embedding-model": testutils.MockEmbeddingModel{MockEmbeddingModelConfig: testutils.MockEmbeddingModelConfig{Foo: "foo"}}}
235+
mockTool := map[string]tools.Tool{"test-tool": testutils.MockTool{MockToolConfig: testutils.MockToolConfig{Foo: "foo"}}}
236+
mockToolset := map[string]tools.Toolset{"test-toolset": tools.Toolset{ToolsetConfig: tools.ToolsetConfig{ToolNames: []string{"test-tool"}}}}
237+
mockPrompt := map[string]prompts.Prompt{"test-prompt": testutils.MockPrompt{MockPromptConfig: testutils.MockPromptConfig{Foo: "foo"}}}
238+
r, shutdown := setUpServer(t, "admin", mockSources, mockAuthServices, mockEmbeddingModel, mockTool, mockToolset, mockPrompt, map[string]prompts.Promptset{})
239+
defer shutdown()
240+
ts := runServer(r, false)
241+
defer ts.Close()
242+
243+
tests := []struct {
244+
name string
245+
kind string
246+
resourceName string
247+
want any
248+
expectedStatusCode int
249+
}{
250+
{
251+
name: "Get Source - Success",
252+
kind: "source",
253+
resourceName: "test-source",
254+
want: mockSources,
255+
expectedStatusCode: http.StatusOK,
256+
},
257+
{
258+
name: "Get Auth Service - Success",
259+
kind: "authService",
260+
resourceName: "test-auth-service",
261+
want: mockAuthServices,
262+
expectedStatusCode: http.StatusOK,
263+
},
264+
{
265+
name: "Get Embedding Model - Success",
266+
kind: "embeddingModel",
267+
resourceName: "test-embedding-model",
268+
want: mockEmbeddingModel,
269+
expectedStatusCode: http.StatusOK,
270+
},
271+
{
272+
name: "Get Tool - Success",
273+
kind: "tool",
274+
resourceName: "test-tool",
275+
want: mockTool,
276+
expectedStatusCode: http.StatusOK,
277+
},
278+
{
279+
name: "Get Toolset - Success",
280+
kind: "toolset",
281+
resourceName: "test-toolset",
282+
want: mockToolset,
283+
expectedStatusCode: http.StatusOK,
284+
},
285+
{
286+
name: "Get Prompt - Success",
287+
kind: "prompt",
288+
resourceName: "test-prompt",
289+
want: mockPrompt,
290+
expectedStatusCode: http.StatusOK,
291+
},
292+
{
293+
name: "Get Non-existent Primitive - Not Found",
294+
kind: "source",
295+
resourceName: "non-existent-source",
296+
expectedStatusCode: http.StatusNotFound,
297+
},
298+
{
299+
name: "Get with Invalid Kind - Bad Request",
300+
kind: "invalidKind",
301+
resourceName: "some-name",
302+
expectedStatusCode: http.StatusBadRequest,
303+
},
304+
}
305+
306+
for _, tt := range tests {
307+
t.Run(tt.name, func(t *testing.T) {
308+
resp, body, err := runRequest(ts, http.MethodGet, fmt.Sprintf("/%s/%s", tt.kind, tt.resourceName), nil, nil)
309+
if err != nil {
310+
t.Fatalf("unexpected error during request: %s", err)
311+
}
312+
if resp.StatusCode != tt.expectedStatusCode {
313+
t.Fatalf("response status code is not %d, got %d, %s", tt.expectedStatusCode, resp.StatusCode, string(body))
314+
}
315+
if tt.expectedStatusCode == http.StatusOK {
316+
var got map[string]any
317+
if err := json.Unmarshal(body, &got); err != nil {
318+
t.Fatalf("error unmarshaling response body")
319+
}
320+
want, err := json.Marshal(tt.want)
321+
if err != nil {
322+
t.Fatalf("error marshaling want struct")
323+
}
324+
if !reflect.DeepEqual(got, want) {
325+
t.Fatalf("unexpected output: got %+v, want %+v", got, want)
326+
}
327+
}
328+
})
329+
}
330+
}

internal/sources/alloydbpg/alloydb_pg.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ type Config struct {
5656
Cluster string `yaml:"cluster" validate:"required"`
5757
Instance string `yaml:"instance" validate:"required"`
5858
IPType sources.IPType `yaml:"ipType" validate:"required"`
59-
User string `yaml:"user"`
60-
Password string `yaml:"password"`
59+
User util.Secret `yaml:"user"`
60+
Password util.Secret `yaml:"password"`
6161
Database string `yaml:"database" validate:"required"`
6262
}
6363

@@ -66,7 +66,7 @@ func (r Config) SourceConfigType() string {
6666
}
6767

6868
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
69-
pool, err := initAlloyDBPgConnectionPool(ctx, tracer, r.Name, r.Project, r.Region, r.Cluster, r.Instance, r.IPType.String(), r.User, r.Password, r.Database)
69+
pool, err := initAlloyDBPgConnectionPool(ctx, tracer, r.Name, r.Project, r.Region, r.Cluster, r.Instance, r.IPType.String(), r.User.String(), r.Password.String(), r.Database)
7070
if err != nil {
7171
return nil, fmt.Errorf("unable to create pool: %w", err)
7272
}

internal/sources/alloydbpg/alloydb_pg_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/googleapis/genai-toolbox/internal/sources"
2424
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
2525
"github.com/googleapis/genai-toolbox/internal/testutils"
26+
"github.com/googleapis/genai-toolbox/internal/util"
2627
)
2728

2829
func TestParseFromYamlAlloyDBPg(t *testing.T) {
@@ -55,8 +56,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
5556
Instance: "my-instance",
5657
IPType: "public",
5758
Database: "my_db",
58-
User: "my_user",
59-
Password: "my_pass",
59+
User: util.Secret("my_user"),
60+
Password: util.Secret("my_pass"),
6061
},
6162
},
6263
},
@@ -85,8 +86,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
8586
Instance: "my-instance",
8687
IPType: "public",
8788
Database: "my_db",
88-
User: "my_user",
89-
Password: "my_pass",
89+
User: util.Secret("my_user"),
90+
Password: util.Secret("my_pass"),
9091
},
9192
},
9293
},
@@ -115,8 +116,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
115116
Instance: "my-instance",
116117
IPType: "private",
117118
Database: "my_db",
118-
User: "my_user",
119-
Password: "my_pass",
119+
User: util.Secret("my_user"),
120+
Password: util.Secret("my_pass"),
120121
},
121122
},
122123
},

internal/sources/clickhouse/clickhouse.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
_ "github.com/ClickHouse/clickhouse-go/v2"
2525
"github.com/goccy/go-yaml"
2626
"github.com/googleapis/genai-toolbox/internal/sources"
27+
"github.com/googleapis/genai-toolbox/internal/util"
2728
"github.com/googleapis/genai-toolbox/internal/util/parameters"
2829
"go.opentelemetry.io/otel/trace"
2930
)
@@ -48,23 +49,23 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
4849
}
4950

5051
type Config struct {
51-
Name string `yaml:"name" validate:"required"`
52-
Type string `yaml:"type" validate:"required"`
53-
Host string `yaml:"host" validate:"required"`
54-
Port string `yaml:"port" validate:"required"`
55-
Database string `yaml:"database" validate:"required"`
56-
User string `yaml:"user" validate:"required"`
57-
Password string `yaml:"password"`
58-
Protocol string `yaml:"protocol"`
59-
Secure bool `yaml:"secure"`
52+
Name string `yaml:"name" validate:"required"`
53+
Type string `yaml:"type" validate:"required"`
54+
Host string `yaml:"host" validate:"required"`
55+
Port string `yaml:"port" validate:"required"`
56+
Database string `yaml:"database" validate:"required"`
57+
User util.Secret `yaml:"user" validate:"required"`
58+
Password util.Secret `yaml:"password"`
59+
Protocol string `yaml:"protocol"`
60+
Secure bool `yaml:"secure"`
6061
}
6162

6263
func (r Config) SourceConfigType() string {
6364
return SourceType
6465
}
6566

6667
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
67-
pool, err := initClickHouseConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.Protocol, r.Secure)
68+
pool, err := initClickHouseConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User.String(), r.Password.String(), r.Database, r.Protocol, r.Secure)
6869
if err != nil {
6970
return nil, fmt.Errorf("unable to create pool: %w", err)
7071
}

internal/sources/cloudsqlmssql/cloud_sql_mssql.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ type Config struct {
5656
Region string `yaml:"region" validate:"required"`
5757
Instance string `yaml:"instance" validate:"required"`
5858
IPType sources.IPType `yaml:"ipType" validate:"required"`
59-
User string `yaml:"user" validate:"required"`
60-
Password string `yaml:"password" validate:"required"`
59+
User util.Secret `yaml:"user" validate:"required"`
60+
Password util.Secret `yaml:"password" validate:"required"`
6161
Database string `yaml:"database" validate:"required"`
6262
}
6363

@@ -68,7 +68,7 @@ func (r Config) SourceConfigType() string {
6868

6969
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
7070
// Initializes a Cloud SQL MSSQL source
71-
db, err := initCloudSQLMssqlConnection(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPType.String(), r.User, r.Password, r.Database)
71+
db, err := initCloudSQLMssqlConnection(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPType.String(), r.User.String(), r.Password.String(), r.Database)
7272
if err != nil {
7373
return nil, fmt.Errorf("unable to create db connection: %w", err)
7474
}

internal/sources/cloudsqlmssql/cloud_sql_mssql_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/googleapis/genai-toolbox/internal/sources"
2424
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
2525
"github.com/googleapis/genai-toolbox/internal/testutils"
26+
"github.com/googleapis/genai-toolbox/internal/util"
2627
)
2728

2829
func TestParseFromYamlCloudSQLMssql(t *testing.T) {
@@ -53,8 +54,8 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
5354
Instance: "my-instance",
5455
IPType: "public",
5556
Database: "my_db",
56-
User: "my_user",
57-
Password: "my_pass",
57+
User: util.Secret("my_user"),
58+
Password: util.Secret("my_pass"),
5859
},
5960
},
6061
},
@@ -81,8 +82,8 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
8182
Instance: "my-instance",
8283
IPType: "psc",
8384
Database: "my_db",
84-
User: "my_user",
85-
Password: "my_pass",
85+
User: util.Secret("my_user"),
86+
Password: util.Secret("my_pass"),
8687
},
8788
},
8889
},

internal/sources/cloudsqlmysql/cloud_sql_mysql.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ type Config struct {
5656
Region string `yaml:"region" validate:"required"`
5757
Instance string `yaml:"instance" validate:"required"`
5858
IPType sources.IPType `yaml:"ipType"`
59-
User string `yaml:"user"`
60-
Password string `yaml:"password"`
59+
User util.Secret `yaml:"user"`
60+
Password util.Secret `yaml:"password"`
6161
Database string `yaml:"database"`
6262
}
6363

@@ -66,7 +66,7 @@ func (r Config) SourceConfigType() string {
6666
}
6767

6868
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
69-
pool, err := initCloudSQLMySQLConnectionPool(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPType.String(), r.User, r.Password, r.Database)
69+
pool, err := initCloudSQLMySQLConnectionPool(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPType.String(), r.User.String(), r.Password.String(), r.Database)
7070
if err != nil {
7171
return nil, fmt.Errorf("unable to create pool: %w", err)
7272
}

0 commit comments

Comments
 (0)