Skip to content

Commit e2654fc

Browse files
authored
feat: shard key selector (#20)
1 parent beb4c48 commit e2654fc

File tree

5 files changed

+69
-14
lines changed

5 files changed

+69
-14
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ Qdrant supports all the Spark data types. The appropriate types are mapped based
213213
| `sparse_vector_index_fields` | Comma-separated names of columns holding the sparse vector indices. | `ArrayType(IntegerType)` ||
214214
| `sparse_vector_value_fields` | Comma-separated names of columns holding the sparse vector values. | `ArrayType(FloatType)` ||
215215
| `sparse_vector_names` | Comma-separated names of the sparse vectors in the collection. | - ||
216+
| `shard_key_selector` | Comma-separated custom shard keys to use during upsert. Values that parse as integers are sent as numeric keys; all others are sent as string keys. | - ||
216217

217218
## LICENSE 📜
218219

src/main/java/io/qdrant/spark/QdrantDataWriter.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public void write(int retries) {
8181
try {
8282
// Instantiate a new QdrantGrpc object to maintain serializability
8383
QdrantGrpc qdrant = new QdrantGrpc(new URL(this.qdrantUrl), this.apiKey);
84-
qdrant.upsert(this.options.collectionName, this.points);
84+
qdrant.upsert(this.options.collectionName, this.points, this.options.shardKeySelector);
8585
qdrant.close();
8686
this.points.clear();
8787
} catch (Exception e) {

src/main/java/io/qdrant/spark/QdrantGrpc.java

+25-10
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,25 @@
33
import io.qdrant.client.QdrantClient;
44
import io.qdrant.client.QdrantGrpcClient;
55
import io.qdrant.client.grpc.Points.PointStruct;
6+
import io.qdrant.client.grpc.Points.ShardKeySelector;
7+
import io.qdrant.client.grpc.Points.UpsertPoints;
8+
69
import java.io.Serializable;
710
import java.net.MalformedURLException;
811
import java.net.URL;
912
import java.util.List;
1013
import java.util.concurrent.ExecutionException;
1114

15+
import javax.annotation.Nullable;
16+
1217
/** A class that provides methods to interact with the Qdrant gRPC API. */
1318
public class QdrantGrpc implements Serializable {
1419
private final QdrantClient client;
1520

1621
/**
1722
* Constructor for QdrantGrpc class.
1823
*
19-
* @param url The URL of the Qdrant instance.
24+
* @param url The URL of the Qdrant instance.
2025
* @param apiKey The API key to authenticate with Qdrant.
2126
* @throws MalformedURLException If the URL is invalid.
2227
*/
@@ -26,23 +31,33 @@ public QdrantGrpc(URL url, String apiKey) throws MalformedURLException {
2631
int port = url.getPort() == -1 ? 6334 : url.getPort();
2732
boolean useTls = url.getProtocol().equalsIgnoreCase("https");
2833

29-
this.client =
30-
new QdrantClient(
31-
QdrantGrpcClient.newBuilder(host, port, useTls).withApiKey(apiKey).build());
34+
this.client = new QdrantClient(
35+
QdrantGrpcClient.newBuilder(host, port, useTls).withApiKey(apiKey).build());
3236
}
3337

3438
/**
3539
* Uploads a batch of points to a Qdrant collection.
3640
*
3741
* @param collectionName The name of the collection to upload the points to.
38-
* @param points The list of points to upload.
39-
* @throws InterruptedException If there was an error uploading the batch to Qdrant.
40-
* @throws ExecutionException If there was an error uploading the batch to Qdrant.
42+
* @param points The list of points to upload.
43+
* @param shardKeySelector The shard key selector to use for the upsert.
44+
*
45+
* @throws InterruptedException If the calling thread is interrupted while
46+
* waiting for the upsert to complete.
47+
* @throws ExecutionException If there was an error uploading the batch to
48+
* Qdrant.
4149
*/
42-
public void upsert(String collectionName, List<PointStruct> points)
50+
public void upsert(String collectionName, List<PointStruct> points, @Nullable ShardKeySelector shardKeySelector)
4351
throws InterruptedException, ExecutionException {
44-
this.client.upsertAsync(collectionName, points).get();
45-
return;
52+
53+
UpsertPoints.Builder upsertPoints = UpsertPoints.newBuilder().setCollectionName(collectionName)
54+
.addAllPoints(points);
55+
56+
if (shardKeySelector != null) {
57+
upsertPoints.setShardKeySelector(shardKeySelector);
58+
}
59+
60+
this.client.upsertAsync(upsertPoints.build()).get();
4661
}
4762

4863
public void close() {

src/main/java/io/qdrant/spark/QdrantOptions.java

+41-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
import java.util.Map;
88
import java.util.Objects;
99

10+
import javax.annotation.Nullable;
11+
12+
import io.qdrant.client.grpc.Collections.ShardKey;
13+
import io.qdrant.client.grpc.Points.ShardKeySelector;
14+
15+
import static io.qdrant.client.ShardKeySelectorFactory.shardKeySelector;
16+
import static io.qdrant.client.ShardKeyFactory.shardKey;
17+
1018
public class QdrantOptions implements Serializable {
1119
private static final int DEFAULT_BATCH_SIZE = 64;
1220
private static final int DEFAULT_RETRIES = 3;
@@ -25,14 +33,14 @@ public class QdrantOptions implements Serializable {
2533
public final String[] vectorFields;
2634
public final String[] vectorNames;
2735
public final List<String> payloadFieldsToSkip;
36+
public final ShardKeySelector shardKeySelector;
2837

2938
public QdrantOptions(Map<String, String> options) {
3039
Objects.requireNonNull(options);
3140

3241
qdrantUrl = options.get("qdrant_url");
3342
collectionName = options.get("collection_name");
34-
batchSize =
35-
Integer.parseInt(options.getOrDefault("batch_size", String.valueOf(DEFAULT_BATCH_SIZE)));
43+
batchSize = Integer.parseInt(options.getOrDefault("batch_size", String.valueOf(DEFAULT_BATCH_SIZE)));
3644
retries = Integer.parseInt(options.getOrDefault("retries", String.valueOf(DEFAULT_RETRIES)));
3745
idField = options.getOrDefault("id_field", "");
3846
apiKey = options.getOrDefault("api_key", "");
@@ -45,6 +53,8 @@ public QdrantOptions(Map<String, String> options) {
4553
vectorFields = parseArray(options.get("vector_fields"));
4654
vectorNames = parseArray(options.get("vector_names"));
4755

56+
shardKeySelector = parseShardKeys(options.get("shard_key_selector"));
57+
4858
validateSparseVectorFields();
4959
validateVectorFields();
5060

@@ -83,4 +93,33 @@ private void validateVectorFields() {
8393
throw new IllegalArgumentException("Vector fields and names should have the same length");
8494
}
8595
}
96+
97+
private ShardKeySelector parseShardKeys(@Nullable String shardKeys) {
98+
if (shardKeys == null) {
99+
return null;
100+
}
101+
String[] keys = shardKeys.split(",");
102+
103+
ShardKey[] shardKeysArray = new ShardKey[keys.length];
104+
105+
for (int i = 0; i < keys.length; i++) {
106+
String key = keys[i];
107+
if (isInt(key.trim())) {
108+
shardKeysArray[i] = shardKey(Integer.parseInt(key.trim()));
109+
} else {
110+
shardKeysArray[i] = shardKey(key.trim());
111+
}
112+
}
113+
114+
return shardKeySelector(shardKeysArray);
115+
}
116+
117+
boolean isInt(String s) {
118+
try {
119+
Integer.parseInt(s);
120+
return true;
121+
} catch (NumberFormatException er) {
122+
return false;
123+
}
124+
}
86125
}

src/test/java/io/qdrant/spark/TestQdrantGrpc.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ public void testUploadBatch() throws Exception {
6969
points.add(point2Builder.build());
7070

7171
// call the upsert method
72-
qdrantGrpc.upsert(collectionName, points);
72+
qdrantGrpc.upsert(collectionName, points, null);
7373

7474
qdrantGrpc.close();
7575
}

0 commit comments

Comments
 (0)