2828
def create_dim_table(table_name, table_format, length=500):
    """Generate a small dimension table and persist it via a CPU session.

    Writes `length` generated rows to `table_name` in `table_format` and
    returns a 2-tuple read back from the generated data:
    (the constant `filter` column value, the first `ex_key` value).
    """
    def build(spark):
        # `filter` repeats a single generated value (period 1), so the whole
        # column holds one constant — usable as an equality predicate later.
        # nullable=False keeps the generated SQL valid by never producing an
        # expression like `filter = None`
        # (https://github.com/NVIDIA/spark-rapids/issues/9817)
        columns = [
            ('key', IntegerGen(nullable=False, min_val=0, max_val=9, special_cases=[])),
            ('skey', IntegerGen(nullable=False, min_val=0, max_val=4, special_cases=[])),
            ('ex_key', IntegerGen(nullable=False, min_val=0, max_val=3, special_cases=[])),
            ('value', value_gen),
            ('filter', RepeatSeqGen(IntegerGen(nullable=False), length=1)),
        ]
        df = gen_df(spark, columns, length)
        df.cache()
        writer = df.write.format(table_format).mode("overwrite")
        writer.saveAsTable(table_name)
        filter_val = df.select('filter').first()[0]
        ex_key_val = df.select('ex_key').first()[0]
        return filter_val, ex_key_val

    return with_cpu_session(build)
4848
@@ -146,7 +146,7 @@ def fn(spark):
146146 dim_table AS (
147147 SELECT dim.key as key, dim.value as value, dim.filter as filter
148148 FROM {1} dim
149- WHERE ex_key = 3
149+ WHERE ex_key = {3}
150150 ORDER BY dim.key
151151 )
152152 SELECT key, max(value)
@@ -181,8 +181,9 @@ def fn(spark):
181181def test_dpp_reuse_broadcast_exchange (spark_tmp_table_factory , store_format , s_index , aqe_enabled ):
182182 fact_table , dim_table = spark_tmp_table_factory .get (), spark_tmp_table_factory .get ()
183183 create_fact_table (fact_table , store_format , length = 10000 )
184- filter_val = create_dim_table (dim_table , store_format , length = 2000 )
185- statement = _statements [s_index ].format (fact_table , dim_table , filter_val )
184+ filter_val , ex_key_val = create_dim_table (dim_table , store_format , length = 2000 )
185+ statement = _statements [s_index ].format (fact_table , dim_table , filter_val , ex_key_val )
186+
186187 if is_databricks113_or_later () and aqe_enabled == 'true' :
187188 # SubqueryBroadcastExec is unoptimized in Databricks 11.3 with EXECUTOR_BROADCAST
188189 # See https://github.com/NVIDIA/spark-rapids/issues/7425
@@ -202,8 +203,8 @@ def test_dpp_reuse_broadcast_exchange(spark_tmp_table_factory, store_format, s_i
202203def test_dpp_reuse_broadcast_exchange_cpu_scan (spark_tmp_table_factory ):
203204 fact_table , dim_table = spark_tmp_table_factory .get (), spark_tmp_table_factory .get ()
204205 create_fact_table (fact_table , 'parquet' , length = 10000 )
205- filter_val = create_dim_table (dim_table , 'parquet' , length = 2000 )
206- statement = _statements [0 ].format (fact_table , dim_table , filter_val )
206+ filter_val , ex_key_val = create_dim_table (dim_table , 'parquet' , length = 2000 )
207+ statement = _statements [0 ].format (fact_table , dim_table , filter_val , ex_key_val )
207208 assert_cpu_and_gpu_are_equal_collect_with_capture (
208209 lambda spark : spark .sql (statement ),
209210 # The existence of GpuSubqueryBroadcastExec indicates the reuse works on the GPU
@@ -226,8 +227,8 @@ def test_dpp_reuse_broadcast_exchange_cpu_scan(spark_tmp_table_factory):
226227def test_dpp_bypass (spark_tmp_table_factory , store_format , s_index , aqe_enabled ):
227228 fact_table , dim_table = spark_tmp_table_factory .get (), spark_tmp_table_factory .get ()
228229 create_fact_table (fact_table , store_format )
229- filter_val = create_dim_table (dim_table , store_format )
230- statement = _statements [s_index ].format (fact_table , dim_table , filter_val )
230+ filter_val , ex_key_val = create_dim_table (dim_table , store_format )
231+ statement = _statements [s_index ].format (fact_table , dim_table , filter_val , ex_key_val )
231232 assert_cpu_and_gpu_are_equal_collect_with_capture (
232233 lambda spark : spark .sql (statement ),
233234 # Bypass with a true literal, if we can not reuse broadcast exchange.
@@ -250,8 +251,8 @@ def test_dpp_bypass(spark_tmp_table_factory, store_format, s_index, aqe_enabled)
250251def test_dpp_via_aggregate_subquery (spark_tmp_table_factory , store_format , s_index , aqe_enabled ):
251252 fact_table , dim_table = spark_tmp_table_factory .get (), spark_tmp_table_factory .get ()
252253 create_fact_table (fact_table , store_format )
253- filter_val = create_dim_table (dim_table , store_format )
254- statement = _statements [s_index ].format (fact_table , dim_table , filter_val )
254+ filter_val , ex_key_val = create_dim_table (dim_table , store_format )
255+ statement = _statements [s_index ].format (fact_table , dim_table , filter_val , ex_key_val )
255256 assert_cpu_and_gpu_are_equal_collect_with_capture (
256257 lambda spark : spark .sql (statement ),
257258 # SubqueryExec appears if we plan extra subquery for DPP
@@ -271,8 +272,8 @@ def test_dpp_via_aggregate_subquery(spark_tmp_table_factory, store_format, s_ind
271272def test_dpp_skip (spark_tmp_table_factory , store_format , s_index , aqe_enabled ):
272273 fact_table , dim_table = spark_tmp_table_factory .get (), spark_tmp_table_factory .get ()
273274 create_fact_table (fact_table , store_format )
274- filter_val = create_dim_table (dim_table , store_format )
275- statement = _statements [s_index ].format (fact_table , dim_table , filter_val )
275+ filter_val , ex_key_val = create_dim_table (dim_table , store_format )
276+ statement = _statements [s_index ].format (fact_table , dim_table , filter_val , ex_key_val )
276277 assert_cpu_and_gpu_are_equal_collect_with_capture (
277278 lambda spark : spark .sql (statement ),
278279 # SubqueryExec appears if we plan extra subquery for DPP
0 commit comments