@@ -66,6 +66,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
6666 GGML_CUDA_FATTN_MMA_CONFIG_CASE (256 , 256 , 32 , 128 , 2 , 32 , 128 , 128 , 128 , 2 , true );
6767 GGML_CUDA_FATTN_MMA_CONFIG_CASE (256 , 256 , 64 , 128 , 2 , 32 , 128 , 128 , 128 , 2 , true );
6868
69+ GGML_CUDA_FATTN_MMA_CONFIG_CASE (320 , 256 , 32 , 128 , 2 , 32 , 128 , 128 , 128 , 1 , false );
70+ GGML_CUDA_FATTN_MMA_CONFIG_CASE (320 , 256 , 64 , 256 , 1 , 32 , 128 , 128 , 128 , 1 , false );
71+
6972 GGML_CUDA_FATTN_MMA_CONFIG_CASE (512 , 512 , 8 , 64 , 4 , 32 , 256 , 256 , 128 , 1 , false );
7073 GGML_CUDA_FATTN_MMA_CONFIG_CASE (512 , 512 , 16 , 64 , 4 , 32 , 256 , 256 , 128 , 1 , false );
7174 GGML_CUDA_FATTN_MMA_CONFIG_CASE (512 , 512 , 32 , 128 , 2 , 32 , 128 , 128 , 128 , 1 , false );
@@ -85,6 +88,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
8588 GGML_CUDA_FATTN_MMA_CONFIG_CASE (256 , 256 , 32 , 128 , 2 , 64 , 128 , 128 , 64 , 2 , true );
8689 GGML_CUDA_FATTN_MMA_CONFIG_CASE (256 , 256 , 64 , 128 , 2 , 64 , 128 , 128 , 64 , 2 , true );
8790
91+ GGML_CUDA_FATTN_MMA_CONFIG_CASE (320 , 256 , 32 , 128 , 2 , 32 , 128 , 128 , 128 , 1 , false );
92+ GGML_CUDA_FATTN_MMA_CONFIG_CASE (320 , 256 , 64 , 256 , 1 , 32 , 128 , 128 , 128 , 1 , false );
93+
8894 GGML_CUDA_FATTN_MMA_CONFIG_CASE (512 , 512 , 8 , 64 , 4 , 32 , 96 , 64 , 128 , 1 , false );
8995 GGML_CUDA_FATTN_MMA_CONFIG_CASE (512 , 512 , 16 , 64 , 4 , 32 , 96 , 64 , 128 , 1 , false );
9096 GGML_CUDA_FATTN_MMA_CONFIG_CASE (512 , 512 , 32 , 128 , 2 , 32 , 128 , 128 , 128 , 1 , false );
@@ -118,6 +124,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
118124 GGML_CUDA_FATTN_MMA_CONFIG_CASE (256 , 256 , 32 , 128 , 2 , 64 , 128 , 128 , 64 , 2 , true );
119125 GGML_CUDA_FATTN_MMA_CONFIG_CASE (256 , 256 , 64 , 128 , 2 , 64 , 128 , 128 , 64 , 2 , true );
120126
127+ GGML_CUDA_FATTN_MMA_CONFIG_CASE (320 , 256 , 32 , 128 , 2 , 64 , 160 , 128 , 64 , 2 , true );
128+ GGML_CUDA_FATTN_MMA_CONFIG_CASE (320 , 256 , 64 , 128 , 2 , 64 , 160 , 128 , 64 , 2 , false );
129+
121130 GGML_CUDA_FATTN_MMA_CONFIG_CASE (512 , 512 , 16 , 64 , 4 , 32 , 128 , 128 , 128 , 1 , false );
122131 GGML_CUDA_FATTN_MMA_CONFIG_CASE (512 , 512 , 32 , 128 , 2 , 32 , 128 , 128 , 128 , 1 , false );
123132 GGML_CUDA_FATTN_MMA_CONFIG_CASE (512 , 512 , 64 , 256 , 1 , 32 , 128 , 128 , 128 , 1 , false );
@@ -1217,7 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
12171226 float KQ_max_scale[cols_per_thread];
12181227#pragma unroll
12191228 for (int col = 0 ; col < cols_per_thread; ++col) {
1220- const int jc = cols_per_warp == 8 ? T_C_KQ::get_j (col) : T_C_KQ::get_i (2 *col);
1229+ const int jc = ( threadIdx . y /np)* cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j (col) : T_C_KQ::get_i (2 *col) );
12211230 const float sink = sinks_f[jc % ncols2];
12221231
12231232 const float KQ_max_new = fmaxf (KQ_max[col], sink);
@@ -1825,6 +1834,10 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
18251834extern DECL_FATTN_MMA_F16_CASE (576 , 512 , 2 , 16 );
18261835extern DECL_FATTN_MMA_F16_CASE (576 , 512 , 4 , 16 );
18271836
1837+ // Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build:
1838+ extern DECL_FATTN_MMA_F16_CASE (320 , 256 , 1 , 32 );
1839+ extern DECL_FATTN_MMA_F16_CASE (320 , 256 , 2 , 32 );
1840+
18281841// For GLM 4.7 Flash
18291842extern DECL_FATTN_MMA_F16_CASE (576 , 512 , 4 , 4 );
18301843extern DECL_FATTN_MMA_F16_CASE (576 , 512 , 8 , 4 );
0 commit comments