13
13
import java .util .Map ;
14
14
import org .apache .spark .sql .catalyst .InternalRow ;
15
15
import org .apache .spark .sql .catalyst .util .ArrayData ;
16
+ import org .apache .spark .sql .types .ArrayType ;
16
17
import org .apache .spark .sql .types .DataType ;
17
18
import org .apache .spark .sql .types .StructField ;
18
19
import org .apache .spark .sql .types .StructType ;
@@ -98,6 +99,12 @@ private static float[] extractFloatArray(InternalRow record, int fieldIndex, Dat
98
99
throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
99
100
}
100
101
102
+ ArrayType arrayType = (ArrayType ) dataType ;
103
+
104
+ if (!arrayType .elementType ().typeName ().equalsIgnoreCase ("float" )) {
105
+ throw new IllegalArgumentException ("Expected array elements to be of FloatType" );
106
+ }
107
+
101
108
return record .getArray (fieldIndex ).toFloatArray ();
102
109
}
103
110
@@ -107,14 +114,26 @@ private static int[] extractIntArray(InternalRow record, int fieldIndex, DataTyp
107
114
throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
108
115
}
109
116
117
+ ArrayType arrayType = (ArrayType ) dataType ;
118
+
119
+ if (!arrayType .elementType ().typeName ().equalsIgnoreCase ("integer" )) {
120
+ throw new IllegalArgumentException ("Expected array elements to be of IntegerType" );
121
+ }
122
+
110
123
return record .getArray (fieldIndex ).toIntArray ();
111
124
}
112
125
113
126
private static float [][] extractMultiVecArray (
114
127
InternalRow record , int fieldIndex , DataType dataType ) {
115
128
116
129
if (!dataType .typeName ().equalsIgnoreCase ("array" )) {
117
- throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
130
+ throw new IllegalArgumentException ("Multi Vector field must be of type ArrayType" );
131
+ }
132
+
133
+ ArrayType arrayType = (ArrayType ) dataType ;
134
+
135
+ if (!arrayType .elementType ().typeName ().equalsIgnoreCase ("array" )) {
136
+ throw new IllegalArgumentException ("Multi Vector elements must be of type ArrayType" );
118
137
}
119
138
120
139
ArrayData arrayData = record .getArray (fieldIndex );
0 commit comments