diff --git a/backend/app/controllers/providers/providers.go b/backend/app/controllers/providers/providers.go
index 70de3d4..6529c68 100644
--- a/backend/app/controllers/providers/providers.go
+++ b/backend/app/controllers/providers/providers.go
@@ -89,7 +89,7 @@ func (h *Controller) Delete(c fiber.Ctx) error {
func (h *Controller) Types(c fiber.Ctx) error {
// For now, hardcoded list of supported provider types
// In the future, this could be dynamic based on available providers
- types := []string{"ollama", "openai"}
+ types := []string{"ollama", "openai", "litellm"}
return c.JSON(fiber.Map{"types": types})
}
diff --git a/backend/app/controllers/settings/settings.go b/backend/app/controllers/settings/settings.go
index 508e28f..3a82716 100644
--- a/backend/app/controllers/settings/settings.go
+++ b/backend/app/controllers/settings/settings.go
@@ -3,7 +3,7 @@ package settings
import (
"fmt"
"sef/app/entities"
- "sef/pkg/ollama"
+ "sef/pkg/providers"
"strconv"
"github.com/gofiber/fiber/v3"
@@ -97,11 +97,20 @@ func (h *Controller) ListEmbeddingModels(c fiber.Ctx) error {
return fiber.NewError(fiber.StatusNotFound, "Provider not found")
}
- // Create Ollama client
- ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+	// Build an embedding provider for the stored provider type.
+	factory := &providers.EmbeddingProviderFactory{}
+	config := map[string]interface{}{
+		"base_url": provider.BaseURL,
+		// TODO: forward a stored API key here for openai/litellm providers.
+	}
+
+	embedProvider, err := factory.NewProvider(provider.Type, config)
+	if err != nil {
+		return fiber.NewError(fiber.StatusInternalServerError, fmt.Sprintf("Failed to create provider: %v", err))
+	}
// Get all models
- allModels, err := ollamaClient.ListModels()
+ allModels, err := embedProvider.ListModels()
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, fmt.Sprintf("Failed to list models: %v", err))
}
diff --git a/backend/pkg/documentservice/service.go b/backend/pkg/documentservice/service.go
index 9c101ce..423f5aa 100644
--- a/backend/pkg/documentservice/service.go
+++ b/backend/pkg/documentservice/service.go
@@ -5,7 +5,7 @@ import (
"fmt"
"sef/app/entities"
"sef/pkg/chunking"
- "sef/pkg/ollama"
+ "sef/pkg/providers"
"sef/pkg/qdrant"
"strings"
@@ -88,8 +88,31 @@ func (ds *DocumentService) ProcessDocument(ctx context.Context, document *entiti
return err
}
- // Create Ollama client for this provider
- ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+	// Build an embedding provider for this document's configured backend.
+	//
+	// Only the base URL is known here. Do NOT send a placeholder key:
+	// the previous "sk-..." literal was transmitted verbatim to
+	// OpenAI-compatible endpoints and rejected as an invalid credential,
+	// failing every embedding request. Ollama needs no key, and LiteLLM
+	// proxies commonly authenticate via server-side environment
+	// configuration, so omitting "api_key" is the safe default.
+	//
+	// TODO(review): persist an API key on the Provider entity (or in
+	// settings) and forward it as config["api_key"] when provider.Type
+	// is "openai" or "litellm"; NewOpenAIEmbeddingProvider reads that
+	// field from this map.
+	factory := &providers.EmbeddingProviderFactory{}
+	config := map[string]interface{}{
+		"base_url": provider.BaseURL,
+	}
+
+
+ embedProvider, err := factory.NewProvider(provider.Type, config)
+ if err != nil {
+ document.Status = "failed"
+ ds.DB.Save(document)
+ return fmt.Errorf("failed to create embedding provider: %w", err)
+ }
// Ensure global collection exists
exists, err := ds.QdrantClient.CollectionExists(GlobalCollectionName)
@@ -127,7 +150,7 @@ func (ds *DocumentService) ProcessDocument(ctx context.Context, document *entiti
totalChunks := len(chunks)
for _, chunk := range chunks {
log.Infof("Generating embedding for document ID %d, chunk %d", document.ID, chunk.Index)
- embedding, err := ollamaClient.GenerateEmbedding(ctx, embedModel, chunk.Text)
+ embedding, err := embedProvider.GenerateEmbedding(ctx, embedModel, chunk.Text)
if err != nil {
document.Status = "failed"
ds.DB.Save(document)
@@ -182,11 +205,20 @@ func (ds *DocumentService) SearchDocuments(ctx context.Context, query string, li
return nil, err
}
- // Create Ollama client
- ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+	// Build an embedding provider matching the configured provider type.
+	factory := &providers.EmbeddingProviderFactory{}
+	config := map[string]interface{}{
+		"base_url": provider.BaseURL,
+		// TODO: forward a stored API key here for openai/litellm providers.
+	}
+
+	embedProvider, err := factory.NewProvider(provider.Type, config)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create embedding provider: %w", err)
+	}
// Generate embedding for query
- queryEmbedding, err := ollamaClient.GenerateEmbedding(ctx, embedModel, query)
+ queryEmbedding, err := embedProvider.GenerateEmbedding(ctx, embedModel, query)
if err != nil {
return nil, fmt.Errorf("failed to generate query embedding: %w", err)
}
diff --git a/backend/pkg/providers/embedding.go b/backend/pkg/providers/embedding.go
new file mode 100644
index 0000000..5e62e60
--- /dev/null
+++ b/backend/pkg/providers/embedding.go
@@ -0,0 +1,127 @@
+package providers
+
+import (
+ "context"
+ "fmt"
+ "sef/pkg/ollama"
+ "strings"
+
+ openai "github.com/sashabaranov/go-openai"
+)
+
+// EmbeddingProvider abstracts an embedding backend (Ollama, OpenAI, LiteLLM).
+type EmbeddingProvider interface {
+	// GenerateEmbedding returns the embedding vector for text using model.
+	GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error)
+	// ListModels returns the model identifiers the backend exposes.
+	ListModels() ([]string, error)
+}
+
+// OllamaEmbeddingProvider adapts ollama.OllamaClient to EmbeddingProvider.
+type OllamaEmbeddingProvider struct {
+	client *ollama.OllamaClient
+}
+
+// NewOllamaEmbeddingProvider builds a provider; config["base_url"] overrides the default http://localhost:11434.
+func NewOllamaEmbeddingProvider(config map[string]interface{}) *OllamaEmbeddingProvider {
+	baseURL := "http://localhost:11434"
+	if url, ok := config["base_url"].(string); ok && url != "" {
+		baseURL = url
+	}
+
+	return &OllamaEmbeddingProvider{
+		client: ollama.NewOllamaClient(baseURL),
+	}
+}
+
+// GenerateEmbedding delegates to the underlying Ollama client.
+func (o *OllamaEmbeddingProvider) GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error) {
+	// NOTE(review): assumes ollama.OllamaClient.GenerateEmbedding already
+	// returns []float32 — confirm against pkg/ollama/client.go.
+	return o.client.GenerateEmbedding(ctx, model, text)
+}
+
+func (o *OllamaEmbeddingProvider) ListModels() ([]string, error) {
+ return o.client.ListModels()
+}
+
+// OpenAIEmbeddingProvider serves OpenAI and OpenAI-compatible (LiteLLM) endpoints.
+type OpenAIEmbeddingProvider struct {
+	client *openai.Client
+}
+
+// NewOpenAIEmbeddingProvider builds a client from config["api_key"] and an optional config["base_url"] override.
+func NewOpenAIEmbeddingProvider(config map[string]interface{}) *OpenAIEmbeddingProvider {
+	apiKey := ""
+	if key, ok := config["api_key"].(string); ok {
+		apiKey = key
+	}
+
+	configOpenAI := openai.DefaultConfig(apiKey)
+	if baseURL, ok := config["base_url"].(string); ok && baseURL != "" {
+		configOpenAI.BaseURL = baseURL
+	}
+
+	client := openai.NewClientWithConfig(configOpenAI)
+	return &OpenAIEmbeddingProvider{
+		client: client,
+	}
+}
+
+// GenerateEmbedding requests one embedding via the OpenAI embeddings API and returns the first result vector.
+func (o *OpenAIEmbeddingProvider) GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error) {
+	resp, err := o.client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
+		Model: openai.EmbeddingModel(model),
+		Input: []string{text},
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	if len(resp.Data) == 0 {
+		return nil, fmt.Errorf("no embedding data returned")
+	}
+
+	return resp.Data[0].Embedding, nil
+}
+
+// ListModels returns model IDs from the OpenAI-compatible /models endpoint,
+// preferring IDs containing "embed" and falling back to all models.
+func (o *OpenAIEmbeddingProvider) ListModels() ([]string, error) {
+	// The interface has no ctx parameter yet; use a local background context
+	// rather than the shared package-level ctx variable.
+	models, err := o.client.ListModels(context.Background())
+	if err != nil {
+		return nil, err
+	}
+
+	var modelNames []string
+	for _, model := range models.Models {
+		if strings.Contains(strings.ToLower(model.ID), "embed") {
+			modelNames = append(modelNames, model.ID)
+		}
+	}
+	// Fallback: nothing matched "embed" (common with LiteLLM proxy names).
+	if len(modelNames) == 0 {
+		for _, model := range models.Models {
+			modelNames = append(modelNames, model.ID)
+		}
+	}
+	return modelNames, nil
+}
+
+// EmbeddingProviderFactory constructs an EmbeddingProvider for a provider type string.
+type EmbeddingProviderFactory struct{}
+
+func (f *EmbeddingProviderFactory) NewProvider(providerType string, config map[string]interface{}) (EmbeddingProvider, error) {
+	switch providerType {
+	case "ollama":
+		return NewOllamaEmbeddingProvider(config), nil
+	case "openai", "litellm":
+		return NewOpenAIEmbeddingProvider(config), nil
+	default:
+		return nil, fmt.Errorf("unsupported embedding provider type: %s", providerType)
+	}
+}
+
+// ctx is a stopgap background context — TODO: thread a context parameter through ListModels and delete this package-level variable.
+var ctx = context.Background()
diff --git a/backend/pkg/providers/interface.go b/backend/pkg/providers/interface.go
index 89c8887..495eec6 100644
--- a/backend/pkg/providers/interface.go
+++ b/backend/pkg/providers/interface.go
@@ -67,7 +67,7 @@ func (f *ProviderFactory) NewProvider(providerType string, config map[string]int
switch providerType {
case "ollama":
return NewOllamaProvider(config), nil
- case "openai":
+ case "openai", "litellm":
return NewOpenAIProvider(config), nil
default:
return nil, fmt.Errorf("unsupported provider type: %s", providerType)
diff --git a/frontend/src/pages/settings/embedding/index.tsx b/frontend/src/pages/settings/embedding/index.tsx
index 0066757..6c3be79 100644
--- a/frontend/src/pages/settings/embedding/index.tsx
+++ b/frontend/src/pages/settings/embedding/index.tsx
@@ -59,10 +59,12 @@ export default function EmbeddingSettingsPage() {
try {
const response = await http.get("/providers")
// Filter only Ollama providers
- const ollamaProviders = response.data?.records?.filter(
- (p: IProvider) => p.type?.toLowerCase() === "ollama"
+      // NOTE(review): backend factory also accepts "openai"; excluded here until API-key handling exists — confirm intent
+      const supportedTypes = ["ollama", "litellm"]
+      const embeddingProviders = response.data?.records?.filter(
+        (p: IProvider) => supportedTypes.includes(p.type?.toLowerCase())
) || []
- setProviders(ollamaProviders)
+ setProviders(embeddingProviders)
} catch (error) {
console.error("Failed to fetch providers:", error)
}
@@ -72,7 +74,7 @@ export default function EmbeddingSettingsPage() {
try {
const response = await http.get("/settings/embedding")
setCurrentConfig(response.data)
-
+
if (response.data.provider) {
setSelectedProviderId(response.data.provider.id.toString())
// Fetch models for this provider
@@ -235,12 +237,12 @@ export default function EmbeddingSettingsPage() {
disabled={!selectedProviderId || loadingModels}
>
-
@@ -273,8 +275,8 @@ export default function EmbeddingSettingsPage() {
-