Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
527b718
formatted
amd-khushbu Oct 31, 2025
c25420e
formatted
amd-khushbu Nov 5, 2025
8999dae
formatting
amd-khushbu Nov 5, 2025
00e61e0
formatting
amd-khushbu Nov 6, 2025
e5d6a80
formatting
amd-khushbu Nov 7, 2025
a967b49
[CK TILE GEMM] Refactor block_scale_gemm examples
CongMa13 Nov 7, 2025
bc26224
[CK TILE GEMM] Refactor block_scale_gemm examples
CongMa13 Nov 7, 2025
c553c87
enable prefill shapes
amd-khushbu Nov 7, 2025
8818018
[CK TILE GEMM] Refactor block_scale_gemm examples
CongMa13 Nov 8, 2025
9debcc1
[CK TILE GEMM] Refactor block_scale_gemm examples
CongMa13 Nov 10, 2025
51ec0c2
merge with Cong's Changes
amd-khushbu Nov 11, 2025
869bc5b
adding preshuffle quant as new parameter and its associated new files
amd-khushbu Nov 11, 2025
075c36b
remove debugging statements
amd-khushbu Nov 11, 2025
0f79fa5
adding test
amd-khushbu Nov 12, 2025
b8b5709
enable preshuffle quant with permuteN
amd-khushbu Nov 12, 2025
903800f
rebase with develop
amd-khushbu Nov 13, 2025
48e7559
updating readme and correcponding gemmconfigs
amd-khushbu Nov 13, 2025
36f2f87
updating cmake file
amd-khushbu Nov 13, 2025
07700cc
fixing CI failures for grouped quant gemm
amd-khushbu Nov 14, 2025
f5856af
Merge branch 'develop' into lwpck-3984
amd-khushbu Nov 14, 2025
2275548
debugging permuteN
amd-khushbu Nov 18, 2025
a974a08
debugging
amd-khushbu Nov 20, 2025
cf3f9b5
Merge branch 'develop' into lwpck-3985
amd-khushbu Nov 20, 2025
04aaf97
debugging PermuteN
amd-khushbu Nov 24, 2025
7788979
initial commit
amd-khushbu Nov 25, 2025
f290428
working code for preshuffleb
amd-khushbu Nov 26, 2025
3447196
resolving merge conflicts
amd-khushbu Nov 26, 2025
2441260
Merge branch 'develop' into 1dQuantPreshuffleWeight
amd-khushbu Dec 4, 2025
18fe146
Merge branch 'develop' into 2dQuantPreshuffleWeight
ThomasNing Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,154 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
void bquant_quantgrouped_preshuffleb_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"preshuffleb",
"non-preshufflequant",
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"preshuffleb",
"non-preshufflequant",
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};

lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"bquant",
"preshuffleb",
"non-preshufflequant",
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"bquant",
"preshuffleb",
"non-preshufflequant",
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
Expand All @@ -52,10 +170,50 @@ void bquant_quantgrouped_preshuffleb_instance_factory(
lut[hash_multiple_strings(
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
Expand Down
20 changes: 14 additions & 6 deletions example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;

constexpr bool TiledPermuteN =
(QuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
if(s.log_level_ > 0)
{
printf(
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, QuantGroupSize::kN);
}
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
Expand All @@ -154,7 +161,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
1,
false,
1,
GemmConfig::TiledMMAPermuteN>>;
TiledPermuteN>>;
using Kernel =
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;

Expand Down Expand Up @@ -550,7 +557,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
b_k_n.SetZero();
bq_tensor_ptr->SetZero();
}

ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
Expand Down Expand Up @@ -603,7 +609,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(GemmConfig::PreshuffleB)
{
if constexpr(GemmConfig::TiledMMAPermuteN)
if constexpr(GemmConfig::TiledMMAPermuteN && QuantGroupSize::kN == 1)
{
printf("PreshuffleB with TiledMMAPermuteN\n");
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
Expand All @@ -628,11 +634,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN)
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN &&
QuantGroupSize::kN == 1)
{
printf("Preshuffle BQ with TiledMMAPermuteN \n");
ck_tile::HostTensor<BQDataType> bq_permuted_host =
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr);
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, QuantGroupSize::kN);

if constexpr(GemmConfig::PreshuffleQuant)
{
Expand All @@ -652,7 +658,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
}
else
{
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
}
}

invoke_gemm<GemmConfig,
Expand Down
10 changes: 7 additions & 3 deletions include/ck_tile/host/tensor_shuffle_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
}
int m_ = t->get_lengths()[0];
int aqk_ = t->get_lengths()[1];

if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
Expand Down Expand Up @@ -110,16 +111,19 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
}

template <typename GemmConfig, typename T>
auto bq_permuteN(const ck_tile::HostTensor<T>& t)
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
assert(t.get_lengths().size() == 2);

int n_ = t.get_lengths()[1];
int bqk_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;

ck_tile::HostTensor<T> t_view(
{n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile / group_n,
NRepeat,
bqk_});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
}
Expand Down
Loading