Commit 40c9981

Merge up to v2.5.5
1 parent 9140339 · commit 40c9981

9 files changed: +117 -40 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ FlashAttention-2 currently supports:
    GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
    GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
-3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
+3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
 
 ## Citation
 If you use this codebase, or otherwise found our work valuable, please cite:

csrc/cutlass

Submodule cutlass updated 61 files

csrc/flash_attn/src/flash_bwd_kernel.h

Lines changed: 1 addition & 1 deletion
@@ -521,7 +521,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }
         // Compute the exponential value.
         flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
-        if (Is_dropout) {
+        if constexpr (Is_dropout) {
             int warp_id = tidx / 32;
             int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
             // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
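The switch to `if constexpr` matters for the new no-dropout configs: with a plain `if`, the dropout branch still has to compile (and can cost registers) in the `Is_dropout == false` instantiation, whereas `if constexpr` discards it entirely at compile time. A minimal host-side sketch of the pattern, using a hypothetical `scale_scores_sketch` stand-in rather than the real kernel:

```cpp
#include <cstdio>

// Stand-in for the kernel's compile-time dropout flag.
template <bool Is_dropout>
void scale_scores_sketch(float* scores, int n, float keep_prob) {
    if constexpr (Is_dropout) {
        // Only instantiated when Is_dropout == true; the non-dropout template
        // never compiles this branch, so it carries no dropout logic at all.
        for (int i = 0; i < n; ++i) { scores[i] *= 1.0f / keep_prob; }
    }
}

int main() {
    float s[4] = {1.f, 2.f, 3.f, 4.f};
    scale_scores_sketch<false>(s, 4, 0.9f);  // compiles to a no-op
    scale_scores_sketch<true>(s, 4, 0.5f);   // rescales by 1 / keep_prob
    printf("%g %g %g %g\n", s[0], s[1], s[2], s[3]);
    return 0;
}
```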

csrc/flash_attn/src/flash_bwd_launch_template.h

Lines changed: 18 additions & 12 deletions
@@ -70,9 +70,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
+                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                         // If Is_local, set Is_causal to false
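For reference, `BOOL_SWITCH` in this codebase turns a runtime boolean into a compile-time constant by instantiating both template branches; the renamed `EVENK_SWITCH` / `LOCAL_SWITCH` / `ALIBI_SWITCH` above (and `DROPOUT_SWITCH` in the hunks below) follow the same shape but can be pinned to a single value at build time to cut the number of instantiated kernels. A minimal sketch of that pattern; the `FLASHATTENTION_DISABLE_ALIBI` flag name is illustrative here, mirroring the `FLASHATTENTION_DISABLE_BACKWARD` guard added in this commit:

```cpp
#include <cstdio>

// Dispatch a runtime bool to a compile-time constant (both branches instantiate).
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)     \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

#ifdef FLASHATTENTION_DISABLE_ALIBI
// With the (illustrative) flag defined, pin the constant to false: only one
// template per call site is instantiated, shrinking compile time and binary size.
#define ALIBI_SWITCH_SKETCH(COND, CONST_NAME, ...)    \
    [&] {                                             \
        constexpr static bool CONST_NAME = false;     \
        return __VA_ARGS__();                         \
    }()
#else
#define ALIBI_SWITCH_SKETCH BOOL_SWITCH_SKETCH
#endif

template <bool Has_alibi>
void launch_sketch() { printf("instantiated with Has_alibi=%d\n", int(Has_alibi)); }

int main() {
    bool has_alibi = true;  // runtime value
    ALIBI_SWITCH_SKETCH(has_alibi, Has_alibi, [&] { launch_sketch<Has_alibi>(); });
    return 0;
}
```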
@@ -101,7 +101,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
 
 template<typename Kernel_traits, bool Is_dropout>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+#ifndef FLASHATTENTION_DISABLE_BACKWARD
     run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
+#endif
 }
 
 template<typename T>
@@ -115,7 +117,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
             if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
                 run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
@@ -140,7 +142,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // Changing AtomLayoutMdQ from 2 to 4 takes the same time
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
@@ -185,7 +187,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
             if constexpr(!Is_dropout) { // 92KB
                 run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
@@ -211,7 +213,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
         // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
         // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
@@ -244,7 +246,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
         } else {
@@ -264,7 +266,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 136 * 1024) {
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
         } else {
@@ -276,7 +278,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 224;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
     });
 }
@@ -292,11 +294,15 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 176 * 1024) { // H100
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
-        } else { // A100, we don't do double buffering to save smem
+        } else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
+        } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
+            if constexpr (!Is_dropout) {
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
+            }
         }
     });
 }
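This new three-way branch is what enables head-dim-256 backward on consumer GPUs: H100-class parts clear the 176 KB threshold, A100-class parts the 144 KB one, and sm86/sm89 (about 99 KB of shared memory per block) fall through to a kBlockN = 32 config with V kept in registers, which only exists for the no-dropout case. A standalone sketch of how that per-device limit is typically queried; it mirrors, but is not copied from, the launcher's own setup code:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0, max_smem_per_block = 0;
    cudaGetDevice(&device);
    // Opt-in per-block shared memory limit: roughly 227 KB on H100, 163 KB on
    // A100, 99 KB on sm86/sm89 (consumer Ampere/Ada).
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (max_smem_per_block >= 176 * 1024) {
        printf("H100-class: 64x64 tiles with double buffering\n");
    } else if (max_smem_per_block >= 144 * 1024) {
        printf("A100-class: 64x64 tiles, no double buffering\n");
    } else {
        printf("sm86/sm89-class: 64x32 tiles, V in registers, dropout unsupported\n");
    }
    return 0;
}
```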

csrc/flash_attn/src/flash_fwd_launch_template.h

Lines changed: 16 additions & 15 deletions
@@ -43,10 +43,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     const bool return_softmax = params.p_ptr != nullptr;
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-            BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                 BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                         // Will only return softmax if dropout, to reduce compilation time.
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If return_softmax, set IsEvenMNConst to false to reduce number of templates
@@ -84,11 +84,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                     BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                         BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                            BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                                 // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                                 // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                                 // If Is_local, set Is_causal to false
@@ -114,7 +114,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // If headdim is divisible by 64, then we set kBlockM = 8, etc.
     constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
     dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
-    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+    EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
         if (params.num_splits <= 2) {
             flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
         } else if (params.num_splits <= 4) {
@@ -148,7 +148,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
 template<typename T>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
         });
@@ -158,7 +158,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
@@ -180,13 +180,14 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
+
     // auto dprops = at::cuda::getCurrentDeviceProperties();
     int device, major, minor;
     C10_CUDA_CHECK(cudaGetDevice(&device));
     C10_CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
     C10_CUDA_CHECK(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
     bool is_sm8x = major == 8 && minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
             if (is_sm8x) {
@@ -217,7 +218,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     bool is_sm8x = major == 8 && minor > 0;
     // auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -259,7 +260,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     bool is_sm8x = major == 8 && minor > 0;
     // auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For A100, H100, 128 x 32 is the fastest.
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -287,7 +288,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -315,7 +316,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
                 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -346,7 +347,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For A100, we want to run with 128 x 64 (128KB smem).
             // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.

csrc/flash_attn/src/kernel_traits.h

Lines changed: 4 additions & 2 deletions
@@ -231,9 +231,11 @@ struct Flash_bwd_kernel_traits : public Base {
     // TODO: generalize to other values of kBlockN
     // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
     // static constexpr int kPBlockN = kBlockN;
-    static_assert(kBlockN >= 64);
+    // Temporarily disabling this for hdim 256 on sm86 and sm89
+    // static_assert(kBlockN >= 64);
+    static_assert(kBlockN >= 32);
     // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
-    static constexpr int kPBlockN = 64;
+    static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
     static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
     // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
     static constexpr int kSwizzlePdS = 3;
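The relaxed assert and the conditional kPBlockN are what let the sm86/sm89 config above (kBlockN = 32) instantiate at all. A reduced, self-contained sketch of just that selection logic; the real trait struct carries many more members:

```cpp
#include <cstdio>

template <int kBlockN>
struct BwdTraitsSketch {
    // Previously static_assert(kBlockN >= 64); relaxed so the hdim-256
    // sm86/sm89 backward config with kBlockN = 32 can be instantiated.
    static_assert(kBlockN >= 32, "kBlockN must be at least 32");
    static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
    static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
};

int main() {
    printf("kBlockN=128 -> kPBlockN=%d\n", BwdTraitsSketch<128>::kPBlockN);  // 64
    printf("kBlockN=32  -> kPBlockN=%d\n", BwdTraitsSketch<32>::kPBlockN);   // 32
    return 0;
}
```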

csrc/flash_attn/src/softmax.h

Lines changed: 8 additions & 7 deletions
@@ -55,10 +55,10 @@ __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tenso
     reduce_<zero_init>(tensor, max, max_op);
 }
 
-template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 __device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
     SumOp<float> sum_op;
-    reduce_(tensor, sum, sum_op);
+    thread_reduce_<zero_init>(tensor, sum, sum_op);
 }
 
 // Apply the exp to all the elements.
@@ -133,7 +133,7 @@ struct Softmax {
         if (Is_first) {
             flash::template reduce_max</*zero_init=*/true>(scores, row_max);
             flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
-            flash::reduce_sum(scores, row_sum);
+            flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
         } else {
             Tensor scores_max_prev = make_fragment_like(row_max);
             cute::copy(row_max, scores_max_prev);
@@ -152,15 +152,16 @@ struct Softmax {
                 for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
             }
             flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
-            Tensor scores_sum_cur = make_fragment_like(row_sum);
-            flash::reduce_sum(scores, scores_sum_cur);
-            #pragma unroll
-            for (int mi = 0; mi < size(row_sum); ++mi) { row_sum(mi) += scores_sum_cur(mi); }
+            // We don't do the reduce across threads here since we don't need to use the row_sum.
+            // We do that reduce at the end when we need to normalize the softmax.
+            flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
         }
     };
 
     template<bool Is_dropout=false, bool Split=false, typename Tensor0>
     __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+        SumOp<float> sum_op;
+        quad_allreduce_(row_sum, row_sum, sum_op);
         TensorT lse = make_fragment_like(row_sum);
         Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
         static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
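The net effect of these softmax changes: during the main loop each thread only accumulates its private partial row sum (`thread_reduce_` with `zero_init=false`), and the cross-thread reduction happens once, in `normalize_softmax_lse`, via `quad_allreduce_` over the four threads that share a row. A minimal CUDA sketch of that final quad allreduce, independent of the cute tensor machinery; names like `quad_allreduce_sum` and `row_sum_demo` are illustrative:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__device__ float quad_allreduce_sum(float x) {
    // Lanes {t, t^1, t^2, t^3} of a warp hold partial sums of the same row;
    // two butterfly shuffles give all four lanes the full row sum.
    x += __shfl_xor_sync(0xffffffff, x, 1);
    x += __shfl_xor_sync(0xffffffff, x, 2);
    return x;
}

__global__ void row_sum_demo(const float* partial, float* out) {
    float sum = partial[threadIdx.x];  // per-thread partial accumulated earlier
    sum = quad_allreduce_sum(sum);     // single cross-thread reduce at the end
    out[threadIdx.x] = sum;
}

int main() {
    float h_in[32], h_out[32];
    for (int i = 0; i < 32; ++i) h_in[i] = float(i % 4 + 1);  // each quad sums to 10
    float *d_in, *d_out;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    row_sum_demo<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("lane 0 row sum = %g (expect 10)\n", h_out[0]);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```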
