@@ -38,6 +38,7 @@ import com.pivotal.gemfirexd.internal.snappy.{LeadNodeExecutionContext, SparkSQL
import io.snappydata.{Constant, QueryHint}

import org.apache.spark.serializer.{KryoSerializerPool, StructTypeSerializer}
+import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.collection.Utils
import org.apache.spark.sql.types._
@@ -88,17 +89,9 @@ class SparkSQLExecuteImpl(val sql: String,
  private[this] lazy val colTypes = getColumnTypes

  // check for query hint to serialize complex types as JSON strings
-  private[this] val complexTypeAsJson = session.getPreviousQueryHints.get(
-    QueryHint.ComplexTypeAsJson.toString) match {
-    case null => true
-    case v => Misc.parseBoolean(v)
-  }
+  private[this] val complexTypeAsJson = SparkSQLExecuteImpl.getJsonProperties(session)

-  private val (allAsClob, columnsAsClob) = session.getPreviousQueryHints.get(
-    QueryHint.ColumnsAsClob.toString) match {
-    case null => (false, Set.empty[String])
-    case v => Utils.parseColumnsAsClob(v)
-  }
+  private val (allAsClob, columnsAsClob) = SparkSQLExecuteImpl.getClobProperties(session)

  override def packRows(msg: LeadNodeExecutorMsg,
      snappyResultHolder: SnappyResultHolder): Unit = {
@@ -121,7 +114,8 @@ class SparkSQLExecuteImpl(val sql: String,
      CachedDataFrame.localBlockStoreResultHandler(rddId, bm),
      CachedDataFrame.localBlockStoreDecoder(querySchema.length, bm))
    hdos.clearForReuse()
-    writeMetaData(srh)
+    SparkSQLExecuteImpl.writeMetaData(srh, hdos, tableNames, nullability, getColumnNames,
+      getColumnTypes, getColumnDataTypes)

    var id = 0
    for (block <- partitionBlocks) {
@@ -191,77 +185,53 @@ class SparkSQLExecuteImpl(val sql: String,
  override def serializeRows(out: DataOutput, hasMetadata: Boolean): Unit =
    SparkSQLExecuteImpl.serializeRows(out, hasMetadata, hdos)

-  private lazy val (tableNames, nullability) = getTableNamesAndNullability
-
-  def getTableNamesAndNullability: (Array[String], Array[Boolean]) = {
-    var i = 0
-    val output = df.queryExecution.analyzed.output
-    val tables = new Array[String](output.length)
-    val nullables = new Array[Boolean](output.length)
-    output.foreach { a =>
-      val fn = a.qualifiedName
-      val dotIdx = fn.lastIndexOf('.')
-      if (dotIdx > 0) {
-        tables(i) = fn.substring(0, dotIdx)
-      } else {
-        tables(i) = ""
-      }
-      nullables(i) = a.nullable
-      i += 1
-    }
-    (tables, nullables)
-  }
-
-  private def writeMetaData(srh: SnappyResultHolder): Unit = {
-    val hdos = this.hdos
-    // indicates that the metadata is being packed too
-    srh.setHasMetadata()
-    DataSerializer.writeStringArray(tableNames, hdos)
-    DataSerializer.writeStringArray(getColumnNames, hdos)
-    DataSerializer.writeBooleanArray(nullability, hdos)
-    for (i <- colTypes.indices) {
-      val (tp, precision, scale) = colTypes(i)
-      InternalDataSerializer.writeSignedVL(tp, hdos)
-      tp match {
-        case StoredFormatIds.SQL_DECIMAL_ID =>
-          InternalDataSerializer.writeSignedVL(precision, hdos) // precision
-          InternalDataSerializer.writeSignedVL(scale, hdos) // scale
-        case StoredFormatIds.SQL_VARCHAR_ID |
-             StoredFormatIds.SQL_CHAR_ID =>
-          // Write the size as precision
-          InternalDataSerializer.writeSignedVL(precision, hdos)
-        case StoredFormatIds.REF_TYPE_ID =>
-          // Write the DataType
-          hdos.write(KryoSerializerPool.serialize((kryo, out) =>
-            StructTypeSerializer.writeType(kryo, out, querySchema(i).dataType)))
-        case _ => // ignore for others
-      }
-    }
-  }
+  private lazy val (tableNames, nullability) = SparkSQLExecuteImpl.
+      getTableNamesAndNullability(df.queryExecution.analyzed.output)

  def getColumnNames: Array[String] = {
    querySchema.fieldNames
  }

  private def getColumnTypes: Array[(Int, Int, Int)] =
-    querySchema.map(f => getSQLType(f)).toArray
+    querySchema.map(f => SparkSQLExecuteImpl.getSQLType(f.dataType, complexTypeAsJson,
+      Some(f.metadata), Some(f.name), Some(allAsClob), Some(columnsAsClob))).toArray
+
+  private def getColumnDataTypes: Array[DataType] =
+    querySchema.map(_.dataType).toArray
+}
+
+object SparkSQLExecuteImpl {

-  private def getSQLType(f: StructField): (Int, Int, Int) = {
-    val dataType = f.dataType
+  def getJsonProperties(session: SnappySession): Boolean = session.getPreviousQueryHints.get(
+    QueryHint.ComplexTypeAsJson.toString) match {
+    case null => true
+    case v => Misc.parseBoolean(v)
+  }
+
+  def getClobProperties(session: SnappySession): (Boolean, Set[String]) =
+    session.getPreviousQueryHints.get(QueryHint.ColumnsAsClob.toString) match {
+      case null => (false, Set.empty[String])
+      case v => Utils.parseColumnsAsClob(v)
+    }
+
+  def getSQLType(dataType: DataType, complexTypeAsJson: Boolean,
+      metaData: Option[Metadata] = None, metaName: Option[String] = None,
+      allAsClob: Option[Boolean] = None, columnsAsClob: Option[Set[String]] = None): (Int,
+      Int, Int) = {
    dataType match {
      case IntegerType => (StoredFormatIds.SQL_INTEGER_ID, -1, -1)
-      case StringType =>
+      case StringType if metaData.isDefined =>
        TypeUtilities.getMetadata[String](Constant.CHAR_TYPE_BASE_PROP,
-          f.metadata) match {
+          metaData.get) match {
          case Some(base) if base != "CLOB" =>
            lazy val size = TypeUtilities.getMetadata[Long](
-              Constant.CHAR_TYPE_SIZE_PROP, f.metadata)
+              Constant.CHAR_TYPE_SIZE_PROP, metaData.get)
            lazy val varcharSize = size.getOrElse(
              Constant.MAX_VARCHAR_SIZE.toLong).toInt
            lazy val charSize = size.getOrElse(
              Constant.MAX_CHAR_SIZE.toLong).toInt
-            if (allAsClob ||
-                (columnsAsClob.nonEmpty && columnsAsClob.contains(f.name))) {
+            if (allAsClob.get ||
+                (columnsAsClob.get.nonEmpty && columnsAsClob.get.contains(metaName.get))) {
              if (base != "STRING") {
                if (base == "VARCHAR") {
                  (StoredFormatIds.SQL_VARCHAR_ID, varcharSize, -1)
@@ -282,6 +252,7 @@ class SparkSQLExecuteImpl(val sql: String,
          case _ => (StoredFormatIds.SQL_CLOB_ID, -1, -1) // CLOB
        }
+      case StringType => (StoredFormatIds.SQL_CLOB_ID, -1, -1) // CLOB
      case LongType => (StoredFormatIds.SQL_LONGINT_ID, -1, -1)
      case TimestampType => (StoredFormatIds.SQL_TIMESTAMP_ID, -1, -1)
      case DateType => (StoredFormatIds.SQL_DATE_ID, -1, -1)
@@ -302,9 +273,48 @@ class SparkSQLExecuteImpl(val sql: String,
      case _ => (StoredFormatIds.REF_TYPE_ID, -1, -1)
    }
  }
-}

-object SparkSQLExecuteImpl {
+  def getTableNamesAndNullability(output: Seq[expressions.Attribute]):
+      (Seq[String], Seq[Boolean]) = {
+    output.map { a =>
+      val fn = a.qualifiedName
+      val dotIdx = fn.lastIndexOf('.')
+      if (dotIdx > 0) {
+        (fn.substring(0, dotIdx), a.nullable)
+      } else {
+        ("", a.nullable)
+      }
+    }.unzip
+  }
+
+  def writeMetaData(srh: SnappyResultHolder, hdos: GfxdHeapDataOutputStream,
+      tableNames: Seq[String], nullability: Seq[Boolean], columnNames: Seq[String],
+      colTypes: Seq[(Int, Int, Int)], dataTypes: Seq[DataType]): Unit = {
+    // indicates that the metadata is being packed too
+    srh.setHasMetadata()
+    DataSerializer.writeStringArray(tableNames.toArray, hdos)
+    DataSerializer.writeStringArray(columnNames.toArray, hdos)
+    DataSerializer.writeBooleanArray(nullability.toArray, hdos)
+    for (i <- colTypes.indices) {
+      val (tp, precision, scale) = colTypes(i)
+      InternalDataSerializer.writeSignedVL(tp, hdos)
+      tp match {
+        case StoredFormatIds.SQL_DECIMAL_ID =>
+          InternalDataSerializer.writeSignedVL(precision, hdos) // precision
+          InternalDataSerializer.writeSignedVL(scale, hdos) // scale
+        case StoredFormatIds.SQL_VARCHAR_ID |
+             StoredFormatIds.SQL_CHAR_ID =>
+          // Write the size as precision
+          InternalDataSerializer.writeSignedVL(precision, hdos)
+        case StoredFormatIds.REF_TYPE_ID =>
+          // Write the DataType
+          hdos.write(KryoSerializerPool.serialize((kryo, out) =>
+            StructTypeSerializer.writeType(kryo, out, dataTypes(i))))
+        case _ => // ignore for others
+      }
+    }
+  }
+
  def getContextOrCurrentClassLoader: ClassLoader =
    Option(Thread.currentThread().getContextClassLoader)
        .getOrElse(getClass.getClassLoader)
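
For context on how the relocated helpers are meant to be consumed, the sketch below is illustrative only and not part of the commit. It assumes a SnappyData cluster classpath; the wrapping object, its method names, and the import paths are the editor's assumptions, while the `SparkSQLExecuteImpl` calls follow the companion-object signatures introduced above.

```scala
// Editor's sketch, not part of this commit: reusing the helpers that now live on
// the SparkSQLExecuteImpl companion object. Import paths are assumptions based on
// the SnappyData source layout.
import io.snappydata.gemxd.SparkSQLExecuteImpl // assumed package of SparkSQLExecuteImpl
import org.apache.spark.sql.SnappySession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object ColumnTypeSketch {

  // Maps every field of a schema to its (storeTypeId, precision, scale) triple,
  // honoring the same ComplexTypeAsJson / ColumnsAsClob query hints that
  // SparkSQLExecuteImpl itself reads.
  def storeTypes(session: SnappySession, schema: StructType): Array[(Int, Int, Int)] = {
    val complexTypeAsJson = SparkSQLExecuteImpl.getJsonProperties(session)
    val (allAsClob, columnsAsClob) = SparkSQLExecuteImpl.getClobProperties(session)
    schema.fields.map(f => SparkSQLExecuteImpl.getSQLType(f.dataType, complexTypeAsJson,
      Some(f.metadata), Some(f.name), Some(allAsClob), Some(columnsAsClob)))
  }

  // Because the Option parameters default to None, a bare DataType can be mapped
  // directly; a StringType without field metadata falls into the new CLOB case.
  def quickCheck(): Unit = {
    val intId = SparkSQLExecuteImpl.getSQLType(IntegerType, complexTypeAsJson = true)
    val clobId = SparkSQLExecuteImpl.getSQLType(StringType, complexTypeAsJson = true)
    println(s"IntegerType -> $intId, StringType without metadata -> $clobId")
  }
}
```

The defaulted `Option` parameters plus the added `case StringType => ... // CLOB` fallback are what allow callers that only have a `DataType` (no `StructField` metadata or CLOB hints) to reuse `getSQLType`.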