Skip to content

Commit c4bb6e3

Browse files
authored
Merge pull request #1157 from bader/assert
Hoist assert from the device code.
2 parents f32f510 + b7c8385 commit c4bb6e3

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

tests/common/semantics_reference.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -466,9 +466,10 @@ void check_kernel(InitFunc init_func, const std::string& type_name,
466466
std::size_t result_count = test_traits::result_count +
467467
test_copy::result_count + test_move::result_count;
468468

469-
std::vector<int> results(result_count, false);
469+
std::size_t result_size = result_count + 1;
470+
std::vector<int> results(result_size, false);
470471
{
471-
sycl::buffer<int> buffer(results.data(), sycl::range<1>{result_count});
472+
sycl::buffer<int> buffer(results.data(), sycl::range<1>{result_size});
472473

473474
queue.submit([&](sycl::handler& cgh) {
474475
auto accessor = buffer.template get_access<sycl::access_mode::write>(cgh);
@@ -484,8 +485,8 @@ void check_kernel(InitFunc init_func, const std::string& type_name,
484485
ptr += test_copy::result_count;
485486
test_move::run<storage>(ptr, t); // last since invalidates instance
486487
ptr += test_move::result_count;
487-
assert(static_cast<std::ptrdiff_t>(result_count) ==
488-
ptr - accessor.begin());
488+
*ptr = static_cast<std::ptrdiff_t>(result_count) ==
489+
ptr - accessor.begin();
489490
});
490491
});
491492
}
@@ -497,7 +498,7 @@ void check_kernel(InitFunc init_func, const std::string& type_name,
497498
ptr += test_copy::result_count;
498499
test_move::evaluate(ptr);
499500
ptr += test_move::result_count;
500-
assert(static_cast<std::ptrdiff_t>(result_count) == ptr - results.data());
501+
assert(*ptr);
501502
}
502503

503504
} // namespace common_reference_semantics

tests/group_functions/group_scan.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ auto joint_inclusive_scan_helper(Group group, T* v_begin, T* v_end,
3838
return sycl::joint_inclusive_scan(group, v_begin, v_end, r_i_begin, op,
3939
I(init));
4040
}
41-
assert((std::is_same_v<I, U> &&
42-
"Without init value I and U should be the same type."));
4341
return (U*)sycl::joint_inclusive_scan(group, v_begin, v_end, (I*)r_i_begin,
4442
op);
4543
}
@@ -51,8 +49,6 @@ auto joint_exclusive_scan_helper(Group group, T* v_begin, T* v_end,
5149
return sycl::joint_exclusive_scan(group, v_begin, v_end, r_e_begin, I(init),
5250
op);
5351
}
54-
assert((std::is_same_v<I, U> &&
55-
"Without init value I and U should be the same type."));
5652
return (U*)sycl::joint_exclusive_scan(group, v_begin, v_end, (I*)r_e_begin,
5753
op);
5854
}
@@ -149,6 +145,9 @@ void check_scan(sycl::queue& queue, size_t size,
149145
sycl::buffer<bool, 1> end_sycl = host_data.create_end_buffer();
150146
sycl::buffer<bool, 1> ret_type_sycl = host_data.create_ret_type_buffer();
151147

148+
assert((with_init || std::is_same_v<I, U>) &&
149+
"Without init value I and U must be the same type.");
150+
152151
queue
153152
.submit([&](sycl::handler& cgh) {
154153
sycl::accessor<T, 1> ref_input_acc(ref_input_sycl, cgh);
@@ -315,8 +314,6 @@ auto inclusive_scan_over_group_helper(Group group, U x, OpT op,
315314
if (with_init) {
316315
return sycl::inclusive_scan_over_group(group, x, op, T(init));
317316
}
318-
assert((std::is_same_v<T, U> &&
319-
"Without init value T and U should be the same type."));
320317
return sycl::inclusive_scan_over_group(group, T(x), op);
321318
}
322319

@@ -326,8 +323,6 @@ auto exclusive_scan_over_group_helper(Group group, U x, OpT op,
326323
if (with_init) {
327324
return sycl::exclusive_scan_over_group(group, x, T(init), op);
328325
}
329-
assert((std::is_same_v<T, U> &&
330-
"Without init value T and U should be the same type."));
331326
return sycl::exclusive_scan_over_group(group, T(x), op);
332327
}
333328

@@ -457,6 +452,9 @@ void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
457452
auto local_id_sycl = host_data.create_local_id_buffer();
458453
auto sub_group_id_sycl = host_data.create_sub_group_id_buffer();
459454

455+
assert((with_init || std::is_same_v<T, U>) &&
456+
"Without init value T and U must be the same type.");
457+
460458
queue
461459
.submit([&](sycl::handler& cgh) {
462460
sycl::accessor<U, 1, sycl::access_mode::read> ref_input_acc(

0 commit comments

Comments
 (0)