Skip to content

Commit 1ae4edc

Browse files
Merge branch 'develop' into aviralgoel/update_copyright_final
2 parents dfea641 + de64664 commit 1ae4edc

File tree

6 files changed, with 43 additions and 26 deletions.

6 files changed, with 43 additions and 26 deletions.

example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -61,12 +61,16 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
6161
GemmTraits,
6262
ComputeDataType>;
6363

64-
// This example only supports BQuant (no AQuant)
65-
// For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
64+
// Base pipeline selection based on quant mode and preshuffle settings
6665
using BaseGemmPipeline = std::conditional_t<
6766
GemmConfig::PreshuffleB == true,
6867
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
69-
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
68+
std::conditional_t<
69+
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
70+
ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
71+
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
72+
ck_tile::BaseAQuantGemmPipelineAgBgCrMem<GemmPipelineProblem>,
73+
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
7074

7175
const ck_tile::index_t K_split =
7276
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -125,7 +129,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
125129
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
126130
std::conditional_t<
127131
QuantMode == ck_tile::QuantType::AQuantGrouped,
128-
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
132+
std::conditional_t<GemmConfig::PreshuffleQuant == true,
133+
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
134+
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
129135
std::conditional_t<GemmConfig::PreshuffleB == true,
130136
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
131137
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;

include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -446,7 +446,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
446446
using GemmADataType = ck::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
447447
using GemmBDataType = ck::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
448448

449-
#define GridwiseGemmMultiABDTemplateParameters \
449+
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS \
450450
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
451451
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
452452
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
@@ -462,7 +462,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
462462
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
463463
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
464464

465-
#define GridwiseGemmTemplateParameters \
465+
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS \
466466
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
467467
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
468468
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
@@ -480,8 +480,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
480480
template <index_t NXdlPerWave_>
481481
using GridwiseGemmBase = ck::conditional_t<
482482
isMultiA || isMultiB,
483-
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>,
484-
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
483+
GridwiseGemmMultipleABD_xdl_cshuffle<CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS>,
484+
GridwiseGemmMultipleD_xdl_cshuffle<CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS>>;
485+
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS
486+
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS
485487
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
486488
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
487489

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -439,7 +439,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
439439
}
440440

441441
// GridwiseGemm
442-
#define GridwiseGemmMultiDTemplateParams \
442+
#define CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS \
443443
ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
444444
AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
445445
MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \
@@ -454,7 +454,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
454454
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
455455
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
456456

457-
#define GridwiseGemmCTransposeTemplateParameters \
457+
#define CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS \
458458
ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
459459
BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
460460
NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave_, MXdlPerWave, \
@@ -470,10 +470,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
470470
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
471471

472472
template <index_t NXdlPerWave_>
473-
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>;
473+
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
474+
CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS>;
474475
template <index_t NXdlPerWave_>
475-
using GridwiseGemmCTransposeBase =
476-
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>;
476+
using GridwiseGemmCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle<
477+
CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS>;
478+
#undef CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS
479+
#undef CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS
477480
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
478481
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
479482

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
485485
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
486486
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
487487

488-
#define GridwiseGemmMultiABDTemplateParameters \
488+
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
489489
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
490490
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
491491
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
@@ -502,7 +502,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
502502
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
503503
BComputeDataType
504504

505-
#define GridwiseGemmTemplateParameters \
505+
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
506506
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
507507
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
508508
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
@@ -518,7 +518,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
518518
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
519519
BComputeDataType, DoElementwiseBeforeCShuffle
520520

521-
#define GridwiseGemmCTransposeTemplateParameters \
521+
#define CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
522522
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
523523
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
524524
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
@@ -536,14 +536,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
536536

537537
// Use appropriate gridwise gemm
538538
template <index_t NXdlPerWave_>
539-
using GridwiseGemmMultipleABDBase =
540-
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>;
539+
using GridwiseGemmMultipleABDBase = GridwiseGemmMultipleABD_xdl_cshuffle<
540+
CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
541541
template <index_t NXdlPerWave_>
542-
using GridwiseGemmMultipleDBase =
543-
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
542+
using GridwiseGemmMultipleDBase = GridwiseGemmMultipleD_xdl_cshuffle<
543+
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
544544
template <index_t NXdlPerWave_>
545-
using GridwiseGemmMultipleDCTransposeBase =
546-
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>;
545+
using GridwiseGemmMultipleDCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle<
546+
CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
547+
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
548+
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
549+
#undef CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
547550

548551
using GridwiseGemm64 =
549552
std::conditional_t<isMultiA || isMultiB,

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -405,7 +405,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
405405
is_split_valid);
406406
}
407407

408-
#define GridwiseGemmTemplateParameters \
408+
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS \
409409
ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
410410
AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
411411
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
@@ -422,9 +422,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
422422
AComputeDataType, DoElementwiseBeforeCShuffle
423423
// Use appropriate gridwise gemm
424424
template <index_t NXdlPerWave_>
425-
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
426-
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
427-
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
425+
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
426+
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS>;
427+
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS
428+
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
429+
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
428430

429431
// desc for blockwise copy
430432
using AGridDesc_AK0_M_AK1 =

include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -256,6 +256,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
256256
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
257257

258258
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
259+
static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!");
259260
static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
260261
KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
261262
"Aq block window has incorrect lengths for defined AqLayout!");

0 commit comments

Comments
 (0)