Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions internal/server/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func adminRouter(s *Server) (chi.Router, error) {

r.Put("/{kind}/{name}", func(w http.ResponseWriter, r *http.Request) { createOrUpdatePrimitives(s, w, r) })
r.Delete("/{kind}/{name}", func(w http.ResponseWriter, r *http.Request) { deletePrimitives(s, w, r) })
r.Get("/{kind}/{name}", func(w http.ResponseWriter, r *http.Request) { getPrimitiveByName(s, w, r) })

return r, nil
}
Expand Down Expand Up @@ -165,3 +166,77 @@ func createOrUpdatePrimitives(s *Server, w http.ResponseWriter, r *http.Request)
return
}
}

func getPrimitiveByName(s *Server, w http.ResponseWriter, r *http.Request) {
kind := chi.URLParam(r, "kind")
name := chi.URLParam(r, "name")
ctx := r.Context()

switch strings.ToLower(kind) {
case "source":
source, ok := s.ResourceMgr.GetSource(name)
if !ok {
err := fmt.Errorf("%s with name %q does not exist", kind, name)
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
return
}
m := source.ToConfig()
render.JSON(w, r, m)
case "authservice":
as, ok := s.ResourceMgr.GetAuthService(name)
if !ok {
err := fmt.Errorf("%s with name %q does not exist", kind, name)
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
return
}
m := as.ToConfig()
render.JSON(w, r, m)
case "embeddingmodel":
em, ok := s.ResourceMgr.GetEmbeddingModel(name)
if !ok {
err := fmt.Errorf("%s with name %q does not exist", kind, name)
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
return
}
m := em.ToConfig()
render.JSON(w, r, m)
case "tool":
tool, ok := s.ResourceMgr.GetTool(name)
if !ok {
err := fmt.Errorf("%s with name %q does not exist", kind, name)
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
return
}
m := tool.ToConfig()
render.JSON(w, r, m)
case "toolset":
ts, ok := s.ResourceMgr.GetToolset(name)
if !ok {
err := fmt.Errorf("%s with name %q does not exist", kind, name)
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
return
}
m := ts.ToConfig()
render.JSON(w, r, m)
case "prompt":
prompt, ok := s.ResourceMgr.GetPrompt(name)
if !ok {
err := fmt.Errorf("%s with name %q does not exist", kind, name)
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
return
}
m := prompt.ToConfig()
render.JSON(w, r, m)
default:
err := fmt.Errorf("invalid primitive kind provided")
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
return
}
}
116 changes: 116 additions & 0 deletions internal/server/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"reflect"
"testing"

"github.com/goccy/go-yaml"
Expand Down Expand Up @@ -225,3 +227,117 @@ func TestAdminDeleteEndpoint(t *testing.T) {
})
}
}

func TestAdminGetEndpoint(t *testing.T) {
mockSource := testutils.MockSource{MockSourceConfig: testutils.MockSourceConfig{Foo: "foo", Password: "password"}}
mockSourceConfigMasked := testutils.MockSourceConfig{Foo: "foo", Password: "***"}
mockAuthService := testutils.MockAuthService{MockAuthServiceConfig: testutils.MockAuthServiceConfig{Foo: "foo"}}
mockEmbeddingModel := testutils.MockEmbeddingModel{MockEmbeddingModelConfig: testutils.MockEmbeddingModelConfig{Foo: "foo"}}
mockTool := testutils.MockTool{MockToolConfig: testutils.MockToolConfig{Foo: "foo"}}
mockToolset := tools.Toolset{ToolsetConfig: tools.ToolsetConfig{ToolNames: []string{"test-tool"}}}
mockPrompt := testutils.MockPrompt{MockPromptConfig: testutils.MockPromptConfig{Foo: "foo"}}

mockSources := map[string]sources.Source{"test-source": mockSource}
mockAuthServices := map[string]auth.AuthService{"test-auth-service": mockAuthService}
mockEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"test-embedding-model": mockEmbeddingModel}
mockTools := map[string]tools.Tool{"test-tool": mockTool}
mockToolsets := map[string]tools.Toolset{"test-toolset": mockToolset}
mockPrompts := map[string]prompts.Prompt{"test-prompt": mockPrompt}

r, shutdown := setUpServer(t, "admin", mockSources, mockAuthServices, mockEmbeddingModels, mockTools, mockToolsets, mockPrompts, map[string]prompts.Promptset{})
defer shutdown()
ts := runServer(r, false)
defer ts.Close()

tests := []struct {
name string
kind string
resourceName string
want any
expectedStatusCode int
}{
{
name: "Get Source - Success",
kind: "source",
resourceName: "test-source",
want: mockSourceConfigMasked,
expectedStatusCode: http.StatusOK,
},
{
name: "Get Auth Service - Success",
kind: "authService",
resourceName: "test-auth-service",
want: mockAuthService.ToConfig(),
expectedStatusCode: http.StatusOK,
},
{
name: "Get Embedding Model - Success",
kind: "embeddingModel",
resourceName: "test-embedding-model",
want: mockEmbeddingModel.ToConfig(),
expectedStatusCode: http.StatusOK,
},
{
name: "Get Tool - Success",
kind: "tool",
resourceName: "test-tool",
want: mockTool.ToConfig(),
expectedStatusCode: http.StatusOK,
},
{
name: "Get Toolset - Success",
kind: "toolset",
resourceName: "test-toolset",
want: mockToolset.ToConfig(),
expectedStatusCode: http.StatusOK,
},
{
name: "Get Prompt - Success",
kind: "prompt",
resourceName: "test-prompt",
want: mockPrompt.ToConfig(),
expectedStatusCode: http.StatusOK,
},
{
name: "Get Non-existent Primitive - Not Found",
kind: "source",
resourceName: "non-existent-source",
expectedStatusCode: http.StatusNotFound,
},
{
name: "Get with Invalid Kind - Bad Request",
kind: "invalidKind",
resourceName: "some-name",
expectedStatusCode: http.StatusBadRequest,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body, err := runRequest(ts, http.MethodGet, fmt.Sprintf("/%s/%s", tt.kind, tt.resourceName), nil, nil)
if err != nil {
t.Fatalf("unexpected error during request: %s", err)
}
if resp.StatusCode != tt.expectedStatusCode {
t.Fatalf("response status code is not %d, got %d, %s", tt.expectedStatusCode, resp.StatusCode, string(body))
}
if tt.expectedStatusCode == http.StatusOK {
var got any
if err := json.Unmarshal(body, &got); err != nil {
t.Fatalf("error unmarshaling response body")
}
var want any
wantBytes, err := json.Marshal(tt.want)
if err != nil {
t.Fatalf("error marshaling want struct")
}
if err = json.Unmarshal(wantBytes, &want); err != nil {
t.Fatalf("error unmarshaling want bytes")
}
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected output: got %+v, want %+v", got, want)
}
}
})
}
}
6 changes: 3 additions & 3 deletions internal/sources/alloydbpg/alloydb_pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ type Config struct {
Cluster string `yaml:"cluster" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
IPType sources.IPType `yaml:"ipType" validate:"required"`
User string `yaml:"user"`
Password string `yaml:"password"`
User util.Secret `yaml:"user"`
Password util.Secret `yaml:"password"`
Database string `yaml:"database" validate:"required"`
}

Expand All @@ -66,7 +66,7 @@ func (r Config) SourceConfigType() string {
}

func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
pool, err := initAlloyDBPgConnectionPool(ctx, tracer, r.Name, r.Project, r.Region, r.Cluster, r.Instance, r.IPType.String(), r.User, r.Password, r.Database)
pool, err := initAlloyDBPgConnectionPool(ctx, tracer, r.Name, r.Project, r.Region, r.Cluster, r.Instance, r.IPType.String(), r.User.String(), r.Password.String(), r.Database)
if err != nil {
return nil, fmt.Errorf("unable to create pool: %w", err)
}
Expand Down
13 changes: 7 additions & 6 deletions internal/sources/alloydbpg/alloydb_pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util"
)

func TestParseFromYamlAlloyDBPg(t *testing.T) {
Expand Down Expand Up @@ -55,8 +56,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
Instance: "my-instance",
IPType: "public",
Database: "my_db",
User: "my_user",
Password: "my_pass",
User: util.Secret("my_user"),
Password: util.Secret("my_pass"),
},
},
},
Expand Down Expand Up @@ -85,8 +86,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
Instance: "my-instance",
IPType: "public",
Database: "my_db",
User: "my_user",
Password: "my_pass",
User: util.Secret("my_user"),
Password: util.Secret("my_pass"),
},
},
},
Expand Down Expand Up @@ -115,8 +116,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
Instance: "my-instance",
IPType: "private",
Database: "my_db",
User: "my_user",
Password: "my_pass",
User: util.Secret("my_user"),
Password: util.Secret("my_pass"),
},
},
},
Expand Down
21 changes: 11 additions & 10 deletions internal/sources/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
_ "github.com/ClickHouse/clickhouse-go/v2"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"go.opentelemetry.io/otel/trace"
)
Expand All @@ -48,23 +49,23 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
}

type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
Database string `yaml:"database" validate:"required"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password"`
Protocol string `yaml:"protocol"`
Secure bool `yaml:"secure"`
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
Database string `yaml:"database" validate:"required"`
User util.Secret `yaml:"user" validate:"required"`
Password util.Secret `yaml:"password"`
Protocol string `yaml:"protocol"`
Secure bool `yaml:"secure"`
}

func (r Config) SourceConfigType() string {
return SourceType
}

func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
pool, err := initClickHouseConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.Protocol, r.Secure)
pool, err := initClickHouseConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User.String(), r.Password.String(), r.Database, r.Protocol, r.Secure)
if err != nil {
return nil, fmt.Errorf("unable to create pool: %w", err)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/sources/cloudsqlmssql/cloud_sql_mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ type Config struct {
Region string `yaml:"region" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
IPType sources.IPType `yaml:"ipType" validate:"required"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
User util.Secret `yaml:"user" validate:"required"`
Password util.Secret `yaml:"password" validate:"required"`
Database string `yaml:"database" validate:"required"`
}

Expand All @@ -68,7 +68,7 @@ func (r Config) SourceConfigType() string {

func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
// Initializes a Cloud SQL MSSQL source
db, err := initCloudSQLMssqlConnection(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPType.String(), r.User, r.Password, r.Database)
db, err := initCloudSQLMssqlConnection(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPType.String(), r.User.String(), r.Password.String(), r.Database)
if err != nil {
return nil, fmt.Errorf("unable to create db connection: %w", err)
}
Expand Down
9 changes: 5 additions & 4 deletions internal/sources/cloudsqlmssql/cloud_sql_mssql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util"
)

func TestParseFromYamlCloudSQLMssql(t *testing.T) {
Expand Down Expand Up @@ -53,8 +54,8 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
Instance: "my-instance",
IPType: "public",
Database: "my_db",
User: "my_user",
Password: "my_pass",
User: util.Secret("my_user"),
Password: util.Secret("my_pass"),
},
},
},
Expand All @@ -81,8 +82,8 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
Instance: "my-instance",
IPType: "psc",
Database: "my_db",
User: "my_user",
Password: "my_pass",
User: util.Secret("my_user"),
Password: util.Secret("my_pass"),
},
},
},
Expand Down
6 changes: 3 additions & 3 deletions internal/sources/cloudsqlmysql/cloud_sql_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ type Config struct {
Region string `yaml:"region" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
IPType sources.IPType `yaml:"ipType"`
User string `yaml:"user"`
Password string `yaml:"password"`
User util.Secret `yaml:"user"`
Password util.Secret `yaml:"password"`
Database string `yaml:"database"`
}

Expand All @@ -66,7 +66,7 @@ func (r Config) SourceConfigType() string {
}

func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
pool, err := initCloudSQLMySQLConnectionPool(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPType.String(), r.User, r.Password, r.Database)
pool, err := initCloudSQLMySQLConnectionPool(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPType.String(), r.User.String(), r.Password.String(), r.Database)
if err != nil {
return nil, fmt.Errorf("unable to create pool: %w", err)
}
Expand Down
Loading
Loading