Skip to content

Commit 15521a1

Browse files
authored
refactor: explicitly validate-array-type (#29)
1 parent 0dc72a7 commit 15521a1

File tree

2 files changed

+40
-16
lines changed

2 files changed

+40
-16
lines changed

pom.xml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<modelVersion>4.0.0</modelVersion>
77
<groupId>io.qdrant</groupId>
88
<artifactId>spark</artifactId>
9-
<version>2.3.0</version>
9+
<version>2.3.1</version>
1010
<name>qdrant-spark</name>
1111
<url>https://github.com/qdrant/qdrant-spark</url>
1212
<description>An Apache Spark connector for the Qdrant vector database</description>

src/main/java/io/qdrant/spark/QdrantVectorHandler.java

+39-15
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,25 @@
1313
import java.util.Map;
1414
import org.apache.spark.sql.catalyst.InternalRow;
1515
import org.apache.spark.sql.catalyst.util.ArrayData;
16+
import org.apache.spark.sql.types.DataType;
17+
import org.apache.spark.sql.types.StructField;
1618
import org.apache.spark.sql.types.StructType;
1719

1820
public class QdrantVectorHandler {
1921

2022
public static Vectors prepareVectors(
2123
InternalRow record, StructType schema, QdrantOptions options) {
2224
Vectors.Builder vectorsBuilder = Vectors.newBuilder();
23-
2425
// Combine sparse, dense and multi vectors
2526
vectorsBuilder.mergeFrom(prepareSparseVectors(record, schema, options));
2627
vectorsBuilder.mergeFrom(prepareDenseVectors(record, schema, options));
2728
vectorsBuilder.mergeFrom(prepareMultiVectors(record, schema, options));
2829

2930
// Maintaining support for the "embedding_field" and "vector_name" options
3031
if (!options.embeddingField.isEmpty()) {
31-
float[] embeddings = extractFloatArray(record, schema, options.embeddingField);
32+
int fieldIndex = schema.fieldIndex(options.embeddingField);
33+
StructField field = schema.fields()[fieldIndex];
34+
float[] embeddings = extractFloatArray(record, fieldIndex, field.dataType());
3235
// 'options.vectorName' defaults to ""
3336
vectorsBuilder.mergeFrom(
3437
namedVectors(Collections.singletonMap(options.vectorName, vector(embeddings))));
@@ -42,10 +45,15 @@ private static Vectors prepareSparseVectors(
4245
Map<String, Vector> sparseVectors = new HashMap<>();
4346

4447
for (int i = 0; i < options.sparseVectorNames.length; i++) {
45-
String name = options.sparseVectorNames[i];
46-
float[] values = extractFloatArray(record, schema, options.sparseVectorValueFields[i]);
47-
int[] indices = extractIntArray(record, schema, options.sparseVectorIndexFields[i]);
48+
int fieldIndex = schema.fieldIndex(options.sparseVectorValueFields[i]);
49+
StructField field = schema.fields()[fieldIndex];
50+
float[] values = extractFloatArray(record, fieldIndex, field.dataType());
4851

52+
fieldIndex = schema.fieldIndex(options.sparseVectorIndexFields[i]);
53+
field = schema.fields()[fieldIndex];
54+
int[] indices = extractIntArray(record, fieldIndex, field.dataType());
55+
56+
String name = options.sparseVectorNames[i];
4957
sparseVectors.put(name, vector(Floats.asList(values), Ints.asList(indices)));
5058
}
5159

@@ -57,8 +65,11 @@ private static Vectors prepareDenseVectors(
5765
Map<String, Vector> denseVectors = new HashMap<>();
5866

5967
for (int i = 0; i < options.vectorNames.length; i++) {
68+
int fieldIndex = schema.fieldIndex(options.vectorFields[i]);
69+
StructField field = schema.fields()[fieldIndex];
70+
float[] values = extractFloatArray(record, fieldIndex, field.dataType());
71+
6072
String name = options.vectorNames[i];
61-
float[] values = extractFloatArray(record, schema, options.vectorFields[i]);
6273
denseVectors.put(name, vector(values));
6374
}
6475

@@ -70,29 +81,42 @@ private static Vectors prepareMultiVectors(
7081
Map<String, Vector> multiVectors = new HashMap<>();
7182

7283
for (int i = 0; i < options.multiVectorNames.length; i++) {
73-
String name = options.multiVectorNames[i];
74-
float[][] vectors = extractMultiVecArray(record, schema, options.multiVectorFields[i]);
84+
int fieldIndex = schema.fieldIndex(options.multiVectorFields[i]);
85+
StructField field = schema.fields()[fieldIndex];
86+
float[][] vectors = extractMultiVecArray(record, fieldIndex, field.dataType());
7587

88+
String name = options.multiVectorNames[i];
7689
multiVectors.put(name, multiVector(vectors));
7790
}
7891

7992
return namedVectors(multiVectors);
8093
}
8194

82-
private static float[] extractFloatArray(
83-
InternalRow record, StructType schema, String fieldName) {
84-
int fieldIndex = schema.fieldIndex(fieldName);
95+
private static float[] extractFloatArray(InternalRow record, int fieldIndex, DataType dataType) {
96+
97+
if (!dataType.typeName().equalsIgnoreCase("array")) {
98+
throw new IllegalArgumentException("Vector field must be of type ArrayType");
99+
}
100+
85101
return record.getArray(fieldIndex).toFloatArray();
86102
}
87103

88-
private static int[] extractIntArray(InternalRow record, StructType schema, String fieldName) {
89-
int fieldIndex = schema.fieldIndex(fieldName);
104+
private static int[] extractIntArray(InternalRow record, int fieldIndex, DataType dataType) {
105+
106+
if (!dataType.typeName().equalsIgnoreCase("array")) {
107+
throw new IllegalArgumentException("Vector field must be of type ArrayType");
108+
}
109+
90110
return record.getArray(fieldIndex).toIntArray();
91111
}
92112

93113
private static float[][] extractMultiVecArray(
94-
InternalRow record, StructType schema, String fieldName) {
95-
int fieldIndex = schema.fieldIndex(fieldName);
114+
InternalRow record, int fieldIndex, DataType dataType) {
115+
116+
if (!dataType.typeName().equalsIgnoreCase("array")) {
117+
throw new IllegalArgumentException("Vector field must be of type ArrayType");
118+
}
119+
96120
ArrayData arrayData = record.getArray(fieldIndex);
97121
int numRows = arrayData.numElements();
98122
ArrayData firstRow = arrayData.getArray(0);

0 commit comments

Comments
 (0)