-
Notifications
You must be signed in to change notification settings - Fork 288
[DO NOT REVIEW] Adds in support for a "CpuBridge" that lets us fall back to the CPU on a per-expression level instead of per-SparkPlan node level. [databricks] #13368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
revans2
wants to merge
75
commits into
NVIDIA:main
Choose a base branch
from
revans2:cursor_cpu_gpu_expr_transitions
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 8 commits
Commits
Show all changes
75 commits
Select commit
Hold shift + click to select a range
39439aa
Adds in support for a "CpuBridge" that lets us fall back to the CPU on
revans2 6aa3a41
Merge branch 'branch-25.10' into cursor_cpu_gpu_expr_transitions
revans2 5e6879d
Addressed review comments and fixed issues with upmerge
revans2 a6e7e8a
Some more fixes
revans2 4bd51ec
Partial fix for ScalarSubquery
revans2 421d77f
Done with the fix
revans2 70b203b
Merge branch 'branch-25.10' into cursor_cpu_gpu_expr_transitions
revans2 246757c
oops
revans2 f9fcc78
Performance fix
revans2 3fa0d76
More improvements and fixes
revans2 3ee7c21
Better optimizer and tests
revans2 e85739f
Cleanup
revans2 c0e77a4
Heuristic not totally where I want it yet
revans2 002a575
Step
revans2 1a034e2
Step
revans2 8083eb3
Bug fix
revans2 703880b
Cleanup
revans2 6c30b6b
Fixes and cleanup
revans2 688f7a8
More cleanup and improved tests
revans2 dc92c30
More tests
revans2 7dcadba
Some join fixes, but still not really working
revans2 ebd0bee
Fixed some issues with AST
revans2 bc1800f
Fixes for join
revans2 4a12461
More join tests
revans2 09d58a4
Fix tests for predicarte push down
revans2 e3b8dc8
Fixed window tests
revans2 ee868fc
Fixed group by test
revans2 0261b91
sort test is working
revans2 c3e9872
Even more better tests
revans2 09eaff0
Now partitions work too
revans2 5267941
Even more test fixes
revans2 6fc2c5b
Merge branch 'branch-25.10' into cursor_cpu_gpu_expr_transitions
revans2 00403bb
Fix tests to avoid ANSI mode issues
revans2 3cc3587
Fix build for 3.2.0
revans2 798f6ef
Update test checks for non-GPU
revans2 0698e66
Fix scala tests
revans2 f27dfc2
Have most tests passing
revans2 c24a826
Fix issues with scalar sub query and higher order functions
revans2 2173c55
Some initial work on metrics
revans2 84082ff
More metrics
revans2 8d8ec97
Metrics
revans2 7e2ad81
Merge branch 'branch-25.12' into cursor_cpu_gpu_expr_transitions
revans2 db149e9
Fixed metrics
revans2 6fab09e
More cleanup
revans2 832e80a
Copyright dates
revans2 1d8d78d
More cleanup
revans2 3fe3c0f
Bug and test fix
revans2 b2208ee
More tests fixes and one code fix
revans2 03b2b43
A few more delta fixes
revans2 460dea4
Test fix
revans2 da64e2b
Checkpointing some changes, but not done yet
revans2 74c3e91
More fixes
revans2 ed87c3d
Merge branch 'branch-25.12' into cursor_cpu_gpu_expr_transitions
revans2 fb15916
Fix some test failures
revans2 b7f33c9
Merge branch 'branch-25.12' into cursor_cpu_gpu_expr_transitions
revans2 861fac4
Performance optimization
revans2 e8c4bd4
Better code generation for better performance
revans2 6f9c2b2
More test fixes
revans2 f9f3e18
Bug fix
revans2 3817d45
Work around performance issue. Need to debug more
revans2 f0aadb5
Perf improvement
revans2 b193dbb
Addressed some review comments
revans2 ea2b494
Copyright update
revans2 3c46afc
Move to passing metrics to bind
revans2 8bf1485
Merge branch 'branch-25.12' into cursor_cpu_gpu_expr_transitions
revans2 2b51bb2
Some test cleanup
revans2 b032ec1
Copyright update
revans2 d2cdcef
Fix compile error
revans2 4d4576f
More fixes
revans2 13a43a2
Another fix
revans2 15a5520
More fixes so that we can do interpreted fallback better
revans2 045e500
Merge branch 'branch-25.12' into cursor_cpu_gpu_expr_transitions
revans2 66f3d5e
Review comments
revans2 64e6557
Fix compile issue. I think this was missed from another PR
revans2 ba84782
Merge branch 'main' into cursor_cpu_gpu_expr_transitions
revans2 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,219 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import pytest | ||
|
|
||
| from pyspark.sql.functions import col | ||
| from asserts import assert_gpu_and_cpu_are_equal_collect, assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_fallback_collect | ||
| from marks import allow_non_gpu | ||
| from data_gen import * | ||
| from marks import ignore_order | ||
| from spark_session import is_before_spark_330, is_databricks_runtime | ||
|
|
||
|
|
||
| # Helper function to create config that forces specific expressions to CPU bridge | ||
| def create_cpu_bridge_fallback_conf(disabled_gpu_expressions, codegen_enabled=True, | ||
| disallowed_bridge_expressions=[]): | ||
| """Create config that enables CPU bridge and disables specific GPU expressions""" | ||
| conf = { | ||
| 'spark.rapids.sql.expression.cpuBridge.enabled': True, | ||
| 'spark.rapids.sql.expression.cpuBridge.codegenEnabled': codegen_enabled | ||
| } | ||
| # Disable specific GPU expressions to force CPU bridge fallback | ||
| for expr_name in disabled_gpu_expressions: | ||
| conf[f'spark.rapids.sql.expression.{expr_name}'] = False | ||
|
|
||
| if disallowed_bridge_expressions: | ||
| conf['spark.rapids.sql.expression.cpuBridge.disallowList'] = ','.join(disallowed_bridge_expressions) | ||
| return conf | ||
|
|
||
| @pytest.mark.parametrize('codegen_enabled', [True, False], ids=['codegen_on', 'codegen_off']) | ||
| def test_cpu_bridge_add_fallback(codegen_enabled): | ||
| """Test CPU bridge when Add expression is forced to fall back to CPU""" | ||
| def test_func(spark): | ||
| df = gen_df(spark, [('a', int_gen), ('b', int_gen)], length=1024) | ||
| # This Add will be forced to use CPU bridge due to config | ||
| return df.selectExpr("a", "b", "a + a as s1", "b * b as p1") | ||
|
|
||
| # Force Add to fall back to CPU bridge | ||
| conf = create_cpu_bridge_fallback_conf(['Add'], codegen_enabled=codegen_enabled) | ||
| assert_gpu_and_cpu_are_equal_collect(test_func, conf=conf) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize('codegen_enabled', [True, False], ids=['codegen_on', 'codegen_off']) | ||
| def test_cpu_bridge_multiply_fallback(codegen_enabled): | ||
| """Test CPU bridge when Multiply expression is forced to fall back to CPU""" | ||
| def test_func(spark): | ||
| df = gen_df(spark, [('a', int_gen), ('b', double_gen)], length=1024) | ||
| # This Multiply will be forced to use CPU bridge due to config | ||
| return df.selectExpr("a", "b", "a + a as s1", "b * b as p1") | ||
|
|
||
| # Force Multiply to fall back to CPU bridge | ||
| conf = create_cpu_bridge_fallback_conf(['Multiply'], codegen_enabled=codegen_enabled) | ||
| assert_gpu_and_cpu_are_equal_collect(test_func, conf=conf) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize('codegen_enabled', [True, False], ids=['codegen_on', 'codegen_off']) | ||
| def test_cpu_bridge_complex_expression_tree(codegen_enabled): | ||
| """Test CPU bridge with complex expression trees containing multiple fallbacks""" | ||
| def test_func(spark): | ||
| df = gen_df(spark, [('a', int_gen), ('b', int_gen), ('c', int_gen)], length=1000) | ||
| return df.selectExpr( | ||
| "a", "b", "c", | ||
| # Complex expression mixing CPU bridge (Add) and GPU (Multiply) operations | ||
| "a + (b * c) as mixed", | ||
| "case when a > 0 then a + b + 5 else a + c + 2 end as conditional" | ||
| ) | ||
|
|
||
| # Force Add to CPU bridge, keep other expressions on GPU | ||
| conf = create_cpu_bridge_fallback_conf(['Add'], codegen_enabled=codegen_enabled) | ||
| assert_gpu_and_cpu_are_equal_collect(test_func, conf=conf) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize('codegen_enabled', [True, False], ids=['codegen_on', 'codegen_off']) | ||
| def test_cpu_bridge_higher_order_function_fallback(codegen_enabled): | ||
| """Test CPU bridge with higher order functions where inner expressions fall back to CPU""" | ||
| def test_func(spark): | ||
| df = gen_df(spark, [('arr', ArrayGen(int_gen, min_length=3, max_length=5))], length=1000) | ||
| return df.selectExpr( | ||
| "arr", | ||
| # transform where the lambda contains Add (forced to CPU bridge) | ||
| "transform(arr, x -> x + 1) as arr_plus_one", | ||
| # filter where the lambda contains Add (forced to CPU bridge) | ||
| "filter(arr, x -> x + 2 > 5) as filtered_arr", | ||
| # exists where the lambda contains Add (forced to CPU bridge) | ||
| "exists(arr, x -> (x + 3) > 10) as has_large_element", | ||
| "transform(arr, x -> (x + 2) * 3) as transformed" | ||
| ) | ||
|
|
||
| # Force Add to CPU bridge - this will affect expressions inside the lambda functions | ||
| conf = create_cpu_bridge_fallback_conf(['Add'], codegen_enabled=codegen_enabled) | ||
| assert_gpu_and_cpu_are_equal_collect(test_func, conf=conf) | ||
|
|
||
| def test_cpu_bridge_nondeterministic_works_next_to_bridge(): | ||
| """Test mixed scenario: some expressions use CPU bridge, others stay on GPU""" | ||
| def test_func(spark): | ||
| df = gen_df(spark, [('a', int_gen), ('b', int_gen)], length=1000) | ||
| # Add should use CPU bridge, rand(42) should stay on GPU - this should work fine | ||
| return df.selectExpr("a", "b", "a + b as sum", "rand(42) * 100 as scaled_random") | ||
|
|
||
| # Force Add to CPU bridge, rand() stays on GPU - should work with mixed execution | ||
| conf = create_cpu_bridge_fallback_conf(['Add']) | ||
|
|
||
| # This should succeed with mixed GPU/CPU bridge execution, not fall back entirely | ||
| assert_gpu_and_cpu_are_equal_collect(test_func, conf=conf) | ||
|
|
||
| # This is borrowed partly from join_test.py | ||
| bloom_filter_confs = { | ||
| "spark.sql.autoBroadcastJoinThreshold": 1, | ||
| "spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold": 1, | ||
| "spark.sql.optimizer.runtime.bloomFilter.creationSideThreshold": "100GB", | ||
| "spark.sql.optimizer.runtime.bloomFilter.enabled": True | ||
| } | ||
|
|
||
| def check_bloom_filter_join(confs, is_multi_column): | ||
| def do_join(spark): | ||
| if is_multi_column: | ||
| left = spark.range(100000).withColumn("second_id", col("id") % 5) | ||
| right = spark.range(10).withColumn("id2", col("id").cast("string")).withColumn("second_id", col("id") % 5) | ||
| return right.filter("cast(id2 as bigint) % 3 = 0").join(left, (left.id == right.id) & (left.second_id == right.second_id), "inner") | ||
| else: | ||
| left = spark.range(100000) | ||
| right = spark.range(10).withColumn("id2", col("id").cast("string")) | ||
| return right.filter("cast(id2 as bigint) % 3 = 0").join(left, left.id == right.id, "inner") | ||
| bridge_conf = create_cpu_bridge_fallback_conf([]) | ||
| partial_conf = copy_and_update(bridge_conf, confs) | ||
| all_confs = copy_and_update(bloom_filter_confs, partial_conf) | ||
| assert_gpu_and_cpu_are_equal_collect(do_join, conf=all_confs) | ||
|
|
||
| @allow_non_gpu("ShuffleExchangeExec") | ||
| @ignore_order(local=True) | ||
| @pytest.mark.parametrize("is_multi_column", [False, True], ids=["SINGLE_COLUMN", "MULTI_COLUMN"]) | ||
| @pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921") | ||
| @pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0") | ||
| def test_bloom_filter_join_cpu_probe(is_multi_column): | ||
| conf = {"spark.rapids.sql.expression.BloomFilterMightContain": "false"} | ||
| check_bloom_filter_join(confs=conf, is_multi_column=is_multi_column) | ||
|
|
||
| # ============================================================================== | ||
| # NEGATIVE TEST CASES - Verify expressions that should NOT use CPU bridge | ||
| # These should cause full CPU fallback instead of using the bridge | ||
| # ============================================================================== | ||
|
|
||
| @allow_non_gpu('ProjectExec') | ||
| def test_cpu_bridge_rand_disabled_fallback(): | ||
| """Test that when rand() is disabled via config, ProjectExec falls back to CPU entirely""" | ||
| def test_func(spark): | ||
| df = gen_df(spark, [('a', int_gen), ('b', int_gen)], length=1000) | ||
| # rand(42) with seed for deterministic results - disabled via config should cause full CPU fallback | ||
| return df.selectExpr("a", "b", "a + b as sum", "rand(42) as random_val") | ||
|
|
||
| # Enable CPU bridge but disable rand() - should cause full ProjectExec fallback | ||
| conf = create_cpu_bridge_fallback_conf(['Rand']) | ||
|
|
||
| # Verify that ProjectExec falls back to CPU (doesn't use GPU or bridge) | ||
| assert_gpu_fallback_collect(test_func, 'ProjectExec', conf=conf) | ||
|
zpuller marked this conversation as resolved.
|
||
|
|
||
|
|
||
| @allow_non_gpu('HashAggregateExec', 'ShuffleExchangeExec') | ||
| @ignore_order(local=True) | ||
| def test_cpu_bridge_aggregation_sum_disabled_fallback(): | ||
| """Test that when sum() is disabled via config, HashAggregateExec falls back to CPU entirely""" | ||
| def test_func(spark): | ||
| df = gen_df(spark, [('a', int_gen), ('b', string_gen)], length=2000) | ||
| # sum() disabled via config should cause full HashAggregateExec fallback, not bridge | ||
| return df.groupBy('b').agg(f.sum('a').alias('total')) | ||
|
|
||
| # Enable CPU bridge but disable sum() - should cause full HashAggregateExec fallback | ||
| conf = create_cpu_bridge_fallback_conf(['Sum']) | ||
|
|
||
| # Verify that HashAggregateExec falls back to CPU (doesn't use GPU or bridge) | ||
| assert_gpu_fallback_collect(test_func, 'HashAggregateExec', conf=conf) | ||
|
|
||
|
|
||
| @allow_non_gpu('WindowExec') | ||
| def test_cpu_bridge_window_lag_disabled_fallback(): | ||
| """Test that when lag() is disabled via config, WindowExec falls back to CPU entirely""" | ||
| def test_func(spark): | ||
| df = gen_df(spark, [('a', int_gen), ('b', string_gen)], length=1000) | ||
| # lag() disabled via config should cause full WindowExec fallback, not bridge | ||
| return df.selectExpr( | ||
| "a", "b", | ||
| "lag(a, 1) over (partition by b order by a) as prev_a", | ||
| "row_number() over (partition by b order by a) as row_num" | ||
| ) | ||
|
|
||
| # Enable CPU bridge but disable lag() - should cause full WindowExec fallback | ||
| conf = create_cpu_bridge_fallback_conf(['Lag']) | ||
|
|
||
| # Verify that WindowExec falls back to CPU (doesn't use GPU or bridge) | ||
| assert_gpu_fallback_collect(test_func, 'WindowExec', conf=conf) | ||
|
|
||
| @allow_non_gpu('ProjectExec') | ||
| def test_disallowed_bridge_fallback(): | ||
| """Test that when an expression is not on the GPU and is disallowed to use | ||
| the cpu_bridge that it is honored""" | ||
| conf = create_cpu_bridge_fallback_conf(['Add'], | ||
| disallowed_bridge_expressions=['org.apache.spark.sql.catalyst.expressions.Add']) | ||
| assert_gpu_fallback_collect(lambda spark: binary_op_df(spark, byte_gen).selectExpr("a + b"), | ||
| 'ProjectExec', conf=conf) | ||
|
|
||
| @allow_non_gpu("GenerateExec", "ShuffleExchangeExec") | ||
| @ignore_order(local=True) | ||
| def test_generate_outer_fallback(): | ||
| conf = create_cpu_bridge_fallback_conf([]) | ||
| assert_gpu_fallback_collect( | ||
| lambda spark: spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as x")\ | ||
| .repartition(1).selectExpr("inline_outer(x)"), | ||
| "GenerateExec", conf = conf) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.