@@ -40,9 +40,11 @@ namespace detail {
40
40
* @param is_batch_interleaved is the input data layout batch interleaved
41
41
* @param workgroup_size The size of the work-group. Must be divisible by 2.
42
42
*/
43
- PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup (bool is_batch_interleaved,
44
- Idx workgroup_size) noexcept {
45
- return is_batch_interleaved ? workgroup_size / 2 : 1 ;
43
+ PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup (bool /* is_batch_interleaved*/ ,
44
+ Idx /* workgroup_size*/ ) noexcept {
45
+ // TODO: re-enable the batch-interleaved expression below once the tests are passing. Until then both parameters are intentionally unused and exactly one batch is always reported.
46
+ // return is_batch_interleaved ? workgroup_size / 2 : 1;
47
+ return 1 ;
46
48
}
47
49
48
50
/* *
@@ -110,8 +112,9 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima
110
112
const IdxGlobal input_distance = kh.get_specialization_constant <detail::SpecConstInputDistance>();
111
113
const IdxGlobal output_distance = kh.get_specialization_constant <detail::SpecConstOutputDistance>();
112
114
113
- const bool is_input_batch_interleaved = input_stride == n_transforms && input_distance == 1 ;
114
- const bool is_input_packed = input_stride == 1 && input_distance == fft_size;
115
+ // TODO reenable when tests are passing
116
+ const bool is_input_batch_interleaved = false ; // input_stride == n_transforms && input_distance == 1;
117
+ const bool is_input_packed = input_stride == 1 && input_distance == fft_size;
115
118
116
119
global_data.log_message_global (__func__, " entered" , " fft_size" , fft_size, " n_transforms" , n_transforms);
117
120
Idx num_workgroups = static_cast <Idx>(global_data.it .get_group_range (0 ));
@@ -280,8 +283,8 @@ struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<SubgroupSize
280
283
PORTFFT_LOG_FUNCTION_ENTRY ();
281
284
auto & kernel_data = compute_direction == direction::FORWARD ? dimension_data.forward_kernels .at (0 )
282
285
: dimension_data.backward_kernels .at (0 );
283
- Idx num_batches_in_local_mem =
284
- input_layout == layout::BATCH_INTERLEAVED ? kernel_data.used_sg_size * PORTFFT_SGS_IN_WG / 2 : 1 ;
286
+ Idx num_batches_in_local_mem = detail::get_num_batches_in_local_mem_workgroup (
287
+ input_layout == layout::BATCH_INTERLEAVED, kernel_data.used_sg_size * PORTFFT_SGS_IN_WG) ;
285
288
constexpr detail::memory Mem = std::is_pointer_v<TOut> ? detail::memory::USM : detail::memory::BUFFER;
286
289
Scalar* twiddles = kernel_data.twiddles_forward .get ();
287
290
std::size_t local_elements =
@@ -355,8 +358,8 @@ struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struc
355
358
// working memory + twiddles for subgroup impl for the two sizes
356
359
Idx num_batches_in_local_mem = detail::get_num_batches_in_local_mem_workgroup (
357
360
input_layout == layout::BATCH_INTERLEAVED, used_sg_size * PORTFFT_SGS_IN_WG);
358
- return detail::pad_local ( static_cast <std::size_t >(2 * num_batches_in_local_mem) * length,
359
- bank_lines_per_pad_wg ( 2 * static_cast <std::size_t >(sizeof (Scalar)) * m) ) +
361
+ const auto bank_lines_per_pad = bank_lines_per_pad_wg ( 2 * static_cast <std::size_t >(sizeof (Scalar)) * m);
362
+ return detail::pad_local ( static_cast <std::size_t >(2 * num_batches_in_local_mem) * length, bank_lines_per_pad ) +
360
363
2 * (m + n);
361
364
}
362
365
};
0 commit comments