Skip to content

Commit fe9cd74

Browse files
committed
Add row-count assert_equal regression test
Signed-off-by: WilliamK112 <164879897+WilliamK112@users.noreply.github.com>
1 parent 808b4fb commit fe9cd74

1 file changed

Lines changed: 40 additions & 0 deletions

File tree

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from asserts import assert_equal
18+
from spark_session import with_cpu_session, with_gpu_session
19+
20+
21+
def test_assert_equal_row_count_match():
22+
cpu_count = with_cpu_session(lambda spark: spark.range(12).count())
23+
gpu_count = with_gpu_session(lambda spark: spark.range(12).count())
24+
25+
assert_equal(cpu_count, gpu_count)
26+
27+
28+
def test_assert_equal_row_count_mismatch_raises_assertion_error(capsys):
29+
cpu_count = with_cpu_session(lambda spark: spark.range(2).count())
30+
gpu_count = with_gpu_session(lambda spark: spark.range(1).count())
31+
32+
with pytest.raises(AssertionError) as exc_info:
33+
assert_equal(cpu_count, gpu_count)
34+
35+
assert "int values are different" in str(exc_info.value)
36+
captured = capsys.readouterr()
37+
assert "--- CPU OUTPUT" in captured.out
38+
assert "+++ GPU OUTPUT" in captured.out
39+
assert "-2" in captured.out
40+
assert "+1" in captured.out

0 commit comments

Comments
 (0)