Skip to content

Commit ead927e

Browse files
authored
Merge pull request #2344 from dgageot/board/docker-agent-issue-1701-implementation-0d8027cc
feat: support custom providers in RAG embedding and reranking models
2 parents f711e4e + 5c8319a commit ead927e

File tree

8 files changed

+74
-63
lines changed

8 files changed

+74
-63
lines changed

examples/rag/custom_provider.yaml

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# This example demonstrates using a custom provider for RAG embedding models.
2+
# For instance, you can use a local Ollama instance or any OpenAI-compatible
3+
# API endpoint for generating embeddings.
4+
5+
providers:
6+
local-ollama:
7+
base_url: http://localhost:11434/v1
8+
9+
models:
10+
local-embed:
11+
provider: local-ollama
12+
model: nomic-embed-text
13+
14+
agents:
15+
root:
16+
model: openai/gpt-5-mini
17+
description: assistant with RAG using custom embedding provider
18+
instruction: |
19+
You are a helpful assistant with access to a knowledge base.
20+
Use the search tool to find relevant information before answering.
21+
toolsets:
22+
- type: rag
23+
ref: knowledge_base
24+
25+
rag:
26+
knowledge_base:
27+
tool:
28+
description: search the knowledge base for relevant information
29+
docs:
30+
- ./docs
31+
strategies:
32+
- type: chunked-embeddings
33+
embedding_model: local-embed # References the model defined above using the custom provider
34+
database: ./custom_provider.db
35+
vector_dimensions: 768
36+
chunking:
37+
size: 1000
38+
overlap: 100
39+
results:
40+
limit: 5

pkg/config/runtime.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type Config struct {
2525
GlobalCodeMode bool
2626
WorkingDir string
2727
Models map[string]latest.ModelConfig
28+
Providers map[string]latest.ProviderConfig
2829

2930
// Hook overrides from CLI flags
3031
HookPreToolUse []string
@@ -40,6 +41,7 @@ func (runConfig *RuntimeConfig) Clone() *RuntimeConfig {
4041
}
4142
clone.EnvFiles = slices.Clone(runConfig.EnvFiles)
4243
clone.Models = maps.Clone(runConfig.Models)
44+
clone.Providers = maps.Clone(runConfig.Providers)
4345
clone.DefaultModel = runConfig.DefaultModel.Clone()
4446
clone.HookPreToolUse = slices.Clone(runConfig.HookPreToolUse)
4547
clone.HookPostToolUse = slices.Clone(runConfig.HookPostToolUse)

pkg/rag/builder.go

Lines changed: 15 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@ import (
55
"errors"
66
"fmt"
77
"log/slog"
8-
"maps"
9-
"slices"
108

119
"github.com/docker/docker-agent/pkg/config/latest"
1210
"github.com/docker/docker-agent/pkg/environment"
1311
"github.com/docker/docker-agent/pkg/model/provider"
12+
"github.com/docker/docker-agent/pkg/model/provider/options"
1413
"github.com/docker/docker-agent/pkg/rag/rerank"
1514
"github.com/docker/docker-agent/pkg/rag/strategy"
1615
"github.com/docker/docker-agent/pkg/rag/types"
@@ -21,7 +20,16 @@ type ManagersBuildConfig struct {
2120
ParentDir string
2221
ModelsGateway string
2322
Env environment.Provider
24-
Models map[string]latest.ModelConfig // Model configurations from config
23+
Models map[string]latest.ModelConfig // Model configurations from config
24+
Providers map[string]latest.ProviderConfig // Custom provider configurations from config
25+
}
26+
27+
// NewProvider creates a model provider for cfg by delegating to provider.New,
28+
// passing c.Env plus the gateway (c.ModelsGateway) and custom provider (c.Providers) options.
// The provider and error from provider.New are returned unchanged.
29+
func (c ManagersBuildConfig) NewProvider(ctx context.Context, cfg *latest.ModelConfig) (provider.Provider, error) {
30+
return provider.New(ctx, cfg, c.Env,
31+
options.WithGateway(c.ModelsGateway),
32+
options.WithProviders(c.Providers))
2533
}
2634

2735
// NewManager constructs a single RAG manager from a RAGConfig.
@@ -46,6 +54,7 @@ func NewManager(
4654
ParentDir: buildCfg.ParentDir,
4755
SharedDocs: GetAbsolutePaths(buildCfg.ParentDir, ragCfg.Docs),
4856
Models: buildCfg.Models,
57+
Providers: buildCfg.Providers,
4958
Env: buildCfg.Env,
5059
ModelsGateway: buildCfg.ModelsGateway,
5160
RespectVCS: ragCfg.GetRespectVCS(),
@@ -146,20 +155,21 @@ func buildRerankingConfig(
146155
"model_ref", rerankCfg.Model)
147156

148157
// Resolve model config - check if it's a reference to a defined model or inline
149-
modelCfg, err := resolveModelConfig(rerankCfg.Model, buildCfg)
158+
modelCfgVal, err := strategy.ResolveModelConfig(rerankCfg.Model, buildCfg.Models)
150159
if err != nil {
151160
slog.Error("Failed to resolve reranking model",
152161
"model_ref", rerankCfg.Model,
153162
"error", err)
154163
return nil, fmt.Errorf("failed to resolve reranking model %q: %w", rerankCfg.Model, err)
155164
}
165+
modelCfg := &modelCfgVal
156166

157167
slog.Debug("Resolved reranking model config",
158168
"provider", modelCfg.Provider,
159169
"model", modelCfg.Model)
160170

161171
// Create provider for reranking model
162-
rerankProvider, err := provider.New(ctx, modelCfg, buildCfg.Env)
172+
rerankProvider, err := buildCfg.NewProvider(ctx, modelCfg)
163173
if err != nil {
164174
slog.Error("Failed to create reranking provider",
165175
"provider", modelCfg.Provider,
@@ -206,55 +216,6 @@ func buildRerankingConfig(
206216
}, nil
207217
}
208218

209-
// resolveModelConfig resolves a model name to a ModelConfig.
210-
// Handles both inline model references (e.g., "dmr/model-name") and defined model names.
// Inline references win: any name containing '/' is returned as a fresh
// provider/model config without consulting buildCfg.Models. Names without '/'
// must exist in buildCfg.Models, otherwise an error is returned.
211-
func resolveModelConfig(modelName string, buildCfg ManagersBuildConfig) (*latest.ModelConfig, error) {
212-
// Check if it's an inline model reference (contains a '/').
213-
if modelName != "" {
214-
parts := splitModelRef(modelName)
215-
if len(parts) == 2 {
216-
// Inline model reference like "dmr/hf.co/model" or "openai/gpt-5";
// only Provider and Model are populated on the returned config.
217-
slog.Debug("Using inline model reference",
218-
"provider", parts[0],
219-
"model", parts[1])
220-
return &latest.ModelConfig{
221-
Provider: parts[0],
222-
Model: parts[1],
223-
}, nil
224-
}
225-
}
226-
227-
// Try to find model in defined models.
228-
if modelCfg, exists := buildCfg.Models[modelName]; exists {
229-
slog.Debug("Using defined model from config",
230-
"model_name", modelName,
231-
"provider", modelCfg.Provider,
232-
"model", modelCfg.Model)
233-
// modelCfg is a (shallow) copy of the map entry, so the returned pointer
// does not point into buildCfg.Models.
return &modelCfg, nil
234-
}
235-
236-
// Not inline and not defined: log the known model names to aid debugging, then fail.
slog.Error("Model not found in configuration",
237-
"model_name", modelName,
238-
"available_models", getModelNames(buildCfg.Models))
239-
return nil, fmt.Errorf("model %q not found in configuration", modelName)
240-
}
241-
242-
// getModelNames extracts model names from the models map for logging.
// The names are collected straight from map iteration, so their order is
// unspecified and may differ between calls.
243-
func getModelNames(models map[string]latest.ModelConfig) []string {
244-
return slices.Collect(maps.Keys(models))
245-
}
246-
247-
// splitModelRef splits a model reference into provider and model parts.
// It returns a two-element slice {provider, model} when ref contains a '/',
// and a one-element slice {ref} otherwise.
248-
func splitModelRef(ref string) []string {
249-
// Handle common patterns: "provider/model". Only the first '/' splits, so
// the model part may itself contain slashes (e.g. "dmr/hf.co/model").
250-
for i := range len(ref) {
251-
if ref[i] == '/' {
252-
return []string{ref[:i], ref[i+1:]}
253-
}
254-
}
255-
// No '/' found: the caller gets the reference back as a single element.
return []string{ref}
256-
}
257-
258219
// buildStrategyConfigs builds the strategy configs for the RAG.
259220
// Returns a slice of strategy configs and a channel for receiving strategy events.
260221
func buildStrategyConfigs(

pkg/rag/strategy/embedding.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"github.com/docker/docker-agent/pkg/config"
1010
"github.com/docker/docker-agent/pkg/config/latest"
1111
"github.com/docker/docker-agent/pkg/model/provider"
12-
"github.com/docker/docker-agent/pkg/model/provider/options"
1312
"github.com/docker/docker-agent/pkg/modelsdev"
1413
"github.com/docker/docker-agent/pkg/rag/embed"
1514
)
@@ -41,8 +40,7 @@ func CreateEmbeddingProvider(ctx context.Context, modelName string, buildCtx Bui
4140
return nil, fmt.Errorf("model '%s' not found: %w", modelName, err)
4241
}
4342

44-
embedModel, err = provider.New(ctx, &modelCfg, buildCtx.Env,
45-
options.WithGateway(buildCtx.ModelsGateway))
43+
embedModel, err = buildCtx.NewProvider(ctx, &modelCfg)
4644
if err != nil {
4745
return nil, fmt.Errorf("failed to create embedding model: %w", err)
4846
}
@@ -80,8 +78,7 @@ func createAutoEmbeddingModel(ctx context.Context, buildCtx BuildContext) (provi
8078
Model: autoModelCfg.Model,
8179
}
8280

83-
model, err := provider.New(ctx, &modelCfg, buildCtx.Env,
84-
options.WithGateway(buildCtx.ModelsGateway))
81+
model, err := buildCtx.NewProvider(ctx, &modelCfg)
8582
if err != nil {
8683
lastErr = err
8784
continue

pkg/rag/strategy/semantic_embeddings.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
"github.com/docker/docker-agent/pkg/config/latest"
1717
"github.com/docker/docker-agent/pkg/js"
1818
"github.com/docker/docker-agent/pkg/model/provider"
19-
"github.com/docker/docker-agent/pkg/model/provider/options"
2019
"github.com/docker/docker-agent/pkg/rag/chunk"
2120
"github.com/docker/docker-agent/pkg/rag/types"
2221
"github.com/docker/docker-agent/pkg/tools"
@@ -89,8 +88,7 @@ func NewSemanticEmbeddingsFromConfig(ctx context.Context, cfg latest.RAGStrategy
8988
return nil, fmt.Errorf("invalid chat_model %q: %w", chatModelName, err)
9089
}
9190

92-
chatProvider, err := provider.New(ctx, &chatModelCfg, buildCtx.Env,
93-
options.WithGateway(buildCtx.ModelsGateway))
91+
chatProvider, err := buildCtx.NewProvider(ctx, &chatModelCfg)
9492
if err != nil {
9593
return nil, fmt.Errorf("failed to create chat model provider: %w", err)
9694
}

pkg/rag/strategy/strategy.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,31 @@ import (
66

77
"github.com/docker/docker-agent/pkg/config/latest"
88
"github.com/docker/docker-agent/pkg/environment"
9+
"github.com/docker/docker-agent/pkg/model/provider"
10+
"github.com/docker/docker-agent/pkg/model/provider/options"
911
"github.com/docker/docker-agent/pkg/rag/types"
1012
)
1113

12-
// BuildContext contains everything needed to build a strategy
14+
// BuildContext contains everything needed to build a strategy.
// Env, ModelsGateway and Providers are the inputs NewProvider forwards to
// provider.New when a strategy needs a model provider.
1315
type BuildContext struct {
1416
RAGName string // NOTE(review): presumably the name of the RAG source being built — confirm against the caller
1517
ParentDir string // Directory that rag.NewManager resolves the RAG's doc paths against
1618
SharedDocs []string // Doc paths; rag.NewManager fills this via GetAbsolutePaths(ParentDir, docs)
1719
Models map[string]latest.ModelConfig // Model configurations from config
20+
Providers map[string]latest.ProviderConfig // Custom provider configurations from config
1821
Env environment.Provider // Environment passed to provider.New by NewProvider
1922
ModelsGateway string // Gateway passed to options.WithGateway by NewProvider
2023
RespectVCS bool // Whether to respect VCS ignore files (e.g., .gitignore) when collecting files
2124
}
2225

26+
// NewProvider creates a model provider for cfg by delegating to provider.New,
27+
// passing c.Env plus the gateway (c.ModelsGateway) and custom provider (c.Providers) options.
// The provider and error from provider.New are returned unchanged.
28+
func (c BuildContext) NewProvider(ctx context.Context, cfg *latest.ModelConfig) (provider.Provider, error) {
29+
return provider.New(ctx, cfg, c.Env,
30+
options.WithGateway(c.ModelsGateway),
31+
options.WithProviders(c.Providers))
32+
}
33+
2334
// BuildStrategy builds a strategy from config.
2435
// It explicitly dispatches to the appropriate constructor based on type.
2536
func BuildStrategy(ctx context.Context, cfg latest.RAGStrategyConfig, buildCtx BuildContext, events chan<- types.Event) (*Config, error) {

pkg/teamloader/registry.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ func createRAGTool(ctx context.Context, toolset latest.Toolset, parentDir string
369369
ModelsGateway: runConfig.ModelsGateway,
370370
Env: runConfig.EnvProvider(),
371371
Models: runConfig.Models,
372+
Providers: runConfig.Providers,
372373
})
373374
if err != nil {
374375
return nil, fmt.Errorf("failed to create RAG manager: %w", err)

pkg/teamloader/teamloader.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c
123123

124124
// Make model and custom provider definitions available to toolset creators (e.g., RAG embedding and reranking)
125125
runConfig.Models = cfg.Models
126+
runConfig.Providers = cfg.Providers
126127

127128
// Load agents
128129
parentDir := cmp.Or(agentSource.ParentDir(), runConfig.WorkingDir)

0 commit comments

Comments
 (0)