Skip to content

Commit 1c67071

Browse files
feat(embedding): multi-provider abstraction layer (#62)
Add embedding.NewProvider factory supporting OpenAI, Ollama, and Cohere via a unified ProviderConfig. Providers self-register via init() so callers only need a blank import. New packages: pkg/embedding/ollama - local Ollama server (/api/embeddings) pkg/embedding/cohere - Cohere API (embed-english-v3.0 default) New files: pkg/embedding/registry.go - NewProvider, RegisterFactory, ProviderConfig, SupportedProviders pkg/embedding/openai/register.go - registers OpenAI into the factory pkg/embedding/ollama/register.go - registers Ollama into the factory pkg/embedding/cohere/register.go - registers Cohere into the factory pkg/embedding/registry_test.go - custom provider, unknown type, ollama resolution, cache wrapping CacheSize=-1 disables the in-memory cache; 0 uses the default (10k). Existing cmd/ code continues to use openai.NewClient directly and is unaffected. Co-authored-by: Ona <no-reply@ona.com>
1 parent b671cd6 commit 1c67071

8 files changed

Lines changed: 599 additions & 0 deletions

File tree

README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,40 @@ Pattern → annotation mapping:
904904
- **Code Review** - Blast radius analysis for PRs
905905
- **Enterprise** - Deterministic outputs with source attribution
906906
907+
## Embedding Providers
908+
909+
Distill supports multiple embedding backends via a unified factory. Import the provider package to register it, then call `embedding.NewProvider`:
910+
911+
```go
912+
import (
913+
"github.com/Siddhant-K-code/distill/pkg/embedding"
914+
_ "github.com/Siddhant-K-code/distill/pkg/embedding/openai" // register OpenAI
915+
_ "github.com/Siddhant-K-code/distill/pkg/embedding/ollama" // register Ollama
916+
_ "github.com/Siddhant-K-code/distill/pkg/embedding/cohere" // register Cohere
917+
)
918+
919+
provider, err := embedding.NewProvider(embedding.ProviderConfig{
920+
Type: embedding.ProviderOllama, // "openai" | "ollama" | "cohere"
921+
BaseURL: "http://localhost:11434", // optional override
922+
Model: "nomic-embed-text", // optional override
923+
CacheSize: 10000, // 0 = default (10k), -1 = disabled
924+
})
925+
```
926+
927+
| Provider | Type string | Default model | Notes |
928+
|----------|-------------|---------------|-------|
929+
| OpenAI | `openai` | `text-embedding-3-small` | Requires `OPENAI_API_KEY` |
930+
| Ollama | `ollama` | `nomic-embed-text` | Local server, no API key |
931+
| Cohere | `cohere` | `embed-english-v3.0` | Requires `COHERE_API_KEY` |
932+
933+
Custom providers can be registered at startup:
934+
935+
```go
936+
embedding.RegisterFactory("my-provider", func(cfg embedding.ProviderConfig) (embedding.Provider, error) {
937+
return myProvider{apiKey: cfg.APIKey}, nil
938+
})
939+
```
940+
907941
## Roadmap
908942

909943
Distill is evolving from a dedup utility into a context intelligence layer. Here's what's next:
@@ -922,6 +956,7 @@ Distill is evolving from a dedup utility into a context intelligence layer. Here
922956
| **Prefix stability validator** | [#48](https://github.com/Siddhant-K-code/distill/issues/48) | Shipped | `StabilityValidator` tracks prefix hashes across requests and detects dynamic content (timestamps, request IDs, UUIDs) bleeding into cached prefixes. |
923957
| **Per-call-site hit rate tracking** | [#47](https://github.com/Siddhant-K-code/distill/issues/47) | Shipped | `CallSiteTracker` records Anthropic cache usage per call site; `AllStats()` returns worst performers first. |
924958
| **TTL-aware cache tracker** | [#49](https://github.com/Siddhant-K-code/distill/issues/49) | Shipped | `TTLTracker` monitors Anthropic's 5-minute cache TTL per prefix hash. `ScheduleDeadline` tells batch jobs the latest safe time to send the next request. |
959+
| **Multi-provider embedding abstraction** | [#33](https://github.com/Siddhant-K-code/distill/issues/33) | Shipped | `embedding.NewProvider` factory supports OpenAI, Ollama, and Cohere via a unified `ProviderConfig`. Custom providers register via `RegisterFactory`. |
925960

926961
### Code Intelligence
927962

pkg/embedding/cohere/client.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// Package cohere provides an embedding.Provider backed by the Cohere API.
2+
package cohere
3+
4+
import (
5+
"bytes"
6+
"context"
7+
"encoding/json"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"time"
12+
13+
"github.com/Siddhant-K-code/distill/pkg/embedding"
14+
)
15+
16+
const (
17+
defaultBaseURL = "https://api.cohere.ai/v1"
18+
defaultModel = "embed-english-v3.0"
19+
defaultTimeout = 30 * time.Second
20+
)
21+
22+
// InputType controls how Cohere classifies the input for retrieval tasks.
23+
type InputType string
24+
25+
const (
26+
InputTypeSearchDocument InputType = "search_document"
27+
InputTypeSearchQuery InputType = "search_query"
28+
InputTypeClassification InputType = "classification"
29+
InputTypeClustering InputType = "clustering"
30+
)
31+
32+
// Model dimensions for common Cohere embedding models.
33+
var modelDimensions = map[string]int{
34+
"embed-english-v3.0": 1024,
35+
"embed-multilingual-v3.0": 1024,
36+
"embed-english-light-v3.0": 384,
37+
}
38+
39+
// Config holds Cohere client configuration.
40+
type Config struct {
41+
// APIKey is the Cohere API key (required).
42+
APIKey string
43+
44+
// Model is the embedding model. Default: embed-english-v3.0
45+
Model string
46+
47+
// InputType controls retrieval optimisation. Default: search_document
48+
InputType InputType
49+
50+
// Timeout for API requests. Default: 30s
51+
Timeout time.Duration
52+
}
53+
54+
// Client implements embedding.Provider for Cohere.
55+
type Client struct {
56+
cfg Config
57+
httpClient *http.Client
58+
dimension int
59+
}
60+
61+
// NewClient creates a new Cohere embedding client.
62+
func NewClient(cfg Config) (*Client, error) {
63+
if cfg.APIKey == "" {
64+
return nil, fmt.Errorf("Cohere API key is required")
65+
}
66+
if cfg.Model == "" {
67+
cfg.Model = defaultModel
68+
}
69+
if cfg.InputType == "" {
70+
cfg.InputType = InputTypeSearchDocument
71+
}
72+
if cfg.Timeout <= 0 {
73+
cfg.Timeout = defaultTimeout
74+
}
75+
dim := modelDimensions[cfg.Model]
76+
return &Client{
77+
cfg: cfg,
78+
httpClient: &http.Client{Timeout: cfg.Timeout},
79+
dimension: dim,
80+
}, nil
81+
}
82+
83+
type embedRequest struct {
84+
Texts []string `json:"texts"`
85+
Model string `json:"model"`
86+
InputType InputType `json:"input_type"`
87+
}
88+
89+
type embedResponse struct {
90+
Embeddings [][]float32 `json:"embeddings"`
91+
}
92+
93+
// Embed returns the embedding for a single text.
94+
func (c *Client) Embed(ctx context.Context, text string) ([]float32, error) {
95+
if text == "" {
96+
return nil, embedding.ErrEmptyInput
97+
}
98+
results, err := c.EmbedBatch(ctx, []string{text})
99+
if err != nil {
100+
return nil, err
101+
}
102+
return results[0], nil
103+
}
104+
105+
// EmbedBatch embeds multiple texts in a single API call.
106+
func (c *Client) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
107+
if len(texts) == 0 {
108+
return nil, nil
109+
}
110+
111+
body, err := json.Marshal(embedRequest{
112+
Texts: texts,
113+
Model: c.cfg.Model,
114+
InputType: c.cfg.InputType,
115+
})
116+
if err != nil {
117+
return nil, fmt.Errorf("marshal request: %w", err)
118+
}
119+
120+
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
121+
defaultBaseURL+"/embed", bytes.NewReader(body))
122+
if err != nil {
123+
return nil, fmt.Errorf("build request: %w", err)
124+
}
125+
req.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
126+
req.Header.Set("Content-Type", "application/json")
127+
128+
resp, err := c.httpClient.Do(req)
129+
if err != nil {
130+
return nil, fmt.Errorf("cohere request: %w", err)
131+
}
132+
defer resp.Body.Close()
133+
134+
if resp.StatusCode == http.StatusTooManyRequests {
135+
return nil, embedding.ErrRateLimited
136+
}
137+
if resp.StatusCode == http.StatusUnauthorized {
138+
return nil, embedding.ErrInvalidAPIKey
139+
}
140+
if resp.StatusCode != http.StatusOK {
141+
b, _ := io.ReadAll(resp.Body)
142+
return nil, fmt.Errorf("cohere %d: %s", resp.StatusCode, string(b))
143+
}
144+
145+
var result embedResponse
146+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
147+
return nil, fmt.Errorf("decode response: %w", err)
148+
}
149+
if len(result.Embeddings) != len(texts) {
150+
return nil, fmt.Errorf("expected %d embeddings, got %d", len(texts), len(result.Embeddings))
151+
}
152+
return result.Embeddings, nil
153+
}
154+
155+
// Dimension returns the embedding dimension for the configured model.
156+
func (c *Client) Dimension() int { return c.dimension }
157+
158+
// ModelName returns the configured model name.
159+
func (c *Client) ModelName() string { return c.cfg.Model }

pkg/embedding/cohere/register.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package cohere
2+
3+
import (
4+
"github.com/Siddhant-K-code/distill/pkg/embedding"
5+
)
6+
7+
func init() {
8+
embedding.RegisterFactory(embedding.ProviderCohere, func(cfg embedding.ProviderConfig) (embedding.Provider, error) {
9+
return NewClient(Config{
10+
APIKey: cfg.APIKey,
11+
Model: cfg.Model,
12+
})
13+
})
14+
}

pkg/embedding/ollama/client.go

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Package ollama provides an embedding.Provider backed by a local Ollama server.
2+
package ollama
3+
4+
import (
5+
"bytes"
6+
"context"
7+
"encoding/json"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"time"
12+
13+
"github.com/Siddhant-K-code/distill/pkg/embedding"
14+
)
15+
16+
const (
17+
defaultBaseURL = "http://localhost:11434"
18+
defaultModel = "nomic-embed-text"
19+
defaultTimeout = 60 * time.Second
20+
)
21+
22+
// Config holds Ollama client configuration.
23+
type Config struct {
24+
// BaseURL is the Ollama server URL. Default: http://localhost:11434
25+
BaseURL string
26+
27+
// Model is the embedding model to use. Default: nomic-embed-text
28+
Model string
29+
30+
// Timeout for API requests. Default: 60s (local models can be slow).
31+
Timeout time.Duration
32+
}
33+
34+
// Client implements embedding.Provider for Ollama.
35+
type Client struct {
36+
cfg Config
37+
httpClient *http.Client
38+
}
39+
40+
// NewClient creates a new Ollama embedding client.
41+
func NewClient(cfg Config) *Client {
42+
if cfg.BaseURL == "" {
43+
cfg.BaseURL = defaultBaseURL
44+
}
45+
if cfg.Model == "" {
46+
cfg.Model = defaultModel
47+
}
48+
if cfg.Timeout <= 0 {
49+
cfg.Timeout = defaultTimeout
50+
}
51+
return &Client{
52+
cfg: cfg,
53+
httpClient: &http.Client{Timeout: cfg.Timeout},
54+
}
55+
}
56+
57+
type embedRequest struct {
58+
Model string `json:"model"`
59+
Prompt string `json:"prompt"`
60+
}
61+
62+
type embedResponse struct {
63+
Embedding []float32 `json:"embedding"`
64+
}
65+
66+
// Embed returns the embedding for a single text.
67+
func (c *Client) Embed(ctx context.Context, text string) ([]float32, error) {
68+
if text == "" {
69+
return nil, embedding.ErrEmptyInput
70+
}
71+
72+
body, err := json.Marshal(embedRequest{Model: c.cfg.Model, Prompt: text})
73+
if err != nil {
74+
return nil, fmt.Errorf("marshal request: %w", err)
75+
}
76+
77+
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
78+
c.cfg.BaseURL+"/api/embeddings", bytes.NewReader(body))
79+
if err != nil {
80+
return nil, fmt.Errorf("build request: %w", err)
81+
}
82+
req.Header.Set("Content-Type", "application/json")
83+
84+
resp, err := c.httpClient.Do(req)
85+
if err != nil {
86+
return nil, fmt.Errorf("ollama request: %w", err)
87+
}
88+
defer resp.Body.Close()
89+
90+
if resp.StatusCode != http.StatusOK {
91+
b, _ := io.ReadAll(resp.Body)
92+
return nil, fmt.Errorf("ollama %d: %s", resp.StatusCode, string(b))
93+
}
94+
95+
var result embedResponse
96+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
97+
return nil, fmt.Errorf("decode response: %w", err)
98+
}
99+
if len(result.Embedding) == 0 {
100+
return nil, fmt.Errorf("ollama returned empty embedding")
101+
}
102+
return result.Embedding, nil
103+
}
104+
105+
// EmbedBatch embeds multiple texts sequentially (Ollama has no batch API).
106+
func (c *Client) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
107+
results := make([][]float32, len(texts))
108+
for i, text := range texts {
109+
emb, err := c.Embed(ctx, text)
110+
if err != nil {
111+
return nil, fmt.Errorf("embed[%d]: %w", i, err)
112+
}
113+
results[i] = emb
114+
}
115+
return results, nil
116+
}
117+
118+
// Dimension returns the embedding dimension. Ollama models vary; we return
119+
// 0 to indicate it is determined at runtime from the first response.
120+
func (c *Client) Dimension() int { return 0 }
121+
122+
// ModelName returns the configured model name.
123+
func (c *Client) ModelName() string { return c.cfg.Model }

pkg/embedding/ollama/register.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package ollama
2+
3+
import (
4+
"time"
5+
6+
"github.com/Siddhant-K-code/distill/pkg/embedding"
7+
)
8+
9+
func init() {
10+
embedding.RegisterFactory(embedding.ProviderOllama, func(cfg embedding.ProviderConfig) (embedding.Provider, error) {
11+
return NewClient(Config{
12+
BaseURL: cfg.BaseURL,
13+
Model: cfg.Model,
14+
Timeout: time.Duration(0), // uses defaultTimeout
15+
}), nil
16+
})
17+
}

pkg/embedding/openai/register.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package openai
2+
3+
import (
4+
"github.com/Siddhant-K-code/distill/pkg/embedding"
5+
)
6+
7+
func init() {
8+
embedding.RegisterFactory(embedding.ProviderOpenAI, func(cfg embedding.ProviderConfig) (embedding.Provider, error) {
9+
return NewClient(Config{
10+
APIKey: cfg.APIKey,
11+
Model: cfg.Model,
12+
BaseURL: cfg.BaseURL,
13+
})
14+
})
15+
}

0 commit comments

Comments
 (0)