Skip to content

Commit 0dc72a7

Browse files
authored
feat: Multi-vector support (#28)
1 parent e50330c commit 0dc72a7

File tree

8 files changed

+745
-23
lines changed

8 files changed

+745
-23
lines changed

README.md

+55-17
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,42 @@ _Click each to expand._
167167

168168
</details>
169169

170+
<details>
171+
<summary><b>Multi-vectors</b></summary>
172+
173+
```python
174+
<pyspark.sql.DataFrame>
175+
.write
176+
.format("io.qdrant.spark.Qdrant")
177+
.option("qdrant_url", "<QDRANT_GRPC_URL>")
178+
.option("collection_name", "<QDRANT_COLLECTION_NAME>")
179+
.option("multi_vector_fields", "<COLUMN_NAME>")
180+
.option("multi_vector_names", "<MULTI_VECTOR_NAME>")
181+
.option("schema", <pyspark.sql.DataFrame>.schema.json())
182+
.mode("append")
183+
.save()
184+
```
185+
186+
</details>
187+
188+
<details>
189+
<summary><b>Multiple Multi-vectors</b></summary>
190+
191+
```python
192+
<pyspark.sql.DataFrame>
193+
.write
194+
.format("io.qdrant.spark.Qdrant")
195+
.option("qdrant_url", "<QDRANT_GRPC_URL>")
196+
.option("collection_name", "<QDRANT_COLLECTION_NAME>")
197+
.option("multi_vector_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
198+
.option("multi_vector_names", "<MULTI_VECTOR_NAME>,<ANOTHER_MULTI_VECTOR_NAME>")
199+
.option("schema", <pyspark.sql.DataFrame>.schema.json())
200+
.mode("append")
201+
.save()
202+
```
203+
204+
</details>
205+
170206
<details>
171207
<summary><b>No vectors - Entire dataframe is stored as payload</b></summary>
172208

@@ -202,23 +238,25 @@ The appropriate Spark data types are mapped to the Qdrant payload based on the p
202238

203239
## Options and Spark types
204240

205-
| Option | Description | Column DataType | Required |
206-
| :--------------------------- | :----------------------------------------------------------------------------------- | :---------------------------- | :------- |
207-
| `qdrant_url` | GRPC URL of the Qdrant instance. Eg: <http://localhost:6334> | - ||
208-
| `collection_name` | Name of the collection to write data into | - ||
209-
| `schema` | JSON string of the dataframe schema | - ||
210-
| `embedding_field` | Name of the column holding the embeddings (Deprecated - Use `vector_fields` instead) | `ArrayType(FloatType)` ||
211-
| `id_field` | Name of the column holding the point IDs. Default: Random UUID | `StringType` or `IntegerType` ||
212-
| `batch_size` | Max size of the upload batch. Default: 64 | - ||
213-
| `retries` | Number of upload retries. Default: 3 | - ||
214-
| `api_key` | Qdrant API key for authentication | - ||
215-
| `vector_name` | Name of the vector in the collection. | - ||
216-
| `vector_fields` | Comma-separated names of columns holding the vectors. | `ArrayType(FloatType)` ||
217-
| `vector_names` | Comma-separated names of vectors in the collection. | - ||
218-
| `sparse_vector_index_fields` | Comma-separated names of columns holding the sparse vector indices. | `ArrayType(IntegerType)` ||
219-
| `sparse_vector_value_fields` | Comma-separated names of columns holding the sparse vector values. | `ArrayType(FloatType)` ||
220-
| `sparse_vector_names` | Comma-separated names of the sparse vectors in the collection. | - ||
221-
| `shard_key_selector` | Comma-separated names of custom shard keys to use during upsert. | - ||
241+
| Option | Description | Column DataType | Required |
242+
| :--------------------------- | :----------------------------------------------------------------------------------- | :-------------------------------- | :------- |
243+
| `qdrant_url` | gRPC URL of the Qdrant instance. Eg: <http://localhost:6334> | - ||
244+
| `collection_name` | Name of the collection to write data into | - ||
245+
| `schema` | JSON string of the dataframe schema | - ||
246+
| `embedding_field` | Name of the column holding the embeddings (Deprecated - Use `vector_fields` instead) | `ArrayType(FloatType)` ||
247+
| `id_field` | Name of the column holding the point IDs. Default: Random UUID | `StringType` or `IntegerType` ||
248+
| `batch_size` | Max size of the upload batch. Default: 64 | - ||
249+
| `retries` | Number of upload retries. Default: 3 | - ||
250+
| `api_key` | Qdrant API key for authentication | - ||
251+
| `vector_name` | Name of the vector in the collection. | - ||
252+
| `vector_fields` | Comma-separated names of columns holding the vectors. | `ArrayType(FloatType)` ||
253+
| `vector_names` | Comma-separated names of vectors in the collection. | - ||
254+
| `sparse_vector_index_fields` | Comma-separated names of columns holding the sparse vector indices. | `ArrayType(IntegerType)` ||
255+
| `sparse_vector_value_fields` | Comma-separated names of columns holding the sparse vector values. | `ArrayType(FloatType)` ||
256+
| `sparse_vector_names` | Comma-separated names of the sparse vectors in the collection. | - ||
257+
| `multi_vector_fields` | Comma-separated names of columns holding the multi-vector values. | `ArrayType(ArrayType(FloatType))` ||
258+
| `multi_vector_names` | Comma-separated names of the multi-vectors in the collection. | - ||
259+
| `shard_key_selector` | Comma-separated names of custom shard keys to use during upsert. | - ||
222260

223261
## LICENSE
224262

pom.xml

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<modelVersion>4.0.0</modelVersion>
77
<groupId>io.qdrant</groupId>
88
<artifactId>spark</artifactId>
9-
<version>2.2.1</version>
9+
<version>2.3.0</version>
1010
<name>qdrant-spark</name>
1111
<url>https://github.com/qdrant/qdrant-spark</url>
1212
<description>An Apache Spark connector for the Qdrant vector database</description>
@@ -39,7 +39,7 @@
3939
<dependency>
4040
<groupId>io.qdrant</groupId>
4141
<artifactId>client</artifactId>
42-
<version>1.9.0</version>
42+
<version>1.10.0</version>
4343
</dependency>
4444
<dependency>
4545
<groupId>com.google.guava</groupId>
@@ -206,4 +206,4 @@
206206
</build>
207207

208208

209-
</project>
209+
</project>

src/main/java/io/qdrant/spark/QdrantOptions.java

+6
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ public class QdrantOptions implements Serializable {
2929
public final String[] sparseVectorNames;
3030
public final String[] vectorFields;
3131
public final String[] vectorNames;
32+
public final String[] multiVectorFields;
33+
public final String[] multiVectorNames;
3234
public final List<String> payloadFieldsToSkip;
3335
public final ShardKeySelector shardKeySelector;
3436

@@ -49,6 +51,8 @@ public QdrantOptions(Map<String, String> options) {
4951
sparseVectorNames = parseArray(options.get("sparse_vector_names"));
5052
vectorFields = parseArray(options.get("vector_fields"));
5153
vectorNames = parseArray(options.get("vector_names"));
54+
multiVectorFields = parseArray(options.get("multi_vector_fields"));
55+
multiVectorNames = parseArray(options.get("multi_vector_names"));
5256

5357
shardKeySelector = parseShardKeys(options.get("shard_key_selector"));
5458

@@ -121,6 +125,8 @@ private List<String> buildPayloadFieldsToSkip() {
121125
fields.addAll(Arrays.asList(sparseVectorNames));
122126
fields.addAll(Arrays.asList(vectorFields));
123127
fields.addAll(Arrays.asList(vectorNames));
128+
fields.addAll(Arrays.asList(multiVectorFields));
129+
fields.addAll(Arrays.asList(multiVectorNames));
124130
return fields;
125131
}
126132
}

src/main/java/io/qdrant/spark/QdrantVectorHandler.java

+34-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.qdrant.spark;
22

3+
import static io.qdrant.client.VectorFactory.multiVector;
34
import static io.qdrant.client.VectorFactory.vector;
45
import static io.qdrant.client.VectorsFactory.namedVectors;
56

@@ -11,6 +12,7 @@
1112
import java.util.HashMap;
1213
import java.util.Map;
1314
import org.apache.spark.sql.catalyst.InternalRow;
15+
import org.apache.spark.sql.catalyst.util.ArrayData;
1416
import org.apache.spark.sql.types.StructType;
1517

1618
public class QdrantVectorHandler {
@@ -19,9 +21,10 @@ public static Vectors prepareVectors(
1921
InternalRow record, StructType schema, QdrantOptions options) {
2022
Vectors.Builder vectorsBuilder = Vectors.newBuilder();
2123

22-
// Combine sparse and dense vectors
24+
// Combine sparse, dense and multi vectors
2325
vectorsBuilder.mergeFrom(prepareSparseVectors(record, schema, options));
2426
vectorsBuilder.mergeFrom(prepareDenseVectors(record, schema, options));
27+
vectorsBuilder.mergeFrom(prepareMultiVectors(record, schema, options));
2528

2629
// Maitaining support for the "embedding_field" and "vector_name" options
2730
if (!options.embeddingField.isEmpty()) {
@@ -62,6 +65,20 @@ private static Vectors prepareDenseVectors(
6265
return namedVectors(denseVectors);
6366
}
6467

68+
private static Vectors prepareMultiVectors(
69+
InternalRow record, StructType schema, QdrantOptions options) {
70+
Map<String, Vector> multiVectors = new HashMap<>();
71+
72+
for (int i = 0; i < options.multiVectorNames.length; i++) {
73+
String name = options.multiVectorNames[i];
74+
float[][] vectors = extractMultiVecArray(record, schema, options.multiVectorFields[i]);
75+
76+
multiVectors.put(name, multiVector(vectors));
77+
}
78+
79+
return namedVectors(multiVectors);
80+
}
81+
6582
private static float[] extractFloatArray(
6683
InternalRow record, StructType schema, String fieldName) {
6784
int fieldIndex = schema.fieldIndex(fieldName);
@@ -72,4 +89,20 @@ private static int[] extractIntArray(InternalRow record, StructType schema, Stri
7289
int fieldIndex = schema.fieldIndex(fieldName);
7390
return record.getArray(fieldIndex).toIntArray();
7491
}
92+
93+
private static float[][] extractMultiVecArray(
94+
InternalRow record, StructType schema, String fieldName) {
95+
int fieldIndex = schema.fieldIndex(fieldName);
96+
ArrayData arrayData = record.getArray(fieldIndex);
97+
int numRows = arrayData.numElements();
98+
ArrayData firstRow = arrayData.getArray(0);
99+
int numCols = firstRow.numElements();
100+
101+
float[][] multiVecArray = new float[numRows][numCols];
102+
for (int i = 0; i < numRows; i++) {
103+
multiVecArray[i] = arrayData.getArray(i).toFloatArray();
104+
}
105+
106+
return multiVecArray;
107+
}
75108
}

src/test/java/io/qdrant/spark/TestIntegration.java

+30
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import io.qdrant.client.QdrantGrpcClient;
55
import io.qdrant.client.grpc.Collections.CreateCollection;
66
import io.qdrant.client.grpc.Collections.Distance;
7+
import io.qdrant.client.grpc.Collections.MultiVectorComparator;
8+
import io.qdrant.client.grpc.Collections.MultiVectorConfig;
79
import io.qdrant.client.grpc.Collections.SparseVectorConfig;
810
import io.qdrant.client.grpc.Collections.SparseVectorParams;
911
import io.qdrant.client.grpc.Collections.VectorParams;
@@ -70,6 +72,16 @@ public void setup() throws InterruptedException, ExecutionException {
7072
.setDistance(DISTANCE)
7173
.setSize(DIMENSION)
7274
.build())
75+
.putMap(
76+
"multi",
77+
VectorParams.newBuilder()
78+
.setSize(DIMENSION)
79+
.setDistance(DISTANCE)
80+
.setMultivectorConfig(
81+
MultiVectorConfig.newBuilder()
82+
.setComparator(MultiVectorComparator.MaxSim)
83+
.build())
84+
.build())
7385
.build())
7486
.build())
7587
.setSparseVectorsConfig(
@@ -252,6 +264,24 @@ public void testMultipleDenseAndSparseVectors() throws InterruptedException, Exe
252264
df.count());
253265
}
254266

267+
@Test
268+
public void testMultiVector() throws InterruptedException, ExecutionException {
269+
270+
df.write()
271+
.format("io.qdrant.spark.Qdrant")
272+
.option("id_field", "id")
273+
.option("schema", df.schema().json())
274+
.option("collection_name", COLLECTION_NAME)
275+
.option("multi_vector_fields", "multi")
276+
.option("multi_vector_names", "multi")
277+
.option("qdrant_url", qdrantUrl)
278+
.mode("append")
279+
.save();
280+
281+
Assert.assertEquals(
282+
"testMultiVector()", (long) client.countAsync(COLLECTION_NAME).get(), df.count());
283+
}
284+
255285
@Test
256286
public void testNoVectors() throws InterruptedException, ExecutionException {
257287

src/test/java/io/qdrant/spark/TestQdrantGrpc.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public class TestQdrantGrpc {
2727
private static int grpcPort = 6334;
2828
private static Distance distance = Distance.Cosine;
2929

30-
@Rule public final QdrantContainer qdrant = new QdrantContainer("qdrant/qdrant");
30+
@Rule public final QdrantContainer qdrant = new QdrantContainer("qdrant/qdrant:dev");
3131

3232
@Before
3333
public void setup() throws InterruptedException, ExecutionException {

src/test/java/io/qdrant/spark/TestSchema.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ public static StructType schema() {
9191
.add(new StructField("crypto", cryptoSchema, true, Metadata.empty()))
9292
.add(new StructField("dense_vector", denseVectorType, false, Metadata.empty()))
9393
.add(new StructField("sparse_indices", sparseIndicesType, false, Metadata.empty()))
94-
.add(new StructField("sparse_values", sparseValuesType, false, Metadata.empty()));
94+
.add(new StructField("sparse_values", sparseValuesType, false, Metadata.empty()))
95+
.add(
96+
new StructField(
97+
"multi", DataTypes.createArrayType(denseVectorType), false, Metadata.empty()));
9598
}
9699
}

0 commit comments

Comments
 (0)