
Pass to 'unroll and jam' the scf.for loops around aievec.matmul #1167

Conversation

@newling newling commented Mar 6, 2025

Currently: the scf.for loops become cf blocks, which get unrolled in llvm's opt (after leaving iree-amd-aie). The unrolling results in an order of matmuls that is linear: iterate k, then iterate m, then iterate n. This isn't optimal for AIE -- I don't know the details, but peano somehow can't orchestrate the pipelining of loads to registers with the mac/mfma instructions.

The basic idea with this PR is to 'tile' the scf.for loops, so that multiple m and n's can be running in parallel. See the lit tests for what unroll and jam actually do. The PR allows us to experiment with different approaches and manually find the best performing approach.

This improves performance (measured as wall clock time to run a matmul in a tight loop) by between 25% and 35%. With this PR, the bf16 on strix kernel performance is about 55% of peak overall (see CI; it will appear in https://nod-ai.github.io/iree-amd-aie/results_history_npu1.html when landed).

FOLLOW-UP work: this PR only improves performance for one case: bf16 on phoenix. It also relies on the tiling strategy not changing (it expects a 64x64x64 matmul on the core). This is a bit brittle, and can hopefully be made more robust.

NOTE: currently this only hits the pack-peel pipeline, see also #1177
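To illustrate the transformation, here is a plain-C sketch of classic unroll-and-jam (not the pass's actual output; M, N, K and the factor 2 are hypothetical):

// Before: the k-reduction for each (m, n) completes before the next one starts.
for (int m = 0; m < M; ++m)
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k)
      C[m][n] += A[m][k] * B[k][n];

// After unroll-and-jam of the m loop by 2 (assuming M is even): two m
// iterations advance together through the inner loops, so the loads for
// rows m and m+1 can be interleaved and pipelined against the macs.
for (int m = 0; m < M; m += 2)
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k) {
      C[m][n]     += A[m][k]     * B[k][n];
      C[m + 1][n] += A[m + 1][k] * B[k][n];
    }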

@newling newling force-pushed the unroll_and_jam_aievec_matmul_bench branch from 29699ea to 667e61a on March 7, 2025 00:26
@newling newling force-pushed the unroll_and_jam_aievec_matmul_bench branch from 667e61a to f2c29bd on March 11, 2025 15:18
@newling newling changed the title [WIP] unroll-and-jam Pass to 'unroll and jam' the scf.for loops around aievec.matmul Mar 11, 2025
@newling newling marked this pull request as ready for review March 11, 2025 15:31
Contributor

@Yu-Zhewen Yu-Zhewen left a comment


LGTM! Really interesting results. Looking forward to seeing how it performs on different devices and pipelines.

namespace {

/// Get the loop count of \p forOp if its loop bounds and step are constant.
/// otherwise return std::nullopt. The loop count is 'end - start / step'.
Contributor

Suggested change
/// otherwise return std::nullopt. The loop count is 'end - start / step'.
/// Otherwise return std::nullopt. The loop count is '(end - start) / step'.
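For context, a minimal sketch of the helper being documented (hypothetical name and body; assumes MLIR's getConstantIntValue from mlir/Dialect/Utils/StaticValueUtils.h):

// Returns the loop count of `forOp` if its bounds and step are constant,
// otherwise std::nullopt.
std::optional<int64_t> getConstantLoopCount(scf::ForOp forOp) {
  std::optional<int64_t> start = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> end = getConstantIntValue(forOp.getUpperBound());
  std::optional<int64_t> step = getConstantIntValue(forOp.getStep());
  if (!start || !end || !step || *step == 0) return std::nullopt;
  // '(end - start) / step', per the corrected comment.
  return (*end - *start) / *step;
}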


// -----

// Test of the 'none' sequence, which specifies that no this pass effectively does nothing.
Contributor

Suggested change
// Test of the 'none' sequence, which specifies that no this pass effectively does nothing.
// Test of the 'none' sequence, which specifies that this pass effectively does nothing.

%c1 = arith.constant 1 : index
%c6 = arith.constant 6 : index
scf.for %arg4 = %c0 to %c6 step %c1 {
//expected-error @below {{'aievec.matmul' op has an unroll sequence "uj_0" whose length is is 3*n + 2 for some n.}}
Contributor

Suggested change
//expected-error @below {{'aievec.matmul' op has an unroll sequence "uj_0" whose length is is 3*n + 2 for some n.}}
//expected-error @below {{'aievec.matmul' op has an unroll sequence "uj_0" whose length is 3*n + 2 for some n.}}

Collaborator

@jtuyls jtuyls left a comment


The basic idea with this PR is to 'tile' the scf.for loops, so that multiple m and n's can be running in parallel

What do you mean by this? How can m and n's run in parallel?

This improves performance (measured as wall clock time to run a matmul in a tight loop) by between 25% and 35%. With this PR, the bf16 on strix kernel performance is about 55% of peak overall (see CI; it will appear in https://nod-ai.github.io/iree-amd-aie/results_history_npu1.html when landed).

I think you mean on Phoenix and not Strix? And how do you calculate the 55% number?

@@ -135,6 +135,10 @@ std::string getConstantIntValuesString(ArrayRef<OpFoldResult> opFoldResults);
bool sinkInto(Region &, IRRewriter &,
std::function<bool(Operation *)> shouldSink);

/// Annotate `forOp` with the llvm.loop_annotation attribute specifying that
/// it should never be unrolled.
void addNoUnrollAttribute(scf::ForOp forOp, OpBuilder &builder);
Collaborator

What should happen if the loop already has a 'NoUnroll' loop annotation? And unless I overlooked it, you don't have a testcase for that?

Contributor Author

Good point, I hadn't considered that.

FWIW, having the _NOUNROLL on the end of the string is a bit ugly and complex, but adding iree-compile options that get filtered down to passes is quite intrusive/heavy.


namespace {

/// Get the loop count of \p forOp if its loop bounds and step are constant.
Collaborator

This \p notation doesn't seem to be used yet in the codebase, nor in IREE; can we just use backticks for consistency within the codebase?

Contributor

+1 It's not very readable.

Contributor Author

Yeah I guess so, especially as we're not generating any doxygen comments. Will update.

+1 It's not very readable.

Looks quite nice with the right vim plugin :)




else if (splitsMod == 2) {
return matmul->emitOpError()
<< errStart() << "whose length is is 3*n + 2 for some n.";
Contributor

Suggested change
<< errStart() << "whose length is is 3*n + 2 for some n.";
<< errStart() << "whose length is 3*n + 2 for some n.";

Also, I don't quite understand what this error means.

Contributor Author

Remainder when divided by 3 is 2. There are 2 supported formats: if the remainder is 0, it's something like uj_0_2_u_1_4; if the remainder is 1, it's like the remainder-0 case but with _UNROLL or _NOUNROLL at the end. If the remainder is 2, it's invalid.
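So the check is roughly this (a sketch using names from the snippets in this PR; it assumes the sequence string is a StringRef named sequence):

// "uj_0_2_u_1_4"          -> 6 tokens: triples only, remainder 0.
// "uj_0_2_u_1_4_NOUNROLL" -> 7 tokens: triples + suffix, remainder 1.
// Remainder 2 can't be formed from triples plus an optional suffix.
SmallVector<StringRef> splits;
sequence.split(splits, '_');
int64_t splitsMod = splits.size() % 3;
if (splitsMod == 1)
  splits.pop_back();  // drop _UNROLL/_NOUNROLL: reduces to the remainder-0 case
else if (splitsMod == 2)
  return matmul->emitOpError()
         << errStart() << "whose length is 3*n + 2 for some n.";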

Comment on lines +280 to +282
assert(
splits.size() % 3 == 0 &&
"case of 1 reduced to case of 0, case of 2 already handled as error.");
Contributor

I don't think this assert is still needed?

Contributor Author

It's used to document the code.

Comment on lines +331 to +333
for (arith::ConstantOp constant : constantsToHoist) {
builder.moveOpBefore(constant, &firstOp);
}
Contributor

I think this can go inside the scfFor.walk?

Contributor Author

I'm always wary (and admittedly a bit lazy) here. Is it safe, or does the walk get confused because the op it is currently visiting gets relocated? i.e. is the next op to visit determined before the constant op moves, or after? If after, it's no longer going to be walking inside the scf.for.
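For what it's worth, the conservative pattern is collect-then-mutate, which sidesteps the question entirely (a sketch using the names from the snippet above):

// Record the constants during the walk, without mutating anything.
SmallVector<arith::ConstantOp> constantsToHoist;
scfFor->walk(
    [&](arith::ConstantOp constant) { constantsToHoist.push_back(constant); });
// Only move them once the walk has finished, so the traversal never
// depends on an op that has been relocated.
for (arith::ConstantOp constant : constantsToHoist)
  builder.moveOpBefore(constant, &firstOp);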

@newling
Contributor Author

newling commented Mar 12, 2025

What do you mean by this? How can m and n's run in parallel?

Below is from a lit test. It's for 2 nested scf.fors (but the same idea applies to 3 nested scf.fors). Does the comment describing it sort of make sense? Unroll-jam can be interpreted as tiling (rereading my comments, I think I can make them clearer and more concise, will do).

// Check multiple levels of unrolling.
// Before this transform the access pattern of the writes into the memref is:
// [[ 0  1  2  3  ]
//  [ 4  5  6  7  ]
//  [ 8  9  10 11 ]
//  [ 12 13 14 15 ]]
//
// After the transform it is:
// [[ 0  2  8  10 ]
//  [ 1  3  9  11 ]
//  [ 4  6  12 14 ]
//  [ 5  7  13 15 ]]

By 'in parallel' I mean that we don't process a full k-reduction for (m0, n0) before starting the k-reduction of (m1, n1). i.e. when looking at the final llvm IR, we don't see

load A[m0, 0]
load B[0, n0]
mac
load A[m0, 1]
load B[1, n0]
mac
...
load A[m0, K-1]
load B[K-1, n0]
mac
store C[m0, n0]
load A[m1, 0]
load B[0, n1]
...

but rather we see multiple loads of A[m1, *] interleaved with loads of A[m0, *] (ditto B).
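For illustration, the interleaved order has roughly this shape (a hypothetical schedule with two m's and two n's in flight; the real order is up to peano):

load A[m0, 0]
load A[m1, 0]
load B[0, n0]
load B[0, n1]
mac (m0, n0)
mac (m0, n1)
mac (m1, n0)
mac (m1, n1)
load A[m0, 1]
load A[m1, 1]
load B[1, n0]
load B[1, n1]
...
store C[m0, n0]
store C[m0, n1]
store C[m1, n0]
store C[m1, n1]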

I think you mean on Phoenix and not Strix? And how do you calculate the 55% number?

Yes, phoenix! From CI:

matmul_512_512_512_bf16_f32_O3_npu1_4col_callrepl_100_outline_benchmark
--------------------------------------------------------------------------------------------------
Benchmark                                        Time             CPU   Iterations UserCounters...
--------------------------------------------------------------------------------------------------
BM_matmul/process_time/real_time_mean       172166 us         1488 us            3 items_per_second=5.80834/s
BM_matmul/process_time/real_time_median     172199 us         1564 us            3 items_per_second=5.80724/s
BM_matmul/process_time/real_time_stddev       60.5 us          144 us            3 items_per_second=2.04121m/s
--------------------------------------------------------------------------------------------------
The largest program memory size (read from byte 72 of elf files) is 13552 bytes

And then the same logic as in #1158 (comment)

So:

flops = 100*512*512*512*2
time = 0.172 seconds

observed tflops = 1e-12 * 100*512*512*512*2 / 0.172 = 0.156

Phoenix's 16 cores can do 4 tflops, so 1 core does 0.25 tflops (the above test is a single core).

observed / peak = 62.4% (so 55% was a slight underestimate). O2 took 181897 us, which is 59%.

@jtuyls
Collaborator

jtuyls commented Mar 12, 2025

By 'in parallel' I mean that we don't process a full k-reduction for (m0, n0) before starting the k-reduction of (m1, n1)

I see. I would avoid 'in parallel' as it conveys multiple parallel macs imo, which is not what is happening. Rather, I would describe it as tiling/processing across the parallel dimensions first, before the reduction dimension.

flops = 100*512*512*512*2
time = 0.172 seconds
observed tflops = 1e-12 * 100*512*512*512*2 / 0.172 = 0.156
Phoenix's 16 cores can do 4 tflops, so 1 core does 0.25 tflops (the above test is a single core).
observed / peak = 62.4% (so 55% was a slight underestimate). O2 took 181897 us, which is 59%.

One caveat is that this CI machine is a Hawk Point version I believe, which runs at a higher frequency (1.6GHz I believe).

@newling
Contributor Author

newling commented Mar 13, 2025

One caveat is that this CI machine is a Hawk Point version I believe, which runs at a higher frequency (1.6GHz I believe).

The "4 tflops" number for phoenix, what frequency is that based on?

@jtuyls
Collaborator

jtuyls commented Mar 13, 2025

One caveat is that this CI machine is a Hawk Point version I believe, which runs at a higher frequency (1.6GHz I believe).

The "4 tflops" number for phoenix, what frequency is that based on?

1GHz

128 MACs (bf16) * 2 (flops) * 16 cores * 1e9 = 4 tflops

@newling
Contributor Author

newling commented Mar 13, 2025

One caveat is that this CI machine is a Hawk Point version I believe, which runs at a higher frequency (1.6GHz I believe).

The "4 tflops" number for phoenix, what frequency is that based on?

1GHz

128 MACs (bf16) * 2 (flops) * 16 cores * 1e9 = 4 tflops

Oh I guess I could have solved for the 1 unknown myself in that equation! But darn, the 62% drops to 39%.
