Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions example/ck_tile/38_block_scale_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
set(EXE_NAME tile_example_gemm_quant)
add_executable(${EXE_NAME}
gemm_quant.cpp
gemm_abquant_quantgrouped.cpp
gemm_aquant_quantgrouped.cpp
gemm_aquant_quantgrouped_preshufflequant.cpp
gemm_bquant_quantgrouped_bf8i4.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// SPDX-License-Identifier: MIT
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.

#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfig_ABQuant_Prefill<T>;

void abquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
"non-preshufflequant",
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
"non-preshufflequant",
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"abquant",
"non-preshuffleb",
"non-preshufflequant",
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
}
17 changes: 15 additions & 2 deletions example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[])
.insert("prec",
"fp8",
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
"or bf8i4")
"or bf8i4; for ABQuant: fp8, bf8, i4fp8, or i4bf8")
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
Expand All @@ -41,7 +41,7 @@ auto create_args(int argc, char* argv[])
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "Flush cache before running the kernel")
.insert("rotating_count", "1000", "Rotating count")
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
.insert("quant_mode", "bquant", "Choose aquant, bquant, abquant, tensor or rowcol")
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
.insert("preshufflequant", "false", "Enable preshuffle of quant tensor")
.insert("group_size",
Expand Down Expand Up @@ -75,6 +75,16 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
arg_parser.get_bool("preshufflequant") ? "preshufflequant" : "non-preshufflequant";
params.push_back(preshufflequant);
}
if(quant_mode == "abquant")
{
std::string preshuffleb =
arg_parser.get_bool("preshuffleb") ? "preshuffleb" : "non-preshuffleb";
params.push_back(preshuffleb);

std::string preshufflequant =
arg_parser.get_bool("preshufflequant") ? "preshufflequant" : "non-preshufflequant";
params.push_back(preshufflequant);
}
if(quant_mode != "rowcol" && quant_mode != "tensor")
{
// NOTE: rowcol and tensor pipeline do not use group size
Expand All @@ -85,6 +95,8 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
return hash_multiple_strings(params);
}

void abquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void aquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void aquant_quantgrouped_preshufflequant_instance_factory(
Expand Down Expand Up @@ -122,6 +134,7 @@ int main(int argc, char* argv[])
ck_tile::hip_check_error(hipSetDevice(device_id));

std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
abquant_quantgrouped_instance_factory(lut);
aquant_quantgrouped_instance_factory(lut);
aquant_quantgrouped_preshufflequant_instance_factory(lut);
bquant_quantgrouped_fp8_instance_factory(lut);
Expand Down
17 changes: 17 additions & 0 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,23 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};

template <typename PrecType>
struct GemmConfig_ABQuant_Prefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
};

template <typename PrecType>
struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
: public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
Expand Down
Loading