Layout Algebra Guide

Core types, construction, coordinate mapping, algebra operations, and layout utilities in FlyDSL.

Important: All fx.* layout operations generate MLIR IR and must be called inside a @flyc.kernel or @flyc.jit function body. Code snippets below show API usage patterns within that context, not standalone scripts.

Quick Reference

| Operation | Python API | Fly Dialect Op | Description |
|---|---|---|---|
| Construction | fx.make_shape(8, 16) | fly.make_shape | Create shape (IntTuple) |
| | fx.make_stride(1, 8) | fly.make_stride | Create stride (IntTuple) |
| | fx.make_layout(shape, stride) | fly.make_layout | Create layout from (shape, stride) |
| | fx.make_coord(i, j) | fly.make_coord | Create coordinate |
| | fx.make_int_tuple(elems) | fly.make_int_tuple | Create generic IntTuple |
| | fx.make_ordered_layout(shape, order) | fly.make_ordered_layout | Create layout with mode ordering |
| Mapping | fx.crd2idx(coord, layout) | fly.crd2idx | Coordinate → linear index |
| | fx.idx2crd(idx, layout) | fly.idx2crd | Linear index → coordinate |
| Query | fx.size(layout) | fly.size | Total element count |
| | fx.cosize(layout) | fly.cosize | Codomain size (max index + 1) |
| | fx.get_shape(layout) | fly.get_shape | Extract shape from layout |
| | fx.get_stride(layout) | fly.get_stride | Extract stride from layout |
| | fx.get(int_tuple, idx) | fly.select + fly.get_scalar | Extract element at index |
| Algebra | fx.composition(A, B) | fly.composition | Compose: A ∘ B |
| | fx.complement(tiler, size) | fly.complement | Complement of tiler |
| | fx.coalesce(layout) | fly.coalesce | Simplify layout |
| | fx.right_inverse(layout) | fly.right_inverse | Right inverse of layout |
| Products | fx.logical_product(A, B) | fly.logical_product | Basic product |
| | fx.zipped_product(A, B) | fly.zipped_product | Zipped product |
| | fx.tiled_product(A, B) | fly.tiled_product | Tiled product |
| | fx.flat_product(A, B) | fly.flat_product | Flat product |
| | fx.raked_product(A, B) | fly.raked_product | Raked product |
| | fx.block_product(A, B) | fly.block_product | Blocked product |
| Divides | fx.logical_divide(A, B) | fly.logical_divide | Basic divide |
| | fx.zipped_divide(A, B) | fly.zipped_divide | Zipped divide |
| | fx.tiled_divide(A, B) | fly.tiled_divide | Tiled divide |
| | fx.flat_divide(A, B) | fly.flat_divide | Flat divide |
| Structural | fx.select(it, indices) | fly.select | Select modes by index |
| | fx.group(it, begin, end) | fly.group | Group modes into nested tuple |
| | fx.append(base, elem) | fly.append | Append mode to IntTuple |
| | fx.prepend(base, elem) | fly.prepend | Prepend mode to IntTuple |
| | fx.zip(lhs, rhs) | fly.zip | Zip two IntTuples |
| Recast | fx.recast_layout(ly, old, new) | fly.recast_layout | Recast layout for type width change |

1. Core Types

The Fly dialect defines several custom MLIR types for layout algebra:

| Type | MLIR Syntax | Description |
|---|---|---|
| !fly.int_tuple | !fly.int_tuple<(8, 16)> | Integer tuple; can be nested |
| !fly.layout | !fly.layout<(8, 16):(1, 8)> | Layout = (Shape, Stride) pair |
| !fly.pointer | !fly.pointer<f16> | Typed pointer |
| !fly.memref | !fly.memref<...> | Memory reference with layout |
| !fly.swizzle | !fly.swizzle<...> | Swizzle descriptor |
| !fly.copy_atom | !fly.copy_atom_universal_copy<...> | Copy atom type |
| !fly.mma_atom | !fly.mma_atom_universal_fma<...> | MMA atom type |

IntTuple Patterns

IntTuples encode structure at the type level:

| Pattern | Meaning | Example |
|---|---|---|
| Integer literal | Static constant | 8 |
| Dynamic value | Runtime SSA value | Provided as operand |
| Nested tuple | Hierarchical mode | (8, (4, 2)) |

2. Construction

Python API (via flydsl.expr)

import flydsl.expr as fx
from flydsl.expr import arith
from flydsl.expr.typing import T

# Shapes and strides (static constants auto-materialized)
shape = fx.make_shape(8, 16)              # !fly.int_tuple<(8, 16)>
stride = fx.make_stride(1, 8)             # !fly.int_tuple<(1, 8)>
layout = fx.make_layout(shape, stride)    # !fly.layout<(8, 16):(1, 8)>

# Shorthand — pass Python tuples directly
layout = fx.make_layout((8, 16), (1, 8))

# Coordinates (i and j are index values already in scope inside the kernel)
coord = fx.make_coord(i, j)

# Generic integer tuple
it = fx.make_int_tuple((4, 8, 2))

# Nested shapes
shape_nested = fx.make_shape(9, (4, 8))   # (9, (4, 8))

# Ordered layout — specify stride order (e.g., column-major vs row-major)
col_major = fx.make_ordered_layout((M, N), order=(0, 1))  # stride order: M-first
row_major = fx.make_ordered_layout((M, N), order=(1, 0))  # stride order: N-first

# Identity layout / tensor
identity = fx.make_identity_layout((M, N))
id_tensor = fx.make_identity_tensor((M, N))
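
To make the effect of order concrete, here is a pure-Python sketch of the strides a compact ordered layout would carry. It does not call FlyDSL. The helper name ordered_strides is hypothetical, and the rule "the mode with the smallest order value gets stride 1" is an assumption read off the column-major and row-major comments above.

```python
def ordered_strides(shape, order):
    """Compact strides for a flat shape; a lower order value means a faster-varying mode.

    Hypothetical helper for illustration only; not part of flydsl.expr.
    """
    strides = [0] * len(shape)
    running = 1
    # Visit modes from fastest-varying to slowest-varying.
    for mode in sorted(range(len(shape)), key=lambda m: order[m]):
        strides[mode] = running
        running *= shape[mode]
    return tuple(strides)


M, N = 8, 16
col_major_strides = ordered_strides((M, N), order=(0, 1))  # (1, 8)
row_major_strides = ordered_strides((M, N), order=(1, 0))  # (16, 1)
print(col_major_strides, row_major_strides)
```

Under that assumption, order=(0, 1) reproduces the (8, 16):(1, 8) layout used throughout this guide.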

3. Coordinate Mapping

The fundamental operation: mapping between logical coordinates and physical memory indices.

Formula: Index = sum(coord_i * stride_i)

crd2idx — Coordinate to Index

idx = fx.crd2idx(coord, layout)

idx2crd — Index to Coordinate (inverse)

coord = fx.idx2crd(idx, layout)

Example

For layout (8, 16):(1, 8) (8x16, column-major):

  • crd2idx((3, 5), layout) = 3*1 + 5*8 = 43
  • idx2crd(43, layout) = (43 % 8, 43 // 8) = (3, 5)
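
The arithmetic in this example can be checked outside a kernel with a small pure-Python model. This is a sketch for flat layouts only, not the FlyDSL implementation; the function bodies are assumptions, and the idx2crd rule is only claimed for compact layouts like the one in the example.

```python
def crd2idx(coord, stride):
    """Inner product of a flat coordinate with the strides."""
    return sum(c * d for c, d in zip(coord, stride))


def idx2crd(idx, shape, stride):
    """Recover a flat coordinate from an index.

    Assumes a compact layout in which every index is reached by
    exactly one coordinate, as in the column-major example.
    """
    return tuple((idx // d) % s for s, d in zip(shape, stride))


shape, stride = (8, 16), (1, 8)
idx = crd2idx((3, 5), stride)
coord = idx2crd(idx, shape, stride)
print(idx, coord)  # 43 (3, 5)
```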

4. Query Operations

| Operation | Description | Example |
|---|---|---|
| size(x) | Product of all dimensions | size((8, 16)) = 128 |
| cosize(layout) | Max index + 1 (codomain size) | cosize((8, 16):(1, 8)) = 128 |
| get_shape(layout) | Extract shape from layout | Returns !fly.int_tuple |
| get_stride(layout) | Extract stride from layout | Returns !fly.int_tuple |
| get(x, i) | Extract i-th element | get((8, 16), 0) = 8 |
| get_scalar(x) | Extract scalar from leaf IntTuple | Returns index value |
| rank(x) | Number of top-level modes | rank((8, 16)) = 2 |
| depth(x) | Nesting depth | depth((8, (4, 2))) = 2 |

s = fx.size(layout)           # total elements (returns Int32 for static)
cs = fx.cosize(layout)        # codomain size (max index + 1)
shape = fx.get_shape(layout)
stride = fx.get_stride(layout)
v = fx.get(shape, 0)          # first dimension
r = fx.rank(shape)            # number of modes
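
The example values in the table can be reproduced with plain nested tuples. The following is a minimal reference model, assuming non-negative strides; the helper names mirror the table but are ordinary Python functions, not the fx.* operations.

```python
from math import prod


def flatten(t):
    """Yield the integer leaves of a possibly nested tuple."""
    if isinstance(t, int):
        yield t
    else:
        for x in t:
            yield from flatten(x)


def size(shape):
    return prod(flatten(shape))


def cosize(shape, stride):
    # Largest reachable index plus one (assumes non-negative strides).
    return 1 + sum((s - 1) * d for s, d in zip(flatten(shape), flatten(stride)))


def rank(t):
    return 1 if isinstance(t, int) else len(t)


def depth(t):
    return 0 if isinstance(t, int) else 1 + max((depth(x) for x in t), default=0)


print(size((8, 16)), cosize((8, 16), (1, 8)), rank((8, 16)), depth((8, (4, 2))))
# 128 128 2 2
```

size and cosize differ as soon as the layout has gaps: for a padded layout (8, 16):(1, 10), size is still 128 while cosize is 1 + 7*1 + 15*10 = 158.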

5. Layout Algebra

5.1 Composition: composition(A, B)

Composes two layouts: result maps through B first, then A.

Semantics: result(x) = A(B(x))

composed = fx.composition(layout_a, layout_b)

Use case: Applying a permutation or tile coordinate mapping to a memory layout.
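
The semantics result(x) = A(B(x)) can be tabulated in pure Python by treating each flat layout as a function from a 1-D index to an offset. This sketch only evaluates the functions; it does not derive the composed shape and stride the way fly.composition would. The layouts A and B and the helper layout_fn are illustrative choices, not taken from the FlyDSL sources.

```python
def layout_fn(shape, stride):
    """Return f(i): 1-D index -> offset for a flat layout (mode 0 varies fastest)."""
    def f(i):
        offset = 0
        for s, d in zip(shape, stride):
            offset += (i % s) * d
            i //= s
        return offset
    return f


A = layout_fn((6, 2), (8, 2))
B = layout_fn((4, 3), (3, 1))

# Apply B first, then A, for every index in B's domain.
composed = [A(B(x)) for x in range(12)]
print(composed)  # [0, 24, 2, 26, 8, 32, 10, 34, 16, 40, 18, 42]

# The same table is produced by the single flat layout (2, 2, 3):(24, 2, 8).
R = layout_fn((2, 2, 3), (24, 2, 8))
print(composed == [R(x) for x in range(12)])  # True
```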

5.2 Complement: complement(tiler, target_size)

Computes the "remaining" modes not covered by the tiler, up to target_size elements.

rest = fx.complement(tiler, target_size)

Use case: Internal building block for logical_divide. Computing complementary iteration space when tiling.
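
As an illustration of what "remaining modes" means, the sketch below computes a complement for flat layouts in pure Python and checks that tiler and complement together reach every offset below target_size exactly once. It assumes positive strides and modes that do not overlap, and it is a simplified model rather than the algorithm in LayoutUtils.h.

```python
from itertools import product


def complement(shape, stride, target_size):
    """Complement of a flat layout with positive, non-overlapping strides."""
    out_shape, out_stride = [], []
    current = 1
    for s, d in sorted(zip(shape, stride), key=lambda m: m[1]):
        if d // current > 1:              # gap below this mode: fill it
            out_shape.append(d // current)
            out_stride.append(current)
        current = s * d
    rest = -(-target_size // current)     # ceiling division
    if rest > 1:                          # repeat the covered block up to target_size
        out_shape.append(rest)
        out_stride.append(current)
    return tuple(out_shape), tuple(out_stride)


def offsets(shape, stride):
    """All offsets a flat layout can produce."""
    return [sum(c * d for c, d in zip(crd, stride))
            for crd in product(*(range(s) for s in shape))]


tiler = ((4,), (2,))                      # 4 elements, stride 2
rest = complement(*tiler, 24)
print(rest)  # ((2, 3), (1, 8))

covered = sorted(a + b for a in offsets(*tiler) for b in offsets(*rest))
print(covered == list(range(24)))  # True
```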

5.3 Coalesce: coalesce(layout)

Simplifies a layout by flattening nested modes and combining adjacent modes when possible.

Post-conditions:

  • size(result) == size(layout) (preserves total size)
  • For all valid indices: layout(i) == result(i) (preserves mapping)

simplified = fx.coalesce(layout)
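
The two post-conditions can be checked on a small case with a pure-Python model. The sketch assumes the layout has already been flattened and uses the merge rule "a mode continues the previous one when its stride equals the previous shape times the previous stride"; it is not the FlyDSL implementation.

```python
def coalesce(shape, stride):
    """Coalesce a flat layout: drop size-1 modes, merge contiguous neighbours."""
    out = []  # list of [shape, stride] pairs
    for s, d in zip(shape, stride):
        if s == 1:
            continue                      # size-1 modes never affect the mapping
        if out and out[-1][0] * out[-1][1] == d:
            out[-1][0] *= s               # this mode continues the previous one
        else:
            out.append([s, d])
    if not out:
        return (1,), (0,)
    return tuple(m[0] for m in out), tuple(m[1] for m in out)


def layout_fn(shape, stride):
    """1-D index -> offset for a flat layout (mode 0 varies fastest)."""
    def f(i):
        offset = 0
        for s, d in zip(shape, stride):
            offset += (i % s) * d
            i //= s
        return offset
    return f


before = ((2, 1, 6), (1, 6, 2))
after = coalesce(*before)
print(after)  # ((12,), (1,))

# Post-condition: the mapping is unchanged for every valid index.
same = all(layout_fn(*before)(i) == layout_fn(*after)(i) for i in range(12))
print(same)  # True
```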

5.4 Right Inverse: right_inverse(layout)

Computes a right inverse R of the layout L, that is, a layout for which L(R(i)) == i for every index i in the domain of R.

inv = fx.right_inverse(layout)
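
For a layout that is a bijection onto 0..size-1, a right inverse can be built by brute force, which makes the defining property easy to see. The sketch below is only a table-based illustration with an example layout chosen here; it says nothing about how fly.right_inverse derives its result.

```python
def layout_fn(shape, stride):
    """1-D index -> offset for a flat layout (mode 0 varies fastest)."""
    def f(i):
        offset = 0
        for s, d in zip(shape, stride):
            offset += (i % s) * d
            i //= s
        return offset
    return f


L = layout_fn((4, 2), (2, 1))   # bijection on 0..7
n = 8

# Brute force: R[offset] is the index that L sends to that offset.
R = [0] * n
for j in range(n):
    R[L(j)] = j
print(R)  # [0, 4, 1, 5, 2, 6, 3, 7]

# Defining property of a right inverse: L(R(i)) == i.
print(all(L(R[i]) == i for i in range(n)))  # True
```

The table R coincides with the flat layout (2, 4):(4, 1), which is the form a layout-valued inverse would take for this example.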

5.5 Recast Layout: recast_layout(layout, old_bits, new_bits)

Adjusts a layout for a type width change (e.g., FP16 → FP8):

# Convert layout from 16-bit to 8-bit elements
recasted = fx.recast_layout(layout, old_type_bits=16, new_type_bits=8)

6. Product Operations

Products combine two layouts to create a larger layout. All products take (layout, tiler).

| Variant | Description |
|---|---|
| logical_product | Mode-wise concatenation (most basic). Scales tiler strides by layout size. |
| zipped_product | Interleaves modes from layout and tiler. |
| tiled_product | Creates hierarchical tiled structure. |
| flat_product | Produces a flattened result. |
| raked_product | Creates a raked (interleaved) access pattern. |
| block_product | Creates a blocked access pattern. |

result = fx.logical_product(layout, tiler)
result = fx.zipped_product(layout, tiler)
result = fx.raked_product(layout, tiler)
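
The description of logical_product ("scales tiler strides by layout size") can be sketched in pure Python for the simplest case, where the layout covers offsets 0..size-1 exactly. The function name logical_product_compact is hypothetical, and the general case (layouts with gaps) is not handled here.

```python
from math import prod


def logical_product_compact(a_shape, a_stride, b_shape, b_stride):
    """Sketch for the case where layout A covers offsets 0..size(A)-1 exactly."""
    size_a = prod(a_shape)
    scaled = tuple(d * size_a for d in b_stride)
    # Mode 0 is A itself; mode 1 repeats A at the tiler's scaled offsets.
    return (a_shape, b_shape), (a_stride, scaled)


shape, stride = logical_product_compact((2, 2), (1, 2), (3,), (1,))
print(shape, stride)  # ((2, 2), (3,)) ((1, 2), (4,))
```

Here a 2x2 layout is repeated three times, so the result has 12 elements and each repetition starts 4 offsets after the previous one.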

7. Divide Operations

Divides partition a layout by a divisor, creating a view that separates "tile" and "rest" dimensions.

| Variant | Description |
|---|---|
| logical_divide | Basic partitioning. Internally uses complement. |
| zipped_divide | Zipped division semantics. |
| tiled_divide | Hierarchical tiled division. |
| flat_divide | Flattened division. |

result = fx.logical_divide(layout, divisor)
result = fx.zipped_divide(layout, divisor)
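
The "tile" and "rest" separation can be tabulated in pure Python. The sketch assumes that logical_divide composes the layout with the pair (divisor, complement of divisor), which is how the note "internally uses complement" is read here; the example layouts are illustrative and the complement (2, 3):(1, 8) is written out by hand.

```python
def layout_fn(shape, stride):
    """1-D index -> offset for a flat layout (mode 0 varies fastest)."""
    def f(i):
        offset = 0
        for s, d in zip(shape, stride):
            offset += (i % s) * d
            i //= s
        return offset
    return f


A = layout_fn((4, 2, 3), (2, 1, 8))   # layout being divided, 24 elements
tile = layout_fn((4,), (2,))          # divisor: 4 elements, stride 2
rest = layout_fn((2, 3), (1, 8))      # assumed complement of the divisor within 24

# result(i, j) = A(tile(i) + rest(j)):
# i walks inside one tile, j walks across the 6 tiles.
divided = [[A(tile(i) + rest(j)) for i in range(4)] for j in range(6)]
for row in divided:
    print(row)
# First rows: [0, 4, 1, 5], [2, 6, 3, 7], [8, 12, 9, 13]
```

Each row is one tile, and every offset of the original layout appears in exactly one row. The same table corresponds to the nested layout ((2, 2), (2, 3)):((4, 1), (2, 8)).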

8. Structural Operations

select(int_tuple, indices)

Select modes by index:

selected = fx.select(int_tuple, indices=[0, 2])  # pick modes 0 and 2

group(int_tuple, begin, end)

Group a range of modes into a nested tuple:

grouped = fx.group(int_tuple, begin=1, end=3)

append(base, elem) / prepend(base, elem)

Add a mode to the end/beginning:

extended = fx.append(base_tuple, new_elem)
extended = fx.prepend(base_tuple, new_elem)

zip(lhs, rhs)

Zip two IntTuples mode-wise:

zipped = fx.zip(shapes_a, shapes_b)

slice(src, coord)

Slice an IntTuple/layout at a coordinate:

sliced = fx.slice(layout, coord)
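
For intuition, most of the structural operations above have direct analogues on plain Python tuples. The sketch below is a model only: it assumes that end in group is exclusive and that zip pairs modes position by position, neither of which is confirmed by this guide, and it leaves out slice.

```python
def select(t, indices):
    return tuple(t[i] for i in indices)


def group(t, begin, end):
    # Assumption: modes begin..end-1 are wrapped into one nested mode.
    return t[:begin] + (t[begin:end],) + t[end:]


def append(t, elem):
    return t + (elem,)


def prepend(t, elem):
    return (elem,) + t


def zip_modes(lhs, rhs):
    # Assumption: pair up the modes of lhs and rhs position by position.
    return tuple(zip(lhs, rhs))


t = (4, 8, 2, 3)
print(select(t, [0, 2]))            # (4, 2)
print(group(t, 1, 3))               # (4, (8, 2), 3)
print(append(t, 5))                 # (4, 8, 2, 3, 5)
print(prepend(t, 5))                # (5, 4, 8, 2, 3)
print(zip_modes((4, 8), (1, 4)))    # ((4, 1), (8, 4))
```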

9. MemRef / View / Copy Operations

MemRef Operations

# Allocate on-chip memory with layout
alloca = fx.memref_alloca(memref_type, layout)

# Load / store through layout
val = fx.memref_load(memref, indices)
fx.memref_store(value, memref, indices)

# Vector load / store
vec = fx.memref_load_vec(memref)
fx.memref_store_vec(vector, memref)

# Get layout from memref
ly = fx.get_layout(memref)

# Get iterator from memref
it = fx.get_iter(memref)

View and Offset

# Create a view from iterator + layout
view = fx.make_view(iterator, layout)

# Add offset to a pointer
ptr = fx.add_offset(ptr, offset)

Copy Atoms and Tiled Copies

Copy Atom Types

| Type Factory | Description |
|---|---|
| fx.UniversalCopy128b() | Generic 128-bit copy |
| fx.UniversalCopy64b() | Generic 64-bit copy |
| fx.UniversalCopy32b() | Generic 32-bit copy |
| fx.UniversalCopy(bits) | Generic copy with custom bit width |
| fx.rocdl.BufferCopy128b() | AMD buffer-descriptor 128-bit copy |
| fx.rocdl.BufferCopy64b() | AMD buffer-descriptor 64-bit copy |
| fx.rocdl.BufferCopy32b() | AMD buffer-descriptor 32-bit copy |

Construction

# Create copy atom (copy_op_type, elem_type)
copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), fx.Float32)

# Create MMA atom
mma_atom = fx.make_mma_atom(fx.rocdl.MFMA(16, 16, 4, fx.Float32))

# Build thread-value layout from thread and value layouts
tiler_mn, layout_tv = fx.make_layout_tv(thr_layout, val_layout)

# Make tiled copy from copy atom + thread-value layout + tile
tiled_copy = fx.make_tiled_copy(copy_atom, layout_tv, tiler_mn)

# Make tiled MMA from MMA atom + atom layout + optional permutation
tiled_mma = fx.make_tiled_mma(mma_atom, atom_layout)
tiled_mma = fx.make_tiled_mma(mma_atom, atom_layout, permutation)

# Make tiled copy matched to a TiledMma's A/B/C partitioning
tiled_copy_a = fx.make_tiled_copy_A(copy_atom, tiled_mma)
tiled_copy_b = fx.make_tiled_copy_B(copy_atom, tiled_mma)
tiled_copy_c = fx.make_tiled_copy_C(copy_atom, tiled_mma)

Thread Slicing and Partitioning

# Get a per-thread view of a tiled copy
thr_copy = tiled_copy.get_slice(tid)   # returns ThrCopy
src_part = thr_copy.partition_S(src)   # partition source tensor
dst_part = thr_copy.partition_D(dst)   # partition destination tensor
retiled  = thr_copy.retile(tensor)     # retile tensor to match copy atom

# Get a per-thread view of a tiled MMA
thr_mma = tiled_mma.thr_slice(tid)     # returns ThrMma (alias: get_slice)

# Register fragments: pass the block-level tensor views (see examples/03-tiledMma.py).
frag_a = thr_mma.make_fragment_A(tensor_a)
frag_b = thr_mma.make_fragment_B(tensor_b)
frag_c = thr_mma.make_fragment_C(tensor_c)

# Optional spatial partition of a tensor for this thread (different use case)
part_a = thr_mma.partition_A(tensor_a)

Execution

# Execute tiled copy
fx.copy(copy_atom, src_part, dst_part)

# Execute tiled copy with predicate mask (for boundary handling)
fx.copy(copy_atom, src_part, dst_part, pred=pred_tensor)

# Execute GEMM: D = A * B + C
fx.gemm(mma_atom, d, a, b, c)

Introspection

| Property | Class | Description |
|---|---|---|
| copy_atom.thr_layout | CopyAtom | Thread layout of copy atom |
| copy_atom.tv_layout_src | CopyAtom | Thread-value layout for source |
| copy_atom.tv_layout_dst | CopyAtom | Thread-value layout for destination |
| mma_atom.thr_layout | MmaAtom | Thread layout |
| mma_atom.shape_mnk | MmaAtom | M×N×K tile dimensions |
| mma_atom.tv_layout_A/B/C | MmaAtom | Thread-value layouts per operand |
| tiled_copy.tiled_tv_layout_S | TiledCopy | Full tiled source layout |
| tiled_copy.tiled_tv_layout_D | TiledCopy | Full tiled destination layout |
| tiled_mma.tile_size_mnk | TiledMma | Tiled MMA dimensions |
| tiled_mma.thr_layout_vmnk | TiledMma | Thread layout across V,M,N,K |
| tiled_mma.tiled_tv_layout_A/B/C | TiledMma | Full tiled layouts per operand |

10. Nested / Hierarchical Layouts

The Fly dialect supports nested layouts for representing multi-level tiling hierarchies:

# Nested shape: 9 elements in first mode, (4, 8) = 32 elements in second
shape = fx.make_shape(9, (4, 8))

Nested layouts are used in GEMM kernels for multi-level tiling (block → warp → thread → instruction).
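
Coordinate mapping extends to nested layouts by applying the same inner product recursively. The pure-Python sketch below pairs the nested shape above with compact column-major strides (1, (9, 36)); those strides are an assumption made for the example, since the guide does not give strides for this shape.

```python
def crd2idx(coord, stride):
    """Recursive inner product over matching nested coord/stride tuples."""
    if isinstance(coord, int):
        return coord * stride
    return sum(crd2idx(c, d) for c, d in zip(coord, stride))


shape = (9, (4, 8))
stride = (1, (9, 36))   # assumed compact column-major strides, not from the guide

idx = crd2idx((2, (1, 3)), stride)
print(idx)  # 2*1 + 1*9 + 3*36 = 119
```

The last coordinate (8, (3, 7)) maps to 287, one less than the 288 elements the shape holds, so with these strides the layout is compact.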


11. IntTuple Arithmetic

# Element-wise operations on IntTuples
sum_it = fx.int_tuple_add(a, b)
diff_it = fx.int_tuple_sub(a, b)
prod_it = fx.int_tuple_mul(a, b)
quot_it = fx.int_tuple_div(a, b)

# Reduce to product
total = fx.int_tuple_product(int_tuple)

# Per-mode product (for nested tuples)
products = fx.int_tuple_product_each(int_tuple)
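
The intended results of these operations can be modelled on nested Python tuples. This is a sketch under assumptions: the helper names are hypothetical, both operands are taken to have the same nesting, and division is modelled as integer floor division, which may differ from what fx.int_tuple_div does.

```python
import operator
from math import prod


def elementwise(op, a, b):
    """Apply op leaf by leaf to two tuples with identical nesting."""
    if isinstance(a, int):
        return op(a, b)
    return tuple(elementwise(op, x, y) for x, y in zip(a, b))


def product(t):
    """Product of all leaves."""
    return t if isinstance(t, int) else prod(product(x) for x in t)


def product_each(t):
    """Product of each top-level mode."""
    return tuple(product(x) for x in t)


a = (8, (4, 2))
b = (2, (2, 1))
print(elementwise(operator.add, a, b))       # (10, (6, 3))
print(elementwise(operator.sub, a, b))       # (6, (2, 1))
print(elementwise(operator.mul, a, b))       # (16, (8, 2))
print(elementwise(operator.floordiv, a, b))  # (4, (2, 2))
print(product(a))                            # 64
print(product_each(a))                       # (8, 8)
```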

12. Printf Debugging

The Fly dialect provides a printf op for kernel debugging:

fx.printf("tid={} bid={} val={}", tid, bid, value)

Supports:

  • ir.Value — dynamic values
  • int, float, bool — auto-converted to constants
  • str, type — embedded as static text
  • DSL types with __fly_values__ — auto-unwrapped

13. Decision Tree

Which layout operation do I need?

├── Creating a layout?
│   ├── From explicit shape + stride → make_layout(shape, stride)
│   ├── Identity layout → make_identity_layout(shape)
│   └── From existing components → make_layout(get_shape(l), new_stride)
│
├── Querying a layout?
│   ├── Total elements → size(layout)
│   ├── Extract component → get_shape(layout), get_stride(layout)
│   ├── Single mode → get(shape, i)
│   └── Number of modes → rank(layout)
│
├── Coordinate mapping?
│   ├── Coord → memory index → crd2idx(coord, layout)
│   ├── Memory index → coord → idx2crd(idx, layout)
│   └── Tuple shortcut → fx.crd2idx([c0, c1], layout)
│
├── Combining layouts?
│   ├── Sequential mapping → composition(A, B)
│   ├── Extending threads → logical_product / raked_product / block_product
│   └── Simplifying → coalesce(layout)
│
├── Partitioning / tiling?
│   ├── Split layout → logical_divide / zipped_divide
│   └── Hierarchical tile → tiled_divide
│
├── Type width change?
│   └── recast_layout(layout, old_bits, new_bits)
│
└── Structural manipulation?
    ├── Select modes → select(it, indices)
    ├── Group modes → group(it, begin, end)
    └── Extend → append(it, elem) / prepend(it, elem)

14. Source Files

| File | Description |
|---|---|
| python/flydsl/expr/primitive.py | All layout functions: construction, query, algebra, divide, product, copy, gemm |
| python/flydsl/expr/derived.py | CopyAtom, MmaAtom, TiledCopy wrapper classes |
| python/flydsl/expr/typing.py | IntTupleType, LayoutType, type definitions |
| include/flydsl/Dialect/Fly/IR/FlyOps.td | Fly dialect op definitions |
| lib/Dialect/Fly/IR/FlyOps.cpp | Type inference for composition, product, divide (Fly) |
| include/flydsl/Dialect/Fly/Utils/LayoutUtils.h | Layout algebra algorithms (composition, product, divide) |
| tests/mlir/LayoutAlgebra/*.mlir | Layout algebra MLIR lit tests |