13
13
import java .util .Map ;
14
14
import org .apache .spark .sql .catalyst .InternalRow ;
15
15
import org .apache .spark .sql .catalyst .util .ArrayData ;
16
+ import org .apache .spark .sql .types .DataType ;
17
+ import org .apache .spark .sql .types .StructField ;
16
18
import org .apache .spark .sql .types .StructType ;
17
19
18
20
public class QdrantVectorHandler {
19
21
20
22
public static Vectors prepareVectors (
21
23
InternalRow record , StructType schema , QdrantOptions options ) {
22
24
Vectors .Builder vectorsBuilder = Vectors .newBuilder ();
23
-
24
25
// Combine sparse, dense and multi vectors
25
26
vectorsBuilder .mergeFrom (prepareSparseVectors (record , schema , options ));
26
27
vectorsBuilder .mergeFrom (prepareDenseVectors (record , schema , options ));
27
28
vectorsBuilder .mergeFrom (prepareMultiVectors (record , schema , options ));
28
29
29
30
    // Maintaining support for the "embedding_field" and "vector_name" options
30
31
if (!options .embeddingField .isEmpty ()) {
31
- float [] embeddings = extractFloatArray (record , schema , options .embeddingField );
32
+ int fieldIndex = schema .fieldIndex (options .embeddingField );
33
+ StructField field = schema .fields ()[fieldIndex ];
34
+ float [] embeddings = extractFloatArray (record , fieldIndex , field .dataType ());
32
35
// 'options.vectorName' defaults to ""
33
36
vectorsBuilder .mergeFrom (
34
37
namedVectors (Collections .singletonMap (options .vectorName , vector (embeddings ))));
@@ -42,10 +45,15 @@ private static Vectors prepareSparseVectors(
42
45
Map <String , Vector > sparseVectors = new HashMap <>();
43
46
44
47
for (int i = 0 ; i < options .sparseVectorNames .length ; i ++) {
45
- String name = options .sparseVectorNames [i ];
46
- float [] values = extractFloatArray ( record , schema , options . sparseVectorValueFields [ i ]) ;
47
- int [] indices = extractIntArray (record , schema , options . sparseVectorIndexFields [ i ] );
48
+ int fieldIndex = schema . fieldIndex ( options .sparseVectorValueFields [i ]) ;
49
+ StructField field = schema . fields ()[ fieldIndex ] ;
50
+ float [] values = extractFloatArray (record , fieldIndex , field . dataType () );
48
51
52
+ fieldIndex = schema .fieldIndex (options .sparseVectorIndexFields [i ]);
53
+ field = schema .fields ()[fieldIndex ];
54
+ int [] indices = extractIntArray (record , fieldIndex , field .dataType ());
55
+
56
+ String name = options .sparseVectorNames [i ];
49
57
sparseVectors .put (name , vector (Floats .asList (values ), Ints .asList (indices )));
50
58
}
51
59
@@ -57,8 +65,11 @@ private static Vectors prepareDenseVectors(
57
65
Map <String , Vector > denseVectors = new HashMap <>();
58
66
59
67
for (int i = 0 ; i < options .vectorNames .length ; i ++) {
68
+ int fieldIndex = schema .fieldIndex (options .vectorFields [i ]);
69
+ StructField field = schema .fields ()[fieldIndex ];
70
+ float [] values = extractFloatArray (record , fieldIndex , field .dataType ());
71
+
60
72
String name = options .vectorNames [i ];
61
- float [] values = extractFloatArray (record , schema , options .vectorFields [i ]);
62
73
denseVectors .put (name , vector (values ));
63
74
}
64
75
@@ -70,29 +81,42 @@ private static Vectors prepareMultiVectors(
70
81
Map <String , Vector > multiVectors = new HashMap <>();
71
82
72
83
for (int i = 0 ; i < options .multiVectorNames .length ; i ++) {
73
- String name = options .multiVectorNames [i ];
74
- float [][] vectors = extractMultiVecArray (record , schema , options .multiVectorFields [i ]);
84
+ int fieldIndex = schema .fieldIndex (options .multiVectorFields [i ]);
85
+ StructField field = schema .fields ()[fieldIndex ];
86
+ float [][] vectors = extractMultiVecArray (record , fieldIndex , field .dataType ());
75
87
88
+ String name = options .multiVectorNames [i ];
76
89
multiVectors .put (name , multiVector (vectors ));
77
90
}
78
91
79
92
return namedVectors (multiVectors );
80
93
}
81
94
82
- private static float [] extractFloatArray (
83
- InternalRow record , StructType schema , String fieldName ) {
84
- int fieldIndex = schema .fieldIndex (fieldName );
95
+ private static float [] extractFloatArray (InternalRow record , int fieldIndex , DataType dataType ) {
96
+
97
+ if (!dataType .typeName ().equalsIgnoreCase ("array" )) {
98
+ throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
99
+ }
100
+
85
101
return record .getArray (fieldIndex ).toFloatArray ();
86
102
}
87
103
88
- private static int [] extractIntArray (InternalRow record , StructType schema , String fieldName ) {
89
- int fieldIndex = schema .fieldIndex (fieldName );
104
+ private static int [] extractIntArray (InternalRow record , int fieldIndex , DataType dataType ) {
105
+
106
+ if (!dataType .typeName ().equalsIgnoreCase ("array" )) {
107
+ throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
108
+ }
109
+
90
110
return record .getArray (fieldIndex ).toIntArray ();
91
111
}
92
112
93
113
private static float [][] extractMultiVecArray (
94
- InternalRow record , StructType schema , String fieldName ) {
95
- int fieldIndex = schema .fieldIndex (fieldName );
114
+ InternalRow record , int fieldIndex , DataType dataType ) {
115
+
116
+ if (!dataType .typeName ().equalsIgnoreCase ("array" )) {
117
+ throw new IllegalArgumentException ("Vector field must be of type ArrayType" );
118
+ }
119
+
96
120
ArrayData arrayData = record .getArray (fieldIndex );
97
121
int numRows = arrayData .numElements ();
98
122
ArrayData firstRow = arrayData .getArray (0 );
0 commit comments