Skip to content

Commit 966d836

Browse files
authored
feat: Add num_rows and TaskContext to CometUDFBridge.evaluate (#4306)
1 parent f008bc2 commit 966d836

9 files changed

Lines changed: 162 additions & 26 deletions

File tree

common/src/main/java/org/apache/comet/udf/CometUdfBridge.java

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import org.apache.arrow.memory.BufferAllocator;
2828
import org.apache.arrow.vector.FieldVector;
2929
import org.apache.arrow.vector.ValueVector;
30+
import org.apache.spark.TaskContext;
31+
import org.apache.spark.comet.CometTaskContextShim;
3032

3133
/**
3234
* JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
@@ -48,13 +50,52 @@ public class CometUdfBridge {
4850
* @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
4951
* @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
5052
* @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
53+
* @param numRows row count of the current batch. Mirrors DataFusion's {@code
54+
* ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a
55+
* zero-arg non-deterministic ScalaUDF) ever sees.
56+
* @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or
57+
* {@code null} outside a Spark task. Treated as ground truth for the call: installed as the
58+
* thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
59+
* Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
60+
* MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
61+
* left on a worker by a previous task.
5162
*/
5263
public static void evaluate(
5364
String udfClassName,
5465
long[] inputArrayPtrs,
5566
long[] inputSchemaPtrs,
5667
long outArrayPtr,
57-
long outSchemaPtr) {
68+
long outSchemaPtr,
69+
int numRows,
70+
TaskContext taskContext) {
71+
// Save-and-restore rather than only-install-if-null: the propagated context is the ground
72+
// truth for this call. Any value already on the thread is either (a) the same object on a
73+
// Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
74+
TaskContext prior = TaskContext.get();
75+
if (taskContext != null) {
76+
CometTaskContextShim.set(taskContext);
77+
}
78+
try {
79+
evaluateInternal(
80+
udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
81+
} finally {
82+
if (taskContext != null) {
83+
if (prior != null) {
84+
CometTaskContextShim.set(prior);
85+
} else {
86+
CometTaskContextShim.unset();
87+
}
88+
}
89+
}
90+
}
91+
92+
private static void evaluateInternal(
93+
String udfClassName,
94+
long[] inputArrayPtrs,
95+
long[] inputSchemaPtrs,
96+
long outArrayPtr,
97+
long outSchemaPtr,
98+
int numRows) {
5899
CometUDF udf =
59100
INSTANCES.computeIfAbsent(
60101
udfClassName,
@@ -84,23 +125,17 @@ public static void evaluate(
84125
inputs[i] = Data.importVector(allocator, inArr, inSch, null);
85126
}
86127

87-
result = udf.evaluate(inputs);
128+
result = udf.evaluate(inputs, numRows);
88129
if (!(result instanceof FieldVector)) {
89130
throw new RuntimeException(
90131
"CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
91132
}
92-
// Result length must match the longest input. Scalar (length-1) inputs
93-
// are allowed to be shorter, but a vector input bounds the output.
94-
int expectedLen = 0;
95-
for (ValueVector v : inputs) {
96-
expectedLen = Math.max(expectedLen, v.getValueCount());
97-
}
98-
if (result.getValueCount() != expectedLen) {
133+
if (result.getValueCount() != numRows) {
99134
throw new RuntimeException(
100135
"CometUDF.evaluate() returned "
101136
+ result.getValueCount()
102137
+ " rows, expected "
103-
+ expectedLen);
138+
+ numRows);
104139
}
105140
ArrowArray outArr = ArrowArray.wrap(outArrayPtr);
106141
ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr);

common/src/main/scala/org/apache/comet/udf/CometUDF.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,16 @@ import org.apache.arrow.vector.ValueVector
2727
*
2828
* - Vector arguments arrive at the row count of the current batch.
2929
* - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0.
30-
* - The returned vector's length must match the longest input.
30+
* - The returned vector's length must match `numRows`.
31+
*
32+
* `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count.
33+
* UDFs that always have at least one batch-length input can derive length from the inputs and
34+
* ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF)
35+
* need `numRows` to know how many rows to produce.
3136
*
3237
* Implementations must have a public no-arg constructor and must be stateless: a single instance
3338
* per class is cached and shared across native worker threads for the lifetime of the JVM.
3439
*/
3540
trait CometUDF {
36-
def evaluate(inputs: Array[ValueVector]): ValueVector
41+
def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector
3742
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.spark.comet
21+
22+
import org.apache.spark.TaskContext
23+
24+
/**
25+
* Package-private access shim for `TaskContext.setTaskContext` / `TaskContext.unset`.
26+
*
27+
* Both methods are declared `protected[spark]` on Spark's `TaskContext` companion, so they are
28+
* reachable from code inside the `org.apache.spark` package tree but not from `org.apache.comet`.
29+
* The Comet JVM UDF bridge needs to set the thread-local `TaskContext` on its caller thread (a
30+
* Tokio worker thread with no `TaskContext`) so the user's UDF body and any partition-sensitive
31+
* built-ins (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, etc.) see the driving Spark task's
32+
* `TaskContext`. This shim lives in `org.apache.spark.comet` so it can call through to the
33+
* protected methods, and exposes plain public forwarders the bridge (which lives in
34+
* `org.apache.comet.udf`) can use.
35+
*/
36+
object CometTaskContextShim {
37+
38+
def set(taskContext: TaskContext): Unit = TaskContext.setTaskContext(taskContext)
39+
40+
def unset(): Unit = TaskContext.unset()
41+
}

native/core/src/execution/jni_api.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,13 @@ struct ExecutionContext {
306306
pub tracing_memory_metric_name: String,
307307
/// Pre-computed tracing event name for executePlan calls
308308
pub tracing_event_name: String,
309+
/// Spark `TaskContext` captured on the driving Spark task thread at `createPlan` time.
310+
/// Threaded into every JVM scalar UDF the planner builds so the JNI bridge can install it
311+
/// as the thread-local `TaskContext` for the Tokio worker running the UDF. `None` when no
312+
/// driving Spark task is present (unit tests, direct native driver runs). The `Arc` is
313+
/// cheap to clone; the underlying `Global<JObject>` releases its JNI global ref on drop
314+
/// via `jni`'s `Drop` impl.
315+
pub task_context: Option<Arc<Global<JObject<'static>>>>,
309316
}
310317

311318
/// Accept serialized query plan and return the address of the native query plan.
@@ -332,6 +339,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
332339
task_attempt_id: jlong,
333340
task_cpus: jlong,
334341
key_unwrapper_obj: JObject,
342+
task_context_obj: JObject,
335343
) -> jlong {
336344
try_unwrap_or_throw(&e, |env| {
337345
// Deserialize Spark configs
@@ -453,6 +461,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
453461
String::new()
454462
};
455463

464+
// Capture the driving Spark task's TaskContext as a JNI global reference when
465+
// non-null. The `Arc<Global<JObject>>` releases its global ref on drop, so cleanup
466+
// is automatic when the ExecutionContext drops.
467+
let task_context = if !task_context_obj.is_null() {
468+
Some(Arc::new(jni_new_global_ref!(env, task_context_obj)?))
469+
} else {
470+
None
471+
};
472+
456473
let exec_context = Box::new(ExecutionContext {
457474
id,
458475
task_attempt_id,
@@ -479,6 +496,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
479496
"thread_{rust_thread_id}_comet_memory_reserved"
480497
),
481498
tracing_event_name,
499+
task_context,
482500
});
483501

484502
Ok(Box::into_raw(exec_context) as i64)
@@ -703,7 +721,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
703721
let start = Instant::now();
704722
let planner =
705723
PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition)
706-
.with_exec_id(exec_context_id);
724+
.with_exec_id(exec_context_id)
725+
.with_task_context(exec_context.task_context.clone());
707726
let (scans, shuffle_scans, root_op) = planner.create_plan(
708727
&exec_context.spark_plan,
709728
&mut exec_context.input_sources.clone(),

native/core/src/execution/planner.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ pub struct PhysicalPlanner {
183183
partition: i32,
184184
session_ctx: Arc<SessionContext>,
185185
query_context_registry: Arc<datafusion_comet_spark_expr::QueryContextMap>,
186+
/// Captured at `createPlan` time on `ExecutionContext`; see that struct for the
187+
/// propagation rationale. `None` when no driving Spark task is available.
188+
task_context: Option<Arc<Global<JObject<'static>>>>,
186189
}
187190

188191
impl Default for PhysicalPlanner {
@@ -198,16 +201,24 @@ impl PhysicalPlanner {
198201
session_ctx,
199202
partition,
200203
query_context_registry: datafusion_comet_spark_expr::create_query_context_map(),
204+
task_context: None,
201205
}
202206
}
203207

204-
pub fn with_exec_id(self, exec_context_id: i64) -> Self {
205-
Self {
206-
exec_context_id,
207-
partition: self.partition,
208-
session_ctx: Arc::clone(&self.session_ctx),
209-
query_context_registry: Arc::clone(&self.query_context_registry),
210-
}
208+
pub fn with_exec_id(mut self, exec_context_id: i64) -> Self {
209+
self.exec_context_id = exec_context_id;
210+
self
211+
}
212+
213+
/// Attach the Spark `TaskContext` global reference captured at `createPlan` time. Cloned
214+
/// into every `JvmScalarUdfExpr` the planner builds so the JNI bridge can install it as
215+
/// the thread-local on the Tokio worker driving the UDF.
216+
pub fn with_task_context(
217+
mut self,
218+
task_context: Option<Arc<Global<JObject<'static>>>>,
219+
) -> Self {
220+
self.task_context = task_context;
221+
self
211222
}
212223

213224
/// Return session context of this planner.
@@ -735,6 +746,7 @@ impl PhysicalPlanner {
735746
args,
736747
return_type,
737748
udf.return_nullable,
749+
self.task_context.clone(),
738750
)))
739751
}
740752
expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),

native/jni-bridge/src/comet_udf_bridge.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ impl<'a> CometUdfBridge<'a> {
4141
method_evaluate: env.get_static_method_id(
4242
JNIString::new(Self::JVM_CLASS),
4343
jni::jni_str!("evaluate"),
44-
jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"),
44+
jni::jni_sig!("(Ljava/lang/String;[J[JJJILorg/apache/spark/TaskContext;)V"),
4545
)?,
4646
method_evaluate_ret: ReturnType::Primitive(Primitive::Void),
4747
class,

native/spark-expr/src/jvm_udf/mod.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use datafusion::physical_expr::PhysicalExpr;
3131

3232
use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError};
3333
use datafusion_comet_jni_bridge::JVMClasses;
34-
use jni::objects::{JObject, JValue};
34+
use jni::objects::{Global, JObject, JValue};
3535

3636
/// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI.
3737
/// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`.
@@ -41,6 +41,14 @@ pub struct JvmScalarUdfExpr {
4141
args: Vec<Arc<dyn PhysicalExpr>>,
4242
return_type: DataType,
4343
return_nullable: bool,
44+
/// Captured at `createPlan` time and threaded here by the planner. Passed through the
45+
/// JNI bridge so `CometUdfBridge.evaluate` can install it as the Tokio worker's
46+
/// thread-local `TaskContext`. Without this, partition-sensitive built-ins inside a UDF
47+
/// tree (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, user code reading
48+
/// `TaskContext.get()`) see `null` and seed / branch incorrectly. `None` when no driving
49+
/// Spark task is available; the bridge then leaves whatever `TaskContext.get()` already
50+
/// returns in place.
51+
task_context: Option<Arc<Global<JObject<'static>>>>,
4452
}
4553

4654
impl JvmScalarUdfExpr {
@@ -49,12 +57,14 @@ impl JvmScalarUdfExpr {
4957
args: Vec<Arc<dyn PhysicalExpr>>,
5058
return_type: DataType,
5159
return_nullable: bool,
60+
task_context: Option<Arc<Global<JObject<'static>>>>,
5261
) -> Self {
5362
Self {
5463
class_name,
5564
args,
5665
return_type,
5766
return_nullable,
67+
task_context,
5868
}
5969
}
6070
}
@@ -186,7 +196,14 @@ impl PhysicalExpr for JvmScalarUdfExpr {
186196
.set_region(env, 0, &in_sch_ptrs)
187197
.map_err(|e| CometError::JNI { source: e })?;
188198

189-
// Call CometUdfBridge.evaluate(String, long[], long[], long, long)
199+
// Pass a null jobject when no TaskContext was propagated so the bridge's null-guard
200+
// leaves the worker thread's current TaskContext.get() in place. The borrow must
201+
// outlive `call_static_method_unchecked`.
202+
let null_task_context = JObject::null();
203+
let task_context_ref: &JObject = match &self.task_context {
204+
Some(gref) => gref.as_obj(),
205+
None => &null_task_context,
206+
};
190207
let ret = unsafe {
191208
env.call_static_method_unchecked(
192209
&bridge.class,
@@ -198,6 +215,8 @@ impl PhysicalExpr for JvmScalarUdfExpr {
198215
JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(),
199216
JValue::Long(out_arr_ptr).as_jni(),
200217
JValue::Long(out_sch_ptr).as_jni(),
218+
JValue::Int(batch.num_rows() as i32).as_jni(),
219+
JValue::Object(task_context_ref).as_jni(),
201220
],
202221
)
203222
};
@@ -234,6 +253,7 @@ impl PhysicalExpr for JvmScalarUdfExpr {
234253
children,
235254
self.return_type.clone(),
236255
self.return_nullable,
256+
self.task_context.clone(),
237257
)))
238258
}
239259
}

spark/src/main/scala/org/apache/comet/CometExecIterator.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,10 @@ class CometExecIterator(
127127
memoryConfig.memoryLimitPerTask,
128128
taskAttemptId,
129129
taskCPUs,
130-
keyUnwrapper)
130+
keyUnwrapper,
131+
// Propagated to Tokio workers running JVM UDFs so they see this Spark task's
132+
// TaskContext. See CometUdfBridge.evaluate.
133+
TaskContext.get())
131134
}
132135

133136
private var nextBatch: Option[ColumnarBatch] = None

spark/src/main/scala/org/apache/comet/Native.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet
2121

2222
import java.nio.ByteBuffer
2323

24-
import org.apache.spark.CometTaskMemoryManager
24+
import org.apache.spark.{CometTaskMemoryManager, TaskContext}
2525
import org.apache.spark.sql.comet.CometMetricNode
2626

2727
import org.apache.comet.parquet.CometFileKeyUnwrapper
@@ -69,7 +69,8 @@ class Native extends NativeBase {
6969
memoryLimitPerTask: Long,
7070
taskAttemptId: Long,
7171
taskCPUs: Long,
72-
keyUnwrapper: CometFileKeyUnwrapper): Long
72+
keyUnwrapper: CometFileKeyUnwrapper,
73+
taskContext: TaskContext): Long
7374
// scalastyle:on
7475

7576
/**

0 commit comments

Comments
 (0)