Skip to content

Commit 4964df5

Browse files
authored
Merge pull request #919 from krissetto/fix-yaml-marshalling-in-push-pull
Fix thinking budget and rag strategy marshaling/unmarshaling
2 parents c519da0 + bd9f990 commit 4964df5

File tree

3 files changed

+470
-0
lines changed

3 files changed

+470
-0
lines changed

pkg/config/latest/types.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
package latest
22

33
import (
4+
"encoding/json"
45
"fmt"
56

7+
"github.com/goccy/go-yaml"
8+
69
"github.com/docker/cagent/pkg/config/types"
710
)
811

@@ -194,6 +197,49 @@ func (t *ThinkingBudget) UnmarshalYAML(unmarshal func(any) error) error {
194197
return nil
195198
}
196199

200+
// MarshalYAML implements custom marshaling to output simple string or int format
201+
func (t ThinkingBudget) MarshalYAML() ([]byte, error) {
202+
// If Effort string is set (non-empty), marshal as string
203+
if t.Effort != "" {
204+
return yaml.Marshal(t.Effort)
205+
}
206+
207+
// Otherwise marshal as integer (includes 0, -1, and positive values)
208+
return yaml.Marshal(t.Tokens)
209+
}
210+
211+
// MarshalJSON implements custom marshaling to output simple string or int format
212+
// This ensures JSON and YAML have the same flattened format for consistency
213+
func (t ThinkingBudget) MarshalJSON() ([]byte, error) {
214+
// If Effort string is set (non-empty), marshal as string
215+
if t.Effort != "" {
216+
return []byte(fmt.Sprintf("%q", t.Effort)), nil
217+
}
218+
219+
// Otherwise marshal as integer (includes 0, -1, and positive values)
220+
return []byte(fmt.Sprintf("%d", t.Tokens)), nil
221+
}
222+
223+
// UnmarshalJSON implements custom unmarshaling to accept simple string or int format
224+
// This ensures JSON and YAML have the same flattened format for consistency
225+
func (t *ThinkingBudget) UnmarshalJSON(data []byte) error {
226+
// Try integer tokens first
227+
var n int
228+
if err := json.Unmarshal(data, &n); err == nil {
229+
*t = ThinkingBudget{Tokens: n}
230+
return nil
231+
}
232+
233+
// Try string level
234+
var s string
235+
if err := json.Unmarshal(data, &s); err == nil {
236+
*t = ThinkingBudget{Effort: s}
237+
return nil
238+
}
239+
240+
return nil
241+
}
242+
197243
// StructuredOutput defines a JSON schema for structured output
198244
type StructuredOutput struct {
199245
// Name is the name of the response format
@@ -283,6 +329,105 @@ func (s *RAGStrategyConfig) UnmarshalYAML(unmarshal func(any) error) error {
283329
return nil
284330
}
285331

332+
// MarshalYAML implements custom marshaling to flatten Params into parent level
333+
func (s RAGStrategyConfig) MarshalYAML() ([]byte, error) {
334+
result := s.buildFlattenedMap()
335+
return yaml.Marshal(result)
336+
}
337+
338+
// MarshalJSON implements custom marshaling to flatten Params into parent level
339+
// This ensures JSON and YAML have the same flattened format for consistency
340+
func (s RAGStrategyConfig) MarshalJSON() ([]byte, error) {
341+
result := s.buildFlattenedMap()
342+
return json.Marshal(result)
343+
}
344+
345+
// UnmarshalJSON implements custom unmarshaling to capture all extra fields into Params
346+
// This ensures JSON and YAML have the same flattened format for consistency
347+
func (s *RAGStrategyConfig) UnmarshalJSON(data []byte) error {
348+
// First unmarshal into a map to capture everything
349+
var raw map[string]any
350+
if err := json.Unmarshal(data, &raw); err != nil {
351+
return err
352+
}
353+
354+
// Extract known fields
355+
if t, ok := raw["type"].(string); ok {
356+
s.Type = t
357+
delete(raw, "type")
358+
}
359+
360+
if docs, ok := raw["docs"].([]any); ok {
361+
s.Docs = make([]string, len(docs))
362+
for i, d := range docs {
363+
if str, ok := d.(string); ok {
364+
s.Docs[i] = str
365+
}
366+
}
367+
delete(raw, "docs")
368+
}
369+
370+
if dbRaw, ok := raw["database"]; ok {
371+
if dbStr, ok := dbRaw.(string); ok {
372+
var db RAGDatabaseConfig
373+
db.value = dbStr
374+
s.Database = db
375+
}
376+
delete(raw, "database")
377+
}
378+
379+
if chunkRaw, ok := raw["chunking"]; ok {
380+
// Re-marshal and unmarshal chunking config
381+
chunkBytes, _ := json.Marshal(chunkRaw)
382+
var chunk RAGChunkingConfig
383+
if err := json.Unmarshal(chunkBytes, &chunk); err == nil {
384+
s.Chunking = chunk
385+
}
386+
delete(raw, "chunking")
387+
}
388+
389+
if limit, ok := raw["limit"].(float64); ok {
390+
s.Limit = int(limit)
391+
delete(raw, "limit")
392+
}
393+
394+
// Everything else goes into Params for strategy-specific configuration
395+
s.Params = raw
396+
397+
return nil
398+
}
399+
400+
// buildFlattenedMap creates a flattened map representation for marshaling
401+
// Used by both MarshalYAML and MarshalJSON to ensure consistent format
402+
func (s RAGStrategyConfig) buildFlattenedMap() map[string]any {
403+
result := make(map[string]any)
404+
405+
if s.Type != "" {
406+
result["type"] = s.Type
407+
}
408+
if len(s.Docs) > 0 {
409+
result["docs"] = s.Docs
410+
}
411+
if !s.Database.IsEmpty() {
412+
dbStr, _ := s.Database.AsString()
413+
result["database"] = dbStr
414+
}
415+
// Only include chunking if any fields are set
416+
if s.Chunking.Size > 0 || s.Chunking.Overlap > 0 || s.Chunking.RespectWordBoundaries {
417+
result["chunking"] = s.Chunking
418+
}
419+
if s.Limit > 0 {
420+
result["limit"] = s.Limit
421+
}
422+
423+
// Flatten Params into the same level
424+
for k, v := range s.Params {
425+
result[k] = v
426+
}
427+
428+
return result
429+
}
430+
286431
// unmarshalDatabaseConfig handles DatabaseConfig unmarshaling from raw YAML data.
287432
// For RAG strategies, the database configuration is intentionally simple:
288433
// a single string value under the `database` key that points to the SQLite

pkg/config/latest/types_test.go

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,188 @@ func TestCommandsUnmarshal_List(t *testing.T) {
3232
require.Equal(t, "check disk", c["df"])
3333
require.Equal(t, "list files", c["ls"])
3434
}
35+
36+
func TestThinkingBudget_MarshalUnmarshal_String(t *testing.T) {
37+
t.Parallel()
38+
39+
// Test string effort level
40+
input := []byte(`thinking_budget: minimal`)
41+
var config struct {
42+
ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
43+
}
44+
45+
// Unmarshal
46+
err := yaml.Unmarshal(input, &config)
47+
require.NoError(t, err)
48+
require.NotNil(t, config.ThinkingBudget)
49+
require.Equal(t, "minimal", config.ThinkingBudget.Effort)
50+
require.Equal(t, 0, config.ThinkingBudget.Tokens)
51+
52+
// Marshal back
53+
output, err := yaml.Marshal(config)
54+
require.NoError(t, err)
55+
require.Equal(t, "thinking_budget: minimal\n", string(output))
56+
}
57+
58+
func TestThinkingBudget_MarshalUnmarshal_Integer(t *testing.T) {
59+
t.Parallel()
60+
61+
// Test integer token budget
62+
input := []byte(`thinking_budget: 8192`)
63+
var config struct {
64+
ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
65+
}
66+
67+
// Unmarshal
68+
err := yaml.Unmarshal(input, &config)
69+
require.NoError(t, err)
70+
require.NotNil(t, config.ThinkingBudget)
71+
require.Empty(t, config.ThinkingBudget.Effort)
72+
require.Equal(t, 8192, config.ThinkingBudget.Tokens)
73+
74+
// Marshal back
75+
output, err := yaml.Marshal(config)
76+
require.NoError(t, err)
77+
require.Equal(t, "thinking_budget: 8192\n", string(output))
78+
}
79+
80+
func TestThinkingBudget_MarshalUnmarshal_NegativeInteger(t *testing.T) {
81+
t.Parallel()
82+
83+
// Test negative integer token budget (e.g., -1 for Gemini dynamic thinking)
84+
input := []byte(`thinking_budget: -1`)
85+
var config struct {
86+
ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
87+
}
88+
89+
// Unmarshal
90+
err := yaml.Unmarshal(input, &config)
91+
require.NoError(t, err)
92+
require.NotNil(t, config.ThinkingBudget)
93+
require.Empty(t, config.ThinkingBudget.Effort)
94+
require.Equal(t, -1, config.ThinkingBudget.Tokens)
95+
96+
// Marshal back
97+
output, err := yaml.Marshal(config)
98+
require.NoError(t, err)
99+
require.Equal(t, "thinking_budget: -1\n", string(output))
100+
}
101+
102+
func TestThinkingBudget_MarshalUnmarshal_Zero(t *testing.T) {
103+
t.Parallel()
104+
105+
// Test zero token budget (e.g., 0 for Gemini no thinking)
106+
input := []byte(`thinking_budget: 0`)
107+
var config struct {
108+
ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
109+
}
110+
111+
// Unmarshal
112+
err := yaml.Unmarshal(input, &config)
113+
require.NoError(t, err)
114+
require.NotNil(t, config.ThinkingBudget)
115+
require.Empty(t, config.ThinkingBudget.Effort)
116+
require.Equal(t, 0, config.ThinkingBudget.Tokens)
117+
118+
// Marshal back
119+
output, err := yaml.Marshal(config)
120+
require.NoError(t, err)
121+
require.Equal(t, "thinking_budget: 0\n", string(output))
122+
}
123+
124+
func TestRAGStrategyConfig_MarshalUnmarshal_FlattenedParams(t *testing.T) {
125+
t.Parallel()
126+
127+
// Test that params are flattened during unmarshal and remain flattened after marshal
128+
input := []byte(`type: chunked-embeddings
129+
model: embeddinggemma
130+
database: ./rag/test.db
131+
threshold: 0.5
132+
vector_dimensions: 768
133+
`)
134+
135+
var strategy RAGStrategyConfig
136+
137+
// Unmarshal
138+
err := yaml.Unmarshal(input, &strategy)
139+
require.NoError(t, err)
140+
require.Equal(t, "chunked-embeddings", strategy.Type)
141+
require.Equal(t, "./rag/test.db", mustGetDBString(t, strategy.Database))
142+
require.NotNil(t, strategy.Params)
143+
require.Equal(t, "embeddinggemma", strategy.Params["model"])
144+
require.InEpsilon(t, 0.5, strategy.Params["threshold"], 0.001)
145+
// YAML may unmarshal numbers as different numeric types (int, uint64, float64)
146+
require.InEpsilon(t, float64(768), toFloat64(strategy.Params["vector_dimensions"]), 0.001)
147+
148+
// Marshal back
149+
output, err := yaml.Marshal(strategy)
150+
require.NoError(t, err)
151+
152+
// Verify it's still flattened (no "params:" key)
153+
outputStr := string(output)
154+
require.Contains(t, outputStr, "type: chunked-embeddings")
155+
require.Contains(t, outputStr, "model: embeddinggemma")
156+
require.Contains(t, outputStr, "threshold: 0.5")
157+
require.Contains(t, outputStr, "vector_dimensions: 768")
158+
require.NotContains(t, outputStr, "params:")
159+
160+
// Unmarshal again to verify round-trip
161+
var strategy2 RAGStrategyConfig
162+
err = yaml.Unmarshal(output, &strategy2)
163+
require.NoError(t, err)
164+
require.Equal(t, strategy.Type, strategy2.Type)
165+
require.Equal(t, strategy.Params["model"], strategy2.Params["model"])
166+
require.Equal(t, strategy.Params["threshold"], strategy2.Params["threshold"])
167+
// YAML may unmarshal numbers as different numeric types (int, uint64, float64)
168+
// Just verify the numeric value is correct
169+
require.InEpsilon(t, float64(768), toFloat64(strategy2.Params["vector_dimensions"]), 0.001)
170+
}
171+
172+
func TestRAGStrategyConfig_MarshalUnmarshal_WithDatabase(t *testing.T) {
173+
t.Parallel()
174+
175+
input := []byte(`type: chunked-embeddings
176+
database: ./test.db
177+
model: test-model
178+
`)
179+
180+
var strategy RAGStrategyConfig
181+
err := yaml.Unmarshal(input, &strategy)
182+
require.NoError(t, err)
183+
184+
// Marshal back
185+
output, err := yaml.Marshal(strategy)
186+
require.NoError(t, err)
187+
188+
// Should contain database as a simple string, not nested with sub-fields
189+
outputStr := string(output)
190+
require.Contains(t, outputStr, "database: ./test.db")
191+
require.NotContains(t, outputStr, " value:") // Should not be nested with internal fields
192+
require.Contains(t, outputStr, "model: test-model")
193+
require.NotContains(t, outputStr, "params:") // Should be flattened
194+
}
195+
196+
func mustGetDBString(t *testing.T, db RAGDatabaseConfig) string {
197+
t.Helper()
198+
str, err := db.AsString()
199+
require.NoError(t, err)
200+
return str
201+
}
202+
203+
// toFloat64 converts a value of any built-in integer or floating-point type
// to float64, so that numbers decoded as different Go types can be compared
// numerically. Any other dynamic type (including nil) yields 0.
func toFloat64(v any) float64 {
	switch val := v.(type) {
	case int:
		return float64(val)
	case int8:
		return float64(val)
	case int16:
		return float64(val)
	case int32:
		return float64(val)
	case int64:
		return float64(val)
	case uint:
		return float64(val)
	case uint8:
		return float64(val)
	case uint16:
		return float64(val)
	case uint32:
		return float64(val)
	case uint64:
		return float64(val)
	case float32:
		return float64(val)
	case float64:
		return val
	default:
		// Not a numeric type this helper understands.
		return 0
	}
}

0 commit comments

Comments
 (0)