2727import org .apache .arrow .memory .BufferAllocator ;
2828import org .apache .arrow .vector .FieldVector ;
2929import org .apache .arrow .vector .ValueVector ;
30+ import org .apache .spark .TaskContext ;
31+ import org .apache .spark .comet .CometTaskContextShim ;
3032
3133/**
3234 * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
@@ -48,13 +50,52 @@ public class CometUdfBridge {
4850 * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
4951 * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
5052 * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
53+ * @param numRows row count of the current batch. Mirrors DataFusion's {@code
54+ * ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a
55+ * zero-arg non-deterministic ScalaUDF) ever sees.
56+ * @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or
57+ * {@code null} outside a Spark task. Treated as ground truth for the call: installed as the
58+ * thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
59+ * Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
60+ * MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
61+ * left on a worker by a previous task.
5162 */
5263 public static void evaluate (
5364 String udfClassName ,
5465 long [] inputArrayPtrs ,
5566 long [] inputSchemaPtrs ,
5667 long outArrayPtr ,
57- long outSchemaPtr ) {
68+ long outSchemaPtr ,
69+ int numRows ,
70+ TaskContext taskContext ) {
71+ // Save-and-restore rather than only-install-if-null: the propagated context is the ground
72+ // truth for this call. Any value already on the thread is either (a) the same object on a
73+ // Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
74+ TaskContext prior = TaskContext .get ();
75+ if (taskContext != null ) {
76+ CometTaskContextShim .set (taskContext );
77+ }
78+ try {
79+ evaluateInternal (
80+ udfClassName , inputArrayPtrs , inputSchemaPtrs , outArrayPtr , outSchemaPtr , numRows );
81+ } finally {
82+ if (taskContext != null ) {
83+ if (prior != null ) {
84+ CometTaskContextShim .set (prior );
85+ } else {
86+ CometTaskContextShim .unset ();
87+ }
88+ }
89+ }
90+ }
91+
92+ private static void evaluateInternal (
93+ String udfClassName ,
94+ long [] inputArrayPtrs ,
95+ long [] inputSchemaPtrs ,
96+ long outArrayPtr ,
97+ long outSchemaPtr ,
98+ int numRows ) {
5899 CometUDF udf =
59100 INSTANCES .computeIfAbsent (
60101 udfClassName ,
@@ -84,23 +125,17 @@ public static void evaluate(
84125 inputs [i ] = Data .importVector (allocator , inArr , inSch , null );
85126 }
86127
87- result = udf .evaluate (inputs );
128+ result = udf .evaluate (inputs , numRows );
88129 if (!(result instanceof FieldVector )) {
89130 throw new RuntimeException (
90131 "CometUDF.evaluate() must return a FieldVector, got: " + result .getClass ().getName ());
91132 }
92- // Result length must match the longest input. Scalar (length-1) inputs
93- // are allowed to be shorter, but a vector input bounds the output.
94- int expectedLen = 0 ;
95- for (ValueVector v : inputs ) {
96- expectedLen = Math .max (expectedLen , v .getValueCount ());
97- }
98- if (result .getValueCount () != expectedLen ) {
133+ if (result .getValueCount () != numRows ) {
99134 throw new RuntimeException (
100135 "CometUDF.evaluate() returned "
101136 + result .getValueCount ()
102137 + " rows, expected "
103- + expectedLen );
138+ + numRows );
104139 }
105140 ArrowArray outArr = ArrowArray .wrap (outArrayPtr );
106141 ArrowSchema outSch = ArrowSchema .wrap (outSchemaPtr );
0 commit comments