Skip to content

Commit ca8d442

Browse files
committed
fix ollama provider
1 parent 6858e8c commit ca8d442

File tree

1 file changed

+49
-35
lines changed

1 file changed

+49
-35
lines changed
Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,49 @@
1-
import { jest, describe, expect, test, beforeEach } from '@jest/globals';
2-
import OllamaEmbeddingProvider from './OllamaEmbeddingProvider.js';
3-
4-
global.fetch = jest.fn();
5-
6-
describe('OllamaEmbeddingProvider', () => {
7-
let provider;
8-
9-
beforeEach(() => {
10-
jest.clearAllMocks();
11-
12-
provider = new OllamaEmbeddingProvider({
13-
baseUrl: 'http://localhost:11434',
14-
model: 'llama3',
15-
});
16-
17-
global.fetch.mockResolvedValue({
18-
ok: true,
19-
json: () => Promise.resolve({ embedding: Array(1536).fill(0.1) })
20-
});
21-
});
22-
23-
test('should fetch embeddings from Ollama', async () => {
24-
const texts = ['test query'];
25-
const embeddings = await provider.getEmbeddings(texts);
26-
27-
expect(global.fetch).toHaveBeenCalled();
28-
expect(embeddings[0]).toHaveLength(1536);
29-
});
30-
31-
test('should handle API errors', async () => {
32-
global.fetch.mockResolvedValueOnce({ ok: false, statusText: 'Error' });
33-
await expect(provider.getEmbeddings(['test'])).rejects.toThrow('Ollama API error: Error');
34-
});
35-
});
1+
import BaseEmbeddingProvider from './BaseEmbeddingProvider.js';
2+
import fetch from 'node-fetch';
3+
import debug from 'debug';
4+
5+
const log = debug('mongodb-rag:embedding:ollama');
6+
7+
class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
8+
constructor(options = {}) {
9+
super({}); // ✅ Skip API key validation
10+
11+
if (!options.baseUrl) {
12+
throw new Error('Ollama base URL is required (e.g., http://localhost:11434)');
13+
}
14+
if (!options.model) {
15+
throw new Error('Ollama model name is required (e.g., llama3)');
16+
}
17+
18+
this.baseUrl = options.baseUrl;
19+
this.model = options.model;
20+
}
21+
22+
async _embedBatch(texts) {
23+
try {
24+
log(`Fetching embeddings from Ollama (${this.model})...`);
25+
26+
const responses = await Promise.all(texts.map(async (text) => {
27+
const response = await fetch(`${this.baseUrl}/api/embeddings`, {
28+
method: 'POST',
29+
headers: { 'Content-Type': 'application/json' },
30+
body: JSON.stringify({ model: this.model, prompt: text }),
31+
});
32+
33+
if (!response.ok) {
34+
throw new Error(`Ollama API error: ${response.statusText}`);
35+
}
36+
37+
const data = await response.json();
38+
return data.embedding;
39+
}));
40+
41+
log(`Successfully retrieved ${responses.length} embeddings`);
42+
return responses;
43+
} catch (error) {
44+
throw new Error(`Ollama embedding error: ${error.message}`);
45+
}
46+
}
47+
}
48+
49+
export default OllamaEmbeddingProvider;

0 commit comments

Comments
 (0)