Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions pkg/config/latest/types.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package latest

import (
"encoding/json"
"fmt"

"github.com/goccy/go-yaml"

"github.com/docker/cagent/pkg/config/types"
)

Expand Down Expand Up @@ -194,6 +197,49 @@ func (t *ThinkingBudget) UnmarshalYAML(unmarshal func(any) error) error {
return nil
}

// MarshalYAML implements custom marshaling to output simple string or int format
func (t ThinkingBudget) MarshalYAML() ([]byte, error) {
// If Effort string is set (non-empty), marshal as string
if t.Effort != "" {
return yaml.Marshal(t.Effort)
}

// Otherwise marshal as integer (includes 0, -1, and positive values)
return yaml.Marshal(t.Tokens)
}

// MarshalJSON implements custom marshaling to output simple string or int format
// This ensures JSON and YAML have the same flattened format for consistency
func (t ThinkingBudget) MarshalJSON() ([]byte, error) {
// If Effort string is set (non-empty), marshal as string
if t.Effort != "" {
return []byte(fmt.Sprintf("%q", t.Effort)), nil
}

// Otherwise marshal as integer (includes 0, -1, and positive values)
return []byte(fmt.Sprintf("%d", t.Tokens)), nil
}

// UnmarshalJSON implements custom unmarshaling to accept simple string or int format
// This ensures JSON and YAML have the same flattened format for consistency
func (t *ThinkingBudget) UnmarshalJSON(data []byte) error {
// Try integer tokens first
var n int
if err := json.Unmarshal(data, &n); err == nil {
*t = ThinkingBudget{Tokens: n}
return nil
}

// Try string level
var s string
if err := json.Unmarshal(data, &s); err == nil {
*t = ThinkingBudget{Effort: s}
return nil
}

return nil
}

// StructuredOutput defines a JSON schema for structured output
type StructuredOutput struct {
// Name is the name of the response format
Expand Down Expand Up @@ -283,6 +329,105 @@ func (s *RAGStrategyConfig) UnmarshalYAML(unmarshal func(any) error) error {
return nil
}

// MarshalYAML implements custom marshaling to flatten Params into parent level
func (s RAGStrategyConfig) MarshalYAML() ([]byte, error) {
result := s.buildFlattenedMap()
return yaml.Marshal(result)
}

// MarshalJSON implements custom marshaling to flatten Params into parent level
// This ensures JSON and YAML have the same flattened format for consistency
func (s RAGStrategyConfig) MarshalJSON() ([]byte, error) {
result := s.buildFlattenedMap()
return json.Marshal(result)
}

// UnmarshalJSON implements custom unmarshaling to capture all extra fields into Params
// This ensures JSON and YAML have the same flattened format for consistency
func (s *RAGStrategyConfig) UnmarshalJSON(data []byte) error {
// First unmarshal into a map to capture everything
var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
return err
}

// Extract known fields
if t, ok := raw["type"].(string); ok {
s.Type = t
delete(raw, "type")
}

if docs, ok := raw["docs"].([]any); ok {
s.Docs = make([]string, len(docs))
for i, d := range docs {
if str, ok := d.(string); ok {
s.Docs[i] = str
}
}
delete(raw, "docs")
}

if dbRaw, ok := raw["database"]; ok {
if dbStr, ok := dbRaw.(string); ok {
var db RAGDatabaseConfig
db.value = dbStr
s.Database = db
}
delete(raw, "database")
}

if chunkRaw, ok := raw["chunking"]; ok {
// Re-marshal and unmarshal chunking config
chunkBytes, _ := json.Marshal(chunkRaw)
var chunk RAGChunkingConfig
if err := json.Unmarshal(chunkBytes, &chunk); err == nil {
s.Chunking = chunk
}
delete(raw, "chunking")
}

if limit, ok := raw["limit"].(float64); ok {
s.Limit = int(limit)
delete(raw, "limit")
}

// Everything else goes into Params for strategy-specific configuration
s.Params = raw

return nil
}

// buildFlattenedMap creates a flattened map representation for marshaling
// Used by both MarshalYAML and MarshalJSON to ensure consistent format
func (s RAGStrategyConfig) buildFlattenedMap() map[string]any {
result := make(map[string]any)

if s.Type != "" {
result["type"] = s.Type
}
if len(s.Docs) > 0 {
result["docs"] = s.Docs
}
if !s.Database.IsEmpty() {
dbStr, _ := s.Database.AsString()
result["database"] = dbStr
}
// Only include chunking if any fields are set
if s.Chunking.Size > 0 || s.Chunking.Overlap > 0 || s.Chunking.RespectWordBoundaries {
result["chunking"] = s.Chunking
}
if s.Limit > 0 {
result["limit"] = s.Limit
}

// Flatten Params into the same level
for k, v := range s.Params {
result[k] = v
}

return result
}

// unmarshalDatabaseConfig handles DatabaseConfig unmarshaling from raw YAML data.
// For RAG strategies, the database configuration is intentionally simple:
// a single string value under the `database` key that points to the SQLite
Expand Down
185 changes: 185 additions & 0 deletions pkg/config/latest/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,188 @@ func TestCommandsUnmarshal_List(t *testing.T) {
require.Equal(t, "check disk", c["df"])
require.Equal(t, "list files", c["ls"])
}

func TestThinkingBudget_MarshalUnmarshal_String(t *testing.T) {
t.Parallel()

// Test string effort level
input := []byte(`thinking_budget: minimal`)
var config struct {
ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
}

// Unmarshal
err := yaml.Unmarshal(input, &config)
require.NoError(t, err)
require.NotNil(t, config.ThinkingBudget)
require.Equal(t, "minimal", config.ThinkingBudget.Effort)
require.Equal(t, 0, config.ThinkingBudget.Tokens)

// Marshal back
output, err := yaml.Marshal(config)
require.NoError(t, err)
require.Equal(t, "thinking_budget: minimal\n", string(output))
}

func TestThinkingBudget_MarshalUnmarshal_Integer(t *testing.T) {
t.Parallel()

// Test integer token budget
input := []byte(`thinking_budget: 8192`)
var config struct {
ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
}

// Unmarshal
err := yaml.Unmarshal(input, &config)
require.NoError(t, err)
require.NotNil(t, config.ThinkingBudget)
require.Empty(t, config.ThinkingBudget.Effort)
require.Equal(t, 8192, config.ThinkingBudget.Tokens)

// Marshal back
output, err := yaml.Marshal(config)
require.NoError(t, err)
require.Equal(t, "thinking_budget: 8192\n", string(output))
}

func TestThinkingBudget_MarshalUnmarshal_NegativeInteger(t *testing.T) {
t.Parallel()

// Test negative integer token budget (e.g., -1 for Gemini dynamic thinking)
input := []byte(`thinking_budget: -1`)
var config struct {
ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
}

// Unmarshal
err := yaml.Unmarshal(input, &config)
require.NoError(t, err)
require.NotNil(t, config.ThinkingBudget)
require.Empty(t, config.ThinkingBudget.Effort)
require.Equal(t, -1, config.ThinkingBudget.Tokens)

// Marshal back
output, err := yaml.Marshal(config)
require.NoError(t, err)
require.Equal(t, "thinking_budget: -1\n", string(output))
}

func TestThinkingBudget_MarshalUnmarshal_Zero(t *testing.T) {
t.Parallel()

// Test zero token budget (e.g., 0 for Gemini no thinking)
input := []byte(`thinking_budget: 0`)
var config struct {
ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
}

// Unmarshal
err := yaml.Unmarshal(input, &config)
require.NoError(t, err)
require.NotNil(t, config.ThinkingBudget)
require.Empty(t, config.ThinkingBudget.Effort)
require.Equal(t, 0, config.ThinkingBudget.Tokens)

// Marshal back
output, err := yaml.Marshal(config)
require.NoError(t, err)
require.Equal(t, "thinking_budget: 0\n", string(output))
}

func TestRAGStrategyConfig_MarshalUnmarshal_FlattenedParams(t *testing.T) {
t.Parallel()

// Test that params are flattened during unmarshal and remain flattened after marshal
input := []byte(`type: chunked-embeddings
model: embeddinggemma
database: ./rag/test.db
threshold: 0.5
vector_dimensions: 768
`)

var strategy RAGStrategyConfig

// Unmarshal
err := yaml.Unmarshal(input, &strategy)
require.NoError(t, err)
require.Equal(t, "chunked-embeddings", strategy.Type)
require.Equal(t, "./rag/test.db", mustGetDBString(t, strategy.Database))
require.NotNil(t, strategy.Params)
require.Equal(t, "embeddinggemma", strategy.Params["model"])
require.InEpsilon(t, 0.5, strategy.Params["threshold"], 0.001)
// YAML may unmarshal numbers as different numeric types (int, uint64, float64)
require.InEpsilon(t, float64(768), toFloat64(strategy.Params["vector_dimensions"]), 0.001)

// Marshal back
output, err := yaml.Marshal(strategy)
require.NoError(t, err)

// Verify it's still flattened (no "params:" key)
outputStr := string(output)
require.Contains(t, outputStr, "type: chunked-embeddings")
require.Contains(t, outputStr, "model: embeddinggemma")
require.Contains(t, outputStr, "threshold: 0.5")
require.Contains(t, outputStr, "vector_dimensions: 768")
require.NotContains(t, outputStr, "params:")

// Unmarshal again to verify round-trip
var strategy2 RAGStrategyConfig
err = yaml.Unmarshal(output, &strategy2)
require.NoError(t, err)
require.Equal(t, strategy.Type, strategy2.Type)
require.Equal(t, strategy.Params["model"], strategy2.Params["model"])
require.Equal(t, strategy.Params["threshold"], strategy2.Params["threshold"])
// YAML may unmarshal numbers as different numeric types (int, uint64, float64)
// Just verify the numeric value is correct
require.InEpsilon(t, float64(768), toFloat64(strategy2.Params["vector_dimensions"]), 0.001)
}

func TestRAGStrategyConfig_MarshalUnmarshal_WithDatabase(t *testing.T) {
t.Parallel()

input := []byte(`type: chunked-embeddings
database: ./test.db
model: test-model
`)

var strategy RAGStrategyConfig
err := yaml.Unmarshal(input, &strategy)
require.NoError(t, err)

// Marshal back
output, err := yaml.Marshal(strategy)
require.NoError(t, err)

// Should contain database as a simple string, not nested with sub-fields
outputStr := string(output)
require.Contains(t, outputStr, "database: ./test.db")
require.NotContains(t, outputStr, " value:") // Should not be nested with internal fields
require.Contains(t, outputStr, "model: test-model")
require.NotContains(t, outputStr, "params:") // Should be flattened
}

func mustGetDBString(t *testing.T, db RAGDatabaseConfig) string {
t.Helper()
str, err := db.AsString()
require.NoError(t, err)
return str
}

// toFloat64 converts various numeric types to float64 for comparison
func toFloat64(v any) float64 {
switch val := v.(type) {
case int:
return float64(val)
case int64:
return float64(val)
case uint64:
return float64(val)
case float64:
return val
case float32:
return float64(val)
default:
return 0
}
}
Loading
Loading