|
3 | 3 | import java.io.Serializable;
|
4 | 4 | import java.util.ArrayList;
|
5 | 5 | import java.util.HashMap;
|
6 |
| -import java.util.Map; |
7 | 6 | import java.util.UUID;
|
8 | 7 | import org.apache.spark.sql.catalyst.InternalRow;
|
9 |
| -import org.apache.spark.sql.catalyst.util.ArrayData; |
10 | 8 | import org.apache.spark.sql.connector.write.DataWriter;
|
11 | 9 | import org.apache.spark.sql.connector.write.WriterCommitMessage;
|
12 |
| -import org.apache.spark.sql.types.ArrayType; |
13 | 10 | import org.apache.spark.sql.types.DataType;
|
14 |
| -import org.apache.spark.sql.types.DataTypes; |
15 | 11 | import org.apache.spark.sql.types.StructField;
|
16 | 12 | import org.apache.spark.sql.types.StructType;
|
17 | 13 | import org.slf4j.Logger;
|
18 | 14 | import org.slf4j.LoggerFactory;
|
19 | 15 |
|
| 16 | +import static io.qdrant.spark.ObjectFactory.object; |
| 17 | + |
20 | 18 | /**
|
21 | 19 | * A DataWriter implementation that writes data to Qdrant, a vector search
|
22 | 20 | * engine. This class takes
|
@@ -53,12 +51,24 @@ public void write(InternalRow record) {
|
53 | 51 | for (StructField field : this.schema.fields()) {
|
54 | 52 | int fieldIndex = this.schema.fieldIndex(field.name());
|
55 | 53 | if (this.options.idField != null && field.name().equals(this.options.idField)) {
|
56 |
| - point.id = record.get(fieldIndex, field.dataType()).toString(); |
| 54 | + |
| 55 | + DataType dataType = field.dataType(); |
| 56 | + switch (dataType.typeName()) { |
| 57 | + case "string": |
| 58 | + point.id = record.getString(fieldIndex); |
| 59 | + break; |
| 60 | + |
| 61 | + case "integer": |
| 62 | + point.id = record.getInt(fieldIndex); |
| 63 | + break; |
| 64 | + default: |
| 65 | + throw new IllegalArgumentException("Point ID should be of type string or integer"); |
| 66 | + } |
57 | 67 | } else if (field.name().equals(this.options.embeddingField)) {
|
58 |
| - float[] vector = record.getArray(fieldIndex).toFloatArray(); |
59 |
| - point.vector = vector; |
| 68 | + point.vector = record.getArray(fieldIndex).toFloatArray(); |
| 69 | + |
60 | 70 | } else {
|
61 |
| - payload.put(field.name(), convertToJavaType(record, field, fieldIndex)); |
| 71 | + payload.put(field.name(), object(record, field, fieldIndex)); |
62 | 72 | }
|
63 | 73 | }
|
64 | 74 |
|
@@ -107,62 +117,10 @@ public void abort() {
|
107 | 117 | @Override
|
108 | 118 | public void close() {
|
109 | 119 | }
|
110 |
| - |
111 |
| - private Object convertToJavaType(InternalRow record, StructField field, int fieldIndex) { |
112 |
| - DataType dataType = field.dataType(); |
113 |
| - |
114 |
| - if (dataType == DataTypes.StringType) { |
115 |
| - return record.getString(fieldIndex); |
116 |
| - } else if (dataType == DataTypes.DateType || dataType == DataTypes.TimestampType) { |
117 |
| - return record.getString(fieldIndex); |
118 |
| - } else if (dataType instanceof ArrayType) { |
119 |
| - ArrayType arrayType = (ArrayType) dataType; |
120 |
| - ArrayData arrayData = record.getArray(fieldIndex); |
121 |
| - return convertArrayToJavaType(arrayData, arrayType.elementType()); |
122 |
| - } else if (dataType instanceof StructType) { |
123 |
| - StructType structType = (StructType) dataType; |
124 |
| - InternalRow structData = record.getStruct(fieldIndex, structType.fields().length); |
125 |
| - return convertStructToJavaType(structData, structType); |
126 |
| - } |
127 |
| - |
128 |
| - // Fall back to the generic get method |
129 |
| - return record.get(fieldIndex, dataType); |
130 |
| - } |
131 |
| - |
132 |
| - private Object convertArrayToJavaType(ArrayData arrayData, DataType elementType) { |
133 |
| - if (elementType == DataTypes.StringType) { |
134 |
| - int length = arrayData.numElements(); |
135 |
| - String[] result = new String[length]; |
136 |
| - for (int i = 0; i < length; i++) { |
137 |
| - result[i] = arrayData.getUTF8String(i).toString(); |
138 |
| - } |
139 |
| - return result; |
140 |
| - } else if (elementType instanceof StructType) { |
141 |
| - StructType structType = (StructType) elementType; |
142 |
| - int length = arrayData.numElements(); |
143 |
| - Object[] result = new Object[length]; |
144 |
| - for (int i = 0; i < length; i++) { |
145 |
| - InternalRow structData = arrayData.getStruct(i, structType.fields().length); |
146 |
| - result[i] = convertStructToJavaType(structData, structType); |
147 |
| - } |
148 |
| - return result; |
149 |
| - } else { |
150 |
| - return arrayData.toObjectArray(elementType); |
151 |
| - } |
152 |
| - } |
153 |
| - |
154 |
| - private Object convertStructToJavaType(InternalRow structData, StructType structType) { |
155 |
| - Map<String, Object> result = new HashMap<>(); |
156 |
| - for (int i = 0; i < structType.fields().length; i++) { |
157 |
| - StructField structField = structType.fields()[i]; |
158 |
| - result.put(structField.name(), convertToJavaType(structData, structField, i)); |
159 |
| - } |
160 |
| - return result; |
161 |
| - } |
162 | 120 | }
|
163 | 121 |
|
164 | 122 | class Point implements Serializable {
|
165 |
| - public String id; |
| 123 | + public Object id; |
166 | 124 | public float[] vector;
|
167 | 125 | public HashMap<String, Object> payload;
|
168 | 126 | }
|
0 commit comments