Skip to content

Commit cfe3d7d

Browse files
authored
feat: HFEI support (#69)
Closes #66
1 parent c21b146 commit cfe3d7d

3 files changed

Lines changed: 203 additions & 89 deletions

File tree

README.md

Lines changed: 144 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -78,32 +78,33 @@ import tech.amikos.chromadb.embeddings.DefaultEmbeddingFunction;
7878
import java.util.*;
7979

8080
public class Main {
81-
public static void main(String[] args) {
82-
try {
83-
Client client = new Client(System.getenv("CHROMA_URL"));
84-
client.reset();
85-
EmbeddingFunction ef = new DefaultEmbeddingFunction();
86-
Collection collection = client.createCollection("test-collection", null, true, ef);
87-
List<Map<String, String>> metadata = new ArrayList<>();
88-
metadata.add(new HashMap<String, String>() {{
89-
put("type", "scientist");
90-
}});
91-
metadata.add(new HashMap<String, String>() {{
92-
put("type", "spy");
93-
}});
94-
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
95-
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
96-
System.out.println(qr);
97-
} catch (Exception e) {
98-
System.out.println(e);
81+
public static void main(String[] args) {
82+
try {
83+
Client client = new Client(System.getenv("CHROMA_URL"));
84+
client.reset();
85+
EmbeddingFunction ef = new DefaultEmbeddingFunction();
86+
Collection collection = client.createCollection("test-collection", null, true, ef);
87+
List<Map<String, String>> metadata = new ArrayList<>();
88+
metadata.add(new HashMap<String, String>() {{
89+
put("type", "scientist");
90+
}});
91+
metadata.add(new HashMap<String, String>() {{
92+
put("type", "spy");
93+
}});
94+
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
95+
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
96+
System.out.println(qr);
97+
} catch (Exception e) {
98+
System.out.println(e);
99+
}
99100
}
100-
}
101101
}
102102
```
103103

104104
### Example OpenAI Embedding Function
105105

106-
In this example we rely on `tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction` to generate embeddings for our documents.
106+
In this example we rely on `tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction` to generate embeddings for
107+
our documents.
107108

108109
| **Important**: Ensure you have `OPENAI_API_KEY` environment variable set
109110

@@ -118,27 +119,27 @@ import tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction;
118119
import java.util.*;
119120

120121
public class Main {
121-
public static void main(String[] args) {
122-
try {
123-
Client client = new Client(System.getenv("CHROMA_URL"));
124-
String apiKey = System.getenv("OPENAI_API_KEY");
125-
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey, "text-embedding-3-small");
126-
Collection collection = client.createCollection("test-collection", null, true, ef);
127-
List<Map<String, String>> metadata = new ArrayList<>();
128-
metadata.add(new HashMap<String, String>() {{
129-
put("type", "scientist");
130-
}});
131-
metadata.add(new HashMap<String, String>() {{
132-
put("type", "spy");
133-
}});
134-
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
135-
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
136-
System.out.println(qr);
137-
} catch (Exception e) {
138-
e.printStackTrace();
139-
System.out.println(e);
122+
public static void main(String[] args) {
123+
try {
124+
Client client = new Client(System.getenv("CHROMA_URL"));
125+
String apiKey = System.getenv("OPENAI_API_KEY");
126+
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey, "text-embedding-3-small");
127+
Collection collection = client.createCollection("test-collection", null, true, ef);
128+
List<Map<String, String>> metadata = new ArrayList<>();
129+
metadata.add(new HashMap<String, String>() {{
130+
put("type", "scientist");
131+
}});
132+
metadata.add(new HashMap<String, String>() {{
133+
put("type", "spy");
134+
}});
135+
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
136+
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
137+
System.out.println(qr);
138+
} catch (Exception e) {
139+
e.printStackTrace();
140+
System.out.println(e);
141+
}
140142
}
141-
}
142143
}
143144
```
144145

@@ -174,7 +175,8 @@ curl http://localhost:11434/api/embeddings -d '{\n "model": "llama2",\n "promp
174175

175176
### Example Cohere Embedding Function
176177

177-
In this example we rely on `tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction` to generate embeddings for our documents.
178+
In this example we rely on `tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction` to generate embeddings for
179+
our documents.
178180

179181
| **Important**: Ensure you have `COHERE_API_KEY` environment variable set
180182

@@ -188,28 +190,28 @@ import tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction;
188190
import java.util.*;
189191

190192
public class Main {
191-
public static void main(String[] args) {
192-
try {
193-
Client client = new Client(System.getenv("CHROMA_URL"));
194-
client.reset();
195-
String apiKey = System.getenv("COHERE_API_KEY");
196-
EmbeddingFunction ef = new CohereEmbeddingFunction(apiKey);
197-
Collection collection = client.createCollection("test-collection", null, true, ef);
198-
List<Map<String, String>> metadata = new ArrayList<>();
199-
metadata.add(new HashMap<String, String>() {{
200-
put("type", "scientist");
201-
}});
202-
metadata.add(new HashMap<String, String>() {{
203-
put("type", "spy");
204-
}});
205-
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
206-
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
207-
System.out.println(qr);
208-
} catch (Exception e) {
209-
e.printStackTrace();
210-
System.out.println(e);
193+
public static void main(String[] args) {
194+
try {
195+
Client client = new Client(System.getenv("CHROMA_URL"));
196+
client.reset();
197+
String apiKey = System.getenv("COHERE_API_KEY");
198+
EmbeddingFunction ef = new CohereEmbeddingFunction(apiKey);
199+
Collection collection = client.createCollection("test-collection", null, true, ef);
200+
List<Map<String, String>> metadata = new ArrayList<>();
201+
metadata.add(new HashMap<String, String>() {{
202+
put("type", "scientist");
203+
}});
204+
metadata.add(new HashMap<String, String>() {{
205+
put("type", "spy");
206+
}});
207+
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
208+
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
209+
System.out.println(qr);
210+
} catch (Exception e) {
211+
e.printStackTrace();
212+
System.out.println(e);
213+
}
211214
}
212-
}
213215
}
214216
```
215217

@@ -221,7 +223,10 @@ The above should output:
221223

222224
### Example Hugging Face Sentence Transformers Embedding Function
223225

224-
In this example we rely on `tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction` to generate embeddings for our documents.
226+
#### Hugging Face Inference API
227+
228+
In this example we rely on `tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction` to generate embeddings for
229+
our documents using HuggingFace cloud-based inference API.
225230

226231
| **Important**: Ensure you have `HF_API_KEY` environment variable set
227232

@@ -235,27 +240,26 @@ import tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction;
235240
import java.util.*;
236241

237242
public class Main {
238-
public static void main(String[] args) {
239-
try {
240-
Client client = new Client(System.getenv("CHROMA_URL"));
241-
client.reset();
242-
String apiKey = System.getenv("HF_API_KEY");
243-
EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(apiKey);
244-
Collection collection = client.createCollection("test-collection", null, true, ef);
245-
List<Map<String, String>> metadata = new ArrayList<>();
246-
metadata.add(new HashMap<String, String>() {{
247-
put("type", "scientist");
248-
}});
249-
metadata.add(new HashMap<String, String>() {{
250-
put("type", "spy");
251-
}});
252-
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
253-
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
254-
System.out.println(qr);
255-
} catch (Exception e) {
256-
System.out.println(e);
243+
public static void main(String[] args) {
244+
try {
245+
Client client = new Client("http://localhost:8000");
246+
String apiKey = System.getenv("HF_API_KEY");
247+
EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(apiKey);
248+
Collection collection = client.createCollection("test-collection", null, true, ef);
249+
List<Map<String, String>> metadata = new ArrayList<>();
250+
metadata.add(new HashMap<String, String>() {{
251+
put("type", "scientist");
252+
}});
253+
metadata.add(new HashMap<String, String>() {{
254+
put("type", "spy");
255+
}});
256+
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
257+
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
258+
System.out.println(qr);
259+
} catch (Exception e) {
260+
System.out.println(e);
261+
}
257262
}
258-
}
259263
}
260264
```
261265

@@ -265,6 +269,63 @@ The above should output:
265269
{"documents":[["Hello, my name is Bond. I am a Spy.","Hello, my name is John. I am a Data Scientist."]],"ids":[["2","1"]],"metadatas":[[{"type":"spy"},{"type":"scientist"}]],"distances":[[0.9073759,1.6440368]]}
266270
```
267271

272+
#### Hugging Face Text Embedding Inference (HFEI) API
273+
274+
In this example we'll use a local Docker based server to generate the embeddings with
275+
`Snowflake/snowflake-arctic-embed-s` mode.
276+
277+
First let's start the HFEI server:
278+
279+
```bash
280+
docker run -d -p 8008:80 --platform linux/amd64 --name hfei ghcr.io/huggingface/text-embeddings-inference:cpu-1.5.0 --model-id Snowflake/snowflake-arctic-embed-s --revision main
281+
```
282+
283+
> Note: Check the official documentation for more details - https://github.com/huggingface/text-embeddings-inference
284+
285+
Then we can use the following code to generate embeddings. Note the use of
286+
`new HuggingFaceEmbeddingFunction.WithAPIType(HuggingFaceEmbeddingFunction.APIType.HFEI_API));` to define the API type,
287+
this will ensure the client uses the correct endpoint.
288+
289+
```java
290+
package tech.amikos;
291+
292+
import tech.amikos.chromadb.*;
293+
import tech.amikos.chromadb.Collection;
294+
import tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction;
295+
296+
import java.util.*;
297+
298+
public class Main {
299+
public static void main(String[] args) {
300+
try {
301+
Client client = new Client("http://localhost:8000");
302+
EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(
303+
WithParam.baseAPI("http://localhost:8008"),
304+
new HuggingFaceEmbeddingFunction.WithAPIType(HuggingFaceEmbeddingFunction.APIType.HFEI_API));
305+
Collection collection = client.createCollection("test-collection", null, true, ef);
306+
List<Map<String, String>> metadata = new ArrayList<>();
307+
metadata.add(new HashMap<String, String>() {{
308+
put("type", "scientist");
309+
}});
310+
metadata.add(new HashMap<String, String>() {{
311+
put("type", "spy");
312+
}});
313+
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
314+
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
315+
System.out.println(qr);
316+
} catch (Exception e) {
317+
System.out.println(e);
318+
}
319+
}
320+
}
321+
```
322+
323+
The above should similar to the following output:
324+
325+
```bash
326+
{"documents":[["Hello, my name is Bond. I am a Spy.","Hello, my name is John. I am a Data Scientist."]],"ids":[["2","1"]],"metadatas":[[{"type":"spy"},{"type":"scientist"}]],"distances":[[0.19665092,0.42433012]]}
327+
```
328+
268329
### Ollama Embedding Function
269330

270331
In this example we rely on `tech.amikos.chromadb.embeddings.ollama.OllamaEmbeddingFunction` to generate embeddings for

src/main/java/tech/amikos/chromadb/embeddings/hf/HuggingFaceEmbeddingFunction.java

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
public class HuggingFaceEmbeddingFunction implements EmbeddingFunction {
1818
public static final String DEFAULT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2";
1919
public static final String DEFAULT_BASE_API = "https://api-inference.huggingface.co/pipeline/feature-extraction/";
20+
public static final String HFEI_API_PATH = "/embed";
2021
public static final String HF_API_KEY_ENV = "HF_API_KEY";
22+
public static final String API_TYPE_CONFIG_KEY = "apiType";
2123
private final OkHttpClient client = new OkHttpClient();
2224
private final Map<String, Object> configParams = new HashMap<>();
2325
private static final Gson gson = new Gson();
2426

2527
private static final List<WithParam> defaults = Arrays.asList(
28+
new WithAPIType(APIType.HF_API),
2629
WithParam.baseAPI(DEFAULT_BASE_API),
2730
WithParam.defaultModel(DEFAULT_MODEL_NAME)
2831
);
@@ -46,14 +49,21 @@ public HuggingFaceEmbeddingFunction(WithParam... params) throws EFException {
4649
}
4750

4851
public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) throws EFException {
49-
Request request = new Request.Builder()
50-
.url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + this.configParams.get(Constants.EF_PARAMS_MODEL).toString())
52+
Request.Builder rb = new Request.Builder()
53+
5154
.post(RequestBody.create(req.json(), JSON))
5255
.addHeader("Accept", "application/json")
5356
.addHeader("Content-Type", "application/json")
54-
.addHeader("User-Agent", Constants.HTTP_AGENT)
55-
.addHeader("Authorization", "Bearer " + configParams.get(Constants.EF_PARAMS_API_KEY).toString())
56-
.build();
57+
.addHeader("User-Agent", Constants.HTTP_AGENT);
58+
if (configParams.containsKey(API_TYPE_CONFIG_KEY) && configParams.get(API_TYPE_CONFIG_KEY).equals(APIType.HFEI_API)) {
59+
rb.url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + HFEI_API_PATH);
60+
} else {
61+
rb.url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + this.configParams.get(Constants.EF_PARAMS_MODEL).toString());
62+
}
63+
if (configParams.containsKey(Constants.EF_PARAMS_API_KEY)) {
64+
rb.addHeader("Authorization", "Bearer " + configParams.get(Constants.EF_PARAMS_API_KEY).toString());
65+
}
66+
Request request = rb.build();
5767
try (Response response = client.newCall(request).execute()) {
5868
if (!response.isSuccessful()) {
5969
throw new IOException("Unexpected code " + response);
@@ -86,4 +96,22 @@ public List<Embedding> embedDocuments(String[] documents) throws EFException {
8696
CreateEmbeddingResponse response = this.createEmbedding(new CreateEmbeddingRequest().inputs(documents));
8797
return response.getEmbeddings().stream().map(Embedding::fromList).collect(Collectors.toList());
8898
}
99+
100+
public static class WithAPIType extends WithParam {
101+
private final APIType apiType;
102+
103+
public WithAPIType(APIType apitype) {
104+
this.apiType = apitype;
105+
}
106+
107+
@Override
108+
public void apply(Map<String, Object> params) {
109+
params.put(API_TYPE_CONFIG_KEY, apiType);
110+
}
111+
}
112+
113+
public enum APIType{
114+
HF_API,
115+
HFEI_API
116+
}
89117
}

0 commit comments

Comments
 (0)