Commit 6d3b286

Another round of improvements based on feedback.
1 parent: c5c7d0d


43 files changed: +2875 −1000 lines

dispatcher/README.md

Lines changed: 119 additions & 1 deletion
@@ -166,7 +166,7 @@ cmake .. \
 ### Step 4: Build
 
 ```bash
-# Build all targets (uses all CPU cores)
+# Build all targets (generates kernels automatically, then compiles)
 make -j$(nproc)
 
 # Or build specific targets
@@ -179,6 +179,31 @@ make dispatcher_conv_bwdw_lib # Conv backward weight library for Python
 make python_libs -j$(nproc)
 ```
 
+### Kernel Generation Targets
+
+Kernels are generated automatically during `make`, but you can also control generation explicitly:
+
+```bash
+# Generate all kernels only (no compilation)
+make generate_all_kernels
+
+# Generate specific kernel types
+make generate_gemm_kernels       # GEMM kernels only
+make generate_conv_kernels       # Conv kernels (fwd + bwd)
+make generate_conv_fwd_kernels   # Conv forward only
+make generate_conv_bwd_kernels   # Conv backward only
+
+# Force regenerate (even if kernels exist)
+make regenerate_all_kernels
+make regenerate_gemm_kernels
+make regenerate_conv_kernels
+
+# Generate for specific GPU architecture
+make generate_kernels_gfx942   # MI300X
+make generate_kernels_gfx90a   # MI200
+make generate_kernels_gfx1100  # RDNA3
+```
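For instance, a typical incremental workflow chains these with the build targets from Step 4 (a sketch using only the targets listed above):

```bash
# Regenerate the conv kernels, then rebuild the Python libraries that use them
make regenerate_conv_kernels
make python_libs -j$(nproc)
```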
+
 ### Step 5: Verify Build
 
 ```bash
@@ -305,6 +330,99 @@ Step 4: GPU Execution
 
 ---
 
+## Benchmark Parameters
+
+The dispatcher supports fine-grained control over benchmarking, matching CK Tile's `stream_config`:
+
+### Available Parameters
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `warmup` | int | 5 | Warmup iterations (discarded from timing) |
+| `repeat` | int | 20 | Benchmark iterations (averaged) |
+| `flush_cache` | bool | false | Flush GPU L2 cache between iterations |
+| `rotating_count` | int | 1 | Rotating buffer count (for cache simulation) |
+| `timer` | string | "gpu" | Timer type: "gpu" (HIP events) or "cpu" |
+| `init` | string | "random" | Matrix initialization: "random", "linear", "constant" |
+| `split_k` | int | 1 | Split-K parallelism factor |
+
+### Python Usage
+
+```python
+from conv_utils import GpuConvRunner
+
+# Basic usage (default benchmark settings)
+runner = GpuConvRunner()
+
+# Advanced benchmark settings
+runner = GpuConvRunner(
+    warmup=10,          # More warmup iterations
+    repeat=100,         # More benchmark iterations
+    flush_cache=True,   # Flush L2 cache (for memory-bound analysis)
+    rotating_count=4,   # 4 rotating buffers
+    timer="gpu",        # Use GPU timer (most accurate)
+)
+
+result = runner.run(input_data, weight_data, problem)
+print(f"Average time: {result['time_ms']:.4f} ms")
+print(f"TFLOPS: {result['tflops']:.2f}")
+```
+
+### C++ Usage
+
+```cpp
+// Basic timing
+ck_tile::stream_config cfg{nullptr, true};
+
+// Advanced benchmark settings
+ck_tile::stream_config cfg{
+    nullptr,  // stream_id (nullptr = default stream)
+    true,     // time_kernel
+    1,        // log_level
+    10,       // cold_niters (warmup)
+    100,      // nrepeat
+    true,     // is_gpu_timer
+    true,     // flush_cache
+    4         // rotating_count
+};
+
+float avg_time = kernel.run(args, cfg);
+```
+
+### Command Line (Python Examples)
+
+```bash
+# Basic run
+python3 examples/gemm/python/10_advanced_benchmark.py
+
+# With benchmark parameters
+python3 examples/gemm/python/10_advanced_benchmark.py \
+    --warmup 10 \
+    --repeat 100 \
+    --flush-cache \
+    --rotating-count 4 \
+    --timer gpu
+
+# For memory-bound analysis
+python3 examples/conv/python/13_advanced_benchmark.py \
+    --flush-cache \
+    --init constant \
+    -n 1 -c 256 -k 256 -hi 56 -wi 56
+```
+
+### When to Use Each Parameter
+
+| Use Case | Recommended Settings |
+|----------|---------------------|
+| Quick test | `warmup=1, repeat=3` |
+| Stable benchmark | `warmup=10, repeat=100` |
+| Memory-bound analysis | `flush_cache=True, rotating_count=4` |
+| Compute-bound analysis | `flush_cache=False` (default) |
+| Debug timing | `timer="cpu"` |
+| Production | `timer="gpu"` (default) |
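As an illustration, these presets map directly onto the `GpuConvRunner` constructor shown earlier (a sketch; only the parameters documented above are used):

```python
from conv_utils import GpuConvRunner

# Quick test: minimal iterations for a fast sanity check
quick = GpuConvRunner(warmup=1, repeat=3)

# Stable benchmark: more iterations for low-variance timing
stable = GpuConvRunner(warmup=10, repeat=100)

# Memory-bound analysis: defeat L2 reuse between iterations
membound = GpuConvRunner(flush_cache=True, rotating_count=4)
```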
+
+---
+
 ## External Integration
 
 ### Using Dispatcher in Your Own Project

dispatcher/codegen/unified_conv_codegen.py

Lines changed: 129 additions & 6 deletions
@@ -124,12 +124,25 @@ class ConvKernelConfig:
     vector_size_b: int = 8
     vector_size_c: int = 8
 
-    # Fixed parameters
+    # Occupancy parameters
     block_per_cu: int = 1
     num_wave_groups: int = 1
+    num_groups_to_merge: int = 1  # For group merged convolution
+
+    # Double buffering
+    double_smem_buffer: bool = False
 
     def name(self, datatype: str) -> str:
-        """Generate kernel name"""
+        """
+        Generate kernel name that uniquely identifies the kernel configuration.
+
+        Format: conv_{variant}_{dtype}_{ndim}d_{pipeline}_{epilogue}_{scheduler}
+                _{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}
+                _{warp_tile_m}x{warp_tile_n}x{warp_tile_k}
+                [_vec{a}_{b}_{c}][_bpc{n}][_wg{n}][_gm{n}][_dsb][_pad{mnk}]
+
+        All parameters that affect kernel behavior are included.
+        """
         t = self.tile
         tr = self.trait
 
@@ -139,12 +152,42 @@ def name(self, datatype: str) -> str:
             ConvVariant.BACKWARD_WEIGHT: "bwdw",
         }[self.variant]
 
+        # Core identity: variant, dtype, dims
         name = f"conv_{variant_str}_{datatype}_{self.ndim_spatial}d"
+
+        # Pipeline configuration
         name += f"_{tr.pipeline}_{tr.epilogue}_{tr.scheduler}"
+
+        # Block tile dimensions (M_Tile x N_Tile x K_Tile)
        name += f"_{t.tile_m}x{t.tile_n}x{t.tile_k}"
+
+        # Wave distribution (M_Warp x N_Warp x K_Warp)
         name += f"_{t.warp_m}x{t.warp_n}x{t.warp_k}"
 
-        # Add padding suffix if not all enabled
+        # Warp tile dimensions (M_Warp_Tile x N_Warp_Tile x K_Warp_Tile)
+        name += f"_{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}"
+
+        # Vector sizes (only if non-default)
+        if (self.vector_size_a, self.vector_size_b, self.vector_size_c) != (4, 8, 8):
+            name += (
+                f"_vec{self.vector_size_a}_{self.vector_size_b}_{self.vector_size_c}"
+            )
+
+        # Occupancy hints (only if non-default)
+        if self.block_per_cu != 1:
+            name += f"_bpc{self.block_per_cu}"
+
+        if self.num_wave_groups != 1:
+            name += f"_wg{self.num_wave_groups}"
+
+        if self.num_groups_to_merge != 1:
+            name += f"_gm{self.num_groups_to_merge}"
+
+        # Double SMEM buffer (for compute V4+)
+        if self.double_smem_buffer or tr.double_smem_buffer:
+            name += "_dsb"
+
+        # Padding suffix (only if not all enabled)
         if not (tr.pad_m and tr.pad_n and tr.pad_k):
             name += f"_pad{int(tr.pad_m)}{int(tr.pad_n)}{int(tr.pad_k)}"
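For illustration, with the fallback defaults used later in `main()` (128x128x64 block tile, 2x2x1 waves, 32x32x16 warp tile, compv4/cshuffle/intrawave, default vector sizes and occupancy, all padding enabled) none of the optional suffixes apply, so the name reduces to the core fields. A sketch, assuming `ConvVariant.FORWARD` maps to `"fwd"`, the remaining fields keep their dataclass defaults, and the kernel is generated for fp16:

```python
cfg = ConvKernelConfig(
    tile=TileConfig(tile_m=128, tile_n=128, tile_k=64,
                    warp_m=2, warp_n=2, warp_k=1,
                    warp_tile_m=32, warp_tile_n=32, warp_tile_k=16),
    trait=TraitConfig(pipeline="compv4", scheduler="intrawave", epilogue="cshuffle",
                      pad_m=True, pad_n=True, pad_k=True),
    variant=ConvVariant.FORWARD,
    ndim_spatial=2,
    arch="gfx942",
)
# Core fields only: no _vec/_bpc/_wg/_gm/_dsb/_pad suffixes are appended
assert cfg.name("fp16") == (
    "conv_fwd_fp16_2d_compv4_cshuffle_intrawave_128x128x64_2x2x1_32x32x16"
)
```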

@@ -786,6 +829,44 @@ def main():
         help="List configurations without generating",
     )
 
+    # Individual kernel configuration (when not using predefined configs)
+    parser.add_argument("--tile-m", type=int, help="Block tile M dimension")
+    parser.add_argument("--tile-n", type=int, help="Block tile N dimension")
+    parser.add_argument("--tile-k", type=int, help="Block tile K dimension")
+    parser.add_argument("--warp-m", type=int, help="Wave distribution M")
+    parser.add_argument("--warp-n", type=int, help="Wave distribution N")
+    parser.add_argument("--warp-k", type=int, default=1, help="Wave distribution K")
+    parser.add_argument("--warp-tile-m", type=int, help="Warp tile M")
+    parser.add_argument("--warp-tile-n", type=int, help="Warp tile N")
+    parser.add_argument("--warp-tile-k", type=int, default=16, help="Warp tile K")
+    parser.add_argument(
+        "--pipeline",
+        type=str,
+        choices=["mem", "compv3", "compv4", "compv5"],
+        help="Pipeline type",
+    )
+    parser.add_argument(
+        "--scheduler",
+        type=str,
+        choices=["intrawave", "interwave"],
+        help="Scheduler type",
+    )
+    parser.add_argument(
+        "--epilogue",
+        type=str,
+        default="cshuffle",
+        choices=["cshuffle", "default"],
+        help="Epilogue type",
+    )
+    parser.add_argument("--pad-m", type=bool, default=True, help="Pad M dimension")
+    parser.add_argument("--pad-n", type=bool, default=True, help="Pad N dimension")
+    parser.add_argument("--pad-k", type=bool, default=True, help="Pad K dimension")
+    parser.add_argument("--vector-a", type=int, default=4, help="Vector size A")
+    parser.add_argument("--vector-b", type=int, default=8, help="Vector size B")
+    parser.add_argument("--vector-c", type=int, default=8, help="Vector size C")
+    parser.add_argument("--block-per-cu", type=int, default=1, help="Blocks per CU")
+    parser.add_argument("--num-wave-groups", type=int, default=1, help="Wave groups")
+
     args = parser.parse_args()
 
     if args.verbose:
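For example, a single custom kernel could be requested entirely from the command line with the new flags (a sketch; the `--arch`, `--variant`, and `--ndim` spellings are assumed from the existing `args.arch`, `args.variant`, and `args.ndim` uses):

```bash
python3 dispatcher/codegen/unified_conv_codegen.py \
    --arch gfx942 --variant fwd --ndim 2 \
    --tile-m 256 --tile-n 128 --tile-k 64 \
    --warp-m 2 --warp-n 2 --warp-k 1 \
    --warp-tile-m 32 --warp-tile-n 32 --warp-tile-k 16 \
    --pipeline compv4 --scheduler intrawave --epilogue cshuffle
```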
@@ -799,11 +880,53 @@ def main():
     }
     requested_variants = [variant_map[v] for v in args.variant]
 
-    # Get configurations for target arch with requested variants and ndims
-    filtered_configs = get_default_configs(
-        arch=args.arch, variants=requested_variants, ndims=args.ndim
+    # Check if user specified custom configuration
+    custom_config = (
+        args.tile_m is not None or args.tile_n is not None or args.pipeline is not None
     )
 
+    if custom_config:
+        # Build custom config from CLI arguments
+        tile = TileConfig(
+            tile_m=args.tile_m or 128,
+            tile_n=args.tile_n or 128,
+            tile_k=args.tile_k or 64,
+            warp_m=args.warp_m or 2,
+            warp_n=args.warp_n or 2,
+            warp_k=args.warp_k or 1,
+            warp_tile_m=args.warp_tile_m or 32,
+            warp_tile_n=args.warp_tile_n or 32,
+            warp_tile_k=args.warp_tile_k or 16,
+        )
+        trait = TraitConfig(
+            pipeline=args.pipeline or "compv4",
+            scheduler=args.scheduler or "intrawave",
+            epilogue=args.epilogue or "cshuffle",
+            pad_m=args.pad_m,
+            pad_n=args.pad_n,
+            pad_k=args.pad_k,
+        )
+        config = ConvKernelConfig(
+            tile=tile,
+            trait=trait,
+            variant=requested_variants[0]
+            if requested_variants
+            else ConvVariant.FORWARD,
+            ndim_spatial=args.ndim[0] if args.ndim else 2,
+            arch=args.arch,
+            vector_size_a=args.vector_a,
+            vector_size_b=args.vector_b,
+            vector_size_c=args.vector_c,
+            block_per_cu=args.block_per_cu,
+            num_wave_groups=args.num_wave_groups,
+        )
+        filtered_configs = [config]
+    else:
+        # Get predefined configurations for target arch with requested variants and ndims
+        filtered_configs = get_default_configs(
+            arch=args.arch, variants=requested_variants, ndims=args.ndim
+        )
+
     if args.list_configs:
         print(f"Convolution configurations for {args.arch}:")
         print(f"  Datatypes: {args.datatype}")
