Skip to content

Commit 6e28433

Browse files
committed
feat: add GET endpoint by primitive kind
1 parent 5713e50 commit 6e28433

File tree

4 files changed

+203
-1
lines changed

4 files changed

+203
-1
lines changed

internal/server/admin.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ func adminRouter(s *Server) (chi.Router, error) {
3737

3838
r.Put("/{kind}/{name}", func(w http.ResponseWriter, r *http.Request) { createOrUpdatePrimitives(s, w, r) })
3939
r.Delete("/{kind}/{name}", func(w http.ResponseWriter, r *http.Request) { deletePrimitives(s, w, r) })
40-
r.Get("/{kind}/{name}", func(w http.ResponseWriter, r *http.Request) { getPrimitiveByName(s, w, r) })
40+
r.Route("/{kind}", func(r chi.Router) {
41+
r.Get("/", func(w http.ResponseWriter, r *http.Request) { getPrimitive(s, w, r) })
42+
r.Get("/{name}", func(w http.ResponseWriter, r *http.Request) { getPrimitiveByName(s, w, r) })
43+
})
4144

4245
return r, nil
4346
}
@@ -167,6 +170,33 @@ func createOrUpdatePrimitives(s *Server, w http.ResponseWriter, r *http.Request)
167170
}
168171
}
169172

173+
func getPrimitive(s *Server, w http.ResponseWriter, r *http.Request) {
174+
kind := chi.URLParam(r, "kind")
175+
ctx := r.Context()
176+
177+
var primitive []string
178+
switch strings.ToLower(kind) {
179+
case "source":
180+
primitive = s.ResourceMgr.GetSources()
181+
case "authservice":
182+
primitive = s.ResourceMgr.GetAuthServices()
183+
case "embeddingmodel":
184+
primitive = s.ResourceMgr.GetEmbeddingModels()
185+
case "tool":
186+
primitive = s.ResourceMgr.GetTools()
187+
case "toolset":
188+
primitive = s.ResourceMgr.GetToolsets()
189+
case "prompt":
190+
primitive = s.ResourceMgr.GetPrompts()
191+
default:
192+
err := fmt.Errorf("invalid primitive kind provided")
193+
s.logger.DebugContext(ctx, err.Error())
194+
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
195+
return
196+
}
197+
render.JSON(w, r, primitive)
198+
}
199+
170200
func getPrimitiveByName(s *Server, w http.ResponseWriter, r *http.Request) {
171201
kind := chi.URLParam(r, "kind")
172202
name := chi.URLParam(r, "name")

internal/server/admin_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,97 @@ func TestAdminDeleteEndpoint(t *testing.T) {
229229
}
230230

231231
func TestAdminGetEndpoint(t *testing.T) {
232+
mockSource := testutils.MockSource{MockSourceConfig: testutils.MockSourceConfig{Foo: "foo", Password: "password"}}
233+
mockAuthService := testutils.MockAuthService{MockAuthServiceConfig: testutils.MockAuthServiceConfig{Foo: "foo"}}
234+
mockEmbeddingModel := testutils.MockEmbeddingModel{MockEmbeddingModelConfig: testutils.MockEmbeddingModelConfig{Foo: "foo"}}
235+
mockTool := testutils.MockTool{MockToolConfig: testutils.MockToolConfig{Foo: "foo"}}
236+
mockToolset := tools.Toolset{ToolsetConfig: tools.ToolsetConfig{ToolNames: []string{"test-tool"}}}
237+
mockPrompt := testutils.MockPrompt{MockPromptConfig: testutils.MockPromptConfig{Foo: "foo"}}
238+
239+
mockSources := map[string]sources.Source{"test-source": mockSource}
240+
mockAuthServices := map[string]auth.AuthService{"test-auth-service": mockAuthService}
241+
mockEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"test-embedding-model": mockEmbeddingModel}
242+
mockTools := map[string]tools.Tool{"test-tool": mockTool}
243+
mockToolsets := map[string]tools.Toolset{"test-toolset": mockToolset}
244+
mockPrompts := map[string]prompts.Prompt{"test-prompt": mockPrompt}
245+
246+
r, shutdown := setUpServer(t, "admin", mockSources, mockAuthServices, mockEmbeddingModels, mockTools, mockToolsets, mockPrompts, map[string]prompts.Promptset{})
247+
defer shutdown()
248+
ts := runServer(r, false)
249+
defer ts.Close()
250+
251+
tests := []struct {
252+
name string
253+
kind string
254+
want []string
255+
expectedStatusCode int
256+
}{
257+
{
258+
name: "Get Source - Success",
259+
kind: "source",
260+
want: []string{"test-source"},
261+
expectedStatusCode: http.StatusOK,
262+
},
263+
{
264+
name: "Get Auth Service - Success",
265+
kind: "authService",
266+
want: []string{"test-auth-service"},
267+
expectedStatusCode: http.StatusOK,
268+
},
269+
{
270+
name: "Get Embedding Model - Success",
271+
kind: "embeddingModel",
272+
want: []string{"test-embedding-model"},
273+
expectedStatusCode: http.StatusOK,
274+
},
275+
{
276+
name: "Get Tool - Success",
277+
kind: "tool",
278+
want: []string{"test-tool"},
279+
expectedStatusCode: http.StatusOK,
280+
},
281+
{
282+
name: "Get Toolset - Success",
283+
kind: "toolset",
284+
want: []string{"test-toolset"},
285+
expectedStatusCode: http.StatusOK,
286+
},
287+
{
288+
name: "Get Prompt - Success",
289+
kind: "prompt",
290+
want: []string{"test-prompt"},
291+
expectedStatusCode: http.StatusOK,
292+
},
293+
{
294+
name: "Get with Invalid Kind - Bad Request",
295+
kind: "invalidKind",
296+
expectedStatusCode: http.StatusBadRequest,
297+
},
298+
}
299+
300+
for _, tt := range tests {
301+
t.Run(tt.name, func(t *testing.T) {
302+
resp, body, err := runRequest(ts, http.MethodGet, fmt.Sprintf("/%s", tt.kind), nil, nil)
303+
if err != nil {
304+
t.Fatalf("unexpected error during request: %s", err)
305+
}
306+
if resp.StatusCode != tt.expectedStatusCode {
307+
t.Fatalf("response status code is not %d, got %d, %s", tt.expectedStatusCode, resp.StatusCode, string(body))
308+
}
309+
if tt.expectedStatusCode == http.StatusOK {
310+
var got []string
311+
if err := json.Unmarshal(body, &got); err != nil {
312+
t.Fatalf("error unmarshaling response body")
313+
}
314+
if !reflect.DeepEqual(got, tt.want) {
315+
t.Fatalf("unexpected output: got %+v, want %+v", got, tt.want)
316+
}
317+
}
318+
})
319+
}
320+
}
321+
322+
func TestAdminGetByNameEndpoint(t *testing.T) {
232323
mockSource := testutils.MockSource{MockSourceConfig: testutils.MockSourceConfig{Foo: "foo", Password: "password"}}
233324
mockSourceConfigMasked := testutils.MockSourceConfig{Foo: "foo", Password: "***"}
234325
mockAuthService := testutils.MockAuthService{MockAuthServiceConfig: testutils.MockAuthServiceConfig{Foo: "foo"}}

internal/server/resources/resources.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ package resources
1717
import (
1818
"context"
1919
"fmt"
20+
"maps"
2021
"reflect"
22+
"slices"
2123
"sync"
2224

2325
"github.com/googleapis/genai-toolbox/internal/auth"
@@ -76,6 +78,42 @@ func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, aut
7678
r.promptsets = promptsetsMap
7779
}
7880

81+
func (r *ResourceManager) GetSources() []string {
82+
r.mu.RLock()
83+
defer r.mu.RUnlock()
84+
return slices.Collect(maps.Keys(r.sources))
85+
}
86+
87+
func (r *ResourceManager) GetAuthServices() []string {
88+
r.mu.RLock()
89+
defer r.mu.RUnlock()
90+
return slices.Collect(maps.Keys(r.authServices))
91+
}
92+
93+
func (r *ResourceManager) GetEmbeddingModels() []string {
94+
r.mu.RLock()
95+
defer r.mu.RUnlock()
96+
return slices.Collect(maps.Keys(r.embeddingModels))
97+
}
98+
99+
func (r *ResourceManager) GetTools() []string {
100+
r.mu.RLock()
101+
defer r.mu.RUnlock()
102+
return slices.Collect(maps.Keys(r.tools))
103+
}
104+
105+
func (r *ResourceManager) GetToolsets() []string {
106+
r.mu.RLock()
107+
defer r.mu.RUnlock()
108+
return slices.Collect(maps.Keys(r.toolsets))
109+
}
110+
111+
func (r *ResourceManager) GetPrompts() []string {
112+
r.mu.RLock()
113+
defer r.mu.RUnlock()
114+
return slices.Collect(maps.Keys(r.prompts))
115+
}
116+
79117
func (r *ResourceManager) GetSource(sourceName string) (sources.Source, bool) {
80118
r.mu.RLock()
81119
defer r.mu.RUnlock()

internal/server/resources/resources_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,46 @@ func TestDeletePrimitives(t *testing.T) {
346346
t.Fatalf("expected delete to be successful")
347347
}
348348
}
349+
350+
func TestGetPrimitives(t *testing.T) {
351+
resMgr := resources.NewResourceManager(
352+
map[string]sources.Source{"foo": testutils.MockSource{}},
353+
map[string]auth.AuthService{"foo": testutils.MockAuthService{}},
354+
map[string]embeddingmodels.EmbeddingModel{"foo": testutils.MockEmbeddingModel{}},
355+
map[string]tools.Tool{"foo": testutils.MockTool{}},
356+
map[string]tools.Toolset{"foo": tools.Toolset{}},
357+
map[string]prompts.Prompt{"foo": testutils.MockPrompt{}},
358+
map[string]prompts.Promptset{},
359+
)
360+
want := []string{"foo"}
361+
362+
got := resMgr.GetSources()
363+
if !reflect.DeepEqual(got, want) {
364+
t.Fatalf("unexpected sources list: got %v, want %v", got, want)
365+
}
366+
367+
got = resMgr.GetAuthServices()
368+
if !reflect.DeepEqual(got, want) {
369+
t.Fatalf("unexpected auth services list: got %v, want %v", got, want)
370+
}
371+
372+
got = resMgr.GetEmbeddingModels()
373+
if !reflect.DeepEqual(got, want) {
374+
t.Fatalf("unexpected embedding models list: got %v, want %v", got, want)
375+
}
376+
377+
got = resMgr.GetTools()
378+
if !reflect.DeepEqual(got, want) {
379+
t.Fatalf("unexpected tools list: got %v, want %v", got, want)
380+
}
381+
382+
got = resMgr.GetToolsets()
383+
if !reflect.DeepEqual(got, want) {
384+
t.Fatalf("unexpected toolsets list: got %v, want %v", got, want)
385+
}
386+
387+
got = resMgr.GetPrompts()
388+
if !reflect.DeepEqual(got, want) {
389+
t.Fatalf("unexpected prompts list: got %v, want %v", got, want)
390+
}
391+
}

0 commit comments

Comments
 (0)