Skip to content

Commit 9484a64

Browse files
authored
refactor: Further type validate (#30)
1 parent 15521a1 commit 9484a64

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

src/main/java/io/qdrant/spark/QdrantVectorHandler.java

+20-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.util.Map;
1414
import org.apache.spark.sql.catalyst.InternalRow;
1515
import org.apache.spark.sql.catalyst.util.ArrayData;
16+
import org.apache.spark.sql.types.ArrayType;
1617
import org.apache.spark.sql.types.DataType;
1718
import org.apache.spark.sql.types.StructField;
1819
import org.apache.spark.sql.types.StructType;
@@ -98,6 +99,12 @@ private static float[] extractFloatArray(InternalRow record, int fieldIndex, Dat
9899
throw new IllegalArgumentException("Vector field must be of type ArrayType");
99100
}
100101

102+
ArrayType arrayType = (ArrayType) dataType;
103+
104+
if (!arrayType.elementType().typeName().equalsIgnoreCase("float")) {
105+
throw new IllegalArgumentException("Expected array elements to be of FloatType");
106+
}
107+
101108
return record.getArray(fieldIndex).toFloatArray();
102109
}
103110

@@ -107,14 +114,26 @@ private static int[] extractIntArray(InternalRow record, int fieldIndex, DataTyp
107114
throw new IllegalArgumentException("Vector field must be of type ArrayType");
108115
}
109116

117+
ArrayType arrayType = (ArrayType) dataType;
118+
119+
if (!arrayType.elementType().typeName().equalsIgnoreCase("integer")) {
120+
throw new IllegalArgumentException("Expected array elements to be of IntegerType");
121+
}
122+
110123
return record.getArray(fieldIndex).toIntArray();
111124
}
112125

113126
private static float[][] extractMultiVecArray(
114127
InternalRow record, int fieldIndex, DataType dataType) {
115128

116129
if (!dataType.typeName().equalsIgnoreCase("array")) {
117-
throw new IllegalArgumentException("Vector field must be of type ArrayType");
130+
throw new IllegalArgumentException("Multi Vector field must be of type ArrayType");
131+
}
132+
133+
ArrayType arrayType = (ArrayType) dataType;
134+
135+
if (!arrayType.elementType().typeName().equalsIgnoreCase("array")) {
136+
throw new IllegalArgumentException("Multi Vector elements must be of type ArrayType");
118137
}
119138

120139
ArrayData arrayData = record.getArray(fieldIndex);

0 commit comments

Comments
 (0)