Skip to content

Commit f38288c

Browse files
authored
feat: Optionally ensure upload (#36)
* feat: wait to ensure upload
  Signed-off-by: Anush008 <[email protected]>
* docs: Updated README.md
  Signed-off-by: Anush008 <[email protected]>
---------
Signed-off-by: Anush008 <[email protected]>
1 parent b66b5ff commit f38288c

File tree

6 files changed

+24
-9
lines changed

6 files changed

+24
-9
lines changed

README.md

+6-3
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,10 @@
44

55
## Installation
66

7+
To integrate the connector into your Spark environment, get the JAR file from one of the sources listed below.
8+
79
> [!IMPORTANT]
8-
> Requires Java 8 or above.
10+
> Ensure your system is running Java 8.
911
1012
### GitHub Releases
1113

@@ -20,11 +22,11 @@ Once the requirements have been satisfied, run the following command in the proj
2022
mvn package
2123
```
2224

23-
This will build and store the fat JAR in the `target` directory by default.
25+
The JAR file will be written into the `target` directory by default.
2426

2527
### Maven Central
2628

27-
For use with Java and Scala projects, the package can be found [here](https://central.sonatype.com/artifact/io.qdrant/spark).
29+
Find the project on Maven Central [here](https://central.sonatype.com/artifact/io.qdrant/spark).
2830

2931
## Usage
3032

@@ -257,6 +259,7 @@ The appropriate Spark data types are mapped to the Qdrant payload based on the p
257259
| `multi_vector_fields` | Comma-separated names of columns holding the multi-vector values. | `ArrayType(ArrayType(FloatType))` ||
258260
| `multi_vector_names` | Comma-separated names of the multi-vectors in the collection. | - ||
259261
| `shard_key_selector` | Comma-separated names of custom shard keys to use during upsert. | - ||
262+
| `wait` | Wait for each batch upsert to complete. `true` or `false`. Defaults to `true`. | - ||
260263

261264
## LICENSE
262265

pom.xml

+1-1
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
<modelVersion>4.0.0</modelVersion>
77
<groupId>io.qdrant</groupId>
88
<artifactId>spark</artifactId>
9-
<version>2.3.2</version>
9+
<version>2.3.3</version>
1010
<name>qdrant-spark</name>
1111
<url>https://github.com/qdrant/qdrant-spark</url>
1212
<description>An Apache Spark connector for the Qdrant vector database</description>

src/main/java/io/qdrant/spark/QdrantDataWriter.java

+1-1
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ private void doWriteBatch() throws Exception {
6868

6969
// Instantiate QdrantGrpc client for each batch to maintain serializability
7070
QdrantGrpc qdrant = new QdrantGrpc(new URL(options.qdrantUrl), options.apiKey);
71-
qdrant.upsert(options.collectionName, pointsBuffer, options.shardKeySelector);
71+
qdrant.upsert(options.collectionName, pointsBuffer, options.shardKeySelector, options.wait);
7272
qdrant.close();
7373
}
7474

src/main/java/io/qdrant/spark/QdrantGrpc.java

+8-2
Original file line number | Diff line number | Diff line change
@@ -26,10 +26,16 @@ public QdrantGrpc(URL url, String apiKey) throws MalformedURLException {
2626
}
2727

2828
public void upsert(
29-
String collectionName, List<PointStruct> points, ShardKeySelector shardKeySelector)
29+
String collectionName,
30+
List<PointStruct> points,
31+
ShardKeySelector shardKeySelector,
32+
boolean wait)
3033
throws InterruptedException, ExecutionException {
3134
UpsertPoints.Builder upsertPoints =
32-
UpsertPoints.newBuilder().setCollectionName(collectionName).addAllPoints(points);
35+
UpsertPoints.newBuilder()
36+
.setCollectionName(collectionName)
37+
.setWait(wait)
38+
.addAllPoints(points);
3339
if (shardKeySelector != null) {
3440
upsertPoints.setShardKeySelector(shardKeySelector);
3541
}

src/main/java/io/qdrant/spark/QdrantOptions.java

+7
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
public class QdrantOptions implements Serializable {
1616
private static final int DEFAULT_BATCH_SIZE = 64;
1717
private static final int DEFAULT_RETRIES = 3;
18+
private static final boolean DEFAULT_WAIT = true;
1819

1920
public final String qdrantUrl;
2021
public final String apiKey;
@@ -33,6 +34,7 @@ public class QdrantOptions implements Serializable {
3334
public final String[] multiVectorNames;
3435
public final List<String> payloadFieldsToSkip;
3536
public final ShardKeySelector shardKeySelector;
37+
public final boolean wait;
3638

3739
public QdrantOptions(Map<String, String> options) {
3840
Objects.requireNonNull(options);
@@ -45,6 +47,7 @@ public QdrantOptions(Map<String, String> options) {
4547
apiKey = options.getOrDefault("api_key", "");
4648
embeddingField = options.getOrDefault("embedding_field", "");
4749
vectorName = options.getOrDefault("vector_name", "");
50+
wait = getBooleanOption(options, "wait", DEFAULT_WAIT);
4851

4952
sparseVectorValueFields = parseArray(options.get("sparse_vector_value_fields"));
5053
sparseVectorIndexFields = parseArray(options.get("sparse_vector_index_fields"));
@@ -66,6 +69,10 @@ private int getIntOption(Map<String, String> options, String key, int defaultVal
6669
return Integer.parseInt(options.getOrDefault(key, String.valueOf(defaultValue)));
6770
}
6871

72+
private boolean getBooleanOption(Map<String, String> options, String key, boolean defaultValue) {
73+
return Boolean.parseBoolean(options.getOrDefault(key, String.valueOf(defaultValue)));
74+
}
75+
6976
private String[] parseArray(String input) {
7077
return input == null
7178
? new String[0]

src/test/java/io/qdrant/spark/TestQdrantGrpc.java

+1-2
Original file line number | Diff line number | Diff line change
@@ -68,8 +68,7 @@ public void testUploadBatch() throws Exception {
6868
point2Builder.putPayload("rand_number", value(89));
6969
points.add(point2Builder.build());
7070

71-
// call the uploadBatch method
72-
qdrantGrpc.upsert(collectionName, points, null);
71+
qdrantGrpc.upsert(collectionName, points, null, true);
7372

7473
qdrantGrpc.close();
7574
}

0 commit comments

Comments
 (0)