Skip to content

Commit 152193e

Browse files
Adding changelog and other fixes.
1 parent 1366a26 commit 152193e

25 files changed

+2432
-71
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
55
## Composable Kernel 1.2.0 for ROCm 7.2.0
66

77
### Added
8+
* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with C++ and Python frontends starting with GEMM support.
89
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
910
* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.
1011
* Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM.

dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ extern "C" {
3535
int conv_bwdw_init()
3636
{
3737
g_bwdw_initialized = true;
38-
return 1;
38+
return 0; // Return 0 on success (consistent with other init functions)
3939
}
4040

4141
void conv_bwdw_cleanup() { g_bwdw_initialized = false; }

dispatcher/bindings/ctypes/conv_ctypes_lib.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818

1919
#include <cstring>
20+
#include <memory>
2021
#include <vector>
2122
#include <hip/hip_runtime.h>
2223

@@ -26,9 +27,9 @@
2627

2728
using namespace ck_tile::dispatcher;
2829

29-
// Global state
30-
static ConvRegistry* g_registry = nullptr;
31-
static ConvDispatcher* g_dispatcher = nullptr;
30+
// Global state (using shared_ptr for safe memory management)
31+
static std::shared_ptr<ConvRegistry> g_registry = nullptr;
32+
static std::shared_ptr<ConvDispatcher> g_dispatcher = nullptr;
3233
static std::vector<const ConvKernelInstance*> g_kernels;
3334

3435
extern "C" {
@@ -42,8 +43,8 @@ int conv_dispatcher_init()
4243
if(g_registry)
4344
return 0; // Already initialized
4445

45-
g_registry = new ConvRegistry();
46-
g_dispatcher = new ConvDispatcher(g_registry);
46+
g_registry = std::make_shared<ConvRegistry>();
47+
g_dispatcher = std::make_shared<ConvDispatcher>(g_registry.get());
4748

4849
// Register kernel configurations
4950
using namespace ck_tile::dispatcher::conv_decl;
@@ -94,10 +95,9 @@ int conv_dispatcher_init()
9495

9596
int conv_dispatcher_cleanup()
9697
{
97-
delete g_dispatcher;
98-
delete g_registry;
99-
g_dispatcher = nullptr;
100-
g_registry = nullptr;
98+
// shared_ptr automatically handles cleanup when reset
99+
g_dispatcher.reset();
100+
g_registry.reset();
101101
g_kernels.clear();
102102
return 0;
103103
}
@@ -343,11 +343,10 @@ float conv_dispatcher_run(const void* input_ptr,
343343

344344
#ifdef CONV_BWD_WEIGHT_AVAILABLE
345345
case 2: // Backward weight
346-
// Convention: caller passes (grad_output, input, grad_weight_buffer)
346+
// Convention: caller passes (input, grad_output, grad_weight_buffer)
347347
// in the (input_ptr, weight_ptr, output_ptr) slots respectively.
348-
// This is consistent with bwd_data where grad_output goes in input_ptr slot.
349348
// run_bwd_weight expects: (input, grad_output, grad_weight)
350-
return run_bwd_weight(weight_ptr, input_ptr, output_ptr, prob, stream);
349+
return run_bwd_weight(input_ptr, weight_ptr, output_ptr, prob, stream);
351350
#endif
352351

353352
default: return -1.0f;

dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <cstdint>
1818
#include <cstring>
1919
#include <iostream>
20+
#include <memory>
2021
#include <sstream>
2122
#include <string>
2223

@@ -31,9 +32,9 @@ using namespace ck_tile::dispatcher;
3132
using namespace ck_tile::dispatcher::backends;
3233
using Priority = ck_tile::dispatcher::Registry::Priority;
3334

34-
// Global dispatcher (initialized once)
35-
static Dispatcher* g_dispatcher = nullptr;
36-
static bool g_initialized = false;
35+
// Global dispatcher (initialized once, managed via shared_ptr for safe cleanup)
36+
static std::shared_ptr<Dispatcher> g_dispatcher = nullptr;
37+
static bool g_initialized = false;
3738

3839
#define HIP_CHECK(call) \
3940
{ \
@@ -98,8 +99,8 @@ int dispatcher_initialize()
9899
Registry::instance().clear();
99100
Registry::instance().register_kernel(kernel, Priority::High);
100101

101-
// Create dispatcher
102-
g_dispatcher = new Dispatcher();
102+
// Create dispatcher (using shared_ptr for safe memory management)
103+
g_dispatcher = std::make_shared<Dispatcher>();
103104
g_initialized = true;
104105

105106
return 0;
@@ -294,19 +295,53 @@ int dispatcher_run_gemm(const void* A, // Host pointer
294295
const BDataType* B_host = static_cast<const BDataType*>(B);
295296
CDataType* C_host = static_cast<CDataType*>(C);
296297

297-
// Allocate GPU memory
298+
// Allocate GPU memory with proper cleanup on failure
298299
ADataType* A_dev = nullptr;
299300
BDataType* B_dev = nullptr;
300301
CDataType* C_dev = nullptr;
301302

302-
HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType)));
303-
HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType)));
304-
HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType)));
303+
// Helper lambda for cleanup
304+
auto cleanup_gpu_mem = [&]() {
305+
if(A_dev)
306+
(void)hipFree(A_dev);
307+
if(B_dev)
308+
(void)hipFree(B_dev);
309+
if(C_dev)
310+
(void)hipFree(C_dev);
311+
};
312+
313+
if(hipMalloc(&A_dev, M * K * sizeof(ADataType)) != hipSuccess)
314+
{
315+
cleanup_gpu_mem();
316+
return -1;
317+
}
318+
if(hipMalloc(&B_dev, K * N * sizeof(BDataType)) != hipSuccess)
319+
{
320+
cleanup_gpu_mem();
321+
return -1;
322+
}
323+
if(hipMalloc(&C_dev, M * N * sizeof(CDataType)) != hipSuccess)
324+
{
325+
cleanup_gpu_mem();
326+
return -1;
327+
}
305328

306329
// Copy input data to GPU
307-
HIP_CHECK(hipMemcpy(A_dev, A_host, M * K * sizeof(ADataType), hipMemcpyHostToDevice));
308-
HIP_CHECK(hipMemcpy(B_dev, B_host, K * N * sizeof(BDataType), hipMemcpyHostToDevice));
309-
HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType)));
330+
if(hipMemcpy(A_dev, A_host, M * K * sizeof(ADataType), hipMemcpyHostToDevice) != hipSuccess)
331+
{
332+
cleanup_gpu_mem();
333+
return -1;
334+
}
335+
if(hipMemcpy(B_dev, B_host, K * N * sizeof(BDataType), hipMemcpyHostToDevice) != hipSuccess)
336+
{
337+
cleanup_gpu_mem();
338+
return -1;
339+
}
340+
if(hipMemset(C_dev, 0, M * N * sizeof(CDataType)) != hipSuccess)
341+
{
342+
cleanup_gpu_mem();
343+
return -1;
344+
}
310345

311346
// Run GEMM via dispatcher (kernel already selected, shouldn't throw)
312347
float exec_time;
@@ -317,14 +352,16 @@ int dispatcher_run_gemm(const void* A, // Host pointer
317352
catch(const std::exception& e)
318353
{
319354
// Unexpected error during execution
320-
(void)hipFree(A_dev);
321-
(void)hipFree(B_dev);
322-
(void)hipFree(C_dev);
355+
cleanup_gpu_mem();
323356
return -1;
324357
}
325358

326359
// Copy result back to host
327-
HIP_CHECK(hipMemcpy(C_host, C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));
360+
if(hipMemcpy(C_host, C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost) != hipSuccess)
361+
{
362+
cleanup_gpu_mem();
363+
return -1;
364+
}
328365

329366
// Store timing if requested
330367
if(time_ms)
@@ -333,9 +370,7 @@ int dispatcher_run_gemm(const void* A, // Host pointer
333370
}
334371

335372
// Cleanup GPU memory
336-
(void)hipFree(A_dev);
337-
(void)hipFree(B_dev);
338-
(void)hipFree(C_dev);
373+
cleanup_gpu_mem();
339374

340375
return 0;
341376
}
@@ -434,11 +469,8 @@ const char* dispatcher_export_registry_json()
434469
*/
435470
void dispatcher_cleanup()
436471
{
437-
if(g_dispatcher)
438-
{
439-
delete g_dispatcher;
440-
g_dispatcher = nullptr;
441-
}
472+
// shared_ptr automatically handles cleanup when reset
473+
g_dispatcher.reset();
442474
g_initialized = false;
443475
}
444476

dispatcher/codegen/arch_filter.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ class OperatorType(Enum):
132132
ELEMENT_SIZE_MAP,
133133
WARP_SUPPORTED_COMBINATIONS,
134134
WARP_TILE_SUPPORTED_COMBINATIONS,
135+
PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS,
136+
PRESHUFFLE_PIPELINES,
135137
LDS_CAPACITY_LIMITS,
136138
TRAIT_UNSUPPORTED_COMBINATIONS,
137139
DTYPE_COMBINATIONS,
@@ -179,6 +181,21 @@ class OperatorType(Enum):
179181
},
180182
}
181183

184+
# Preshuffle-specific warp tile combinations (no [4, 64, 16])
185+
PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS = {
186+
"gfx942": {
187+
"fp16_fp16_fp32": [
188+
[32, 32, 8],
189+
[16, 16, 16],
190+
[32, 32, 16],
191+
[16, 16, 32],
192+
[64, 4, 16],
193+
],
194+
},
195+
}
196+
197+
PRESHUFFLE_PIPELINES = ["preshufflev2"]
198+
182199
LDS_CAPACITY_LIMITS = {"compv4": 32768, "preshufflev2": 32768, "default": 65536}
183200

184201
TRAIT_UNSUPPORTED_COMBINATIONS = {
@@ -566,9 +583,20 @@ def _validate_warp_config(self, config: KernelConfig, result: ValidationResult):
566583

567584
def _validate_warp_tile_combo(self, config: KernelConfig, result: ValidationResult):
568585
"""Validate warp tile combination against architecture and data types"""
569-
gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(self.gpu_arch, {})
586+
# Use preshuffle-specific warp tiles for preshuffle operator
587+
if config.operator == OperatorType.GEMM_PRESHUFFLE:
588+
gpu_combos = PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS.get(
589+
self.gpu_arch, {}
590+
)
591+
combo_source = "preshuffle"
592+
else:
593+
gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(self.gpu_arch, {})
594+
combo_source = "standard"
595+
570596
if not gpu_combos:
571-
msg = f"No warp tile combinations defined for {self.gpu_arch}"
597+
msg = (
598+
f"No {combo_source} warp tile combinations defined for {self.gpu_arch}"
599+
)
572600
if self.strict_mode:
573601
result.add_error(msg)
574602
else:
@@ -579,19 +607,27 @@ def _validate_warp_tile_combo(self, config: KernelConfig, result: ValidationResu
579607
if not dtype_combos:
580608
# Data type combo not explicitly listed - may still be valid
581609
result.add_warning(
582-
f"No warp tile combinations defined for {config.dtype_key} on {self.gpu_arch}"
610+
f"No {combo_source} warp tile combinations defined for {config.dtype_key} on {self.gpu_arch}"
583611
)
584612
return
585613

586614
current = [config.warp_tile_m, config.warp_tile_n, config.warp_tile_k]
587615
if current not in dtype_combos:
588616
result.add_error(
589-
f"Invalid warp tile {current} for {config.dtype_key} on {self.gpu_arch}. "
617+
f"Invalid warp tile {current} for {config.dtype_key} on {self.gpu_arch} ({combo_source}). "
590618
f"Allowed: {dtype_combos}"
591619
)
592620

593621
def _validate_trait_combo(self, config: KernelConfig, result: ValidationResult):
594622
"""Validate trait (pipeline, epilogue, scheduler) combination"""
623+
# Preshuffle requires specific pipelines
624+
if config.operator == OperatorType.GEMM_PRESHUFFLE:
625+
if config.pipeline not in PRESHUFFLE_PIPELINES:
626+
result.add_error(
627+
f"Preshuffle GEMM requires pipeline in {PRESHUFFLE_PIPELINES}, "
628+
f"got {config.pipeline}"
629+
)
630+
595631
combo = (config.pipeline, config.epilogue, config.scheduler)
596632
if combo in TRAIT_UNSUPPORTED_COMBINATIONS:
597633
result.add_error(
@@ -769,7 +805,7 @@ def get_supported_archs() -> List[str]:
769805
def get_arch_family(gpu_arch: str) -> Optional[str]:
770806
"""Get the GPU family for an architecture"""
771807
family = ARCH_FAMILY_MAP.get(gpu_arch.lower())
772-
return family.value if family else None
808+
return family if family else None # ARCH_FAMILY_MAP contains strings, not Enums
773809

774810

775811
def create_filter_for_current_gpu() -> Optional[ArchFilter]:

dispatcher/codegen/arch_specs.json

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,5 +232,33 @@
232232
["compv4", "cshuffle", "interwave"],
233233
["compv4", "default", "interwave"]
234234
]
235+
},
236+
237+
"preshuffle_warp_tile_combos": {
238+
"_comment": "Preshuffle-specific warp tile combinations (subset of standard GEMM, no [4, 64, 16])",
239+
"gfx90a": {
240+
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
241+
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
242+
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]],
243+
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]]
244+
},
245+
"gfx942": {
246+
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
247+
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
248+
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
249+
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
250+
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]]
251+
},
252+
"gfx950": {
253+
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
254+
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
255+
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
256+
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
257+
}
258+
},
259+
260+
"preshuffle_pipelines": {
261+
"_comment": "Pipelines supported for preshuffle GEMM variant",
262+
"supported": ["preshufflev2"]
235263
}
236264
}

0 commit comments

Comments
 (0)