Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2e7f777

Browse files
committedMar 7, 2024
avoid batch interleaved path for now
1 parent 037fa6f commit 2e7f777

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed
 

‎src/portfft/dispatcher/workgroup_dispatcher.hpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ namespace detail {
4040
* @param is_batch_interleaved is the input data layout batch interleaved
4141
* @param workgroup_size The size of the work-group. Must be divisible by 2.
4242
*/
43-
PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup(bool is_batch_interleaved,
44-
Idx workgroup_size) noexcept {
45-
return is_batch_interleaved ? workgroup_size / 2 : 1;
43+
PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup(bool /*is_batch_interleaved*/,
44+
Idx /*workgroup_size*/) noexcept {
45+
// TODO reenable when tests are passing
46+
// return is_batch_interleaved ? workgroup_size / 2 : 1;
47+
return 1;
4648
}
4749

4850
/**
@@ -110,8 +112,9 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima
110112
const IdxGlobal input_distance = kh.get_specialization_constant<detail::SpecConstInputDistance>();
111113
const IdxGlobal output_distance = kh.get_specialization_constant<detail::SpecConstOutputDistance>();
112114

113-
const bool is_input_batch_interleaved = input_stride == n_transforms && input_distance == 1;
114-
const bool is_input_packed = input_stride == 1 && input_distance == fft_size;
115+
// TODO reable when tests are passing
116+
const bool is_input_batch_interleaved = false; // input_stride == n_transforms && input_distance == 1;
117+
const bool is_input_packed = input_stride == 1 && input_distance == fft_size;
115118

116119
global_data.log_message_global(__func__, "entered", "fft_size", fft_size, "n_transforms", n_transforms);
117120
Idx num_workgroups = static_cast<Idx>(global_data.it.get_group_range(0));
@@ -280,8 +283,8 @@ struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<SubgroupSize
280283
PORTFFT_LOG_FUNCTION_ENTRY();
281284
auto& kernel_data = compute_direction == direction::FORWARD ? dimension_data.forward_kernels.at(0)
282285
: dimension_data.backward_kernels.at(0);
283-
Idx num_batches_in_local_mem =
284-
input_layout == layout::BATCH_INTERLEAVED ? kernel_data.used_sg_size * PORTFFT_SGS_IN_WG / 2 : 1;
286+
Idx num_batches_in_local_mem = detail::get_num_batches_in_local_mem_workgroup(
287+
input_layout == layout::BATCH_INTERLEAVED, kernel_data.used_sg_size * PORTFFT_SGS_IN_WG);
285288
constexpr detail::memory Mem = std::is_pointer_v<TOut> ? detail::memory::USM : detail::memory::BUFFER;
286289
Scalar* twiddles = kernel_data.twiddles_forward.get();
287290
std::size_t local_elements =
@@ -355,8 +358,8 @@ struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struc
355358
// working memory + twiddles for subgroup impl for the two sizes
356359
Idx num_batches_in_local_mem = detail::get_num_batches_in_local_mem_workgroup(
357360
input_layout == layout::BATCH_INTERLEAVED, used_sg_size * PORTFFT_SGS_IN_WG);
358-
return detail::pad_local(static_cast<std::size_t>(2 * num_batches_in_local_mem) * length,
359-
bank_lines_per_pad_wg(2 * static_cast<std::size_t>(sizeof(Scalar)) * m)) +
361+
const auto bank_lines_per_pad = bank_lines_per_pad_wg(2 * static_cast<std::size_t>(sizeof(Scalar)) * m);
362+
return detail::pad_local(static_cast<std::size_t>(2 * num_batches_in_local_mem) * length, bank_lines_per_pad) +
360363
2 * (m + n);
361364
}
362365
};

0 commit comments

Comments
 (0)
Please sign in to comment.