Skip to content

Commit 0bbda17

Browse files
committed
feat: add apply subcommand
1 parent 6b31505 commit 0bbda17

File tree

3 files changed

+501
-0
lines changed

3 files changed

+501
-0
lines changed

cmd/internal/apply/command.go

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package apply
16+
17+
import (
18+
"bytes"
19+
"context"
20+
"encoding/json"
21+
"fmt"
22+
"io"
23+
"net/http"
24+
"os"
25+
"strings"
26+
"time"
27+
28+
"github.com/goccy/go-yaml"
29+
"github.com/googleapis/genai-toolbox/cmd/internal"
30+
"github.com/googleapis/genai-toolbox/internal/log"
31+
"github.com/spf13/cobra"
32+
)
33+
34+
type applyCmd struct {
35+
*cobra.Command
36+
port int
37+
address string
38+
}
39+
40+
func NewCommand(opts *internal.ToolboxOptions) *cobra.Command {
41+
cmd := &applyCmd{}
42+
cmd.Command = &cobra.Command{
43+
Use: "apply",
44+
Short: "Apply configuration to the toolbox server",
45+
Long: "Apply configuration to the toolbox server",
46+
}
47+
flags := cmd.Flags()
48+
internal.ConfigFileFlags(flags, opts)
49+
flags.StringVarP(&cmd.address, "address", "a", "127.0.0.1", "Address of the server that is running on.")
50+
flags.IntVarP(&cmd.port, "port", "p", 5000, "Port of the server that is listening on.")
51+
cmd.RunE = func(*cobra.Command, []string) error { return runApply(cmd, opts) }
52+
return cmd.Command
53+
}
54+
55+
// using this type allow O(1) lookups and deletes
56+
type primitives map[string]map[string]struct{}
57+
58+
func NewPrimitives() primitives {
59+
return make(primitives)
60+
}
61+
62+
// Helper to fetch and populate the map
63+
func (p primitives) Load(ctx context.Context, address string, port int) error {
64+
kinds := []string{"source", "authservice", "embeddingmodel", "tool", "toolset", "prompt"}
65+
66+
for _, kind := range kinds {
67+
list, err := getByPrimitiveRequest(ctx, address, port, kind)
68+
if err != nil {
69+
return fmt.Errorf("error getting %s primitives: %w", kind, err)
70+
}
71+
72+
p[kind] = make(map[string]struct{})
73+
for _, name := range list {
74+
p[kind][name] = struct{}{}
75+
}
76+
}
77+
return nil
78+
}
79+
80+
func (p primitives) Exists(kind, name string) bool {
81+
names, ok := p[strings.ToLower(kind)]
82+
if !ok {
83+
return false
84+
}
85+
_, exists := names[name]
86+
return exists
87+
}
88+
89+
func (p primitives) Remove(kind, name string) {
90+
if names, ok := p[strings.ToLower(kind)]; ok {
91+
delete(names, name)
92+
}
93+
}
94+
95+
// adminRequest is a generic helper for admin api requests.
96+
func adminRequest(ctx context.Context, method, url string, body any) ([]byte, error) {
97+
var bodyReader io.Reader
98+
if body != nil {
99+
b, err := json.Marshal(body)
100+
if err != nil {
101+
return nil, fmt.Errorf("error marshaling body: %w", err)
102+
}
103+
bodyReader = bytes.NewReader(b)
104+
}
105+
106+
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
107+
if err != nil {
108+
return nil, fmt.Errorf("request creation failed: %w", err)
109+
}
110+
req.Header.Set("Content-Type", "application/json")
111+
112+
client := &http.Client{
113+
Timeout: 30 * time.Second,
114+
}
115+
resp, err := client.Do(req)
116+
if err != nil {
117+
return nil, err
118+
}
119+
defer resp.Body.Close()
120+
121+
respBody, err := io.ReadAll(resp.Body)
122+
if err != nil {
123+
return nil, fmt.Errorf("could not read response: %w", err)
124+
}
125+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
126+
return nil, fmt.Errorf("server returned %d: %s", resp.StatusCode, string(bytes.TrimSpace(respBody)))
127+
}
128+
return respBody, nil
129+
}
130+
131+
// getByPrimitiveRequest sends a GET request to the admin endpoint
132+
func getByPrimitiveRequest(ctx context.Context, address string, port int, kind string) ([]string, error) {
133+
url := fmt.Sprintf("http://%s:%d/admin/%s", address, port, kind)
134+
135+
respBody, err := adminRequest(ctx, http.MethodGet, url, nil)
136+
if err != nil {
137+
return nil, err
138+
}
139+
140+
var resList []string
141+
if err := json.Unmarshal(respBody, &resList); err != nil {
142+
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
143+
}
144+
return resList, nil
145+
}
146+
147+
// getPrimitiveByName sends a GET request (by primitive name) to the admin endpoint
148+
func getPrimitiveByName(ctx context.Context, address string, port int, kind, name string) (map[string]any, error) {
149+
url := fmt.Sprintf("http://%s:%d/admin/%s/%s", address, port, kind, name)
150+
151+
respBody, err := adminRequest(ctx, http.MethodGet, url, nil)
152+
if err != nil {
153+
return nil, err
154+
}
155+
156+
var body map[string]any
157+
if err := json.Unmarshal(respBody, &body); err != nil {
158+
return nil, fmt.Errorf("erorr unmarshaling response body: %w", err)
159+
}
160+
return body, nil
161+
}
162+
163+
// applyPrimitive sends a PUT request to the admin endpoint
164+
func applyPrimitive(ctx context.Context, address string, port int, kind, name string, data map[string]any) error {
165+
url := fmt.Sprintf("http://%s:%d/admin/%s/%s", address, port, kind, name)
166+
167+
_, err := adminRequest(ctx, http.MethodPut, url, data)
168+
if err != nil {
169+
return err
170+
}
171+
return nil
172+
}
173+
174+
func runApply(cmd *applyCmd, opts *internal.ToolboxOptions) error {
175+
ctx, cancel := context.WithCancel(cmd.Context())
176+
defer cancel()
177+
178+
ctx, shutdown, err := opts.Setup(ctx)
179+
if err != nil {
180+
return err
181+
}
182+
defer func() {
183+
_ = shutdown(ctx)
184+
}()
185+
186+
filePaths, _, err := opts.GetCustomConfigFiles(ctx)
187+
if err != nil {
188+
errMsg := fmt.Errorf("failed to retrieve config files: %w", err)
189+
opts.Logger.ErrorContext(ctx, errMsg.Error())
190+
return errMsg
191+
}
192+
193+
// GET all the primitive lists in the server
194+
p := NewPrimitives()
195+
if err := p.Load(ctx, cmd.address, cmd.port); err != nil {
196+
return err
197+
}
198+
199+
opts.Logger.InfoContext(ctx, "starting apply sequence", "count", len(filePaths))
200+
for _, filePath := range filePaths {
201+
if err := processFile(ctx, opts.Logger, filePath, p, cmd.address, cmd.port); err != nil {
202+
opts.Logger.ErrorContext(ctx, err.Error())
203+
return err
204+
}
205+
}
206+
opts.Logger.InfoContext(ctx, "Done applying")
207+
return nil
208+
}
209+
210+
func processFile(ctx context.Context, logger log.Logger, path string, p primitives, address string, port int) error {
211+
f, err := os.Open(path)
212+
if err != nil {
213+
return fmt.Errorf("unable to open file at %q: %w", path, err)
214+
}
215+
defer f.Close() // Safe closure
216+
217+
decoder := yaml.NewDecoder(f)
218+
// loop through documents with the `---` separator
219+
for {
220+
var doc map[string]any
221+
if err := decoder.Decode(&doc); err == io.EOF {
222+
break
223+
} else if err != nil {
224+
return fmt.Errorf("unable to decode YAML document: %w", err)
225+
}
226+
227+
if len(doc) == 0 {
228+
continue
229+
}
230+
231+
kind, kOk := doc["kind"].(string)
232+
name, nOk := doc["name"].(string)
233+
if !kOk || !nOk || kind == "" || name == "" {
234+
logger.WarnContext(ctx, fmt.Sprintf("invalid primitive schema: missing metadata in %s: kind and name are required", path))
235+
continue
236+
}
237+
238+
delete(doc, "kind")
239+
240+
if p.Exists(kind, name) {
241+
p.Remove(kind, name)
242+
remoteBody, err := getPrimitiveByName(ctx, address, port, kind, name)
243+
if err != nil {
244+
return err
245+
}
246+
localJSON, err := json.Marshal(doc)
247+
if err != nil {
248+
return fmt.Errorf("failed to marshal local config for %s: %w", name, err)
249+
}
250+
remoteJSON, err := json.Marshal(remoteBody)
251+
if err != nil {
252+
return fmt.Errorf("failed to marshal remote config for %s: %w", name, err)
253+
}
254+
255+
if bytes.Equal(localJSON, remoteJSON) {
256+
logger.DebugContext(ctx, "skipping: no changes detected", "kind", kind, "name", name)
257+
continue
258+
}
259+
logger.DebugContext(ctx, "change detected, updating resource", "kind", kind, "name", name)
260+
}
261+
262+
// TODO: check --prune flag: if prune, delete primitives that are left
263+
// in the primitive list
264+
// TODO: check for --dry-run flag.
265+
266+
if err := applyPrimitive(ctx, address, port, kind, name, doc); err != nil {
267+
return err
268+
}
269+
}
270+
return nil
271+
}

0 commit comments

Comments
 (0)