What this cast does:
- reshape the tensor to shape (-1, block_size), where block_size is typically 32 or 16
- for each block, calculate a single scale, then cast that block's elements to torch.float8_e4m3fn
- rearrange the scales into the swizzled format expected by the gemm kernel
- return the casted elements and the swizzled scales
What we currently see:

```shell
TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_with_to_blocked
```
Output: https://gist.github.com/vkuzo/9bb4194b289003b6d8bf32d066e3f8e1
Two kernels are generated: (i) one kernel that calculates the unswizzled scales and casts the elements, and (ii) one kernel that converts the scale layout.