|
1 | | -import { jest, describe, expect, test, beforeEach } from '@jest/globals'; |
2 | | -import OllamaEmbeddingProvider from './OllamaEmbeddingProvider.js'; |
3 | | - |
4 | | -global.fetch = jest.fn(); |
5 | | - |
6 | | -describe('OllamaEmbeddingProvider', () => { |
7 | | - let provider; |
8 | | - |
9 | | - beforeEach(() => { |
10 | | - jest.clearAllMocks(); |
11 | | - |
12 | | - provider = new OllamaEmbeddingProvider({ |
13 | | - baseUrl: 'http://localhost:11434', |
14 | | - model: 'llama3', |
15 | | - }); |
16 | | - |
17 | | - global.fetch.mockResolvedValue({ |
18 | | - ok: true, |
19 | | - json: () => Promise.resolve({ embedding: Array(1536).fill(0.1) }) |
20 | | - }); |
21 | | - }); |
22 | | - |
23 | | - test('should fetch embeddings from Ollama', async () => { |
24 | | - const texts = ['test query']; |
25 | | - const embeddings = await provider.getEmbeddings(texts); |
26 | | - |
27 | | - expect(global.fetch).toHaveBeenCalled(); |
28 | | - expect(embeddings[0]).toHaveLength(1536); |
29 | | - }); |
30 | | - |
31 | | - test('should handle API errors', async () => { |
32 | | - global.fetch.mockResolvedValueOnce({ ok: false, statusText: 'Error' }); |
33 | | - await expect(provider.getEmbeddings(['test'])).rejects.toThrow('Ollama API error: Error'); |
34 | | - }); |
35 | | -}); |
| 1 | +import BaseEmbeddingProvider from './BaseEmbeddingProvider.js'; |
| 2 | +import fetch from 'node-fetch'; |
| 3 | +import debug from 'debug'; |
| 4 | + |
// Namespaced debug logger; enable with DEBUG=mongodb-rag:embedding:ollama
const log = debug('mongodb-rag:embedding:ollama');
| 6 | + |
/**
 * Embedding provider backed by a locally running Ollama server.
 *
 * Talks to the `/api/embeddings` endpoint; no API key is required because
 * Ollama is a local service.
 */
class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
  /**
   * @param {object} options
   * @param {string} options.baseUrl - Ollama server URL (e.g. http://localhost:11434).
   * @param {string} options.model - Model to embed with (e.g. llama3).
   * @throws {Error} If baseUrl or model is missing.
   */
  constructor(options = {}) {
    // Pass an empty config so the base class skips API-key validation —
    // Ollama runs locally and needs no key.
    super({});

    if (!options.baseUrl) {
      throw new Error('Ollama base URL is required (e.g., http://localhost:11434)');
    }
    if (!options.model) {
      throw new Error('Ollama model name is required (e.g., llama3)');
    }

    this.baseUrl = options.baseUrl;
    this.model = options.model;
  }

  /**
   * Embed a batch of texts via Ollama's `/api/embeddings` endpoint.
   * One request is issued per text (the endpoint takes a single prompt),
   * all in parallel; order of results matches the order of `texts`.
   *
   * @param {string[]} texts - Texts to embed.
   * @returns {Promise<number[][]>} One embedding vector per input text.
   * @throws {Error} On HTTP failure or a malformed response body; the
   *   original failure is attached as `cause`.
   */
  async _embedBatch(texts) {
    try {
      log(`Fetching embeddings from Ollama (${this.model})...`);

      const responses = await Promise.all(texts.map(async (text) => {
        const response = await fetch(`${this.baseUrl}/api/embeddings`, {
          method: 'POST',
          headers: { 'Content-Type': 'application/json' },
          body: JSON.stringify({ model: this.model, prompt: text }),
        });

        if (!response.ok) {
          throw new Error(`Ollama API error: ${response.statusText}`);
        }

        const data = await response.json();
        // Fix: a malformed response used to propagate `undefined` silently,
        // letting callers store empty vectors. Fail loudly instead.
        if (!Array.isArray(data.embedding)) {
          throw new Error('Ollama API error: response missing "embedding" array');
        }
        return data.embedding;
      }));

      log(`Successfully retrieved ${responses.length} embeddings`);
      return responses;
    } catch (error) {
      // Keep the original error reachable for debugging via `cause`;
      // the message prefix is preserved for callers matching on it.
      throw new Error(`Ollama embedding error: ${error.message}`, { cause: error });
    }
  }
}

export default OllamaEmbeddingProvider;
0 commit comments