Skip to content

Commit de5ed17

Browse files
committed
Merge branch 'main' into df51
# Conflicts:
#	native/spark-expr/src/math_funcs/internal/make_decimal.rs
2 parents 961957f + bf1f3a2 commit de5ed17

11 files changed

Lines changed: 206 additions & 110 deletions

File tree

common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java

Lines changed: 25 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -101,13 +101,35 @@ public class CometFileKeyUnwrapper {
101101
// Cache the hadoopConf just to assert the assumption above.
102102
private Configuration conf = null;
103103

104+
/**
105+
* Normalizes S3 URI schemes to a canonical form. S3 can be accessed via multiple schemes (s3://,
106+
* s3a://, s3n://) that refer to the same logical filesystem. This method ensures consistent cache
107+
* lookups regardless of which scheme is used.
108+
*
109+
* @param filePath The file path that may contain an S3 URI
110+
* @return The file path with normalized S3 scheme (s3a://)
111+
*/
112+
private String normalizeS3Scheme(final String filePath) {
113+
// Normalize s3:// and s3n:// to s3a:// for consistent cache lookups
114+
// This handles the case where ObjectStoreUrl uses s3:// but Spark uses s3a://
115+
String s3Prefix = "s3://";
116+
String s3nPrefix = "s3n://";
117+
if (filePath.startsWith(s3Prefix)) {
118+
return "s3a://" + filePath.substring(s3Prefix.length());
119+
} else if (filePath.startsWith(s3nPrefix)) {
120+
return "s3a://" + filePath.substring(s3nPrefix.length());
121+
}
122+
return filePath;
123+
}
124+
104125
/**
105126
* Creates and stores a DecryptionKeyRetriever instance for the given file path.
106127
*
107128
* @param filePath The path to the Parquet file
108129
* @param hadoopConf The Hadoop Configuration to use for this file path
109130
*/
110131
public void storeDecryptionKeyRetriever(final String filePath, final Configuration hadoopConf) {
132+
final String normalizedPath = normalizeS3Scheme(filePath);
111133
// Use DecryptionPropertiesFactory.loadFactory to get the factory and then call
112134
// getFileDecryptionProperties
113135
if (factory == null) {
@@ -122,7 +144,7 @@ public void storeDecryptionKeyRetriever(final String filePath, final Configurati
122144
factory.getFileDecryptionProperties(hadoopConf, path);
123145

124146
DecryptionKeyRetriever keyRetriever = decryptionProperties.getKeyRetriever();
125-
retrieverCache.put(filePath, keyRetriever);
147+
retrieverCache.put(normalizedPath, keyRetriever);
126148
}
127149

128150
/**
@@ -136,7 +158,8 @@ public void storeDecryptionKeyRetriever(final String filePath, final Configurati
136158
*/
137159
public byte[] getKey(final String filePath, final byte[] keyMetadata)
138160
throws ParquetCryptoRuntimeException {
139-
DecryptionKeyRetriever keyRetriever = retrieverCache.get(filePath);
161+
final String normalizedPath = normalizeS3Scheme(filePath);
162+
DecryptionKeyRetriever keyRetriever = retrieverCache.get(normalizedPath);
140163
if (keyRetriever == null) {
141164
throw new ParquetCryptoRuntimeException(
142165
"Failed to find DecryptionKeyRetriever for path: " + filePath);

native/spark-expr/src/math_funcs/internal/make_decimal.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,21 @@ pub fn spark_make_decimal(
4040
))),
4141
sv => internal_err!("Expected Int64 but found {sv:?}"),
4242
},
43-
ColumnarValue::Array(a) => {
44-
let arr = a.as_primitive::<Int64Type>();
45-
let mut result = Decimal128Builder::new();
46-
for v in arr.into_iter() {
47-
result.append_option(long_to_decimal(&v, precision, scale))
48-
}
49-
let result_type = DataType::Decimal128(precision, scale);
43+
ColumnarValue::Array(a) => match a.data_type() {
44+
DataType::Int64 => {
45+
let arr = a.as_primitive::<Int64Type>();
46+
let mut result = Decimal128Builder::new();
47+
for v in arr.into_iter() {
48+
result.append_option(long_to_decimal(&v, precision, scale))
49+
}
50+
let result_type = DataType::Decimal128(precision, scale);
5051

51-
Ok(ColumnarValue::Array(Arc::new(
52-
result.finish().with_data_type(result_type),
53-
)))
54-
}
52+
Ok(ColumnarValue::Array(Arc::new(
53+
result.finish().with_data_type(result_type),
54+
)))
55+
}
56+
av => internal_err!("Expected Int64 but found {av:?}"),
57+
},
5558
}
5659
}
5760

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 0 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -21,24 +21,14 @@ package org.apache.comet
2121

2222
import java.nio.ByteOrder
2323

24-
import scala.collection.mutable.ListBuffer
25-
2624
import org.apache.spark.SparkConf
2725
import org.apache.spark.internal.Logging
2826
import org.apache.spark.network.util.ByteUnit
2927
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
3028
import org.apache.spark.sql.catalyst.rules.Rule
3129
import org.apache.spark.sql.catalyst.trees.TreeNode
3230
import org.apache.spark.sql.comet._
33-
import org.apache.spark.sql.comet.util.Utils
3431
import org.apache.spark.sql.execution._
35-
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
36-
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
37-
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
38-
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
39-
import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
40-
import org.apache.spark.sql.execution.datasources.v2.json.JsonScan
41-
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
4232
import org.apache.spark.sql.internal.SQLConf
4333

4434
import org.apache.comet.CometConf._
@@ -76,10 +66,6 @@ class CometSparkSessionExtensions
7666
object CometSparkSessionExtensions extends Logging {
7767
lazy val isBigEndian: Boolean = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)
7868

79-
private[comet] def isANSIEnabled(conf: SQLConf): Boolean = {
80-
conf.getConf(SQLConf.ANSI_ENABLED)
81-
}
82-
8369
/**
8470
* Checks whether Comet extension should be loaded for Spark.
8571
*/
@@ -122,21 +108,6 @@ object CometSparkSessionExtensions extends Logging {
122108
}
123109
}
124110

125-
private[comet] def isCometBroadCastForceEnabled(conf: SQLConf): Boolean = {
126-
COMET_EXEC_BROADCAST_FORCE_ENABLED.get(conf)
127-
}
128-
129-
private[comet] def getCometBroadcastNotEnabledReason(conf: SQLConf): Option[String] = {
130-
if (!CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.get(conf) &&
131-
!isCometBroadCastForceEnabled(conf)) {
132-
Some(
133-
s"${COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.key}.enabled is not specified and " +
134-
s"${COMET_EXEC_BROADCAST_FORCE_ENABLED.key} is not specified")
135-
} else {
136-
None
137-
}
138-
}
139-
140111
// Check whether Comet shuffle is enabled:
141112
// 1. `COMET_EXEC_SHUFFLE_ENABLED` is true
142113
// 2. `spark.shuffle.manager` is set to `CometShuffleManager`
@@ -149,62 +120,10 @@ object CometSparkSessionExtensions extends Logging {
149120
"org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager"
150121
}
151122

152-
private[comet] def isCometScanEnabled(conf: SQLConf): Boolean = {
153-
COMET_NATIVE_SCAN_ENABLED.get(conf)
154-
}
155-
156-
private[comet] def isCometExecEnabled(conf: SQLConf): Boolean = {
157-
COMET_EXEC_ENABLED.get(conf)
158-
}
159-
160123
def isCometScan(op: SparkPlan): Boolean = {
161124
op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
162125
}
163126

164-
def shouldApplySparkToColumnar(conf: SQLConf, op: SparkPlan): Boolean = {
165-
// Only consider converting leaf nodes to columnar currently, so that all the following
166-
// operators can have a chance to be converted to columnar. Leaf operators that output
167-
// columnar batches, such as Spark's vectorized readers, will also be converted to native
168-
// comet batches.
169-
val fallbackReasons = new ListBuffer[String]()
170-
if (CometSparkToColumnarExec.isSchemaSupported(op.schema, fallbackReasons)) {
171-
op match {
172-
// Convert Spark DS v1 scan to Arrow format
173-
case scan: FileSourceScanExec =>
174-
scan.relation.fileFormat match {
175-
case _: CSVFileFormat => CometConf.COMET_CONVERT_FROM_CSV_ENABLED.get(conf)
176-
case _: JsonFileFormat => CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf)
177-
case _: ParquetFileFormat => CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.get(conf)
178-
case _ => isSparkToArrowEnabled(conf, op)
179-
}
180-
// Convert Spark DS v2 scan to Arrow format
181-
case scan: BatchScanExec =>
182-
scan.scan match {
183-
case _: CSVScan => CometConf.COMET_CONVERT_FROM_CSV_ENABLED.get(conf)
184-
case _: JsonScan => CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf)
185-
case _: ParquetScan => CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.get(conf)
186-
case _ => isSparkToArrowEnabled(conf, op)
187-
}
188-
// other leaf nodes
189-
case _: LeafExecNode =>
190-
isSparkToArrowEnabled(conf, op)
191-
case _ =>
192-
// TODO: consider converting other intermediate operators to columnar.
193-
false
194-
}
195-
} else {
196-
false
197-
}
198-
}
199-
200-
private def isSparkToArrowEnabled(conf: SQLConf, op: SparkPlan) = {
201-
COMET_SPARK_TO_ARROW_ENABLED.get(conf) && {
202-
val simpleClassName = Utils.getSimpleName(op.getClass)
203-
val nodeName = simpleClassName.replaceAll("Exec$", "")
204-
COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST.get(conf).contains(nodeName)
205-
}
206-
}
207-
208127
def isSpark35Plus: Boolean = {
209128
org.apache.spark.SPARK_VERSION >= "3.5"
210129
}
@@ -364,12 +283,4 @@ object CometSparkSessionExtensions extends Logging {
364283
node.getTagValue(CometExplainInfo.EXTENSION_INFO).exists(_.nonEmpty)
365284
}
366285

367-
// Helper to reduce boilerplate
368-
def createMessage(condition: Boolean, message: => String): Option[String] = {
369-
if (condition) {
370-
Some(message)
371-
} else {
372-
None
373-
}
374-
}
375286
}

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,35 @@
1919

2020
package org.apache.comet.rules
2121

22+
import scala.collection.mutable.ListBuffer
23+
2224
import org.apache.spark.sql.SparkSession
2325
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
2426
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
2527
import org.apache.spark.sql.catalyst.rules.Rule
2628
import org.apache.spark.sql.catalyst.util.sideBySide
2729
import org.apache.spark.sql.comet._
2830
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
31+
import org.apache.spark.sql.comet.util.Utils
2932
import org.apache.spark.sql.execution._
3033
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
3134
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
3235
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
33-
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
36+
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
37+
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
38+
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
39+
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, V2CommandExec}
40+
import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
41+
import org.apache.spark.sql.execution.datasources.v2.json.JsonScan
42+
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
3443
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
3544
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
3645
import org.apache.spark.sql.execution.window.WindowExec
46+
import org.apache.spark.sql.internal.SQLConf
3747
import org.apache.spark.sql.types._
3848

3949
import org.apache.comet.{CometConf, CometExplainInfo, ExtendedExplainInfo}
50+
import org.apache.comet.CometConf.{COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST}
4051
import org.apache.comet.CometSparkSessionExtensions._
4152
import org.apache.comet.rules.CometExecRule.allExecs
4253
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported}
@@ -211,7 +222,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
211222
}
212223
if (!newChildren.exists(_.isInstanceOf[BroadcastExchangeExec])) {
213224
val newPlan = convertNode(plan.withNewChildren(newChildren))
214-
if (isCometNative(newPlan) || isCometBroadCastForceEnabled(conf)) {
225+
if (isCometNative(newPlan) || CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.get(conf)) {
215226
newPlan
216227
} else {
217228
// copy fallback reasons to the original plan
@@ -347,7 +358,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
347358
// We shouldn't transform Spark query plan if Comet is not loaded.
348359
if (!isCometLoaded(conf)) return plan
349360

350-
if (!isCometExecEnabled(conf)) {
361+
if (!CometConf.COMET_EXEC_ENABLED.get(conf)) {
351362
// Comet exec is disabled, but for Spark shuffle, we still can use Comet columnar shuffle
352363
if (isCometShuffleEnabled(conf)) {
353364
applyCometShuffle(plan)
@@ -518,4 +529,49 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
518529
false
519530
}
520531
}
532+
533+
private def shouldApplySparkToColumnar(conf: SQLConf, op: SparkPlan): Boolean = {
534+
// Only consider converting leaf nodes to columnar currently, so that all the following
535+
// operators can have a chance to be converted to columnar. Leaf operators that output
536+
// columnar batches, such as Spark's vectorized readers, will also be converted to native
537+
// comet batches.
538+
val fallbackReasons = new ListBuffer[String]()
539+
if (CometSparkToColumnarExec.isSchemaSupported(op.schema, fallbackReasons)) {
540+
op match {
541+
// Convert Spark DS v1 scan to Arrow format
542+
case scan: FileSourceScanExec =>
543+
scan.relation.fileFormat match {
544+
case _: CSVFileFormat => CometConf.COMET_CONVERT_FROM_CSV_ENABLED.get(conf)
545+
case _: JsonFileFormat => CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf)
546+
case _: ParquetFileFormat => CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.get(conf)
547+
case _ => isSparkToArrowEnabled(conf, op)
548+
}
549+
// Convert Spark DS v2 scan to Arrow format
550+
case scan: BatchScanExec =>
551+
scan.scan match {
552+
case _: CSVScan => CometConf.COMET_CONVERT_FROM_CSV_ENABLED.get(conf)
553+
case _: JsonScan => CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf)
554+
case _: ParquetScan => CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.get(conf)
555+
case _ => isSparkToArrowEnabled(conf, op)
556+
}
557+
// other leaf nodes
558+
case _: LeafExecNode =>
559+
isSparkToArrowEnabled(conf, op)
560+
case _ =>
561+
// TODO: consider converting other intermediate operators to columnar.
562+
false
563+
}
564+
} else {
565+
false
566+
}
567+
}
568+
569+
private def isSparkToArrowEnabled(conf: SQLConf, op: SparkPlan) = {
570+
COMET_SPARK_TO_ARROW_ENABLED.get(conf) && {
571+
val simpleClassName = Utils.getSimpleName(op.getClass)
572+
val nodeName = simpleClassName.replaceAll("Exec$", "")
573+
COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST.get(conf).contains(nodeName)
574+
}
575+
}
576+
521577
}

spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ import org.apache.spark.sql.types._
4242

4343
import org.apache.comet.{CometConf, CometNativeException, DataTypeSupport}
4444
import org.apache.comet.CometConf._
45-
import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isCometScanEnabled, withInfo, withInfos}
45+
import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, withInfo, withInfos}
4646
import org.apache.comet.DataTypeSupport.isComplexType
4747
import org.apache.comet.iceberg.{CometIcebergNativeScanMetadata, IcebergReflection}
4848
import org.apache.comet.objectstore.NativeConfig
@@ -108,7 +108,7 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com
108108
}
109109

110110
def transformScan(plan: SparkPlan): SparkPlan = plan match {
111-
case scan if !isCometScanEnabled(conf) =>
111+
case scan if !CometConf.COMET_NATIVE_SCAN_ENABLED.get(conf) =>
112112
withInfo(scan, "Comet Scan is not enabled")
113113

114114
case scan if hasMetadataCol(scan) =>

spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ object CometUnscaledValue extends CometExpressionSerde[UnscaledValue] {
3838
}
3939

4040
object CometMakeDecimal extends CometExpressionSerde[MakeDecimal] {
41+
42+
override def getSupportLevel(expr: MakeDecimal): SupportLevel = {
43+
expr.child.dataType match {
44+
case LongType => Compatible()
45+
case other => Unsupported(Some(s"Unsupported input data type: $other"))
46+
}
47+
}
48+
4149
override def convert(
4250
expr: MakeDecimal,
4351
inputs: Seq[Attribute],

0 commit comments

Comments
 (0)