Skip to content

Commit d1c37eb

Browse files
authored
feat: EF Interface (#55)
Closes #48
1 parent 112a1c5 commit d1c37eb

19 files changed

Lines changed: 280 additions & 184 deletions
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package tech.amikos.chromadb;
2+
3+
public class ChromaException extends Exception {
4+
public ChromaException(String message) {
5+
super(message);
6+
}
7+
8+
public ChromaException(String message, Throwable cause) {
9+
super(message, cause);
10+
}
11+
12+
public ChromaException(Throwable cause) {
13+
super(cause);
14+
}
15+
}

src/main/java/tech/amikos/chromadb/Collection.java

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -82,33 +82,41 @@ public Object delete() throws ApiException {
8282
return this.delete(null, null, null);
8383
}
8484

85-
public Object upsert(List<List<Float>> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ApiException {
85+
public Object upsert(List<Embedding> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ChromaException {
8686
AddEmbedding req = new AddEmbedding();
87-
List<List<Float>> _embeddings = embeddings;
87+
List<Embedding> _embeddings = embeddings;
8888
if (_embeddings == null) {
89-
_embeddings = this.embeddingFunction.createEmbedding(documents);
89+
_embeddings = this.embeddingFunction.embedDocuments(documents);
9090
}
9191
req.setEmbeddings((List<Object>) (Object) _embeddings);
9292
req.setMetadatas((List<Map<String, Object>>) (Object) metadatas);
9393
req.setDocuments(documents);
9494
req.incrementIndex(true);
9595
req.setIds(ids);
96-
return api.upsert(req, this.collectionId);
96+
try {
97+
return api.upsert(req, this.collectionId);
98+
} catch (ApiException e) {
99+
throw new ChromaException(e);
100+
}
97101
}
98102

99103

100-
public Object add(List<List<Float>> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ApiException {
104+
public Object add(List<Embedding> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ChromaException {
101105
AddEmbedding req = new AddEmbedding();
102-
List<List<Float>> _embeddings = embeddings;
106+
List<Embedding> _embeddings = embeddings;
103107
if (_embeddings == null) {
104-
_embeddings = this.embeddingFunction.createEmbedding(documents);
108+
_embeddings = this.embeddingFunction.embedDocuments(documents);
105109
}
106110
req.setEmbeddings((List<Object>) (Object) _embeddings);
107111
req.setMetadatas((List<Map<String, Object>>) (Object) metadatas);
108112
req.setDocuments(documents);
109113
req.incrementIndex(true);
110114
req.setIds(ids);
111-
return api.add(req, this.collectionId);
115+
try {
116+
return api.add(req, this.collectionId);
117+
} catch (ApiException e) {
118+
throw new ChromaException(e);
119+
}
112120
}
113121

114122
public Integer count() throws ApiException {
@@ -161,23 +169,27 @@ public Object update(String newName, Map<String, Object> newMetadata) throws Api
161169
return resp;
162170
}
163171

164-
public Object updateEmbeddings(List<List<Float>> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ApiException {
172+
public Object updateEmbeddings(List<Embedding> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ChromaException {
165173
UpdateEmbedding req = new UpdateEmbedding();
166-
List<List<Float>> _embeddings = embeddings;
174+
List<Embedding> _embeddings = embeddings;
167175
if (_embeddings == null) {
168-
_embeddings = this.embeddingFunction.createEmbedding(documents);
176+
_embeddings = this.embeddingFunction.embedDocuments(documents);
169177
}
170178
req.setEmbeddings((List<Object>) (Object) _embeddings);
171179
req.setDocuments(documents);
172180
req.setMetadatas((List<Object>) (Object) metadatas);
173181
req.setIds(ids);
174-
return api.update(req, this.collectionId);
182+
try {
183+
return api.update(req, this.collectionId);
184+
} catch (ApiException e) {
185+
throw new ChromaException(e);
186+
}
175187
}
176188

177189

178-
public QueryResponse query(List<String> queryTexts, Integer nResults, Map<String, Object> where, Map<String, Object> whereDocument, List<QueryEmbedding.IncludeEnum> include) throws ApiException {
190+
public QueryResponse query(List<String> queryTexts, Integer nResults, Map<String, Object> where, Map<String, Object> whereDocument, List<QueryEmbedding.IncludeEnum> include) throws ChromaException {
179191
QueryEmbedding body = new QueryEmbedding();
180-
body.queryEmbeddings((List<Object>) (Object) this.embeddingFunction.createEmbedding(queryTexts));
192+
body.queryEmbeddings((List<Object>) (Object) this.embeddingFunction.embedDocuments(queryTexts));
181193
body.nResults(nResults);
182194
body.include(include);
183195
if (where != null) {
@@ -186,9 +198,13 @@ public QueryResponse query(List<String> queryTexts, Integer nResults, Map<String
186198
if (whereDocument != null) {
187199
body.whereDocument(whereDocument.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
188200
}
189-
Gson gson = new Gson();
190-
String json = gson.toJson(api.getNearestNeighbors(body, this.collectionId));
191-
return new Gson().fromJson(json, QueryResponse.class);
201+
try {
202+
Gson gson = new Gson();
203+
String json = gson.toJson(api.getNearestNeighbors(body, this.collectionId));
204+
return new Gson().fromJson(json, QueryResponse.class);
205+
} catch (ApiException e) {
206+
throw new ChromaException(e);
207+
}
192208
}
193209

194210
public static class QueryResponse {

src/main/java/tech/amikos/chromadb/EFException.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
/**
44
* This exception encapsulates all exceptions thrown by the EmbeddingFunction class.
55
*/
6-
public class EFException extends Exception {
6+
public class EFException extends ChromaException {
77
public EFException(String message) {
88
super(message);
99
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package tech.amikos.chromadb;
2+
3+
import java.util.List;
4+
import java.util.stream.Collectors;
5+
import java.util.stream.IntStream;
6+
7+
public class Embedding {
8+
private final float[] embedding;
9+
10+
public Embedding(float[] embeddings) {
11+
this.embedding = embeddings;
12+
}
13+
14+
public Embedding(List<? extends Number> embedding) {
15+
this.embedding = new float[embedding.size()];
16+
for (int i = 0; i < embedding.size(); i++) {
17+
//TODO what if embeddings are integers?
18+
this.embedding[i] = embedding.get(i).floatValue();
19+
}
20+
}
21+
22+
23+
public List<Float> asList() {
24+
return IntStream.range(0, embedding.length)
25+
.mapToObj(i -> embedding[i])
26+
.collect(Collectors.toList());
27+
28+
}
29+
30+
public int getDimensions() {
31+
return embedding.length;
32+
}
33+
34+
public float[] asArray() {
35+
return embedding;
36+
}
37+
38+
public static Embedding fromList(List<Float> embedding) {
39+
return new Embedding(embedding);
40+
}
41+
42+
public static Embedding fromArray(float[] embedding) {
43+
return new Embedding(embedding);
44+
}
45+
}

src/main/java/tech/amikos/chromadb/EmbeddingFunction.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44

55
public interface EmbeddingFunction {
66

7-
List<List<Float>> createEmbedding(List<String> documents);
7+
Embedding embedQuery(String query) throws EFException;
8+
9+
List<Embedding> embedDocuments(List<String> documents) throws EFException;
10+
11+
List<Embedding> embedDocuments(String[] documents) throws EFException;
812

9-
List<List<Float>> createEmbedding(List<String> documents, String model);
1013
}

src/main/java/tech/amikos/chromadb/embeddings/DefaultEmbeddingFunction.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
55
import ai.onnxruntime.*;
66

7+
import java.util.stream.Collectors;
78
import java.util.zip.GZIPInputStream;
89

910
import org.apache.commons.compress.archivers.tar.*;
@@ -12,6 +13,7 @@
1213
import org.nd4j.linalg.factory.Nd4j;
1314
import org.nd4j.shade.guava.primitives.Floats;
1415
import tech.amikos.chromadb.EFException;
16+
import tech.amikos.chromadb.Embedding;
1517
import tech.amikos.chromadb.EmbeddingFunction;
1618

1719
import java.io.*;
@@ -203,22 +205,25 @@ private boolean validateModel() {
203205
}
204206

205207
@Override
206-
public List<List<Float>> createEmbedding(List<String> documents) {
208+
public Embedding embedQuery(String query) throws EFException {
207209
try {
208-
return forward(documents);
210+
return Embedding.fromList(forward(Collections.singletonList(query)).get(0));
209211
} catch (OrtException e) {
210-
//TODO not great to throw a runtime exception but we need to update the interface in upcoming release to rethrow
211-
throw new RuntimeException(e);
212+
throw new EFException(e);
212213
}
213214
}
214215

215216
@Override
216-
public List<List<Float>> createEmbedding(List<String> documents, String model) {
217+
public List<Embedding> embedDocuments(List<String> documents) throws EFException {
217218
try {
218-
return forward(documents);
219+
return forward(documents).stream().map(Embedding::new).collect(Collectors.toList());
219220
} catch (OrtException e) {
220-
//TODO not great to throw a runtime exception but we need to update the interface in upcoming release to rethrow
221-
throw new RuntimeException(e);
221+
throw new EFException(e);
222222
}
223223
}
224+
225+
@Override
226+
public List<Embedding> embedDocuments(String[] documents) throws EFException {
227+
return embedDocuments(Arrays.asList(documents));
228+
}
224229
}

src/main/java/tech/amikos/chromadb/embeddings/cohere/CohereClient.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.google.gson.Gson;
44
import okhttp3.*;
5+
import tech.amikos.chromadb.EFException;
56

67
import java.io.IOException;
78

@@ -43,7 +44,7 @@ private String getApiKey() {
4344
return this.apiKey;
4445
}
4546

46-
public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) {
47+
public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) throws EFException {
4748
Request request = new Request.Builder()
4849
.url(this.baseUrl + apiVersion + "/embed")
4950
.post(RequestBody.create(req.json(), JSON))
@@ -61,9 +62,8 @@ public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) {
6162

6263
return gson.fromJson(responseData, CreateEmbeddingResponse.class);
6364
} catch (IOException e) {
64-
e.printStackTrace();
65+
throw new EFException(e);
6566
}
66-
return null;
6767
}
6868

6969
}
Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,37 @@
11
package tech.amikos.chromadb.embeddings.cohere;
22

3+
import org.jetbrains.annotations.NotNull;
4+
import tech.amikos.chromadb.EFException;
5+
import tech.amikos.chromadb.Embedding;
36
import tech.amikos.chromadb.EmbeddingFunction;
47

58
import java.util.List;
9+
import java.util.stream.Collectors;
610

711
public class CohereEmbeddingFunction implements EmbeddingFunction {
812

9-
private final String cohereAPIKey;
13+
private final CohereClient client;
1014

1115
public CohereEmbeddingFunction(String cohereAPIKey) {
12-
this.cohereAPIKey = cohereAPIKey;
16+
this.client = new CohereClient(cohereAPIKey);
1317

1418
}
1519

1620
@Override
17-
public List<List<Float>> createEmbedding(List<String> documents) {
18-
CohereClient client = new CohereClient(this.cohereAPIKey);
21+
public Embedding embedQuery(String query) throws EFException {
22+
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().texts(new String[]{query}));
23+
return new Embedding(response.getEmbeddings().get(0));
24+
}
25+
26+
@Override
27+
public List<Embedding> embedDocuments(@NotNull List<String> documents) throws EFException{
1928
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().texts(documents.toArray(new String[0])));
20-
return response.getEmbeddings();
29+
return response.getEmbeddings().stream().map(Embedding::new).collect(Collectors.toList());
2130
}
2231

2332
@Override
24-
public List<List<Float>> createEmbedding(List<String> documents, String model) {
25-
CohereClient client = new CohereClient(this.cohereAPIKey);
26-
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().texts(documents.toArray(new String[0])).model(model));
27-
return response.getEmbeddings();
33+
public List<Embedding> embedDocuments(String[] documents) throws EFException{
34+
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().texts(documents));
35+
return response.getEmbeddings().stream().map(Embedding::new).collect(Collectors.toList());
2836
}
2937
}

src/main/java/tech/amikos/chromadb/embeddings/hf/HuggingFaceClient.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.google.gson.Gson;
44
import okhttp3.*;
5+
import tech.amikos.chromadb.EFException;
56

67
import java.io.IOException;
78
import java.util.List;
@@ -46,7 +47,7 @@ private String getApiKey() {
4647
return this.apiKey;
4748
}
4849

49-
public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) {
50+
public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) throws EFException {
5051
Request request = new Request.Builder()
5152
.url(this.baseUrl + this.modelId)
5253
.post(RequestBody.create(req.json(), JSON))
@@ -65,9 +66,8 @@ public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) {
6566

6667
return new CreateEmbeddingResponse(parsedResponse);
6768
} catch (IOException e) {
68-
e.printStackTrace();
69+
throw new EFException(e);
6970
}
70-
return null;
7171
}
7272

7373
}
Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,38 @@
11
package tech.amikos.chromadb.embeddings.hf;
22

33

4+
import org.jetbrains.annotations.NotNull;
5+
import tech.amikos.chromadb.EFException;
6+
import tech.amikos.chromadb.Embedding;
47
import tech.amikos.chromadb.EmbeddingFunction;
58

69
import java.util.List;
10+
import java.util.stream.Collectors;
711

812
public class HuggingFaceEmbeddingFunction implements EmbeddingFunction {
913
private final String hfAPIKey;
1014
private final HuggingFaceClient client;
15+
1116
public HuggingFaceEmbeddingFunction(String hfAPIKey) {
1217
this.hfAPIKey = hfAPIKey;
13-
this.client = new HuggingFaceClient(this.hfAPIKey);
18+
this.client = new HuggingFaceClient(this.hfAPIKey);
1419
}
1520

1621
@Override
17-
public List<List<Float>> createEmbedding(List<String> documents) {
22+
public Embedding embedQuery(String query) throws EFException {
23+
CreateEmbeddingResponse response = this.client.createEmbedding(new CreateEmbeddingRequest().inputs(new String[]{query}));
24+
return new Embedding(response.getEmbeddings().get(0));
25+
}
1826

27+
@Override
28+
public List<Embedding> embedDocuments(@NotNull List<String> documents) throws EFException {
1929
CreateEmbeddingResponse response = this.client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0])));
20-
return response.getEmbeddings();
30+
return response.getEmbeddings().stream().map(Embedding::fromList).collect(Collectors.toList());
2131
}
2232

2333
@Override
24-
public List<List<Float>> createEmbedding(List<String> documents, String model) {
25-
client.modelId(model);
26-
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0])));
27-
return response.getEmbeddings();
34+
public List<Embedding> embedDocuments(String[] documents) throws EFException {
35+
CreateEmbeddingResponse response = this.client.createEmbedding(new CreateEmbeddingRequest().inputs(documents));
36+
return response.getEmbeddings().stream().map(Embedding::fromList).collect(Collectors.toList());
2837
}
2938
}

0 commit comments

Comments
 (0)