Skip to content

Commit 0f84323

Browse files
authored
feat: EF configurability (#56)
Closes #51
1 parent d1c37eb commit 0f84323

21 files changed

Lines changed: 459 additions & 447 deletions

.github/workflows/integration-test.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
name: Integration test
22

33
on:
4-
push:
5-
branches:
6-
- develop
7-
- feature/*
84
pull_request:
95
branches:
106
- main
7+
- "**"
118

129
jobs:
1310
integration-test:

src/main/java/tech/amikos/chromadb/Client.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package tech.amikos.chromadb;
22

33
import com.google.gson.internal.LinkedTreeMap;
4+
import tech.amikos.chromadb.embeddings.EmbeddingFunction;
45
import tech.amikos.chromadb.handler.ApiClient;
56
import tech.amikos.chromadb.handler.ApiException;
67
import tech.amikos.chromadb.handler.DefaultApi;

src/main/java/tech/amikos/chromadb/Collection.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.google.gson.Gson;
44
import com.google.gson.annotations.SerializedName;
55
import com.google.gson.internal.LinkedTreeMap;
6+
import tech.amikos.chromadb.embeddings.EmbeddingFunction;
67
import tech.amikos.chromadb.handler.ApiException;
78
import tech.amikos.chromadb.handler.DefaultApi;
89
import tech.amikos.chromadb.model.*;
@@ -11,6 +12,8 @@
1112
import java.util.Map;
1213
import java.util.stream.Collectors;
1314

15+
import static java.lang.Thread.sleep;
16+
1417
public class Collection {
1518
static Gson gson = new Gson();
1619
DefaultApi api;
@@ -88,7 +91,7 @@ public Object upsert(List<Embedding> embeddings, List<Map<String, String>> metad
8891
if (_embeddings == null) {
8992
_embeddings = this.embeddingFunction.embedDocuments(documents);
9093
}
91-
req.setEmbeddings((List<Object>) (Object) _embeddings);
94+
req.setEmbeddings(_embeddings.stream().map(Embedding::asArray).collect(Collectors.toList()));
9295
req.setMetadatas((List<Map<String, Object>>) (Object) metadatas);
9396
req.setDocuments(documents);
9497
req.incrementIndex(true);
@@ -107,7 +110,7 @@ public Object add(List<Embedding> embeddings, List<Map<String, String>> metadata
107110
if (_embeddings == null) {
108111
_embeddings = this.embeddingFunction.embedDocuments(documents);
109112
}
110-
req.setEmbeddings((List<Object>) (Object) _embeddings);
113+
req.setEmbeddings(_embeddings.stream().map(Embedding::asArray).collect(Collectors.toList()));
111114
req.setMetadatas((List<Map<String, Object>>) (Object) metadatas);
112115
req.setDocuments(documents);
113116
req.incrementIndex(true);
@@ -175,7 +178,7 @@ public Object updateEmbeddings(List<Embedding> embeddings, List<Map<String, Stri
175178
if (_embeddings == null) {
176179
_embeddings = this.embeddingFunction.embedDocuments(documents);
177180
}
178-
req.setEmbeddings((List<Object>) (Object) _embeddings);
181+
req.setEmbeddings(_embeddings.stream().map(Embedding::asArray).collect(Collectors.toList()));
179182
req.setDocuments(documents);
180183
req.setMetadatas((List<Object>) (Object) metadatas);
181184
req.setIds(ids);
@@ -189,7 +192,7 @@ public Object updateEmbeddings(List<Embedding> embeddings, List<Map<String, Stri
189192

190193
public QueryResponse query(List<String> queryTexts, Integer nResults, Map<String, Object> where, Map<String, Object> whereDocument, List<QueryEmbedding.IncludeEnum> include) throws ChromaException {
191194
QueryEmbedding body = new QueryEmbedding();
192-
body.queryEmbeddings((List<Object>) (Object) this.embeddingFunction.embedDocuments(queryTexts));
195+
body.queryEmbeddings(this.embeddingFunction.embedDocuments(queryTexts).stream().map(Embedding::asArray).collect(Collectors.toList()));
193196
body.nResults(nResults);
194197
body.include(include);
195198
if (where != null) {
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package tech.amikos.chromadb;
2+
3+
import okhttp3.MediaType;
4+
5+
public class Constants {
6+
7+
public static final String EF_PARAMS_BASE_API = "baseAPI";
8+
public static final String EF_PARAMS_MODEL = "modelName";
9+
public static final String EF_PARAMS_API_KEY = "apiKey";
10+
public static final String EF_PARAMS_API_KEY_FROM_ENV = "envAPIKey";
11+
public static final String MODEL_NAME = "MODEL_NAME";
12+
public static final MediaType JSON = MediaType.parse("application/json; charset=utf-8");
13+
public static final String HTTP_AGENT = "chroma-java-client";
14+
}

src/main/java/tech/amikos/chromadb/embeddings/DefaultEmbeddingFunction.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.nd4j.shade.guava.primitives.Floats;
1515
import tech.amikos.chromadb.EFException;
1616
import tech.amikos.chromadb.Embedding;
17-
import tech.amikos.chromadb.EmbeddingFunction;
1817

1918
import java.io.*;
2019
import java.net.URL;

src/main/java/tech/amikos/chromadb/EmbeddingFunction.java renamed to src/main/java/tech/amikos/chromadb/embeddings/EmbeddingFunction.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
package tech.amikos.chromadb;
1+
package tech.amikos.chromadb.embeddings;
2+
3+
import tech.amikos.chromadb.EFException;
4+
import tech.amikos.chromadb.Embedding;
25

36
import java.util.List;
47

@@ -9,5 +12,4 @@ public interface EmbeddingFunction {
912
List<Embedding> embedDocuments(List<String> documents) throws EFException;
1013

1114
List<Embedding> embedDocuments(String[] documents) throws EFException;
12-
1315
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package tech.amikos.chromadb.embeddings;
2+
3+
4+
import tech.amikos.chromadb.Constants;
5+
import tech.amikos.chromadb.EFException;
6+
7+
import java.util.Map;
8+
9+
public abstract class WithParam {
10+
public abstract void apply(Map<String, Object> params) throws EFException;
11+
12+
public static WithParam apiKey(String apiKey) {
13+
return new WithAPIKey(apiKey);
14+
}
15+
16+
public static WithParam apiKeyFromEnv(String apiKeyEnvVarName) {
17+
return new WithEnvAPIKey(apiKeyEnvVarName);
18+
}
19+
20+
public static WithParam model(String model) {
21+
return new WithModel(model);
22+
}
23+
24+
public static WithParam modelFromEnv(String modelEnvVarName) {
25+
return new WithModelFromEnv(modelEnvVarName);
26+
}
27+
28+
public static WithParam baseAPI(String baseAPI) {
29+
return new WithBaseAPI(baseAPI);
30+
}
31+
32+
public static WithParam defaultModel(String model) {
33+
return new WithDefaultModel(model);
34+
}
35+
36+
37+
}
38+
39+
class WithBaseAPI extends WithParam {
40+
private final String baseAPI;
41+
42+
public WithBaseAPI(String baseAPI) {
43+
this.baseAPI = baseAPI;
44+
}
45+
46+
@Override
47+
public void apply(Map<String, Object> params) {
48+
params.put(Constants.EF_PARAMS_BASE_API, baseAPI);
49+
}
50+
}
51+
52+
class WithModel extends WithParam {
53+
private final String model;
54+
55+
public WithModel(String model) {
56+
this.model = model;
57+
}
58+
59+
@Override
60+
public void apply(Map<String, Object> params) {
61+
params.put(Constants.EF_PARAMS_MODEL, model);
62+
}
63+
}
64+
65+
class WithModelFromEnv extends WithParam {
66+
67+
private String modelEnvVarName = Constants.MODEL_NAME;
68+
69+
public WithModelFromEnv(String modelEnvVarName) {
70+
this.modelEnvVarName = modelEnvVarName;
71+
}
72+
73+
/**
74+
* Reads MODEL_NAME from the environment
75+
*/
76+
public WithModelFromEnv() {
77+
}
78+
79+
@Override
80+
public void apply(Map<String, Object> params) throws EFException {
81+
if (System.getenv(modelEnvVarName) == null) {
82+
throw new EFException("Model not found in environment variable: " + modelEnvVarName);
83+
}
84+
params.put(Constants.EF_PARAMS_MODEL, System.getenv(modelEnvVarName));
85+
}
86+
}
87+
88+
class WithDefaultModel extends WithParam {
89+
90+
private final String model;
91+
92+
public WithDefaultModel(String model) {
93+
this.model = model;
94+
}
95+
96+
@Override
97+
public void apply(Map<String, Object> params) {
98+
params.put(Constants.EF_PARAMS_MODEL, model);
99+
}
100+
}
101+
102+
class WithAPIKey extends WithParam {
103+
private final String apiKey;
104+
105+
public WithAPIKey(String apiKey) {
106+
this.apiKey = apiKey;
107+
}
108+
109+
@Override
110+
public void apply(Map<String, Object> params) {
111+
params.put(Constants.EF_PARAMS_API_KEY, apiKey);
112+
}
113+
}
114+
115+
class WithEnvAPIKey extends WithParam {
116+
private final String apiKeyEnvVarName;
117+
118+
public WithEnvAPIKey(String apiKeyEnvVarName) {
119+
this.apiKeyEnvVarName = apiKeyEnvVarName;
120+
}
121+
122+
@Override
123+
public void apply(Map<String, Object> params) throws EFException {
124+
if (System.getenv(apiKeyEnvVarName) == null) {
125+
throw new EFException("API Key not found in environment variable: " + apiKeyEnvVarName);
126+
}
127+
params.put(Constants.EF_PARAMS_API_KEY_FROM_ENV, System.getenv(apiKeyEnvVarName));
128+
}
129+
}

src/main/java/tech/amikos/chromadb/embeddings/cohere/CohereClient.java

Lines changed: 0 additions & 69 deletions
This file was deleted.

0 commit comments

Comments
 (0)