Skip to content

Commit ddb3e95

Browse files
committed
Fix lazy adjacent_difference output selection on MUSA
1 parent 7ffab56 commit ddb3e95

3 files changed

Lines changed: 76 additions & 3 deletions

File tree

cub/detail/type_traits.cuh

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,32 @@ struct adjacent_difference_output_impl<Invokable, InputT, false>
123123
template <typename Invokable, typename InputT>
124124
using adjacent_difference_output_t = typename adjacent_difference_output_impl<Invokable, InputT>::type;
125125

126+
// Primary template of the output-type selector used for adjacent difference.
// The defaulted bool parameter is true when iterator_value_t<OutputIteratorT>
// is void. This primary template covers the remaining (false) case and exposes
// the output iterator's own value type as `type`. Invokable and InputT are not
// used here; the `true` case is handled by a partial specialization.
template <typename OutputIteratorT,
127+
typename Invokable,
128+
typename InputT,
129+
bool = std::is_same<iterator_value_t<OutputIteratorT>, void>::value>
130+
struct adjacent_difference_output_select_impl
131+
{
132+
using type = iterator_value_t<OutputIteratorT>;
133+
};
134+
135+
// Partial specialization for the case where the bool selector is `true`,
// i.e. the output iterator's value type is void. Only this specialization
// names adjacent_difference_output_t<Invokable, InputT>, so that fallback
// alias is referenced solely when this branch is the one selected.
template <typename OutputIteratorT,
136+
typename Invokable,
137+
typename InputT>
138+
struct adjacent_difference_output_select_impl<OutputIteratorT,
139+
Invokable,
140+
InputT,
141+
true>
142+
{
143+
using type = adjacent_difference_output_t<Invokable, InputT>;
144+
};
145+
146+
// Convenience alias: yields the `type` member of
// adjacent_difference_output_select_impl for the given output iterator,
// invokable and input type (the bool selector takes its default value).
template <typename OutputIteratorT, typename Invokable, typename InputT>
147+
using adjacent_difference_output_select_t =
148+
typename adjacent_difference_output_select_impl<OutputIteratorT,
149+
Invokable,
150+
InputT>::type;
151+
126152

127153
} // namespace detail
128154
CUB_NAMESPACE_END

cub/device/dispatch/dispatch_adjacent_difference.cuh

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,11 @@ DeviceAdjacentDifferenceDifferenceKernel(InputIteratorT input,
7878

7979
// Prefer a concrete output value type when one is available. This avoids
8080
// device-lambda result introspection on MUSA while preserving write-only
81-
// iterators whose value_type is void or unavailable.
82-
using OutputT = detail::non_void_iterator_value_t<
83-
OutputIteratorT, detail::adjacent_difference_output_t<DifferenceOpT, InputT>>;
81+
// iterators whose value_type is void or unavailable. The selection must be
82+
// lazy: eager alias arguments still instantiate the lambda fallback on MUSA.
83+
using OutputT = detail::adjacent_difference_output_select_t<OutputIteratorT,
84+
DifferenceOpT,
85+
InputT>;
8486

8587
using Agent = AgentDifference<ActivePolicyT,
8688
InputIteratorT,

test/test_device_adjacent_difference_lambda.cu

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,50 @@ void TestDeviceLambdaWithoutResultType()
135135
}
136136
}
137137

138+
// Runs RunSubtractLeftCopy with a capturing __device__ lambda that has no
// explicit return type, writing into a device_vector<int>, and checks the
// three resulting values against a fixed expectation.
void TestCapturedDeviceLambdaWithoutResultType()
139+
{
140+
// Flat storage for the values 1,2,1,2,3,4. With row_width = 2 below, the
// lambda indexes it as three rows: (1,2), (1,2), (3,4).
thrust::device_vector<int> rows_flat(6);
141+
rows_flat[0] = 1;
142+
rows_flat[1] = 2;
143+
rows_flat[2] = 1;
144+
rows_flat[3] = 2;
145+
rows_flat[4] = 3;
146+
rows_flat[5] = 4;
147+
148+
// Three output slots pre-filled with -1.
thrust::device_vector<int> output(3, -1);
149+
// Input sequence of type long starting at 0; its values are used as row
// indices by the lambda.
auto row_ids = thrust::make_counting_iterator<long>(0);
150+
151+
// Raw pointer to the vector's storage, referenced inside the lambda.
const int *rows_ptr = thrust::raw_pointer_cast(rows_flat.data());
152+
const int row_width = 2;
153+
154+
// Compares two rows column by column: returns 1 at the first differing
// column, 0 if every column matches. Uses a by-value capture default ([=]).
auto difference_op = [=] __device__(long lhs_row, long rhs_row) {
155+
for (int column = 0; column < row_width; ++column)
156+
{
157+
const int lhs = rows_ptr[lhs_row * row_width + column];
158+
const int rhs = rows_ptr[rhs_row * row_width + column];
159+
if (lhs != rhs)
160+
{
161+
return 1;
162+
}
163+
}
164+
165+
return 0;
166+
};
167+
168+
// NOTE(review): RunSubtractLeftCopy is defined outside this view; presumably
// it runs a subtract-left-copy adjacent difference over 3 items - confirm.
RunSubtractLeftCopy(row_ids, output.begin(), difference_op, 3);
169+
170+
// Expected result {0, 0, 1}. Rows 0 and 1 are equal and rows 1 and 2 differ,
// which matches elements 1 and 2. Element 0 is presumably the first input
// (row id 0) copied through unchanged - TODO confirm against the helper.
thrust::host_vector<int> expected(3);
171+
expected[0] = 0;
172+
expected[1] = 0;
173+
expected[2] = 1;
174+
175+
AssertEquals(output.size(), expected.size());
176+
// Element-wise comparison of the device output with the host expectation.
for (std::size_t i = 0; i < expected.size(); ++i)
177+
{
178+
AssertEquals(output[i], expected[i]);
179+
}
180+
}
181+
138182
void TestWriteOnlyOutputIteratorFallback()
139183
{
140184
thrust::device_vector<bool> all_differences_correct(1, true);
@@ -152,6 +196,7 @@ int main(int argc, char **argv)
152196
CubDebugExit(args.DeviceInit());
153197

154198
TestDeviceLambdaWithoutResultType();
199+
TestCapturedDeviceLambdaWithoutResultType();
155200
TestWriteOnlyOutputIteratorFallback();
156201

157202
return 0;

0 commit comments

Comments
 (0)