Core types, construction, coordinate mapping, algebra operations, and layout utilities in FlyDSL.
Important: All `fx.*` layout operations generate MLIR IR and must be called inside a `@flyc.kernel` or `@flyc.jit` function body. The code snippets below show API usage patterns within that context; they are not standalone scripts.
| Category | Python API | Fly Dialect Op | Description |
|---|---|---|---|
| Construction | `fx.make_shape(8, 16)` | `fly.make_shape` | Create shape (IntTuple) |
| | `fx.make_stride(1, 8)` | `fly.make_stride` | Create stride (IntTuple) |
| | `fx.make_layout(shape, stride)` | `fly.make_layout` | Create layout from (shape, stride) |
| | `fx.make_coord(i, j)` | `fly.make_coord` | Create coordinate |
| | `fx.make_int_tuple(elems)` | `fly.make_int_tuple` | Create generic IntTuple |
| | `fx.make_ordered_layout(shape, order)` | `fly.make_ordered_layout` | Create layout with mode ordering |
| Mapping | `fx.crd2idx(coord, layout)` | `fly.crd2idx` | Coordinate → linear index |
| | `fx.idx2crd(idx, layout)` | `fly.idx2crd` | Linear index → coordinate |
| Query | `fx.size(layout)` | `fly.size` | Total element count |
| | `fx.cosize(layout)` | `fly.cosize` | Codomain size (max index + 1) |
| | `fx.get_shape(layout)` | `fly.get_shape` | Extract shape from layout |
| | `fx.get_stride(layout)` | `fly.get_stride` | Extract stride from layout |
| | `fx.get(int_tuple, idx)` | `fly.select` + `fly.get_scalar` | Extract element at index |
| Algebra | `fx.composition(A, B)` | `fly.composition` | Compose: A ∘ B |
| | `fx.complement(tiler, size)` | `fly.complement` | Complement of tiler |
| | `fx.coalesce(layout)` | `fly.coalesce` | Simplify layout |
| | `fx.right_inverse(layout)` | `fly.right_inverse` | Right inverse of layout |
| Products | `fx.logical_product(A, B)` | `fly.logical_product` | Basic product |
| | `fx.zipped_product(A, B)` | `fly.zipped_product` | Zipped product |
| | `fx.tiled_product(A, B)` | `fly.tiled_product` | Tiled product |
| | `fx.flat_product(A, B)` | `fly.flat_product` | Flat product |
| | `fx.raked_product(A, B)` | `fly.raked_product` | Raked product |
| | `fx.block_product(A, B)` | `fly.block_product` | Blocked product |
| Divides | `fx.logical_divide(A, B)` | `fly.logical_divide` | Basic divide |
| | `fx.zipped_divide(A, B)` | `fly.zipped_divide` | Zipped divide |
| | `fx.tiled_divide(A, B)` | `fly.tiled_divide` | Tiled divide |
| | `fx.flat_divide(A, B)` | `fly.flat_divide` | Flat divide |
| Structural | `fx.select(it, indices)` | `fly.select` | Select modes by index |
| | `fx.group(it, begin, end)` | `fly.group` | Group modes into nested tuple |
| | `fx.append(base, elem)` | `fly.append` | Append mode to IntTuple |
| | `fx.prepend(base, elem)` | `fly.prepend` | Prepend mode to IntTuple |
| | `fx.zip(lhs, rhs)` | `fly.zip` | Zip two IntTuples |
| Recast | `fx.recast_layout(ly, old, new)` | `fly.recast_layout` | Recast layout for type width change |
The Fly dialect defines several custom MLIR types for layout algebra:
| Type | MLIR Syntax | Description |
|---|---|---|
| `!fly.int_tuple` | `!fly.int_tuple<(8, 16)>` | Integer tuple (can be nested) |
| `!fly.layout` | `!fly.layout<(8, 16):(1, 8)>` | Layout = (Shape, Stride) pair |
| `!fly.pointer` | `!fly.pointer<f16>` | Typed pointer |
| `!fly.memref` | `!fly.memref<...>` | Memory reference with layout |
| `!fly.swizzle` | `!fly.swizzle<...>` | Swizzle descriptor |
| `!fly.copy_atom` | `!fly.copy_atom_universal_copy<...>` | Copy atom type |
| `!fly.mma_atom` | `!fly.mma_atom_universal_fma<...>` | MMA atom type |
IntTuples encode structure at the type level:
| Pattern | Meaning | Example |
|---|---|---|
| Integer literal | Static constant | 8 |
| Dynamic value | Runtime SSA value | Provided as operand |
| Nested tuple | Hierarchical mode | (8, (4, 2)) |
```python
import flydsl.expr as fx
from flydsl.expr import arith
from flydsl.expr.typing import T

# i, j (indices) and M, N (extents) are assumed to be defined in the enclosing kernel.

# Shapes and strides (static constants auto-materialized)
shape = fx.make_shape(8, 16)              # !fly.int_tuple<(8, 16)>
stride = fx.make_stride(1, 8)             # !fly.int_tuple<(1, 8)>
layout = fx.make_layout(shape, stride)    # !fly.layout<(8, 16):(1, 8)>

# Shorthand: pass Python tuples directly
layout = fx.make_layout((8, 16), (1, 8))

# Coordinates
coord = fx.make_coord(i, j)

# Generic integer tuple
it = fx.make_int_tuple((4, 8, 2))

# Nested shapes
shape_nested = fx.make_shape(9, (4, 8))   # (9, (4, 8))

# Ordered layout: specify stride order (e.g., column-major vs row-major)
col_major = fx.make_ordered_layout((M, N), order=(0, 1))  # stride order: M-first
row_major = fx.make_ordered_layout((M, N), order=(1, 0))  # stride order: N-first

# Identity layout / tensor
identity = fx.make_identity_layout((M, N))
id_tensor = fx.make_identity_tensor((M, N))
```

The fundamental operation: mapping between logical coordinates and physical memory indices.
Formula: `index = sum(coord_i * stride_i)`

```python
idx = fx.crd2idx(coord, layout)
coord = fx.idx2crd(idx, layout)
```

For the layout `(8, 16):(1, 8)` (8×16, column-major):

```
crd2idx((3, 5), layout) = 3*1 + 5*8 = 43
idx2crd(43, layout)     = (43 % 8, 43 // 8) = (3, 5)
```
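To make the formula concrete, here is a minimal pure-Python model of both mappings for a flat (non-nested) layout. It does not call FlyDSL and emits no IR; `crd2idx` and `idx2crd` here are plain hypothetical helpers. The per-mode inverse `(idx // stride_i) % shape_i` is an assumption that holds for compact layouts such as this column-major one.

```python
def crd2idx(coord, shape, stride):
    """Coordinate -> linear index: sum(coord_i * stride_i)."""
    assert len(coord) == len(shape) == len(stride)
    assert all(0 <= c < s for c, s in zip(coord, shape)), "coordinate out of range"
    return sum(c * d for c, d in zip(coord, stride))


def idx2crd(idx, shape, stride):
    """Linear index -> coordinate, recovering each mode as (idx // stride_i) % shape_i.
    Only a true inverse of crd2idx when the layout is compact (every index hit once)."""
    return tuple((idx // d) % s for s, d in zip(shape, stride))


shape, stride = (8, 16), (1, 8)          # 8x16, column-major
print(crd2idx((3, 5), shape, stride))    # 43
print(idx2crd(43, shape, stride))        # (3, 5)

# Round trip over every coordinate of the 8x16 layout.
assert all(
    idx2crd(crd2idx((i, j), shape, stride), shape, stride) == (i, j)
    for i in range(8)
    for j in range(16)
)
```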
| Operation | Description | Example |
|---|---|---|
| `size(x)` | Product of all dimensions | `size((8, 16)) = 128` |
| `cosize(layout)` | Max index + 1 (codomain size) | `cosize(((8, 16), (1, 8))) = 128` |
| `get_shape(layout)` | Extract shape from layout | Returns `!fly.int_tuple` |
| `get_stride(layout)` | Extract stride from layout | Returns `!fly.int_tuple` |
| `get(x, i)` | Extract i-th element | `get((8, 16), 0) = 8` |
| `get_scalar(x)` | Extract scalar from leaf IntTuple | Returns index value |
| `rank(x)` | Number of top-level modes | `rank((8, 16)) = 2` |
| `depth(x)` | Nesting depth | `depth((8, (4, 2))) = 2` |
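The query operations reduce to simple recursions over (possibly nested) tuples. The following pure-Python sketch models them on plain Python tuples; the helper names are hypothetical, not the FlyDSL API, and `cosize` assumes non-negative strides, where the largest index is reached at the last coordinate.

```python
from math import prod


def flatten(t):
    """Flatten a nested tuple of ints, e.g. (8, (4, 2)) -> (8, 4, 2)."""
    return (t,) if isinstance(t, int) else tuple(x for m in t for x in flatten(m))


def size(shape):
    """Product of all dimensions."""
    return prod(flatten(shape))


def cosize(shape, stride):
    """Max index + 1; assumes non-negative strides."""
    return sum((s - 1) * d for s, d in zip(flatten(shape), flatten(stride))) + 1


def rank(t):
    """Number of top-level modes (a bare int counts as one mode)."""
    return 1 if isinstance(t, int) else len(t)


def depth(t):
    """Nesting depth: 0 for an int, 1 for a flat tuple, and so on."""
    return 0 if isinstance(t, int) else 1 + max((depth(m) for m in t), default=0)


print(size((8, 16)))                       # 128
print(cosize((8, 16), (1, 8)))             # 128: compact, so cosize == size
print(cosize((8, 16), (1, 9)))             # 143: padded column stride leaves gaps
print(rank((8, 16)), depth((8, (4, 2))))   # 2 2
```

The padded case shows why both queries exist: `size` counts logical elements, while `cosize` is the amount of memory the layout can touch.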
```python
s = fx.size(layout)             # total elements (returns Int32 for static)
cs = fx.cosize(layout)          # codomain size (max index + 1)
shape = fx.get_shape(layout)
stride = fx.get_stride(layout)
v = fx.get(shape, 0)            # first dimension
r = fx.rank(shape)              # number of modes
```

Composes two layouts: the result maps through B first, then A.
Semantics: `result(x) = A(B(x))`

```python
composed = fx.composition(layout_a, layout_b)
```

Use case: applying a permutation or tile coordinate mapping to a memory layout.
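Composition is easiest to see with layouts treated as functions of a linear coordinate. The sketch below (plain Python, hypothetical helper names) builds `A` and `B` as callables and composes them; `fx.composition` instead returns a new layout, whose mapping should agree with this function. The flat, colexicographic evaluation order is an assumption about how FlyDSL applies a layout to a linear coordinate.

```python
def layout_fn(shape, stride):
    """Flat layout as a function: split x colexicographically over shape, dot with stride."""
    def fn(x):
        idx = 0
        for s, d in zip(shape, stride):
            idx += (x % s) * d
            x //= s
        return idx
    return fn


def compose(A, B):
    """result(x) = A(B(x)): map through B first, then A."""
    return lambda x: A(B(x))


A = layout_fn((4, 8), (8, 1))   # 4x8 row-major memory layout (32 elements)
B = layout_fn((16,), (2,))      # picks every other coordinate of A's domain
C = compose(A, B)

print([C(x) for x in range(4)])  # [0, 16, 1, 17]
```

Here `B` selects coordinates 0, 2, 4, ... of `A`'s domain, so `C` visits rows 0 and 2 of each column of the row-major matrix, in column order.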
Computes the "remaining" modes not covered by the tiler, up to `target_size` elements.

```python
rest = fx.complement(tiler, target_size)
```

Use case: an internal building block for `logical_divide`, and for computing the complementary iteration space when tiling.
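For intuition, here is a pure-Python sketch of the complement algorithm as CuTe defines it; that `fx.complement` follows the same algorithm is an assumption, not something this page confirms. It handles flat, injective layouts with positive strides: sort modes by stride, fill each gap below the next stride with a new mode, then add one final mode that extends coverage to `target_size`. The helper names are hypothetical.

```python
def coalesce_flat(shape, stride):
    """Drop size-1 modes; merge neighbors when stride_next == shape * stride."""
    out = []
    for s, d in zip(shape, stride):
        if s == 1:
            continue
        if out and out[-1][0] * out[-1][1] == d:
            out[-1] = (out[-1][0] * s, out[-1][1])
        else:
            out.append((s, d))
    if not out:
        return (1,), (0,)
    return tuple(s for s, _ in out), tuple(d for _, d in out)


def complement(shape, stride, target_size):
    """Modes that, combined with (shape, stride), cover [0, target_size)."""
    rest_shape, rest_stride = [], []
    current = 1                                    # smallest index span not yet covered
    for s, d in sorted(zip(shape, stride), key=lambda m: m[1]):
        assert d % current == 0, "no complement: strides are not divisible"
        rest_shape.append(d // current)            # fill the gap below stride d
        rest_stride.append(current)
        current = s * d
    rest_shape.append(-(-target_size // current))  # ceil division
    rest_stride.append(current)
    return coalesce_flat(rest_shape, rest_stride)


print(complement((4,), (1,), 24))        # ((6,), (4,))
print(complement((4,), (2,), 24))        # ((2, 3), (1, 8))
print(complement((2, 2), (1, 6), 24))    # ((3, 2), (2, 12))
```

In the second case, the original `(4,):(2,)` covers the even offsets 0 to 6, the complement mode of size 2 at stride 1 fills in the odd ones, and the mode of size 3 at stride 8 repeats that block up to 24.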
Simplifies a layout by flattening nested modes and combining adjacent modes when possible.
Post-conditions:

- `size(result) == size(layout)` (preserves total size)
- For all valid indices `i`: `layout(i) == result(i)` (preserves the mapping)
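A minimal pure-Python sketch of coalescing (hypothetical helpers, not the FlyDSL implementation) makes both post-conditions checkable. It flattens the modes, drops size-1 modes, and merges a mode into its predecessor whenever its stride equals the predecessor's shape times stride, since the two then walk one contiguous arithmetic sequence.

```python
from math import prod


def flatten(t):
    return (t,) if isinstance(t, int) else tuple(x for m in t for x in flatten(m))


def evaluate(shape, stride, x):
    """Layout applied to a linear coordinate (colexicographic over all modes)."""
    idx = 0
    for s, d in zip(flatten(shape), flatten(stride)):
        idx += (x % s) * d
        x //= s
    return idx


def coalesce(shape, stride):
    out = []
    for s, d in zip(flatten(shape), flatten(stride)):
        if s == 1:
            continue                                # size-1 modes contribute nothing
        if out and out[-1][0] * out[-1][1] == d:
            out[-1] = (out[-1][0] * s, out[-1][1])  # contiguous with the previous mode
        else:
            out.append((s, d))
    if not out:
        return 1, 0
    if len(out) == 1:
        return out[0]
    return tuple(s for s, _ in out), tuple(d for _, d in out)


layout = ((2, (1, 6)), (1, (6, 2)))
result = coalesce(*layout)
print(result)                                       # (12, 1)

# Post-conditions: same size, same mapping at every valid index.
assert prod(flatten(result[0])) == prod(flatten(layout[0]))
assert all(evaluate(*layout, i) == evaluate(*result, i) for i in range(12))
```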
```python
simplified = fx.coalesce(layout)
```

Computes the right inverse of a layout: a layout `R` such that `layout(R(i)) == i` for every `i` in `R`'s domain.
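A brute-force pure-Python sketch (hypothetical helper, and not how `fx.right_inverse` computes its result) tabulates `R` by searching the layout's outputs. It stops at the first index the layout never produces, because a right inverse only needs to cover a contiguous range of indices starting at 0.

```python
from math import prod


def evaluate(shape, stride, x):
    idx = 0
    for s, d in zip(shape, stride):
        idx += (x % s) * d
        x //= s
    return idx


def right_inverse_table(shape, stride):
    """Return [R(0), R(1), ...] with evaluate(shape, stride, R(i)) == i."""
    first_preimage = {}
    for x in range(prod(shape)):
        first_preimage.setdefault(evaluate(shape, stride, x), x)
    table = []
    while len(table) in first_preimage:
        table.append(first_preimage[len(table)])
    return table


# 4x8 row-major layout: element (r, c) lives at index r*8 + c.
R = right_inverse_table((4, 8), (8, 1))
assert all(evaluate((4, 8), (8, 1), R[i]) == i for i in range(len(R)))
print(len(R), R[:4])   # 32 [0, 4, 8, 12]
```

For this example the table coincides with the layout `(8, 4):(4, 1)`, so the inverse is itself a layout. A strided layout such as `(4,):(2,)` yields a table of length 1, since index 1 is never produced.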
```python
inv = fx.right_inverse(layout)
```

Adjusts a layout for a change in element type width (e.g., FP16 → FP8).
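Assuming FlyDSL follows CuTe's recast convention (this page does not confirm it), narrowing the element type by a factor `n` multiplies the extent of the unit-stride mode by `n` and multiplies every other stride by `n`, so each wide element becomes `n` consecutive narrow ones. A pure-Python sketch of the narrowing direction only, with hypothetical helper names:

```python
def recast_narrow(shape, stride, old_bits, new_bits):
    """Recast a flat layout to narrower elements (e.g. 16-bit -> 8-bit).
    Hypothetical model; the widening direction is deliberately not handled."""
    assert old_bits >= new_bits and old_bits % new_bits == 0
    n = old_bits // new_bits
    new_shape = tuple(s * n if d == 1 else s for s, d in zip(shape, stride))
    new_stride = tuple(d if d == 1 else d * n for d in stride)
    return new_shape, new_stride


def index(coord, stride):
    return sum(c * d for c, d in zip(coord, stride))


old = ((8, 16), (1, 8))                    # 16-bit elements, column-major
new = recast_narrow(*old, old_bits=16, new_bits=8)
print(new)                                 # ((16, 16), (1, 16))

# Element (i, j) starts at byte 2 * index_old; in 8-bit units that is coordinate (2i, j).
assert all(
    index((2 * i, j), new[1]) == 2 * index((i, j), old[1])
    for i in range(8)
    for j in range(16)
)
```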
```python
# Convert layout from 16-bit to 8-bit elements
recasted = fx.recast_layout(layout, old_type_bits=16, new_type_bits=8)
```

Products combine two layouts to create a larger layout. All products take `(layout, tiler)`.
| Variant | Description |
|---|---|
| `logical_product` | Mode-wise concatenation (most basic). Scales tiler strides by layout size. |
| `zipped_product` | Interleaves modes from layout and tiler. |
| `tiled_product` | Creates hierarchical tiled structure. |
| `flat_product` | Produces a flattened result. |
| `raked_product` | Creates a raked (interleaved) access pattern. |
| `block_product` | Creates a blocked access pattern. |
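Assuming FlyDSL mirrors CuTe's definition, `logical_product(A, B)` is `(A, composition(complement(A, size(A) * cosize(B)), B))`. When `A` is compact (it covers `[0, size(A))` exactly once), that complement is the single mode `cosize(B):size(A)`, so the second mode is just `B` with its strides scaled by `size(A)`, which matches the table's description. A pure-Python sketch of that special case, with hypothetical helper names:

```python
from math import prod


def flatten(t):
    return (t,) if isinstance(t, int) else tuple(x for m in t for x in flatten(m))


def evaluate(shape, stride, x):
    idx = 0
    for s, d in zip(flatten(shape), flatten(stride)):
        idx += (x % s) * d
        x //= s
    return idx


def logical_product_compact(a_shape, a_stride, b_shape, b_stride):
    """(A, B with strides scaled by size(A)); valid only for a compact A."""
    n = prod(flatten(a_shape))
    return (a_shape, b_shape), (a_stride, tuple(d * n for d in b_stride))


# A: 2x2 column-major tile (compact). B: repeat it 3 times.
shape, stride = logical_product_compact((2, 2), (1, 2), (3,), (1,))
print(shape, stride)   # ((2, 2), (3,)) ((1, 2), (4,))
print(sorted(evaluate(shape, stride, x) for x in range(12)) == list(range(12)))  # True

# A strided B (3,):(2,) places the copies of A at every other 4-element slot.
shape, stride = logical_product_compact((2, 2), (1, 2), (3,), (2,))
print([evaluate(shape, stride, x) for x in range(12)])
# [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
```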
```python
result = fx.logical_product(layout, tiler)
result = fx.zipped_product(layout, tiler)
result = fx.raked_product(layout, tiler)
```

Divides partition a layout by a divisor, creating a view that separates "tile" and "rest" dimensions.
| Variant | Description |
|---|---|
| `logical_divide` | Basic partitioning. Internally uses `complement`. |
| `zipped_divide` | Zipped division semantics. |
| `tiled_divide` | Hierarchical tiled division. |
| `flat_divide` | Flattened division. |
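Assuming FlyDSL mirrors CuTe here as well, `logical_divide(A, B)` is `composition(A, (B, complement(B, size(A))))`. For a one-dimensional tiler `B = t:1` with `t` dividing `size(A)`, the complement is `(size(A)/t):t`, so the result is a 2-D view whose mode 0 indexes within a tile and whose mode 1 selects the tile: `result(j, k) = A(j + t*k)`. A pure-Python sketch of that special case, with hypothetical helper names:

```python
from math import prod


def evaluate(shape, stride, x):
    idx = 0
    for s, d in zip(shape, stride):
        idx += (x % s) * d
        x //= s
    return idx


def logical_divide_1d(a_shape, a_stride, t):
    """Split A into size(A) // t tiles of t consecutive logical elements."""
    n = prod(a_shape)
    assert n % t == 0, "tile size must divide size(A)"

    def view(j, k):  # j: position inside the tile, k: tile index
        return evaluate(a_shape, a_stride, j + t * k)

    return (t, n // t), view


shape, view = logical_divide_1d((16,), (2,), 4)   # 16 elements at stride 2, tiles of 4
print(shape)                                      # (4, 4)
print([[view(j, k) for j in range(4)] for k in range(2)])
# [[0, 2, 4, 6], [8, 10, 12, 14]]
```

The resulting mapping equals the layout `(4, 4):(2, 8)`, which is what CuTe's `logical_divide` gives for `16:2` divided by `4:1`.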
```python
result = fx.logical_divide(layout, divisor)
result = fx.zipped_divide(layout, divisor)
```

Select modes by index:
```python
selected = fx.select(int_tuple, indices=[0, 2])  # pick modes 0 and 2
```

Group a range of modes into a nested tuple:
```python
grouped = fx.group(int_tuple, begin=1, end=3)
```

Add a mode to the end/beginning:
```python
extended = fx.append(base_tuple, new_elem)
extended = fx.prepend(base_tuple, new_elem)
```

Zip two IntTuples mode-wise:
```python
zipped = fx.zip(shapes_a, shapes_b)
```

Slice an IntTuple/layout at a coordinate:
```python
sliced = fx.slice(layout, coord)
```

```python
# Allocate on-chip memory with layout
alloca = fx.memref_alloca(memref_type, layout)

# Load / store through layout
val = fx.memref_load(memref, indices)
fx.memref_store(value, memref, indices)

# Vector load / store
vec = fx.memref_load_vec(memref)
fx.memref_store_vec(vector, memref)

# Get layout from memref
ly = fx.get_layout(memref)

# Get iterator from memref
it = fx.get_iter(memref)
```

```python
# Create a view from iterator + layout
view = fx.make_view(iterator, layout)

# Add offset to a pointer
ptr = fx.add_offset(ptr, offset)
```

| Type Factory | Description |
|---|---|
| `fx.UniversalCopy128b()` | Generic 128-bit copy |
| `fx.UniversalCopy64b()` | Generic 64-bit copy |
| `fx.UniversalCopy32b()` | Generic 32-bit copy |
| `fx.UniversalCopy(bits)` | Generic copy with custom bit width |
| `fx.rocdl.BufferCopy128b()` | AMD buffer-descriptor 128-bit copy |
| `fx.rocdl.BufferCopy64b()` | AMD buffer-descriptor 64-bit copy |
| `fx.rocdl.BufferCopy32b()` | AMD buffer-descriptor 32-bit copy |
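The bit width in each factory name fixes how much data one copy operation moves, so the number of elements it covers depends on the element type. Assuming the width applies per copy operation, the arithmetic is a single division; `elems_per_copy` is a hypothetical helper, not part of FlyDSL.

```python
def elems_per_copy(copy_bits, elem_bits):
    """Elements moved by one copy of `copy_bits` for elements of `elem_bits` each."""
    if copy_bits % elem_bits:
        raise ValueError("copy width must be a multiple of the element width")
    return copy_bits // elem_bits


print(elems_per_copy(128, 32))   # 4: e.g. BufferCopy128b with Float32
print(elems_per_copy(128, 16))   # 8: 16-bit elements
print(elems_per_copy(64, 8))     # 8: UniversalCopy64b with 8-bit elements
```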
```python
# Create copy atom (copy_op_type, elem_type)
copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), fx.Float32)

# Create MMA atom
mma_atom = fx.make_mma_atom(fx.rocdl.MFMA(16, 16, 4, fx.Float32))

# Build thread-value layout from thread and value layouts
tiler_mn, layout_tv = fx.make_layout_tv(thr_layout, val_layout)

# Make tiled copy from copy atom + TV layout + tiler
tiled_copy = fx.make_tiled_copy(copy_atom, layout_tv, tiler_mn)

# Make tiled MMA from MMA atom + atom layout + optional permutation
tiled_mma = fx.make_tiled_mma(mma_atom, atom_layout)
tiled_mma = fx.make_tiled_mma(mma_atom, atom_layout, permutation)

# Make tiled copy matched to a TiledMma's A/B/C partitioning
tiled_copy_a = fx.make_tiled_copy_A(copy_atom, tiled_mma)
tiled_copy_b = fx.make_tiled_copy_B(copy_atom, tiled_mma)
tiled_copy_c = fx.make_tiled_copy_C(copy_atom, tiled_mma)
```

```python
# Get a per-thread view of a tiled copy
thr_copy = tiled_copy.get_slice(tid)   # returns ThrCopy
src_part = thr_copy.partition_S(src)   # partition source tensor
dst_part = thr_copy.partition_D(dst)   # partition destination tensor
retiled = thr_copy.retile(tensor)      # retile tensor to match copy atom

# Get a per-thread view of a tiled MMA
thr_mma = tiled_mma.thr_slice(tid)     # returns ThrMma (alias: get_slice)

# Register fragments: pass the block-level tensor views (see examples/03-tiledMma.py).
frag_a = thr_mma.make_fragment_A(tensor_a)
frag_b = thr_mma.make_fragment_B(tensor_b)
frag_c = thr_mma.make_fragment_C(tensor_c)

# Optional spatial partition of a tensor for this thread (different use case)
part_a = thr_mma.partition_A(tensor_a)
```

```python
# Execute tiled copy
fx.copy(copy_atom, src_part, dst_part)

# Execute tiled copy with predicate mask (for boundary handling)
fx.copy(copy_atom, src_part, dst_part, pred=pred_tensor)

# Execute GEMM: D = A * B + C
fx.gemm(mma_atom, d, a, b, c)
```

| Property | Class | Description |
|---|---|---|
| `copy_atom.thr_layout` | `CopyAtom` | Thread layout of copy atom |
| `copy_atom.tv_layout_src` | `CopyAtom` | Thread-value layout for source |
| `copy_atom.tv_layout_dst` | `CopyAtom` | Thread-value layout for destination |
| `mma_atom.thr_layout` | `MmaAtom` | Thread layout |
| `mma_atom.shape_mnk` | `MmaAtom` | M×N×K tile dimensions |
| `mma_atom.tv_layout_A/B/C` | `MmaAtom` | Thread-value layouts per operand |
| `tiled_copy.tiled_tv_layout_S` | `TiledCopy` | Full tiled source layout |
| `tiled_copy.tiled_tv_layout_D` | `TiledCopy` | Full tiled destination layout |
| `tiled_mma.tile_size_mnk` | `TiledMma` | Tiled MMA dimensions |
| `tiled_mma.thr_layout_vmnk` | `TiledMma` | Thread layout across V, M, N, K |
| `tiled_mma.tiled_tv_layout_A/B/C` | `TiledMma` | Full tiled layouts per operand |
The Fly dialect supports nested layouts for representing multi-level tiling hierarchies:
```python
# Nested shape: 9 elements in first mode, (4, 8) = 32 elements in second
shape = fx.make_shape(9, (4, 8))
```

Nested layouts are used in GEMM kernels for multi-level tiling (block → warp → thread → instruction).
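With a nested shape, a coordinate can be given hierarchically, `(i, (j, k))`, or partially flattened, `(i, m)`, in which case the integer `m` is first split colexicographically across the `(4, 8)` sub-shape. A pure-Python sketch of that rule, using an assumed compact column-major stride `(1, (9, 36))` for illustration (hypothetical helpers, not the FlyDSL API):

```python
from math import prod


def size(shape):
    return shape if isinstance(shape, int) else prod(size(s) for s in shape)


def split(idx, shape):
    """Split an int coordinate colexicographically across the modes of `shape`."""
    out = []
    for s in shape:
        out.append(idx % size(s))
        idx //= size(s)
    return tuple(out)


def crd2idx(coord, shape, stride):
    """Hierarchical coordinate -> index; ints aimed at tuple modes are split first."""
    if isinstance(shape, int):
        return coord * stride
    if isinstance(coord, int):
        coord = split(coord, shape)
    return sum(crd2idx(c, s, d) for c, s, d in zip(coord, shape, stride))


shape, stride = (9, (4, 8)), (1, (9, 36))
print(size(shape))                           # 288
print(crd2idx((2, (1, 3)), shape, stride))   # 119: 2*1 + 1*9 + 3*36
print(crd2idx((2, 13), shape, stride))       # 119: 13 splits into (1, 3)
print(crd2idx(119, shape, stride))           # 119: compact, so index == coordinate
```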
```python
# Element-wise operations on IntTuples
sum_it = fx.int_tuple_add(a, b)
diff_it = fx.int_tuple_sub(a, b)
prod_it = fx.int_tuple_mul(a, b)
quot_it = fx.int_tuple_div(a, b)

# Reduce to product
total = fx.int_tuple_product(int_tuple)

# Per-mode product (for nested tuples)
products = fx.int_tuple_product_each(int_tuple)
```

The Fly dialect provides a `printf` op for kernel debugging:
```python
fx.printf("tid={} bid={} val={}", tid, bid, value)
```

Supports:

- `ir.Value`: dynamic values
- `int`, `float`, `bool`: auto-converted to constants
- `str`, `type`: embedded as static text
- DSL types with `__fly_values__`: auto-unwrapped
```text
Which layout operation do I need?
├── Creating a layout?
│   ├── From explicit shape + stride → make_layout(shape, stride)
│   ├── Identity layout → make_identity_layout(shape)
│   └── From existing components → make_layout(get_shape(l), new_stride)
│
├── Querying a layout?
│   ├── Total elements → size(layout)
│   ├── Extract component → get_shape(layout), get_stride(layout)
│   ├── Single mode → get(shape, i)
│   └── Number of modes → rank(layout)
│
├── Coordinate mapping?
│   ├── Coord → memory index → crd2idx(coord, layout)
│   ├── Memory index → coord → idx2crd(idx, layout)
│   └── Tuple shortcut → fx.crd2idx([c0, c1], layout)
│
├── Combining layouts?
│   ├── Sequential mapping → composition(A, B)
│   ├── Extending threads → logical_product / raked_product / block_product
│   └── Simplifying → coalesce(layout)
│
├── Partitioning / tiling?
│   ├── Split layout → logical_divide / zipped_divide
│   └── Hierarchical tile → tiled_divide
│
├── Type width change?
│   └── recast_layout(layout, old_bits, new_bits)
│
└── Structural manipulation?
    ├── Select modes → select(it, indices)
    ├── Group modes → group(it, begin, end)
    └── Extend → append(it, elem) / prepend(it, elem)
```
| File | Description |
|---|---|
| `python/flydsl/expr/primitive.py` | All layout functions: construction, query, algebra, divide, product, copy, gemm |
| `python/flydsl/expr/derived.py` | `CopyAtom`, `MmaAtom`, `TiledCopy` wrapper classes |
| `python/flydsl/expr/typing.py` | `IntTupleType`, `LayoutType`, type definitions |
| `include/flydsl/Dialect/Fly/IR/FlyOps.td` | Fly dialect op definitions |
| `lib/Dialect/Fly/IR/FlyOps.cpp` | Type inference for composition, product, divide (Fly) |
| `include/flydsl/Dialect/Fly/Utils/LayoutUtils.h` | Layout algebra algorithms (composition, product, divide) |
| `tests/mlir/LayoutAlgebra/*.mlir` | Layout algebra MLIR lit tests |