diff --git a/backend/app/controllers/providers/providers.go b/backend/app/controllers/providers/providers.go
index 70de3d4..6529c68 100644
--- a/backend/app/controllers/providers/providers.go
+++ b/backend/app/controllers/providers/providers.go
@@ -89,7 +89,7 @@ func (h *Controller) Delete(c fiber.Ctx) error {
 func (h *Controller) Types(c fiber.Ctx) error {
     // For now, hardcoded list of supported provider types
     // In the future, this could be dynamic based on available providers
-    types := []string{"ollama", "openai"}
+    types := []string{"ollama", "openai", "litellm"}
     return c.JSON(fiber.Map{"types": types})
 }
diff --git a/backend/app/controllers/settings/settings.go b/backend/app/controllers/settings/settings.go
index 508e28f..3a82716 100644
--- a/backend/app/controllers/settings/settings.go
+++ b/backend/app/controllers/settings/settings.go
@@ -3,7 +3,7 @@ package settings
 import (
     "fmt"
     "sef/app/entities"
-    "sef/pkg/ollama"
+    "sef/pkg/providers"
     "strconv"
 
     "github.com/gofiber/fiber/v3"
@@ -97,11 +97,20 @@ func (h *Controller) ListEmbeddingModels(c fiber.Ctx) error {
         return fiber.NewError(fiber.StatusNotFound, "Provider not found")
     }
 
-    // Create Ollama client
-    ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+    // Create an embedding provider through the factory
+    factory := &providers.EmbeddingProviderFactory{}
+    config := map[string]interface{}{
+        "base_url": provider.BaseURL,
+        // "api_key": populate once API keys are stored on the provider entity
+    }
+
+    embedProvider, err := factory.NewProvider(provider.Type, config)
+    if err != nil {
+        return fiber.NewError(fiber.StatusInternalServerError, fmt.Sprintf("Failed to create provider: %v", err))
+    }
 
     // Get all models
-    allModels, err := ollamaClient.ListModels()
+    allModels, err := embedProvider.ListModels()
     if err != nil {
         return fiber.NewError(fiber.StatusInternalServerError, fmt.Sprintf("Failed to list models: %v", err))
     }
diff --git a/backend/pkg/documentservice/service.go b/backend/pkg/documentservice/service.go
index 9c101ce..423f5aa 100644
--- a/backend/pkg/documentservice/service.go
+++ b/backend/pkg/documentservice/service.go
@@ -5,7 +5,7 @@ import (
     "fmt"
     "sef/app/entities"
     "sef/pkg/chunking"
-    "sef/pkg/ollama"
+    "sef/pkg/providers"
     "sef/pkg/qdrant"
     "strings"
 
@@ -88,8 +88,21 @@ func (ds *DocumentService) ProcessDocument(ctx context.Context, document *entiti
         return err
     }
 
-    // Create Ollama client for this provider
-    ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+    // Create an embedding provider for this document's provider entry
+    factory := &providers.EmbeddingProviderFactory{}
+    config := map[string]interface{}{
+        "base_url": provider.BaseURL,
+        // TODO: pass "api_key" from the provider entity once it is stored there.
+        // OpenAI requires a key; a LiteLLM proxy often needs only the base URL
+        // (or reads credentials from its own environment).
+    }
+
+    embedProvider, err := factory.NewProvider(provider.Type, config)
+    if err != nil {
+        document.Status = "failed"
+        ds.DB.Save(document)
+        return fmt.Errorf("failed to create embedding provider: %w", err)
+    }
 
     // Ensure global collection exists
     exists, err := ds.QdrantClient.CollectionExists(GlobalCollectionName)
@@ -127,7 +140,7 @@ func (ds *DocumentService) ProcessDocument(ctx context.Context, document *entiti
     totalChunks := len(chunks)
     for _, chunk := range chunks {
         log.Infof("Generating embedding for document ID %d, chunk %d", document.ID, chunk.Index)
-        embedding, err := ollamaClient.GenerateEmbedding(ctx, embedModel, chunk.Text)
+        embedding, err := embedProvider.GenerateEmbedding(ctx, embedModel, chunk.Text)
         if err != nil {
             document.Status = "failed"
             ds.DB.Save(document)
@@ -182,11 +195,20 @@ func (ds *DocumentService) SearchDocuments(ctx context.Context, query string, li
         return nil, err
     }
 
-    // Create Ollama client
-    ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+    // Create an embedding provider through the factory
+    factory := &providers.EmbeddingProviderFactory{}
+    config := map[string]interface{}{
+        "base_url": provider.BaseURL,
+        // "api_key": populate once API keys are stored on the provider entity
+    }
+
+    embedProvider, err := factory.NewProvider(provider.Type, config)
+    if err != nil {
+        return nil, fmt.Errorf("failed to create embedding provider: %w", err)
+    }
 
     // Generate embedding for query
-    queryEmbedding, err := ollamaClient.GenerateEmbedding(ctx, embedModel, query)
+    queryEmbedding, err := embedProvider.GenerateEmbedding(ctx, embedModel, query)
     if err != nil {
         return nil, fmt.Errorf("failed to generate query embedding: %w", err)
     }
diff --git a/backend/pkg/providers/embedding.go b/backend/pkg/providers/embedding.go
new file mode 100644
index 0000000..5e62e60
--- /dev/null
+++ b/backend/pkg/providers/embedding.go
@@ -0,0 +1,122 @@
+package providers
+
+import (
+    "context"
+    "fmt"
+    "sef/pkg/ollama"
+    "strings"
+
+    openai "github.com/sashabaranov/go-openai"
+)
+
+// EmbeddingProvider defines the interface for embedding providers
+type EmbeddingProvider interface {
+    // GenerateEmbedding generates embeddings for a given text
+    GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error)
+    // ListModels returns available models for the provider
+    ListModels() ([]string, error)
+}
+
+// OllamaEmbeddingProvider implements EmbeddingProvider for Ollama
+type OllamaEmbeddingProvider struct {
+    client *ollama.OllamaClient
+}
+
+// NewOllamaEmbeddingProvider creates a new Ollama embedding provider
+func NewOllamaEmbeddingProvider(config map[string]interface{}) *OllamaEmbeddingProvider {
+    baseURL := "http://localhost:11434"
+    if url, ok := config["base_url"].(string); ok && url != "" {
+        baseURL = url
+    }
+
+    return &OllamaEmbeddingProvider{
+        client: ollama.NewOllamaClient(baseURL),
+    }
+}
+
+// GenerateEmbedding delegates to the Ollama client, which already returns []float32.
+func (o *OllamaEmbeddingProvider) GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error) {
+    return o.client.GenerateEmbedding(ctx, model, text)
+}
+
+func (o *OllamaEmbeddingProvider) ListModels() ([]string, error) {
+    return o.client.ListModels()
+}
+
+// OpenAIEmbeddingProvider implements EmbeddingProvider for OpenAI and for
+// OpenAI-compatible endpoints such as LiteLLM
+type OpenAIEmbeddingProvider struct {
+    client *openai.Client
+}
+
+// NewOpenAIEmbeddingProvider creates a new OpenAI embedding provider
+func NewOpenAIEmbeddingProvider(config map[string]interface{}) *OpenAIEmbeddingProvider {
+    apiKey := ""
+    if key, ok := config["api_key"].(string); ok {
+        apiKey = key
+    }
+
+    configOpenAI := openai.DefaultConfig(apiKey)
+    if baseURL, ok := config["base_url"].(string); ok && baseURL != "" {
+        configOpenAI.BaseURL = baseURL
+    }
+
+    client := openai.NewClientWithConfig(configOpenAI)
+    return &OpenAIEmbeddingProvider{
+        client: client,
+    }
+}
+
+func (o *OpenAIEmbeddingProvider) GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error) {
+    resp, err := o.client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
+        Model: openai.EmbeddingModel(model),
+        Input: []string{text},
+    })
+    if err != nil {
+        return nil, err
+    }
+
+    if len(resp.Data) == 0 {
+        return nil, fmt.Errorf("no embedding data returned")
+    }
+
+    return resp.Data[0].Embedding, nil
+}
+
+func (o *OpenAIEmbeddingProvider) ListModels() ([]string, error) {
+    // The interface has no context parameter, so use context.Background() here.
+    models, err := o.client.ListModels(context.Background())
+    if err != nil {
+        return nil, err
+    }
+
+    // Prefer models whose ID suggests an embedding model. A LiteLLM proxy may
+    // expose every configured model, so fall back to the full list if none match.
+    var modelNames []string
+    for _, model := range models.Models {
+        if strings.Contains(strings.ToLower(model.ID), "embed") {
+            modelNames = append(modelNames, model.ID)
+        }
+    }
+    if len(modelNames) == 0 {
+        for _, model := range models.Models {
+            modelNames = append(modelNames, model.ID)
+        }
+    }
+
+    return modelNames, nil
+}
+
+// EmbeddingProviderFactory builds an EmbeddingProvider for a given provider type
+type EmbeddingProviderFactory struct{}
+
+func (f *EmbeddingProviderFactory) NewProvider(providerType string, config map[string]interface{}) (EmbeddingProvider, error) {
+    switch providerType {
+    case "ollama":
+        return NewOllamaEmbeddingProvider(config), nil
+    case "openai", "litellm":
+        return NewOpenAIEmbeddingProvider(config), nil
+    default:
+        return nil, fmt.Errorf("unsupported embedding provider type: %s", providerType)
+    }
+}
diff --git a/backend/pkg/providers/interface.go b/backend/pkg/providers/interface.go
index 89c8887..495eec6 100644
--- a/backend/pkg/providers/interface.go
+++ b/backend/pkg/providers/interface.go
@@ -67,7 +67,7 @@ func (f *ProviderFactory) NewProvider(providerType string, config map[string]int
     switch providerType {
     case "ollama":
         return NewOllamaProvider(config), nil
-    case "openai":
+    case "openai", "litellm":
        return NewOpenAIProvider(config), nil
    default:
        return nil, fmt.Errorf("unsupported provider type: %s", providerType)
diff --git a/frontend/src/pages/settings/embedding/index.tsx b/frontend/src/pages/settings/embedding/index.tsx
index 0066757..6c3be79 100644
--- a/frontend/src/pages/settings/embedding/index.tsx
+++ b/frontend/src/pages/settings/embedding/index.tsx
@@ -59,10 +59,11 @@ export default function EmbeddingSettingsPage() {
     try {
       const response = await http.get("/providers")
-      // Filter only Ollama providers
-      const ollamaProviders = response.data?.records?.filter(
-        (p: IProvider) => p.type?.toLowerCase() === "ollama"
+      // Filter supported embedding providers
+      const supportedTypes = ["ollama", "litellm"]
+      const embeddingProviders = response.data?.records?.filter(
+        (p: IProvider) => supportedTypes.includes(p.type?.toLowerCase())
       ) || []
-      setProviders(ollamaProviders)
+      setProviders(embeddingProviders)
     } catch (error) {
       console.error("Failed to fetch providers:", error)
     }
@@ -72,7 +73,7 @@ export default function EmbeddingSettingsPage() {
     try {
       const response = await http.get("/settings/embedding")
       setCurrentConfig(response.data)
-      
+
       if (response.data.provider) {
         setSelectedProviderId(response.data.provider.id.toString())
         // Fetch models for this provider
@@ -235,12 +236,12 @@ export default function EmbeddingSettingsPage() {
                 disabled={!selectedProviderId || loadingModels}
               >
[remaining JSX lines of this hunk were lost in extraction]
@@ -273,8 +274,8 @@ export default function EmbeddingSettingsPage() {
[JSX lines of this hunk were lost in extraction]
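
---

A quick usage sketch for reviewers, outside the diff itself: it assumes only what the diff introduces (the `EmbeddingProviderFactory`, the `"base_url"`/`"api_key"` config keys, and the `"ollama"`/`"openai"`/`"litellm"` type strings). The base URL, API key, and model name below are hypothetical placeholders, not values from the repo.

```go
package main

import (
	"context"
	"fmt"

	"sef/pkg/providers"
)

func main() {
	factory := &providers.EmbeddingProviderFactory{}

	// "litellm" is routed to the OpenAI-compatible client; only "base_url"
	// and "api_key" are read from the config map.
	embedProvider, err := factory.NewProvider("litellm", map[string]interface{}{
		"base_url": "http://localhost:4000", // hypothetical LiteLLM proxy address
		"api_key":  "sk-local",              // hypothetical; a proxy may not check it
	})
	if err != nil {
		panic(err)
	}

	models, err := embedProvider.ListModels()
	if err != nil {
		panic(err)
	}
	fmt.Println("embedding models:", models)

	// Model name is illustrative; pick one reported by ListModels().
	vec, err := embedProvider.GenerateEmbedding(context.Background(), "text-embedding-3-small", "hello world")
	if err != nil {
		panic(err)
	}
	fmt.Println("dimensions:", len(vec))
}
```

Because `GenerateEmbedding` returns `[]float32` for every provider type, the document service can swap Ollama for an OpenAI-compatible backend without touching the Qdrant upsert path.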