
Commit 7ea10b8

feat: sparse, multiple vectors support (#16)
* feat: sparse, multiple vectors support
* refactor: Qdrant options
* chore: formatting
* chore: removed redundant validator
* test: multiple dense, multiple sparse
* docs: Updated README.md
1 parent 7579916 commit 7ea10b8

14 files changed: +757 −210 lines
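The net effect of the change set: one DataFrame row can now carry several dense and sparse vectors at once. Here is a minimal, hypothetical PySpark sketch of the new ingestion path — the collection name, column names, and sample values are placeholders, and a collection named `demo` with a matching dense/sparse vector configuration is assumed to already exist:

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    ArrayType,
    FloatType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)

spark = (
    SparkSession.builder.config("spark.jars", "spark-2.1.0.jar")
    .master("local[*]")
    .appName("qdrant")
    .getOrCreate()
)

# One dense vector column plus one sparse vector split into its
# index and value columns, typed as the connector expects.
schema = StructType(
    [
        StructField("id", StringType()),
        StructField("dense", ArrayType(FloatType())),
        StructField("sparse_indices", ArrayType(IntegerType())),
        StructField("sparse_values", ArrayType(FloatType())),
    ]
)

df = spark.createDataFrame(
    [("434cf762-5f20-4bbe-ad52-cd1f3e8cf57e", [0.1, 0.2, 0.3], [7, 42], [0.8, 0.5])],
    schema,
)

(
    df.write.format("io.qdrant.spark.Qdrant")
    .option("qdrant_url", "http://localhost:6334")  # gRPC URL of the Qdrant instance
    .option("collection_name", "demo")
    .option("id_field", "id")  # optional; a random UUID is generated otherwise
    .option("vector_fields", "dense")
    .option("vector_names", "dense")
    .option("sparse_vector_index_fields", "sparse_indices")
    .option("sparse_vector_value_fields", "sparse_values")
    .option("sparse_vector_names", "sparse")
    .option("schema", df.schema.json())
    .mode("append")
    .save()
)
```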

README.md (+152 −35)
````diff
@@ -26,24 +26,16 @@ This will build and store the fat JAR in the `target` directory by default.
 
 For use with Java and Scala projects, the package can be found [here](https://central.sonatype.com/artifact/io.qdrant/spark).
 
-```xml
-<dependency>
-  <groupId>io.qdrant</groupId>
-  <artifactId>spark</artifactId>
-  <version>2.0.1</version>
-</dependency>
-```
-
 ## Usage 📝
 
-### Creating a Spark session (Single-node) with Qdrant support 🌟
+### Creating a Spark session (Single-node) with Qdrant support
 
 ```python
 from pyspark.sql import SparkSession
 
 spark = SparkSession.builder.config(
         "spark.jars",
-        "spark-2.0.1.jar", # specify the downloaded JAR file
+        "spark-2.1.0.jar", # specify the downloaded JAR file
     )
     .master("local[*]")
     .appName("qdrant")
@@ -52,30 +44,150 @@ spark = SparkSession.builder.config(
 
 ### Loading data 📊
 
-To load data into Qdrant, a collection has to be created beforehand with the appropriate vector dimensions and configurations.
+> [!IMPORTANT]
+> Before loading the data using this connector, a collection has to be [created](https://qdrant.tech/documentation/concepts/collections/#create-a-collection) in advance with the appropriate vector dimensions and configurations.
+
+The connector supports ingesting multiple named/unnamed, dense/sparse vectors.
+
+<details>
+<summary><b>Unnamed/Default vector</b></summary>
+
+```python
+<pyspark.sql.DataFrame>
+   .write
+   .format("io.qdrant.spark.Qdrant")
+   .option("qdrant_url", <QDRANT_GRPC_URL>)
+   .option("collection_name", <QDRANT_COLLECTION_NAME>)
+   .option("embedding_field", <EMBEDDING_FIELD_NAME>) # Expected to be a field of type ArrayType(FloatType)
+   .option("schema", <pyspark.sql.DataFrame>.schema.json())
+   .mode("append")
+   .save()
+```
+
+</details>
+
+<details>
+<summary><b>Named vector</b></summary>
+
+```python
+<pyspark.sql.DataFrame>
+   .write
+   .format("io.qdrant.spark.Qdrant")
+   .option("qdrant_url", <QDRANT_GRPC_URL>)
+   .option("collection_name", <QDRANT_COLLECTION_NAME>)
+   .option("embedding_field", <EMBEDDING_FIELD_NAME>) # Expected to be a field of type ArrayType(FloatType)
+   .option("vector_name", <VECTOR_NAME>)
+   .option("schema", <pyspark.sql.DataFrame>.schema.json())
+   .mode("append")
+   .save()
+```
+
+> #### NOTE
+>
+> The `embedding_field` and `vector_name` options are maintained for backward compatibility. It is recommended to use `vector_fields` and `vector_names` for named vectors as shown below.
+
+</details>
+
+<details>
+<summary><b>Multiple named vectors</b></summary>
+
+```python
+<pyspark.sql.DataFrame>
+   .write
+   .format("io.qdrant.spark.Qdrant")
+   .option("qdrant_url", "<QDRANT_GRPC_URL>")
+   .option("collection_name", "<QDRANT_COLLECTION_NAME>")
+   .option("vector_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
+   .option("vector_names", "<VECTOR_NAME>,<ANOTHER_VECTOR_NAME>")
+   .option("schema", <pyspark.sql.DataFrame>.schema.json())
+   .mode("append")
+   .save()
+```
+
+</details>
+
+<details>
+<summary><b>Sparse vectors</b></summary>
+
+```python
+<pyspark.sql.DataFrame>
+   .write
+   .format("io.qdrant.spark.Qdrant")
+   .option("qdrant_url", "<QDRANT_GRPC_URL>")
+   .option("collection_name", "<QDRANT_COLLECTION_NAME>")
+   .option("sparse_vector_value_fields", "<COLUMN_NAME>")
+   .option("sparse_vector_index_fields", "<COLUMN_NAME>")
+   .option("sparse_vector_names", "<SPARSE_VECTOR_NAME>")
+   .option("schema", <pyspark.sql.DataFrame>.schema.json())
+   .mode("append")
+   .save()
+```
+
+</details>
+
+<details>
+<summary><b>Multiple sparse vectors</b></summary>
+
+```python
+<pyspark.sql.DataFrame>
+   .write
+   .format("io.qdrant.spark.Qdrant")
+   .option("qdrant_url", "<QDRANT_GRPC_URL>")
+   .option("collection_name", "<QDRANT_COLLECTION_NAME>")
+   .option("sparse_vector_value_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
+   .option("sparse_vector_index_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
+   .option("sparse_vector_names", "<SPARSE_VECTOR_NAME>,<ANOTHER_SPARSE_VECTOR_NAME>")
+   .option("schema", <pyspark.sql.DataFrame>.schema.json())
+   .mode("append")
+   .save()
+```
+
+</details>
+
+<details>
+<summary><b>Combination of named dense and sparse vectors</b></summary>
+
+```python
+<pyspark.sql.DataFrame>
+   .write
+   .format("io.qdrant.spark.Qdrant")
+   .option("qdrant_url", "<QDRANT_GRPC_URL>")
+   .option("collection_name", "<QDRANT_COLLECTION_NAME>")
+   .option("vector_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
+   .option("vector_names", "<VECTOR_NAME>,<ANOTHER_VECTOR_NAME>")
+   .option("sparse_vector_value_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
+   .option("sparse_vector_index_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
+   .option("sparse_vector_names", "<SPARSE_VECTOR_NAME>,<ANOTHER_SPARSE_VECTOR_NAME>")
+   .option("schema", <pyspark.sql.DataFrame>.schema.json())
+   .mode("append")
+   .save()
+```
+
+</details>
+
+<details>
+<summary><b>No vectors - Entire dataframe is stored as payload</b></summary>
 
 ```python
-<pyspark.sql.DataFrame>
-   .write
-   .format("io.qdrant.spark.Qdrant")
-   .option("qdrant_url", <QDRANT_GRPC_URL>)
-   .option("collection_name", <QDRANT_COLLECTION_NAME>)
-   .option("embedding_field", <EMBEDDING_FIELD_NAME>) # Expected to be a field of type ArrayType(FloatType)
-   .option("schema", <pyspark.sql.DataFrame>.schema.json())
-   .mode("append")
-   .save()
+<pyspark.sql.DataFrame>
+   .write
+   .format("io.qdrant.spark.Qdrant")
+   .option("qdrant_url", "<QDRANT_GRPC_URL>")
+   .option("collection_name", "<QDRANT_COLLECTION_NAME>")
+   .option("schema", <pyspark.sql.DataFrame>.schema.json())
+   .mode("append")
+   .save()
 ```
 
-- By default, UUIDs are generated for each row. If you need to use custom IDs, you can do so by setting the `id_field` option.
-- An API key can be set using the `api_key` option to make authenticated requests.
+</details>
 
 ## Databricks
 
-You can use the `qdrant-spark` connector as a library in Databricks to ingest data into Qdrant.
+You can use the connector as a library in Databricks to ingest data into Qdrant.
 
 - Go to the `Libraries` section in your cluster dashboard.
 - Select `Install New` to open the library installation modal.
-- Search for `io.qdrant:spark:2.0.1` in the Maven packages and click `Install`.
+- Search for `io.qdrant:spark:2.1.0` in the Maven packages and click `Install`.
 
 <img width="1064" alt="Screenshot 2024-01-05 at 17 20 01 (1)" src="https://github.com/qdrant/qdrant-spark/assets/46051506/d95773e0-c5c6-4ff2-bf50-8055bb08fd1b">
 
@@ -85,17 +197,22 @@ Qdrant supports all the Spark data types. The appropriate types are mapped based
 
 ## Options and Spark types 🛠️
 
-| Option            | Description                                                                | DataType               | Required |
-| :---------------- | :------------------------------------------------------------------------ | :--------------------- | :------- |
-| `qdrant_url`      | GRPC URL of the Qdrant instance. Eg: <http://localhost:6334>               | `StringType`           | ✅       |
-| `collection_name` | Name of the collection to write data into                                  | `StringType`           | ✅       |
-| `embedding_field` | Name of the field holding the embeddings                                   | `ArrayType(FloatType)` | ✅       |
-| `schema`          | JSON string of the dataframe schema                                        | `StringType`           | ✅       |
-| `id_field`        | Name of the field holding the point IDs. Default: Generates a random UUId  | `StringType`           | ❌       |
-| `batch_size`      | Max size of the upload batch. Default: 100                                 | `IntType`              | ❌       |
-| `retries`         | Number of upload retries. Default: 3                                       | `IntType`              | ❌       |
-| `api_key`         | Qdrant API key to be sent in the header. Default: null                     | `StringType`           | ❌       |
-| `vector_name`     | Name of the vector in the collection. Default: null                        | `StringType`           | ❌       |
+| Option                       | Description                                                          | Column DataType               | Required |
+| :--------------------------- | :------------------------------------------------------------------ | :---------------------------- | :------- |
+| `qdrant_url`                 | GRPC URL of the Qdrant instance. Eg: <http://localhost:6334>         | -                             | ✅       |
+| `collection_name`            | Name of the collection to write data into                            | -                             | ✅       |
+| `schema`                     | JSON string of the dataframe schema                                  | -                             | ✅       |
+| `embedding_field`            | Name of the column holding the embeddings                            | `ArrayType(FloatType)`        | ❌       |
+| `id_field`                   | Name of the column holding the point IDs. Default: Random UUID       | `StringType` or `IntegerType` | ❌       |
+| `batch_size`                 | Max size of the upload batch. Default: 64                            | -                             | ❌       |
+| `retries`                    | Number of upload retries. Default: 3                                 | -                             | ❌       |
+| `api_key`                    | Qdrant API key for authentication                                    | -                             | ❌       |
+| `vector_name`                | Name of the vector in the collection.                                | -                             | ❌       |
+| `vector_fields`              | Comma-separated names of columns holding the vectors.                | `ArrayType(FloatType)`        | ❌       |
+| `vector_names`               | Comma-separated names of vectors in the collection.                  | -                             | ❌       |
+| `sparse_vector_index_fields` | Comma-separated names of columns holding the sparse vector indices.  | `ArrayType(IntegerType)`      | ❌       |
+| `sparse_vector_value_fields` | Comma-separated names of columns holding the sparse vector values.   | `ArrayType(FloatType)`        | ❌       |
+| `sparse_vector_names`        | Comma-separated names of the sparse vectors in the collection.       | -                             | ❌       |
 
 ## LICENSE 📜
````
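Since the connector only writes into an existing collection (per the new IMPORTANT callout above), the dense and sparse vector names have to be configured up front. A sketch of one way to do that with the `qdrant-client` Python package — the collection name, vector names, size, and distance below are illustrative assumptions, not part of this commit:

```python
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")  # REST port; the connector itself talks gRPC on 6334

# Named dense vector "dense" and named sparse vector "sparse",
# matching the vector_names/sparse_vector_names passed to the connector.
client.create_collection(
    collection_name="demo",
    vectors_config={"dense": models.VectorParams(size=3, distance=models.Distance.COSINE)},
    sparse_vectors_config={"sparse": models.SparseVectorParams()},
)
```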

src/main/java/io/qdrant/spark/Qdrant.java (+1 −33)
```diff
@@ -1,7 +1,5 @@
 package io.qdrant.spark;
 
-import java.util.Arrays;
-import java.util.List;
 import java.util.Map;
 import org.apache.spark.sql.connector.catalog.Table;
 import org.apache.spark.sql.connector.catalog.TableProvider;
@@ -17,8 +15,7 @@
  */
 public class Qdrant implements TableProvider, DataSourceRegister {
 
-  private final String[] requiredFields =
-      new String[] {"schema", "collection_name", "embedding_field", "qdrant_url"};
+  private final String[] requiredFields = new String[] {"schema", "collection_name", "qdrant_url"};
 
   /**
    * Returns the short name of the data source.
@@ -44,11 +41,9 @@ public StructType inferSchema(CaseInsensitiveStringMap options) {
       }
     }
     StructType schema = (StructType) StructType.fromJson(options.get("schema"));
-    validateOptions(options, schema);
 
     return schema;
   }
-  ;
 
   /**
   * Returns a table for the data source based on the provided schema, partitioning, and properties.
@@ -64,31 +59,4 @@ public Table getTable(
     QdrantOptions options = new QdrantOptions(properties);
     return new QdrantCluster(options, schema);
   }
-
-  /**
-   * Checks if the required options are present in the provided options and chekcs if the specified
-   * id_field and embedding_field are present in the provided schema.
-   *
-   * @param options The options to check.
-   * @param schema The schema to check.
-   */
-  void validateOptions(CaseInsensitiveStringMap options, StructType schema) {
-
-    List<String> fieldNames = Arrays.asList(schema.fieldNames());
-
-    if (options.containsKey("id_field")) {
-      String idField = options.get("id_field").toString();
-
-      if (!fieldNames.contains(idField)) {
-        throw new IllegalArgumentException("Specified 'id_field' is not present in the schema");
-      }
-    }
-
-    String embeddingField = options.get("embedding_field").toString();
-
-    if (!fieldNames.contains(embeddingField)) {
-      throw new IllegalArgumentException(
-          "Specified 'embedding_field' is not present in the schema");
-    }
-  }
 }
```

src/main/java/io/qdrant/spark/QdrantDataWriter.java (+10 −46)
```diff
@@ -1,25 +1,16 @@
 package io.qdrant.spark;
 
-import static io.qdrant.client.PointIdFactory.id;
-import static io.qdrant.client.VectorFactory.vector;
-import static io.qdrant.client.VectorsFactory.namedVectors;
-import static io.qdrant.client.VectorsFactory.vectors;
-import static io.qdrant.spark.QdrantValueFactory.value;
-
 import io.qdrant.client.grpc.JsonWithInt.Value;
+import io.qdrant.client.grpc.Points.PointId;
 import io.qdrant.client.grpc.Points.PointStruct;
+import io.qdrant.client.grpc.Points.Vectors;
 import java.io.Serializable;
 import java.net.URL;
 import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
 import java.util.Map;
-import java.util.UUID;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.write.DataWriter;
 import org.apache.spark.sql.connector.write.WriterCommitMessage;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -40,7 +31,7 @@ public class QdrantDataWriter implements DataWriter<InternalRow>, Serializable {
 
   private final ArrayList<PointStruct> points = new ArrayList<>();
 
-  public QdrantDataWriter(QdrantOptions options, StructType schema) throws Exception {
+  public QdrantDataWriter(QdrantOptions options, StructType schema) {
     this.options = options;
     this.schema = schema;
     this.qdrantUrl = options.qdrantUrl;
@@ -50,44 +41,17 @@ public QdrantDataWriter(QdrantOptions options, StructType schema) throws Excepti
   @Override
   public void write(InternalRow record) {
     PointStruct.Builder pointBuilder = PointStruct.newBuilder();
-    Map<String, Value> payload = new HashMap<>();
-
-    if (this.options.idField == null) {
-      pointBuilder.setId(id(UUID.randomUUID()));
-    }
-    for (StructField field : this.schema.fields()) {
-      int fieldIndex = this.schema.fieldIndex(field.name());
-      if (this.options.idField != null && field.name().equals(this.options.idField)) {
 
-        DataType dataType = field.dataType();
-        switch (dataType.typeName()) {
-          case "string":
-            pointBuilder.setId(id(UUID.fromString(record.getString(fieldIndex))));
-            break;
+    PointId pointId = QdrantPointIdHandler.preparePointId(record, this.schema, this.options);
+    pointBuilder.setId(pointId);
 
-          case "integer":
-          case "long":
-            pointBuilder.setId(id(record.getInt(fieldIndex)));
-            break;
-
-          default:
-            throw new IllegalArgumentException("Point ID should be of type string or integer");
-        }
-
-      } else if (field.name().equals(this.options.embeddingField)) {
-        float[] embeddings = record.getArray(fieldIndex).toFloatArray();
-        if (options.vectorName != null) {
-          pointBuilder.setVectors(
-              namedVectors(Collections.singletonMap(options.vectorName, vector(embeddings))));
-        } else {
-          pointBuilder.setVectors(vectors(embeddings));
-        }
-      } else {
-        payload.put(field.name(), value(record, field, fieldIndex));
-      }
-    }
+    Vectors vectors = QdrantVectorHandler.prepareVectors(record, this.schema, this.options);
+    pointBuilder.setVectors(vectors);
 
+    Map<String, Value> payload =
+        QdrantPayloadHandler.preparePayload(record, this.schema, this.options);
     pointBuilder.putAllPayload(payload);
+
     this.points.add(pointBuilder.build());
 
     if (this.points.size() >= this.options.batchSize) {
```