Skip to content

Commit 8583ff2

Browse files
authored
fix: point id (#10)
* fix: point id
* ci: try release master only
* chore: bump version
1 parent b2ac52d commit 8583ff2

File tree

6 files changed

+112
-69
lines changed

6 files changed

+112
-69
lines changed

.github/workflows/release.yml

+10-2
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,10 @@
11
name: Release
22

3-
on: [push, workflow_dispatch]
3+
on:
4+
push:
5+
branches:
6+
- master
7+
workflow_dispatch:
48

59
jobs:
610
build:
@@ -38,6 +42,10 @@ jobs:
3842
restore-keys: |
3943
${{ runner.os }}-maven-
4044
45+
- uses: actions/setup-node@v4
46+
with:
47+
node-version: 20
48+
4149
- name: "🔧 setup Bun"
4250
uses: oven-sh/setup-bun@v1
4351

@@ -54,4 +62,4 @@ jobs:
5462
GIT_COMMITTER_NAME: "github-actions[bot]"
5563
GIT_COMMITTER_EMAIL: "41898282+github-actions[bot]@users.noreply.github.com"
5664
GIT_AUTHOR_NAME: ${{ steps.author_info.outputs.AUTHOR_NAME }}
57-
GIT_AUTHOR_EMAIL: ${{ steps.author_info.outputs.AUTHOR_EMAIL }}
65+
GIT_AUTHOR_EMAIL: ${{ steps.author_info.outputs.AUTHOR_EMAIL }}

README.md

+3-3
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ For use with Java and Scala projects, the package can be found [here](https://ce
3030
<dependency>
3131
<groupId>io.qdrant</groupId>
3232
<artifactId>spark</artifactId>
33-
<version>1.12</version>
33+
<version>1.13</version>
3434
</dependency>
3535
```
3636

@@ -43,7 +43,7 @@ from pyspark.sql import SparkSession
4343

4444
spark = SparkSession.builder.config(
4545
"spark.jars",
46-
"spark-1.12-jar-with-dependencies.jar", # specify the downloaded JAR file
46+
"spark-1.13-jar-with-dependencies.jar", # specify the downloaded JAR file
4747
)
4848
.master("local[*]")
4949
.appName("qdrant")
@@ -73,7 +73,7 @@ To load data into Qdrant, a collection has to be created beforehand with the app
7373
You can use the `qdrant-spark` connector as a library in Databricks to ingest data into Qdrant.
7474
- Go to the `Libraries` section in your cluster dashboard.
7575
- Select `Install New` to open the library installation modal.
76-
- Search for `io.qdrant:spark:1.12` in the Maven packages and click `Install`.
76+
- Search for `io.qdrant:spark:1.13` in the Maven packages and click `Install`.
7777

7878
<img width="1064" alt="Screenshot 2024-01-05 at 17 20 01 (1)" src="https://github.com/qdrant/qdrant-spark/assets/46051506/d95773e0-c5c6-4ff2-bf50-8055bb08fd1b">
7979

pom.xml

+1-1
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
<modelVersion>4.0.0</modelVersion>
77
<groupId>io.qdrant</groupId>
88
<artifactId>spark</artifactId>
9-
<version>1.12</version>
9+
<version>1.13</version>
1010
<name>qdrant-spark</name>
1111
<url>https://github.com/qdrant/qdrant-spark</url>
1212
<description>An Apache Spark connector for the Qdrant vector database</description>
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,78 @@
1+
package io.qdrant.spark;
2+
3+
import java.util.HashMap;
4+
import java.util.Map;
5+
6+
import org.apache.spark.sql.catalyst.InternalRow;
7+
import org.apache.spark.sql.catalyst.util.ArrayData;
8+
import org.apache.spark.sql.types.DataType;
9+
import org.apache.spark.sql.types.StructField;
10+
import org.apache.spark.sql.types.StructType;
11+
import org.apache.spark.sql.types.ArrayType;
12+
13+
class ObjectFactory {
14+
public static Object object(InternalRow record, StructField field, int fieldIndex) {
15+
DataType dataType = field.dataType();
16+
17+
switch (dataType.typeName()) {
18+
case "integer":
19+
return record.getInt(fieldIndex);
20+
case "float":
21+
return record.getFloat(fieldIndex);
22+
case "double":
23+
return record.getDouble(fieldIndex);
24+
case "long":
25+
return record.getLong(fieldIndex);
26+
case "boolean":
27+
return record.getBoolean(fieldIndex);
28+
case "string":
29+
return record.getString(fieldIndex);
30+
case "array":
31+
ArrayType arrayType = (ArrayType) dataType;
32+
ArrayData arrayData = record.getArray(fieldIndex);
33+
return object(arrayData, arrayType.elementType());
34+
case "struct":
35+
StructType structType = (StructType) dataType;
36+
InternalRow structData = record.getStruct(fieldIndex, structType.fields().length);
37+
return object(structData, structType);
38+
default:
39+
return null;
40+
}
41+
}
42+
43+
public static Object object(ArrayData arrayData, DataType elementType) {
44+
45+
switch (elementType.typeName()) {
46+
case "string": {
47+
int length = arrayData.numElements();
48+
String[] result = new String[length];
49+
for (int i = 0; i < length; i++) {
50+
result[i] = arrayData.getUTF8String(i).toString();
51+
}
52+
return result;
53+
}
54+
55+
case "struct": {
56+
StructType structType = (StructType) elementType;
57+
int length = arrayData.numElements();
58+
Object[] result = new Object[length];
59+
for (int i = 0; i < length; i++) {
60+
InternalRow structData = arrayData.getStruct(i, structType.fields().length);
61+
result[i] = object(structData, structType);
62+
}
63+
return result;
64+
}
65+
default:
66+
return arrayData.toObjectArray(elementType);
67+
}
68+
}
69+
70+
public static Object object(InternalRow structData, StructType structType) {
71+
Map<String, Object> result = new HashMap<>();
72+
for (int i = 0; i < structType.fields().length; i++) {
73+
StructField structField = structType.fields()[i];
74+
result.put(structField.name(), object(structData, structField, i));
75+
}
76+
return result;
77+
}
78+
}

src/main/java/io/qdrant/spark/QdrantDataWriter.java

+19-61
Original file line number | Diff line number | Diff line change
@@ -3,20 +3,18 @@
33
import java.io.Serializable;
44
import java.util.ArrayList;
55
import java.util.HashMap;
6-
import java.util.Map;
76
import java.util.UUID;
87
import org.apache.spark.sql.catalyst.InternalRow;
9-
import org.apache.spark.sql.catalyst.util.ArrayData;
108
import org.apache.spark.sql.connector.write.DataWriter;
119
import org.apache.spark.sql.connector.write.WriterCommitMessage;
12-
import org.apache.spark.sql.types.ArrayType;
1310
import org.apache.spark.sql.types.DataType;
14-
import org.apache.spark.sql.types.DataTypes;
1511
import org.apache.spark.sql.types.StructField;
1612
import org.apache.spark.sql.types.StructType;
1713
import org.slf4j.Logger;
1814
import org.slf4j.LoggerFactory;
1915

16+
import static io.qdrant.spark.ObjectFactory.object;
17+
2018
/**
2119
* A DataWriter implementation that writes data to Qdrant, a vector search
2220
* engine. This class takes
@@ -53,12 +51,24 @@ public void write(InternalRow record) {
5351
for (StructField field : this.schema.fields()) {
5452
int fieldIndex = this.schema.fieldIndex(field.name());
5553
if (this.options.idField != null && field.name().equals(this.options.idField)) {
56-
point.id = record.get(fieldIndex, field.dataType()).toString();
54+
55+
DataType dataType = field.dataType();
56+
switch (dataType.typeName()) {
57+
case "string":
58+
point.id = record.getString(fieldIndex);
59+
break;
60+
61+
case "integer":
62+
point.id = record.getInt(fieldIndex);
63+
break;
64+
default:
65+
throw new IllegalArgumentException("Point ID should be of type string or integer");
66+
}
5767
} else if (field.name().equals(this.options.embeddingField)) {
58-
float[] vector = record.getArray(fieldIndex).toFloatArray();
59-
point.vector = vector;
68+
point.vector = record.getArray(fieldIndex).toFloatArray();
69+
6070
} else {
61-
payload.put(field.name(), convertToJavaType(record, field, fieldIndex));
71+
payload.put(field.name(), object(record, field, fieldIndex));
6272
}
6373
}
6474

@@ -107,62 +117,10 @@ public void abort() {
107117
@Override
108118
public void close() {
109119
}
110-
111-
private Object convertToJavaType(InternalRow record, StructField field, int fieldIndex) {
112-
DataType dataType = field.dataType();
113-
114-
if (dataType == DataTypes.StringType) {
115-
return record.getString(fieldIndex);
116-
} else if (dataType == DataTypes.DateType || dataType == DataTypes.TimestampType) {
117-
return record.getString(fieldIndex);
118-
} else if (dataType instanceof ArrayType) {
119-
ArrayType arrayType = (ArrayType) dataType;
120-
ArrayData arrayData = record.getArray(fieldIndex);
121-
return convertArrayToJavaType(arrayData, arrayType.elementType());
122-
} else if (dataType instanceof StructType) {
123-
StructType structType = (StructType) dataType;
124-
InternalRow structData = record.getStruct(fieldIndex, structType.fields().length);
125-
return convertStructToJavaType(structData, structType);
126-
}
127-
128-
// Fall back to the generic get method
129-
return record.get(fieldIndex, dataType);
130-
}
131-
132-
private Object convertArrayToJavaType(ArrayData arrayData, DataType elementType) {
133-
if (elementType == DataTypes.StringType) {
134-
int length = arrayData.numElements();
135-
String[] result = new String[length];
136-
for (int i = 0; i < length; i++) {
137-
result[i] = arrayData.getUTF8String(i).toString();
138-
}
139-
return result;
140-
} else if (elementType instanceof StructType) {
141-
StructType structType = (StructType) elementType;
142-
int length = arrayData.numElements();
143-
Object[] result = new Object[length];
144-
for (int i = 0; i < length; i++) {
145-
InternalRow structData = arrayData.getStruct(i, structType.fields().length);
146-
result[i] = convertStructToJavaType(structData, structType);
147-
}
148-
return result;
149-
} else {
150-
return arrayData.toObjectArray(elementType);
151-
}
152-
}
153-
154-
private Object convertStructToJavaType(InternalRow structData, StructType structType) {
155-
Map<String, Object> result = new HashMap<>();
156-
for (int i = 0; i < structType.fields().length; i++) {
157-
StructField structField = structType.fields()[i];
158-
result.put(structField.name(), convertToJavaType(structData, structField, i));
159-
}
160-
return result;
161-
}
162120
}
163121

164122
class Point implements Serializable {
165-
public String id;
123+
public Object id;
166124
public float[] vector;
167125
public HashMap<String, Object> payload;
168126
}

src/test/java/io/qdrant/spark/TestQdrant.java

+1-2
Original file line number | Diff line number | Diff line change
@@ -54,8 +54,7 @@ public void testGetTable() {
5454
options.put("embedding_field", "embedding");
5555
options.put("qdrant_url", "http://localhost:8080");
5656
CaseInsensitiveStringMap dataSourceOptions = new CaseInsensitiveStringMap(options);
57-
var reader = qdrant.getTable(schema, null, dataSourceOptions);
58-
Assert.assertTrue(reader instanceof QdrantCluster);
57+
Assert.assertTrue(qdrant.getTable(schema, null, dataSourceOptions) instanceof QdrantCluster);
5958
}
6059

6160
@Test()

0 commit comments

Comments
 (0)