|
| 1 | +import mlir.extras.types as T |
| 2 | +import numpy as np |
| 3 | +from mlir.dialects import builtin |
| 4 | +from mlir.dialects.transform import any_op_t |
| 5 | +from mlir.dialects.transform.extras import named_sequence, apply_patterns |
| 6 | +from mlir.dialects.transform.structured import MatchInterfaceEnum, VectorizeOp |
| 7 | +from mlir.dialects.transform.vector import ( |
| 8 | + VectorContractLowering, |
| 9 | +) |
| 10 | +from mlir.ir import StringAttr, UnitAttr, Attribute |
| 11 | + |
| 12 | +# you need this to register the memref value caster |
| 13 | +# noinspection PyUnresolvedReferences |
| 14 | +import mlir.extras.dialects.memref |
| 15 | +from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule |
| 16 | +from mlir.extras.dialects import linalg |
| 17 | +from mlir.extras.dialects import transform, llvm |
| 18 | +from mlir.extras.dialects.func import func |
| 19 | +from mlir.extras.dialects.transform import ( |
| 20 | + match, |
| 21 | + get_parent_op, |
| 22 | +) |
| 23 | +from mlir.extras.runtime.passes import Pipeline, run_pipeline |
| 24 | +from mlir.extras.runtime.refbackend import LLVMJITBackend |
| 25 | +from mlir.extras.util import find_ops |
| 26 | + |
| 27 | +ctx = RAIIMLIRContext() |
| 28 | +backend = LLVMJITBackend() |
| 29 | +module = ExplicitlyManagedModule() |
| 30 | + |
| 31 | +M, K, N = 7, 13, 7 |
| 32 | + |
| 33 | + |
| 34 | +@func |
| 35 | +def matmul_armsme( |
| 36 | + A: T.tensor(M, K, T.f32()), |
| 37 | + B: T.tensor(K, N, T.f32()), |
| 38 | + C: T.tensor(M, N, T.f32()), |
| 39 | +): |
| 40 | + return linalg.matmul(A, B, C) |
| 41 | + |
| 42 | + |
| 43 | +@builtin.module(attrs={"transform.target_tag": StringAttr.get("payload")}) |
| 44 | +def payload(): |
| 45 | + matmul_armsme.emit(force=True) |
| 46 | + |
| 47 | + |
| 48 | +# based on https://github.com/llvm/llvm-project/blob/ad656d3a1954dd6157ba689b3003b6fbb97a0833/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir |
| 49 | +@builtin.module(attrs={"transform.with_named_sequence": UnitAttr.get()}) |
| 50 | +def mod_transform(): |
| 51 | + @named_sequence("main", [any_op_t()], []) |
| 52 | + def main(module_op: any_op_t()): |
| 53 | + # Step 1: Match the linalg.matmul operation |
| 54 | + matmul_op = match(module_op, ops=["linalg.matmul"]) |
| 55 | + |
| 56 | + # Step 2: Tile for size [4] x [4], which corresponds to SVLs x SVLs |
| 57 | + tiled_linalg_op, loops = transform.tile_to_scf_for( |
| 58 | + matmul_op, sizes=[[4], [4], 1] |
| 59 | + ) |
| 60 | + |
| 61 | + # Step 3: Vectorize |
| 62 | + VectorizeOp(tiled_linalg_op, vector_sizes=[[4], [4], 1]) |
| 63 | + |
| 64 | + # Step 4: Bufferize ahead of TransferReadDropUnitDimsPattern |
| 65 | + bufferize = transform.bufferization.one_shot_bufferize( |
| 66 | + module_op, bufferize_function_boundaries=True |
| 67 | + ) |
| 68 | + |
| 69 | + # Step 5: Match func.func operations |
| 70 | + func_op = match(bufferize, ops=["func.func"]) |
| 71 | + |
| 72 | + # Step 6: Lower vector.multi_reduction to vector.contract (+ some helpful patterns) |
| 73 | + @apply_patterns(func_op) |
| 74 | + def patterns1(): |
| 75 | + transform.apply_patterns.vector.lower_masked_transfers() |
| 76 | + transform.apply_patterns.vector.transfer_permutation_patterns() |
| 77 | + transform.apply_patterns.vector.reduction_to_contract() |
| 78 | + transform.apply_patterns.vector.sink_ops() |
| 79 | + |
| 80 | + # Step 7: Lower vector.contract to vector.outerproduct |
| 81 | + @apply_patterns(func_op) |
| 82 | + def patterns2(): |
| 83 | + transform.apply_patterns.vector.lower_contraction( |
| 84 | + lowering_strategy=VectorContractLowering.OuterProduct |
| 85 | + ) |
| 86 | + transform.apply_patterns.vector.lower_masks() |
| 87 | + transform.apply_patterns.vector.rank_reducing_subview_patterns() |
| 88 | + transform.apply_patterns.canonicalization() |
| 89 | + |
| 90 | + # # Step 8 (optional optimization): Hoist accumulator load/store |
| 91 | + func_h = transform.structured.hoist_redundant_vector_transfers( |
| 92 | + any_op_t(), func_op |
| 93 | + ) |
| 94 | + |
| 95 | + all_loops = match(bufferize, interface=MatchInterfaceEnum.LoopLikeInterface) |
| 96 | + |
| 97 | + transform.apply_licm(all_loops) |
| 98 | + transform.loop.hoist_loop_invariant_subsets(all_loops) |
| 99 | + |
| 100 | + |
| 101 | +module = module.finish() |
| 102 | + |
| 103 | +vectorized_module = run_pipeline( |
| 104 | + module, |
| 105 | + pipeline=Pipeline() |
| 106 | + .transform_interpreter(entry_point="main", debug_payload_root_tag="payload") |
| 107 | + .canonicalize() |
| 108 | + .cse(), |
| 109 | +) |
| 110 | + |
| 111 | +# print(vectorized_module) |
| 112 | + |
| 113 | +kernel_funcs = find_ops( |
| 114 | + vectorized_module.operation, lambda o: isinstance(o.opview, llvm.LLVMFuncOp) |
| 115 | +) |
| 116 | +for k in kernel_funcs: |
| 117 | + k.attributes["target_features"] = Attribute.parse( |
| 118 | + '#llvm.target_features<["+sme", "+sve"]>' |
| 119 | + ) |
| 120 | + |
| 121 | + |
| 122 | +lower_to_llvm = ( |
| 123 | + Pipeline() |
| 124 | + # https://github.com/llvm/llvm-project/blob/9146ef5df0543f08a86686cfeb3bd1ea7338f4c6/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp#L45 |
| 125 | + # Legalize vector operations so they can be converted to ArmSME. |
| 126 | + .arm_sme_vector_legalization() |
| 127 | + # Sprinkle some cleanups. |
| 128 | + .canonicalize() |
| 129 | + .cse() |
| 130 | + # Passes that convert operations on vectors to ArmSME operations. |
| 131 | + # Convert Arith to ArmSME. |
| 132 | + .convert_arith_to_arm_sme() |
| 133 | + # Convert Vector to ArmSME. |
| 134 | + .convert_vector_to_arm_sme() |
| 135 | + # Convert operations on high-level vectors to loops. |
| 136 | + # Convert ArmSME to SCF. |
| 137 | + .convert_arm_sme_to_scf() |
| 138 | + # Convert Vector to SCF (with full unroll enabled). |
| 139 | + .convert_vector_to_scf(full_unroll=True) |
| 140 | + # Enable streaming-mode and ZA. |
| 141 | + .Func( |
| 142 | + Pipeline().enable_arm_streaming( |
| 143 | + streaming_mode="streaming-locally", |
| 144 | + za_mode="new-za", |
| 145 | + if_required_by_ops=True, |
| 146 | + ) |
| 147 | + ) |
| 148 | + # Convert SCF to CF (required for ArmSME tile allocation). |
| 149 | + .convert_scf_to_cf() |
| 150 | + # Convert ArmSME to LLVM. |
| 151 | + .Func(Pipeline().convert_arm_sme_to_llvm()) |
| 152 | + # Sprinkle some cleanups. |
| 153 | + .canonicalize() |
| 154 | + .cse() |
| 155 | + # https://github.com/makslevental/llvm-project/blob/f6643263631bcb0d191ef923963ac1a5ca9ac5fd/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp#L44 |
| 156 | + .Func( |
| 157 | + Pipeline() |
| 158 | + # Blanket-convert any remaining high-level vector ops to loops if any remain. |
| 159 | + .convert_vector_to_scf() |
| 160 | + # Blanket-convert any remaining linalg ops to loops if any remain. |
| 161 | + .convert_linalg_to_loops() |
| 162 | + ) |
| 163 | + # Blanket-convert any remaining affine ops if any remain. |
| 164 | + .lower_affine() |
| 165 | + # Convert SCF to CF (always needed). |
| 166 | + .convert_scf_to_cf() |
| 167 | + # Sprinkle some cleanups. |
| 168 | + .canonicalize() |
| 169 | + .cse() |
| 170 | + # Convert vector to LLVM (always needed). |
| 171 | + .convert_vector_to_llvm() |
| 172 | + # Convert Math to LLVM (always needed). |
| 173 | + .Func(Pipeline().convert_math_to_llvm()) |
| 174 | + # Expand complicated MemRef operations before lowering them. |
| 175 | + .expand_strided_metadata() |
| 176 | + # The expansion may create affine expressions. Get rid of them. |
| 177 | + .lower_affine() |
| 178 | + # Convert MemRef to LLVM (always needed). |
| 179 | + .finalize_memref_to_llvm() |
| 180 | + # Convert Func to LLVM (always needed). |
| 181 | + .convert_func_to_llvm() |
| 182 | + .convert_arith_to_llvm() |
| 183 | + .convert_cf_to_llvm() |
| 184 | + # Convert Index to LLVM (always needed). |
| 185 | + .convert_index_to_llvm() |
| 186 | + # Convert UB to LLVM (always needed). |
| 187 | + .convert_ub_to_llvm() |
| 188 | + # Convert remaining unrealized_casts (always needed). |
| 189 | + .reconcile_unrealized_casts() |
| 190 | +) |
| 191 | + |
| 192 | +compiled_module = backend.compile( |
| 193 | + find_ops( |
| 194 | + vectorized_module.operation, |
| 195 | + lambda x: "transform.target_tag" in x.attributes |
| 196 | + and x.attributes["transform.target_tag"].value == "payload", |
| 197 | + single=True, |
| 198 | + ), |
| 199 | + kernel_name=matmul_armsme.__name__, |
| 200 | + pipeline=lower_to_llvm, |
| 201 | +) |
| 202 | + |
| 203 | +print(compiled_module) |
| 204 | + |
| 205 | +A = np.random.randint(0, 10, (M, K)).astype(np.float32) |
| 206 | +B = np.random.randint(0, 10, (K, N)).astype(np.float32) |
| 207 | +C = np.zeros((M, N), dtype=np.float32) |
| 208 | + |
| 209 | +backend.load(compiled_module).matmul_armsme_capi_wrapper(A, B, C) |
| 210 | +assert np.allclose(A @ B, C) |
0 commit comments