
Commit 32f1dbd

resolved merge conflict with master in SnappySessionState

2 parents: 4be42d3 + 5708864

400 files changed: +6332 -1708 lines changed


cluster/src/dunit/scala/io/snappydata/cluster/QueryRoutingDUnitTest.scala
40 additions & 2 deletions

@@ -18,12 +18,13 @@
 package io.snappydata.cluster
 
 import java.io.File
+import java.math.BigDecimal
 import java.sql.{Connection, DatabaseMetaData, DriverManager, ResultSet, SQLException, Statement}
 
 import com.gemstone.gemfire.distributed.DistributedMember
+
 import scala.collection.mutable
 import scala.collection.JavaConverters._
-
 import com.gemstone.gemfire.distributed.internal.membership.InternalDistributedMember
 import com.gemstone.gemfire.internal.cache.PartitionedRegion
 import com.pivotal.gemfirexd.internal.engine.Misc
@@ -34,7 +35,6 @@ import io.snappydata.test.dunit.{AvailablePortHelper, SerializableRunnable}
 import junit.framework.TestCase
 import org.apache.commons.io.FileUtils
 import org.junit.Assert
-
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -858,6 +858,44 @@ class QueryRoutingDUnitTest(val s: String)
     session.dropTable(table)
   }
 
+  def testSNAP2247(): Unit = {
+    val serverHostPort = AvailablePortHelper.getRandomAvailableTCPPort
+    vm2.invoke(classOf[ClusterManagerTestBase], "startNetServer", serverHostPort)
+    val conn = DriverManager.getConnection(
+      "jdbc:snappydata://localhost:" + serverHostPort)
+    val st = conn.createStatement()
+    try {
+      val conn = DriverManager.getConnection(
+        "jdbc:snappydata://localhost:" + serverHostPort)
+
+      val st = conn.createStatement()
+      st.execute(s"create table trade.securities " +
+          s"(sec_id int not null, symbol varchar(10) not null, " +
+          s"price decimal (30, 20), exchange varchar(10) not null, " +
+          s"tid int, constraint sec_pk primary key (sec_id), " +
+          s"constraint sec_uq unique (symbol, exchange), constraint exc_ch check " +
+          s"(exchange in ('nasdaq', 'nye', 'amex', 'lse', 'fse', 'hkse', 'tse'))) " +
+          s"ENABLE CONCURRENCY CHECKS")
+
+      val ps = conn.prepareStatement(s"select price, symbol, exchange from trade.securities" +
+          s" where (price<? or price >=?) and tid =? order by CASE when exchange ='nasdaq'" +
+          s" then symbol END desc, CASE when exchange in('nye', 'amex') then sec_id END desc," +
+          s" CASE when exchange ='lse' then symbol END asc, CASE when exchange ='fse' then" +
+          s" sec_id END desc, CASE when exchange ='hkse' then symbol END asc," +
+          s" CASE when exchange ='tse' then symbol END desc")
+
+      ps.setBigDecimal(1, new BigDecimal("0.02"))
+      ps.setBigDecimal(2, new BigDecimal("20.02"))
+      ps.setInt(3, 3)
+
+      ps.execute()
+      assert(!ps.getResultSet.next())
+    } finally {
+      st.execute(s"drop table trade.securities")
+      conn.close()
+    }
+  }
+
   def limitInsertRows(numRows: Int, serverHostPort: Int, tableName: String): Unit = {
 
     val conn = DriverManager.getConnection(
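
Note: the following is a minimal, standalone sketch of the JDBC flow the new testSNAP2247 exercises: prepare the query with CASE expressions in ORDER BY, bind the decimal and int parameters, and expect an empty result from the still-empty trade.securities table. It is not part of the commit; the object name and host/port are placeholders, a SnappyData network server must already be listening there, and the table must already exist as in the test above.

import java.math.BigDecimal
import java.sql.DriverManager

object Snap2247Sketch {
  def main(args: Array[String]): Unit = {
    // Placeholder host/port; the dunit test uses a random free port instead.
    val conn = DriverManager.getConnection("jdbc:snappydata://localhost:1527")
    try {
      val ps = conn.prepareStatement(
        "select price, symbol, exchange from trade.securities " +
          "where (price < ? or price >= ?) and tid = ? " +
          "order by CASE when exchange = 'nasdaq' then symbol END desc")
      ps.setBigDecimal(1, new BigDecimal("0.02"))  // lower price bound
      ps.setBigDecimal(2, new BigDecimal("20.02")) // upper price bound
      ps.setInt(3, 3)                              // tid
      ps.execute()
      assert(!ps.getResultSet.next()) // nothing was inserted, so no rows come back
    } finally {
      conn.close()
    }
  }
}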

cluster/src/dunit/scala/org/apache/spark/sql/TPCHDUnitTest.scala
17 additions & 19 deletions

@@ -349,14 +349,14 @@ object TPCHUtils extends Logging {
   def validateResult(snc: SQLContext, isSnappy: Boolean, isTokenization: Boolean = false): Unit = {
     val sc: SparkContext = snc.sparkContext
 
-    val fineName = if (!isTokenization) {
+    val fileName = if (!isTokenization) {
       if (isSnappy) "Result_Snappy.out" else "Result_Spark.out"
     } else {
       "Result_Snappy_Tokenization.out"
     }
 
-    val resultFileStream: FileOutputStream = new FileOutputStream(new File(fineName))
-    val resultOutputStream: PrintStream = new PrintStream(resultFileStream)
+    val resultsLogFileStream: FileOutputStream = new FileOutputStream(new File(fileName))
+    val resultsLogStream: PrintStream = new PrintStream(resultsLogFileStream)
 
     // scalastyle:off
     for (query <- queries) {
@@ -366,24 +366,25 @@ object TPCHUtils extends Logging {
       val expectedFile = sc.textFile(getClass.getResource(
         s"/TPCH/RESULT/Snappy_$query.out").getPath)
 
-      val queryFileName = if (isSnappy) s"1_Snappy_$query.out" else s"1_Spark_$query.out"
-      val actualFile = sc.textFile(queryFileName)
+      // val queryFileName = if (isSnappy) s"1_Snappy_$query.out" else s"1_Spark_$query.out"
+      val queryResultsFileName = if (isSnappy) s"1_Snappy_Q${query}_Results.out" else s"1_Spark_Q${query}_Results.out"
+      val actualFile = sc.textFile(queryResultsFileName)
 
       val expectedLineSet = expectedFile.collect().toList.sorted
       val actualLineSet = actualFile.collect().toList.sorted
 
       if (!actualLineSet.equals(expectedLineSet)) {
         if (!(expectedLineSet.size == actualLineSet.size)) {
-          resultOutputStream.println(s"For $query " +
+          resultsLogStream.println(s"For $query " +
             s"result count mismatched observed with " +
             s"expected ${expectedLineSet.size} and actual ${actualLineSet.size}")
         } else {
           for ((expectedLine, actualLine) <- expectedLineSet zip actualLineSet) {
             if (!expectedLine.equals(actualLine)) {
-              resultOutputStream.println(s"For $query result mismatched observed")
-              resultOutputStream.println(s"Expected : $expectedLine")
-              resultOutputStream.println(s"Found : $actualLine")
-              resultOutputStream.println(s"-------------------------------------")
+              resultsLogStream.println(s"For $query result mismatched observed")
+              resultsLogStream.println(s"Expected : $expectedLine")
+              resultsLogStream.println(s"Found : $actualLine")
+              resultsLogStream.println(s"-------------------------------------")
             }
           }
         }
@@ -399,16 +400,16 @@ object TPCHUtils extends Logging {
         val actualLineSet = secondRunFile.collect().toList.sorted
 
         if (actualLineSet.equals(expectedLineSet)) {
-          resultOutputStream.println(s"For $query result matched observed")
-          resultOutputStream.println(s"-------------------------------------")
+          resultsLogStream.println(s"For $query result matched observed")
+          resultsLogStream.println(s"-------------------------------------")
         }
       }
     }
     // scalastyle:on
-    resultOutputStream.close()
-    resultFileStream.close()
+    resultsLogStream.close()
+    resultsLogFileStream.close()
 
-    val resultOutputFile = sc.textFile(fineName)
+    val resultOutputFile = sc.textFile(fileName)
 
     if(!isTokenization) {
       assert(resultOutputFile.count() == 0,
@@ -433,11 +434,8 @@ object TPCHUtils extends Logging {
       fileName: String = ""): Unit = {
     snc.sql(s"set spark.sql.crossJoin.enabled = true")
 
-    // queries.foreach(query => TPCH_Snappy.execute(query, snc,
-    //   isResultCollection, isSnappy, warmup = warmup,
-    //   runsForAverage = runsForAverage, avgPrintStream = System.out))
     queries.foreach(query => QueryExecutor.execute(query, snc, isResultCollection,
       isSnappy, isDynamic = isDynamic, warmup = warmup, runsForAverage = runsForAverage,
-      avgPrintStream = System.out))
+      avgTimePrintStream = System.out))
   }
 }
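
Note: the renames above (fineName to fileName, resultOutputStream to resultsLogStream, and so on) do not change the validation logic: read the expected and freshly produced result files, sort the lines, and write any mismatch to a log file. Below is a simplified, standalone sketch of that comparison using local files instead of Spark RDDs; all file names are placeholders, not the paths the test actually uses.

import java.io.{File, FileOutputStream, PrintStream}
import scala.io.Source

object ResultDiffSketch {
  def main(args: Array[String]): Unit = {
    // Sort both sides so row-ordering differences do not count as mismatches.
    val expected = Source.fromFile("expected_Q1.out").getLines().toList.sorted
    val actual = Source.fromFile("actual_Q1.out").getLines().toList.sorted
    val log = new PrintStream(new FileOutputStream(new File("Result_Diff.out")))
    if (expected.size != actual.size) {
      log.println(s"result count mismatched: expected ${expected.size}, actual ${actual.size}")
    } else {
      for ((e, a) <- expected zip actual if e != a) {
        log.println(s"Expected : $e")
        log.println(s"Found    : $a")
      }
    }
    log.close()
  }
}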

cluster/src/main/scala/io/snappydata/gemxd/SparkSQLExecuteImpl.scala
78 additions & 68 deletions

@@ -38,6 +38,7 @@ import com.pivotal.gemfirexd.internal.snappy.{LeadNodeExecutionContext, SparkSQL
 import io.snappydata.{Constant, QueryHint}
 
 import org.apache.spark.serializer.{KryoSerializerPool, StructTypeSerializer}
+import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.collection.Utils
 import org.apache.spark.sql.types._
@@ -88,17 +89,9 @@ class SparkSQLExecuteImpl(val sql: String,
   private[this] lazy val colTypes = getColumnTypes
 
   // check for query hint to serialize complex types as JSON strings
-  private[this] val complexTypeAsJson = session.getPreviousQueryHints.get(
-    QueryHint.ComplexTypeAsJson.toString) match {
-    case null => true
-    case v => Misc.parseBoolean(v)
-  }
+  private[this] val complexTypeAsJson = SparkSQLExecuteImpl.getJsonProperties(session)
 
-  private val (allAsClob, columnsAsClob) = session.getPreviousQueryHints.get(
-    QueryHint.ColumnsAsClob.toString) match {
-    case null => (false, Set.empty[String])
-    case v => Utils.parseColumnsAsClob(v)
-  }
+  private val (allAsClob, columnsAsClob) = SparkSQLExecuteImpl.getClobProperties(session)
 
   override def packRows(msg: LeadNodeExecutorMsg,
       snappyResultHolder: SnappyResultHolder): Unit = {
@@ -121,7 +114,8 @@ class SparkSQLExecuteImpl(val sql: String,
       CachedDataFrame.localBlockStoreResultHandler(rddId, bm),
       CachedDataFrame.localBlockStoreDecoder(querySchema.length, bm))
     hdos.clearForReuse()
-    writeMetaData(srh)
+    SparkSQLExecuteImpl.writeMetaData(srh, hdos, tableNames, nullability, getColumnNames,
+      getColumnTypes, getColumnDataTypes)
 
     var id = 0
     for (block <- partitionBlocks) {
@@ -191,77 +185,53 @@ class SparkSQLExecuteImpl(val sql: String,
   override def serializeRows(out: DataOutput, hasMetadata: Boolean): Unit =
     SparkSQLExecuteImpl.serializeRows(out, hasMetadata, hdos)
 
-  private lazy val (tableNames, nullability) = getTableNamesAndNullability
-
-  def getTableNamesAndNullability: (Array[String], Array[Boolean]) = {
-    var i = 0
-    val output = df.queryExecution.analyzed.output
-    val tables = new Array[String](output.length)
-    val nullables = new Array[Boolean](output.length)
-    output.foreach { a =>
-      val fn = a.qualifiedName
-      val dotIdx = fn.lastIndexOf('.')
-      if (dotIdx > 0) {
-        tables(i) = fn.substring(0, dotIdx)
-      } else {
-        tables(i) = ""
-      }
-      nullables(i) = a.nullable
-      i += 1
-    }
-    (tables, nullables)
-  }
-
-  private def writeMetaData(srh: SnappyResultHolder): Unit = {
-    val hdos = this.hdos
-    // indicates that the metadata is being packed too
-    srh.setHasMetadata()
-    DataSerializer.writeStringArray(tableNames, hdos)
-    DataSerializer.writeStringArray(getColumnNames, hdos)
-    DataSerializer.writeBooleanArray(nullability, hdos)
-    for (i <- colTypes.indices) {
-      val (tp, precision, scale) = colTypes(i)
-      InternalDataSerializer.writeSignedVL(tp, hdos)
-      tp match {
-        case StoredFormatIds.SQL_DECIMAL_ID =>
-          InternalDataSerializer.writeSignedVL(precision, hdos) // precision
-          InternalDataSerializer.writeSignedVL(scale, hdos) // scale
-        case StoredFormatIds.SQL_VARCHAR_ID |
-             StoredFormatIds.SQL_CHAR_ID =>
-          // Write the size as precision
-          InternalDataSerializer.writeSignedVL(precision, hdos)
-        case StoredFormatIds.REF_TYPE_ID =>
-          // Write the DataType
-          hdos.write(KryoSerializerPool.serialize((kryo, out) =>
-            StructTypeSerializer.writeType(kryo, out, querySchema(i).dataType)))
-        case _ => // ignore for others
-      }
-    }
-  }
+  private lazy val (tableNames, nullability) = SparkSQLExecuteImpl.
+    getTableNamesAndNullability(df.queryExecution.analyzed.output)
 
   def getColumnNames: Array[String] = {
     querySchema.fieldNames
   }
 
   private def getColumnTypes: Array[(Int, Int, Int)] =
-    querySchema.map(f => getSQLType(f)).toArray
+    querySchema.map(f => SparkSQLExecuteImpl.getSQLType(f.dataType, complexTypeAsJson,
+      Some(f.metadata), Some(f.name), Some(allAsClob), Some(columnsAsClob))).toArray
+
+  private def getColumnDataTypes: Array[DataType] =
+    querySchema.map(_.dataType).toArray
+}
+
+object SparkSQLExecuteImpl {
 
-  private def getSQLType(f: StructField): (Int, Int, Int) = {
-    val dataType = f.dataType
+  def getJsonProperties(session: SnappySession): Boolean = session.getPreviousQueryHints.get(
+    QueryHint.ComplexTypeAsJson.toString) match {
+    case null => true
+    case v => Misc.parseBoolean(v)
+  }
+
+  def getClobProperties(session: SnappySession): (Boolean, Set[String]) =
+    session.getPreviousQueryHints.get(QueryHint.ColumnsAsClob.toString) match {
+      case null => (false, Set.empty[String])
+      case v => Utils.parseColumnsAsClob(v)
+    }
+
+  def getSQLType(dataType: DataType, complexTypeAsJson: Boolean,
+      metaData: Option[Metadata] = None, metaName: Option[String] = None,
+      allAsClob: Option[Boolean] = None, columnsAsClob: Option[Set[String]] = None): (Int,
+      Int, Int) = {
     dataType match {
       case IntegerType => (StoredFormatIds.SQL_INTEGER_ID, -1, -1)
-      case StringType =>
+      case StringType if metaData.isDefined =>
        TypeUtilities.getMetadata[String](Constant.CHAR_TYPE_BASE_PROP,
-          f.metadata) match {
+          metaData.get) match {
          case Some(base) if base != "CLOB" =>
            lazy val size = TypeUtilities.getMetadata[Long](
-              Constant.CHAR_TYPE_SIZE_PROP, f.metadata)
+              Constant.CHAR_TYPE_SIZE_PROP, metaData.get)
            lazy val varcharSize = size.getOrElse(
              Constant.MAX_VARCHAR_SIZE.toLong).toInt
            lazy val charSize = size.getOrElse(
              Constant.MAX_CHAR_SIZE.toLong).toInt
-            if (allAsClob ||
-                (columnsAsClob.nonEmpty && columnsAsClob.contains(f.name))) {
+            if (allAsClob.get ||
+                (columnsAsClob.get.nonEmpty && columnsAsClob.get.contains(metaName.get))) {
              if (base != "STRING") {
                if (base == "VARCHAR") {
                  (StoredFormatIds.SQL_VARCHAR_ID, varcharSize, -1)
@@ -282,6 +252,7 @@ class SparkSQLExecuteImpl(val sql: String,
 
          case _ => (StoredFormatIds.SQL_CLOB_ID, -1, -1) // CLOB
        }
+      case StringType => (StoredFormatIds.SQL_CLOB_ID, -1, -1) // CLOB
      case LongType => (StoredFormatIds.SQL_LONGINT_ID, -1, -1)
      case TimestampType => (StoredFormatIds.SQL_TIMESTAMP_ID, -1, -1)
      case DateType => (StoredFormatIds.SQL_DATE_ID, -1, -1)
@@ -302,9 +273,48 @@
      case _ => (StoredFormatIds.REF_TYPE_ID, -1, -1)
    }
  }
-}
 
-object SparkSQLExecuteImpl {
+  def getTableNamesAndNullability(output: Seq[expressions.Attribute]):
+      (Seq[String], Seq[Boolean]) = {
+    output.map { a =>
+      val fn = a.qualifiedName
+      val dotIdx = fn.lastIndexOf('.')
+      if (dotIdx > 0) {
+        (fn.substring(0, dotIdx), a.nullable)
+      } else {
+        ("", a.nullable)
+      }
+    }.unzip
+  }
+
+  def writeMetaData(srh: SnappyResultHolder, hdos: GfxdHeapDataOutputStream,
+      tableNames: Seq[String], nullability: Seq[Boolean], columnNames: Seq[String],
+      colTypes: Seq[(Int, Int, Int)], dataTypes: Seq[DataType]): Unit = {
+    // indicates that the metadata is being packed too
+    srh.setHasMetadata()
+    DataSerializer.writeStringArray(tableNames.toArray, hdos)
+    DataSerializer.writeStringArray(columnNames.toArray, hdos)
+    DataSerializer.writeBooleanArray(nullability.toArray, hdos)
+    for (i <- colTypes.indices) {
+      val (tp, precision, scale) = colTypes(i)
+      InternalDataSerializer.writeSignedVL(tp, hdos)
+      tp match {
+        case StoredFormatIds.SQL_DECIMAL_ID =>
+          InternalDataSerializer.writeSignedVL(precision, hdos) // precision
+          InternalDataSerializer.writeSignedVL(scale, hdos) // scale
+        case StoredFormatIds.SQL_VARCHAR_ID |
+             StoredFormatIds.SQL_CHAR_ID =>
+          // Write the size as precision
+          InternalDataSerializer.writeSignedVL(precision, hdos)
+        case StoredFormatIds.REF_TYPE_ID =>
+          // Write the DataType
+          hdos.write(KryoSerializerPool.serialize((kryo, out) =>
+            StructTypeSerializer.writeType(kryo, out, dataTypes(i))))
+        case _ => // ignore for others
+      }
+    }
+  }
+
   def getContextOrCurrentClassLoader: ClassLoader =
     Option(Thread.currentThread().getContextClassLoader)
       .getOrElse(getClass.getClassLoader)
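
Note: the bulk of this change moves the metadata helpers (getSQLType, getTableNamesAndNullability, writeMetaData and the query-hint readers) from the SparkSQLExecuteImpl instance into its companion object, with their inputs passed as parameters, presumably so other lead-node execution paths can reuse them. Below is a tiny, self-contained sketch of the qualified-name/nullability split that getTableNamesAndNullability now performs, using plain (name, nullable) pairs in place of Catalyst attributes; the object name and sample data are illustrative only.

object TableNamesSketch {
  def main(args: Array[String]): Unit = {
    // Stand-ins for the analyzed output attributes: (qualifiedName, nullable).
    val output = Seq(("trade.securities.price", true), ("symbol", false))
    // Split each qualified name at its last '.', then unzip into parallel seqs.
    val (tables, nullables) = output.map { case (fn, nullable) =>
      val dotIdx = fn.lastIndexOf('.')
      if (dotIdx > 0) (fn.substring(0, dotIdx), nullable) else ("", nullable)
    }.unzip
    println(tables)    // List(trade.securities, )
    println(nullables) // List(true, false)
  }
}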
