Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/app/controllers/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (h *Controller) Delete(c fiber.Ctx) error {
func (h *Controller) Types(c fiber.Ctx) error {
// For now, hardcoded list of supported provider types
// In the future, this could be dynamic based on available providers
types := []string{"ollama", "openai"}
types := []string{"ollama", "openai", "litellm"}
return c.JSON(fiber.Map{"types": types})
}

Expand Down
17 changes: 13 additions & 4 deletions backend/app/controllers/settings/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package settings
import (
"fmt"
"sef/app/entities"
"sef/pkg/ollama"
"sef/pkg/providers"
"strconv"

"github.com/gofiber/fiber/v3"
Expand Down Expand Up @@ -97,11 +97,20 @@ func (h *Controller) ListEmbeddingModels(c fiber.Ctx) error {
return fiber.NewError(fiber.StatusNotFound, "Provider not found")
}

// Create Ollama client
ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
// Create embedding provider factory
factory := &providers.EmbeddingProviderFactory{}
config := map[string]interface{}{
"base_url": provider.BaseURL,
// "api_key": ...
}

embedProvider, err := factory.NewProvider(provider.Type, config)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, fmt.Sprintf("Failed to create provider: %v", err))
}

// Get all models
allModels, err := ollamaClient.ListModels()
allModels, err := embedProvider.ListModels()
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, fmt.Sprintf("Failed to list models: %v", err))
}
Expand Down
46 changes: 39 additions & 7 deletions backend/pkg/documentservice/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"fmt"
"sef/app/entities"
"sef/pkg/chunking"
"sef/pkg/ollama"
"sef/pkg/providers"
"sef/pkg/qdrant"
"strings"

Expand Down Expand Up @@ -88,8 +88,31 @@ func (ds *DocumentService) ProcessDocument(ctx context.Context, document *entiti
return err
}

// Create Ollama client for this provider
ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
// Create embedding provider
factory := &providers.EmbeddingProviderFactory{}
config := map[string]interface{}{
"base_url": provider.BaseURL,
"api_key": "sk-...", // Placeholder or from config if available for OpenAI/LiteLLM
}

// NOTE(review): the Provider entity does not currently carry an API key, so
// the config map above only passes base_url plus a hardcoded placeholder
// key. Wire a real API key through the provider entity or settings before
// targeting authenticated OpenAI/LiteLLM endpoints.
if provider.Type == "openai" || provider.Type == "litellm" {
// Assuming we might store API key in some secure way or config.
// For LiteLLM it often just needs BaseURL if it's acting as a proxy without auth or with env auth.
// If real OpenAI, we need API Key.
// Checking providers.go, NewOpenAIProvider reads "api_key".
// We should probably pass more config from provider entity if available.
}

embedProvider, err := factory.NewProvider(provider.Type, config)
if err != nil {
document.Status = "failed"
ds.DB.Save(document)
return fmt.Errorf("failed to create embedding provider: %w", err)
}

// Ensure global collection exists
exists, err := ds.QdrantClient.CollectionExists(GlobalCollectionName)
Expand Down Expand Up @@ -127,7 +150,7 @@ func (ds *DocumentService) ProcessDocument(ctx context.Context, document *entiti
totalChunks := len(chunks)
for _, chunk := range chunks {
log.Infof("Generating embedding for document ID %d, chunk %d", document.ID, chunk.Index)
embedding, err := ollamaClient.GenerateEmbedding(ctx, embedModel, chunk.Text)
embedding, err := embedProvider.GenerateEmbedding(ctx, embedModel, chunk.Text)
if err != nil {
document.Status = "failed"
ds.DB.Save(document)
Expand Down Expand Up @@ -182,11 +205,20 @@ func (ds *DocumentService) SearchDocuments(ctx context.Context, query string, li
return nil, err
}

// Create Ollama client
ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
// Create embedding provider
factory := &providers.EmbeddingProviderFactory{}
config := map[string]interface{}{
"base_url": provider.BaseURL,
// "api_key": ... (add logic to retrieve API key if needed)
}

embedProvider, err := factory.NewProvider(provider.Type, config)
if err != nil {
return nil, fmt.Errorf("failed to create embedding provider: %w", err)
}

// Generate embedding for query
queryEmbedding, err := ollamaClient.GenerateEmbedding(ctx, embedModel, query)
queryEmbedding, err := embedProvider.GenerateEmbedding(ctx, embedModel, query)
if err != nil {
return nil, fmt.Errorf("failed to generate query embedding: %w", err)
}
Expand Down
127 changes: 127 additions & 0 deletions backend/pkg/providers/embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package providers

import (
"context"
"fmt"
"sef/pkg/ollama"
"strings"

openai "github.com/sashabaranov/go-openai"
)

// EmbeddingProvider defines the interface for embedding providers
// (Ollama, OpenAI, and OpenAI-compatible proxies such as LiteLLM).
type EmbeddingProvider interface {
	// GenerateEmbedding generates the embedding vector for a single text
	// using the given model name.
	GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error)
	// ListModels returns the model names available from the provider.
	// NOTE(review): no context parameter — implementations that need one
	// must fall back to context.Background().
	ListModels() ([]string, error)
}

// OllamaEmbeddingProvider implements EmbeddingProvider for Ollama by
// delegating to the existing ollama HTTP client.
type OllamaEmbeddingProvider struct {
	client *ollama.OllamaClient // bound to the provider's base URL
}

// NewOllamaEmbeddingProvider creates an Ollama embedding provider from a
// generic config map. Recognized keys:
//   - "base_url": Ollama server URL; defaults to http://localhost:11434
//     when absent, empty, or not a string.
func NewOllamaEmbeddingProvider(config map[string]interface{}) *OllamaEmbeddingProvider {
	const defaultBaseURL = "http://localhost:11434"

	endpoint, _ := config["base_url"].(string)
	if endpoint == "" {
		endpoint = defaultBaseURL
	}

	return &OllamaEmbeddingProvider{
		client: ollama.NewOllamaClient(endpoint),
	}
}

// GenerateEmbedding generates the embedding for text with the given model
// by delegating directly to the underlying Ollama client.
// NOTE(review): assumes ollama.OllamaClient.GenerateEmbedding already
// returns []float32 as required by the EmbeddingProvider interface —
// confirm against pkg/ollama/client.go.
func (o *OllamaEmbeddingProvider) GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error) {
	return o.client.GenerateEmbedding(ctx, model, text)
}

// ListModels returns the model names reported by the Ollama server.
func (o *OllamaEmbeddingProvider) ListModels() ([]string, error) {
	return o.client.ListModels()
}

// OpenAIEmbeddingProvider implements EmbeddingProvider for OpenAI and
// OpenAI-compatible endpoints (e.g. a LiteLLM proxy).
type OpenAIEmbeddingProvider struct {
	client *openai.Client
}

// NewOpenAIEmbeddingProvider creates an OpenAI-compatible embedding provider
// from a generic config map. Recognized keys:
//   - "api_key":  bearer token; may be empty (e.g. an unauthenticated proxy).
//   - "base_url": overrides the default OpenAI endpoint when non-empty
//     (used to point at LiteLLM or another compatible server).
func NewOpenAIEmbeddingProvider(config map[string]interface{}) *OpenAIEmbeddingProvider {
	key, _ := config["api_key"].(string)

	clientConfig := openai.DefaultConfig(key)
	if url, ok := config["base_url"].(string); ok && url != "" {
		clientConfig.BaseURL = url
	}

	return &OpenAIEmbeddingProvider{
		client: openai.NewClientWithConfig(clientConfig),
	}
}

// GenerateEmbedding requests an embedding for text from the OpenAI-compatible
// embeddings endpoint using the given model.
func (o *OpenAIEmbeddingProvider) GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error) {
	req := openai.EmbeddingRequest{
		Model: openai.EmbeddingModel(model),
		Input: []string{text},
	}

	resp, err := o.client.CreateEmbeddings(ctx, req)
	if err != nil {
		return nil, err
	}
	if len(resp.Data) == 0 {
		return nil, fmt.Errorf("no embedding data returned")
	}

	// Exactly one input was sent, so only the first entry is used.
	return resp.Data[0].Embedding, nil
}

// ListModels returns model names from the OpenAI-compatible models endpoint.
//
// The endpoint does not mark which models are for embeddings, so the result
// is heuristically filtered to IDs containing "embed"; if nothing matches
// (e.g. a LiteLLM proxy exposing arbitrary aliases), every model ID is
// returned as a fallback — same behavior as before, now in a single pass.
func (o *OpenAIEmbeddingProvider) ListModels() ([]string, error) {
	// The EmbeddingProvider interface takes no context, so use a local
	// background context for this one-shot request instead of relying on
	// the mutable package-level ctx variable.
	ctx := context.Background()

	models, err := o.client.ListModels(ctx)
	if err != nil {
		return nil, err
	}

	all := make([]string, 0, len(models.Models))
	var embedOnly []string
	for _, model := range models.Models {
		all = append(all, model.ID)
		if strings.Contains(strings.ToLower(model.ID), "embed") {
			embedOnly = append(embedOnly, model.ID)
		}
	}

	if len(embedOnly) > 0 {
		return embedOnly, nil
	}
	return all, nil
}

// EmbeddingProviderFactory builds EmbeddingProvider implementations by
// provider type name.
type EmbeddingProviderFactory struct{}

// NewProvider returns the embedding provider for providerType. "ollama" maps
// to the Ollama client; "openai" and "litellm" share the OpenAI-compatible
// client. Any other type yields an error.
func (f *EmbeddingProviderFactory) NewProvider(providerType string, config map[string]interface{}) (EmbeddingProvider, error) {
	switch providerType {
	case "ollama":
		return NewOllamaEmbeddingProvider(config), nil
	case "openai", "litellm":
		return NewOpenAIEmbeddingProvider(config), nil
	}
	return nil, fmt.Errorf("unsupported embedding provider type: %s", providerType)
}

// NOTE(review): package-level fallback context used by the OpenAI
// ListModels implementation because the EmbeddingProvider interface takes
// no context parameter. Mutable package-level state like this is a code
// smell — prefer a local context.Background() inside the method, then
// delete this variable.
var ctx = context.Background()
2 changes: 1 addition & 1 deletion backend/pkg/providers/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (f *ProviderFactory) NewProvider(providerType string, config map[string]int
switch providerType {
case "ollama":
return NewOllamaProvider(config), nil
case "openai":
case "openai", "litellm":
return NewOpenAIProvider(config), nil
default:
return nil, fmt.Errorf("unsupported provider type: %s", providerType)
Expand Down
24 changes: 13 additions & 11 deletions frontend/src/pages/settings/embedding/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ export default function EmbeddingSettingsPage() {
try {
const response = await http.get("/providers")
// Filter only Ollama providers
const ollamaProviders = response.data?.records?.filter(
(p: IProvider) => p.type?.toLowerCase() === "ollama"
// Filter supported embedding providers
const supportedTypes = ["ollama", "litellm"]
const embeddingProviders = response.data?.records?.filter(
(p: IProvider) => supportedTypes.includes(p.type?.toLowerCase())
) || []
setProviders(ollamaProviders)
setProviders(embeddingProviders)
} catch (error) {
console.error("Failed to fetch providers:", error)
}
Expand All @@ -72,7 +74,7 @@ export default function EmbeddingSettingsPage() {
try {
const response = await http.get("/settings/embedding")
setCurrentConfig(response.data)

if (response.data.provider) {
setSelectedProviderId(response.data.provider.id.toString())
// Fetch models for this provider
Expand Down Expand Up @@ -235,12 +237,12 @@ export default function EmbeddingSettingsPage() {
disabled={!selectedProviderId || loadingModels}
>
<SelectTrigger id="model">
<SelectValue
<SelectValue
placeholder={
loadingModels
? t("embedding.loading_models")
loadingModels
? t("embedding.loading_models")
: t("embedding.select_model")
}
}
/>
</SelectTrigger>
<SelectContent>
Expand Down Expand Up @@ -273,8 +275,8 @@ export default function EmbeddingSettingsPage() {
</div>
</CardContent>
<CardFooter>
<Button
onClick={handleSave}
<Button
onClick={handleSave}
disabled={loading || !selectedProviderId || !selectedModel || !vectorSize}
className="w-full"
>
Expand All @@ -297,4 +299,4 @@ export default function EmbeddingSettingsPage() {
</div>
</>
)
}
}