-
Notifications
You must be signed in to change notification settings - Fork 893
Expand file tree
/
Copy pathStreamOpFolders.cpp
More file actions
4417 lines (3945 loc) · 169 KB
/
StreamOpFolders.cpp
File metadata and controls
4417 lines (3945 loc) · 169 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <algorithm>
#include <optional>
#include "iree/compiler/Dialect/Encoding/Utils/Utils.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir::iree_compiler::IREE::Stream {
//===----------------------------------------------------------------------===//
// Utilities shared across patterns
//===----------------------------------------------------------------------===//
namespace {
// Returns the narrowest integer (8/16/32 bits) whose bit pattern, when
// repeated, reproduces exactly the same bytes as |pattern|. This is only valid
// for splats/fills where the in-memory byte sequence is all that matters.
// Bytes are compared starting from the least-significant byte. If no narrower
// repeating unit exists the input is returned unchanged.
//
// Examples:
//   0 : i64                  -> 0x00 : i8
//   0xFFFFFFFF : i32         -> 0xFF : i8       (all ones)
//   0xCDCDCDCD : i32         -> 0xCD : i8
//   0xABCDABCD : i32         -> 0xABCD : i16
//   0x1122334411223344 : i64 -> 0x11223344 : i32
//   1 : i32                  -> 1 : i32         (bytes 01 00 00 00 don't repeat)
//
// Apart from the zero/all-ones special cases, patterns whose width is not a
// power of two, is 8 bits or less, or exceeds 64 bits are returned unchanged.
static APInt computeRequiredPatternBits(APInt pattern) {
  // Special case for well-known constant values.
  if (pattern.isZero()) {
    return APInt(8, 0u);
  }
  if (pattern.isAllOnes()) {
    return APInt(8, 0xFF);
  }

  // Only power-of-two widths can be split into repeating 1/2/4/8-byte units;
  // anything else (e.g. i24 or an 80-bit float) is left for the caller to
  // handle earlier.
  uint64_t bitWidth = llvm::PowerOf2Ceil(pattern.getBitWidth());
  if (bitWidth != pattern.getBitWidth()) {
    return pattern;
  }

  // getZExtValue() asserts when the value needs more than 64 bits, so wide
  // patterns (i128, f128, ...) must bail out before the value is read. They
  // would hit the `default` case below anyway.
  if (bitWidth > 64) {
    return pattern;
  }

  uint64_t byteWidth = bitWidth / 8;
  uint64_t value = pattern.getZExtValue();
  switch (byteWidth) {
  case 1:
    // Can't go smaller than 1 byte.
    return pattern;
  case 2: {
    uint64_t b0 = value & 0xFF;
    uint64_t b1 = (value >> 8) & 0xFF;
    if (b0 == b1) {
      // 0xAAAA : i16 => 0xAA : i8
      return APInt(8, value & 0xFF);
    }
    return pattern;
  }
  case 4: {
    uint64_t b0 = value & 0xFF;
    uint64_t b1 = (value >> 8) & 0xFF;
    uint64_t b2 = (value >> 16) & 0xFF;
    uint64_t b3 = (value >> 24) & 0xFF;
    if (b0 == b1 && b0 == b2 && b0 == b3) {
      // 0xAAAAAAAA : i32 => 0xAA : i8
      return APInt(8, b0);
    } else if (b0 == b2 && b1 == b3) {
      // 0xAABBAABB : i32 => 0xAABB : i16
      return APInt(16, b0 | (b1 << 8));
    }
    return pattern;
  }
  case 8: {
    uint64_t b0 = value & 0xFF;
    uint64_t b1 = (value >> 8) & 0xFF;
    uint64_t b2 = (value >> 16) & 0xFF;
    uint64_t b3 = (value >> 24) & 0xFF;
    uint64_t b4 = (value >> 32) & 0xFF;
    uint64_t b5 = (value >> 40) & 0xFF;
    uint64_t b6 = (value >> 48) & 0xFF;
    uint64_t b7 = (value >> 56) & 0xFF;
    if (b0 == b1 && b0 == b2 && b0 == b3 && b0 == b4 && b0 == b5 && b0 == b6 &&
        b0 == b7) {
      // 0xAAAAAAAAAAAAAAAA : i64 => 0xAA : i8
      return APInt(8, b0);
    } else if ((b0 == b2 && b0 == b4 && b0 == b6) &&
               (b1 == b3 && b1 == b5 && b1 == b7)) {
      // 0xAABBAABBAABBAABB : i64 => 0xAABB : i16
      return APInt(16, b0 | (b1 << 8));
    } else if (b0 == b4 && b1 == b5 && b2 == b6 && b3 == b7) {
      // 0xAABBCCDDAABBCCDD : i64 => 0xAABBCCDD : i32
      return APInt(32, b0 | (b1 << 8) | (b2 << 16) | (b3 << 24));
    }
    return pattern;
  }
  default:
    // Unhandled bit width (sub-byte power-of-two widths such as i1/i2/i4).
    return pattern;
  }
}
// Attempts to re-express a constant fill/splat |patternAttr| using fewer bits.
// Fills are byte-wise, so a float pattern is reinterpreted as its raw bits and
// a narrower repeating unit is searched for via computeRequiredPatternBits.
// This matters because HAL targets may not support 64-bit fills (forcing
// emulation), and some (e.g. Metal) only support 8-bit fills natively.
//
// Returns |patternAttr| itself when narrowing is impossible or unsupported;
// otherwise returns a new IntegerAttr (never a FloatAttr) of the narrow width.
static TypedAttr tryNarrowPatternBits(TypedAttr patternAttr) {
  APInt originalBits;
  if (auto intAttr = dyn_cast<IntegerAttr>(patternAttr)) {
    originalBits = intAttr.getValue();
  } else if (auto floatAttr = dyn_cast<FloatAttr>(patternAttr)) {
    originalBits = floatAttr.getValue().bitcastToAPInt();
  } else {
    // Neither an integer nor a float: nothing we know how to narrow.
    return patternAttr;
  }

  // Narrowing always produces at least 8 bits; values already that small have
  // no signedness info telling us how they would be extended, so skip them.
  if (originalBits.getBitWidth() <= 8) {
    return patternAttr;
  }

  APInt narrowedBits = computeRequiredPatternBits(originalBits);
  const unsigned narrowedWidth = narrowedBits.getBitWidth();
  if (narrowedWidth == originalBits.getBitWidth()) {
    return patternAttr;
  }

  auto narrowedType = IntegerType::get(patternAttr.getContext(), narrowedWidth);
  return IntegerAttr::get(narrowedType, narrowedBits);
}
// Rewrites fill/splat ops whose value operand is a constant so that the
// constant uses the narrowest bit width reproducing the same byte pattern
// (see tryNarrowPatternBits). The op is updated in place to consume a new
// arith.constant; the old constant is left for DCE.
template <typename Op>
struct NarrowFillPattern : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op fillOp,
                                PatternRewriter &rewriter) const override {
    TypedAttr currentAttr;
    const bool isConstant =
        matchPattern(fillOp.getValue(), m_Constant(&currentAttr));
    if (!isConstant) {
      return failure(); // dynamic fill value; nothing to narrow
    }

    TypedAttr narrowedAttr = tryNarrowPatternBits(currentAttr);
    if (narrowedAttr == currentAttr) {
      return failure(); // already as narrow as it can get
    }

    auto narrowedValue =
        arith::ConstantOp::create(rewriter, fillOp.getLoc(), narrowedAttr);
    rewriter.modifyOpInPlace(fillOp, [&]() {
      fillOp.getValueMutable().assign(narrowedValue);
    });
    return success();
  }
};
// Returns the stream.yield terminating |block| when it is the block's sole
// operation; returns std::nullopt for empty blocks, blocks with more than one
// op, or a single op that is not a stream.yield.
//
// Example (matches):
//   stream.async.concurrent ... {
//     stream.yield
//   }
static std::optional<IREE::Stream::YieldOp> getYieldIfOnlyOp(Block &block) {
  // Short-circuit keeps front()/back() from being called on an empty block.
  const bool hasExactlyOneOp =
      !block.empty() && &block.front() == &block.back();
  if (!hasExactlyOneOp) {
    return std::nullopt;
  }
  auto yieldOp = dyn_cast<IREE::Stream::YieldOp>(block.back());
  if (!yieldOp) {
    return std::nullopt;
  }
  return yieldOp;
}
// Reports whether |op| belongs to the set of op kinds that sinking patterns in
// this file may move toward their users (possibly to the end of a block).
// Several such patterns can otherwise compete for the same insertion point
// and recurse forever, so the set is shared here so every pattern can see
// which neighbors might be moved again. Keep it in sync with those patterns.
static bool isSinkCandidate(Operation *op) {
  return isa<TimepointAwaitOp, AsyncAllocaOp, AsyncSplatOp>(op);
}
// Determines whether sinking |toBeSunkOp| to just before |targetOp| is
// guaranteed not to start an unstable oscillation across sinking patterns.
// Oscillations can occur if there are multiple ops inserted before a single op,
// as the insertion order chosen by canonicalization is undefined.
//
// Returns false when the move would be a no-op (|toBeSunkOp| already sits
// immediately before |targetOp|) or when every op in the range
// [toBeSunkOp, targetOp) could itself be sunk again (and, when use-def pruning
// is allowed, is a direct producer of a |targetOp| operand). Returns true for
// cross-block moves and whenever some op in that range cannot fight.
//
// NOTE(review): for the same-block case the scan below appears to assume
// |toBeSunkOp| precedes |targetOp| in the block; confirm callers guarantee it.
//
// Example:
// %0 = op.a
// %1 = op.b
// %2 = op.c %0, %1
// If %0 and %1 are sunk to %2 the ordering will depend on which sink pattern
// runs first and each of the patterns will fight trying to sink lower than the
// other. As long as sinking only happens when this function returns `true`,
// then the sinking across patterns will reach a fixed-point.
static bool canStablySinkTo(Operation *toBeSunkOp, Operation *targetOp) {
// Stably sinking implies that other sinking won't "fight" with this
// sinking. This is obviously not possible in an open pattern ecosystem,
// but for the purpose of this function, we assume that all sinking patterns
// that we are concerned with are the other patterns in the `stream` dialect.
//
// In typical usage, this function will result in various patterns sinking
// their relevant ops before `targetOp`. This results in a sequence of
// sinkable ops before `targetOp`. This is fine, until we start to sink
// them again, which can result in "fighting". We detect that scenario
// by seeing if all the ops between `toBeSunkOp` and `targetOp` might be sunk
// again.
//
// To prove that this function results in sinking that reaches a fixed-point,
// we can design a potential function `f(the_module) -> int`, and show that it
// decreases strictly monotonically with each sinking operation (and cannot go
// below 0). In particular, we choose the following function: `f(the_module) =
// sum(g(op) for op in the_module)`, where `g(op) -> int` gives the distance
// between op's current location and the latest it could appear in the program
// (infinite, if that location is in another block).
assert(isSinkCandidate(toBeSunkOp) && "asking to sink a non-sinkable op");
// If `targetOp` is a terminator, then it might be chosen as a sink location
// purely for control flow reasons, and not due to use-def chains. This means
// that if `targetOp` is not a terminator, then we can prune the set of
// sinkable ops that might fight with `toBeSunkOp` more aggressively by using
// use-def chains.
// The use-def chains check below doesn't detect implicit captures (which can
// be expensive to check), so we also ignore a `targetOp` with regions. This
// can be relaxed if needed.
bool allowUseDefPruning =
!targetOp->hasTrait<mlir::OpTrait::IsTerminator>() &&
targetOp->getNumRegions() == 0;
// If the sinking operation would be a no-op (the op is already immediately
// before `targetOp`), then we need to prevent the sinking operation to avoid
// infinite pattern applications.
if (Block::iterator(targetOp) == std::next(Block::iterator(toBeSunkOp))) {
return false;
}
// If the sinking is to a different block, then it is okay, since for any
// later sinkings this reduces the problem to stable sinking within a single
// block (handled below).
if (toBeSunkOp->getBlock() != targetOp->getBlock()) {
return true;
}
// Collect the ops that directly produce `targetOp`'s operands; only those
// are treated as potential competitors when use-def pruning is enabled.
SmallPtrSet<Operation *, 4> producerOps;
if (allowUseDefPruning) {
for (auto operand : targetOp->getOperands()) {
if (operand.getDefiningOp()) {
producerOps.insert(operand.getDefiningOp());
}
}
}
// If any of the ops between `toBeSunkOp` and `targetOp` are known to not
// fight with this op, then it is stable to sink. Note that the half-open range
// starts at `toBeSunkOp` itself, so it is also subject to the pruning check.
for (Operation &op : llvm::make_range(Block::iterator(toBeSunkOp),
Block::iterator(targetOp))) {
// If an intervening op is not even a sink candidate itself, then it cannot
// fight.
if (!isSinkCandidate(&op)) {
return true;
}
// If the op is pruned by use-def chains (it does not feed `targetOp`
// directly), then it won't fight.
if (allowUseDefPruning && !producerOps.contains(&op)) {
return true;
}
}
// Every op in the range might be sunk again toward `targetOp`; moving now
// could ping-pong with those patterns.
return false;
}
// Moves |op| to immediately before |targetOp| when canStablySinkTo() says the
// move cannot oscillate with other sinking patterns.
// Returns success if |op| was moved and failure if it was left in place.
static LogicalResult sinkOp(Operation *op, Operation *targetOp) {
  if (canStablySinkTo(op, targetOp)) {
    op->moveBefore(targetOp);
    return success();
  }
  return failure();
}
// Positions |rewriter| immediately before the execution region enclosing |op|.
// A stream.async.execute ancestor takes priority over a stream.cmd.execute
// ancestor. Asserts (debug builds) when |op| has neither; release builds leave
// the insertion point unchanged in that case.
//
// Example:
//   %0 = ...
//   <-- insertion point set to here -->
//   stream.async.execute ... {
//     %1 = op
//   }
static void setInsertionPointToParentExecutionScope(Operation *op,
                                                    PatternRewriter &rewriter) {
  Operation *scopeOp = op->getParentOfType<AsyncExecuteOp>();
  if (!scopeOp) {
    scopeOp = op->getParentOfType<CmdExecuteOp>();
  }
  assert(scopeOp && "must be nested within an execution region");
  if (scopeOp) {
    rewriter.setInsertionPoint(scopeOp);
  }
}
// Erases an |Op| instance whose results are never used.
// This exists for ops that behave as "pure" but cannot carry that trait,
// because MLIR CSE would then deduplicate them. The high benefit makes the
// erasure run ahead of other patterns registered for the same op.
template <typename Op>
struct ElideUnusedOp : OpRewritePattern<Op> {
  explicit ElideUnusedOp(MLIRContext *context)
      : OpRewritePattern<Op>(context, /*benefit=*/1000) {}
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    if (op.use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure(); // still has consumers
  }
};
// Clones ops that prefer to be cloned directly.
// This prevents us from splatting out a value and then cloning that (keeping
// the memory live/etc) instead of just splatting it again on-demand.
//
// Every use of the clone op's result is given its own copy of the producer
// (inserted right before the using op). The copy's result takes on the clone
// op's result type, and the clone op is erased once it has no uses.
//
// Example:
//  %0 = stream.async.splat %c123_i32
//  %1 = stream.async.clone %0
// ->
//  %1 = stream.async.splat %c123_i32
template <typename Op>
struct PropagateCloneableOps : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      // No consumers to clone for.
      return failure();
    }
    auto sourceOp =
        cloneOp.getSource()
            .template getDefiningOp<IREE::Stream::StreamableOpInterface>();
    if (!sourceOp || !sourceOp.preferCloneToConsumers()) {
      // Only look at cloneable producer ops.
      return failure();
    }
    // The cloned value may be any result of |sourceOp|. Each copy must
    // forward that same result index, not the index of the clone op's own
    // result (which is always 0).
    unsigned sourceResultNumber =
        cast<OpResult>(cloneOp.getSource()).getResultNumber();
    for (auto &use :
         llvm::make_early_inc_range(cloneOp.getResult().getUses())) {
      Operation *userOp = use.getOwner();
      rewriter.setInsertionPoint(userOp);
      auto clonedOp = rewriter.clone(*sourceOp);
      auto clonedResult = clonedOp->getResult(sourceResultNumber);
      // Adopt the clone op's result type so the consumer sees the same type
      // it did before the rewrite.
      clonedResult.setType(use.get().getType());
      // Route the operand update through the rewriter so listeners (e.g. the
      // greedy driver) are notified that the user changed.
      rewriter.modifyOpInPlace(userOp, [&]() { use.set(clonedResult); });
    }
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
    }
    return success();
  }
};
// Marks an execution region's results as tied to its operands whenever the
// yielded value traces back (through tied ops) to a region block argument.
// Results that are already tied are left alone. Only single-block regions are
// supported (asserted).
//
// Example:
//  %ret:2 = stream.async.execute with(%src as %arg0) -> !stream.resource<*> {
//    %2 = stream.async.dispatch ... (%arg0) -> %arg0
//    stream.yield %2
//  }
// ->
//  %ret:2 = stream.async.execute with(%src as %arg0) -> %src {
//    %2 = stream.async.dispatch ... (%arg0) -> %arg0
//    stream.yield %2
//  }
template <typename Op>
struct TieRegionResults : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    assert(op.getRegion().getBlocks().size() == 1 &&
           "only one stream block supported");
    bool changed = false;
    for (auto yieldOp : op.template getOps<IREE::Stream::YieldOp>()) {
      auto yieldedResources = yieldOp.getResourceOperands();
      for (unsigned resultIdx = 0, resultCount = yieldedResources.size();
           resultIdx < resultCount; ++resultIdx) {
        if (op.getTiedResultOperandIndex(resultIdx).has_value()) {
          continue; // Already tied.
        }
        Value baseValue = IREE::Util::TiedOpInterface::findTiedBaseValue(
            yieldedResources[resultIdx]);
        auto blockArg = dyn_cast<BlockArgument>(baseValue);
        if (!blockArg) {
          continue; // Produced by something other than a region argument.
        }
        unsigned operandIdx = blockArg.getArgNumber();
        rewriter.modifyOpInPlace(op, [&]() {
          op.setTiedResultOperandIndex(resultIdx, operandIdx);
        });
        changed = true;
      }
    }
    return success(changed);
  }
};
// Computes the timepoint an op should await after adding |newTimepoints| to
// its current |existingTimepoint| (which may be null).
//  - No new timepoints: |existingTimepoint| is returned unchanged.
//  - Exactly one new timepoint and no existing one: that timepoint is reused.
//  - Otherwise: a stream.timepoint.join of the existing timepoint (first, if
//    present) followed by the new ones is built with |builder| at |loc|.
static Value joinAwaitTimepoints(Location loc, Value existingTimepoint,
                                 ArrayRef<Value> newTimepoints,
                                 OpBuilder &builder) {
  if (newTimepoints.empty()) {
    return existingTimepoint;
  }
  if (!existingTimepoint && newTimepoints.size() == 1) {
    return newTimepoints.front();
  }
  SmallVector<Value> timepoints;
  timepoints.reserve(newTimepoints.size() + 1);
  if (existingTimepoint) {
    timepoints.push_back(existingTimepoint);
  }
  timepoints.append(newTimepoints.begin(), newTimepoints.end());
  return IREE::Stream::TimepointJoinOp::join(loc, timepoints, builder);
}
// Drops an op's await timepoint when it comes directly from
// stream.timepoint.immediate, since waiting on it is a no-op.
//
// Example:
//  %0 = stream.timepoint.immediate
//  %1 = stream.resource.alloca await(%0) ...
// ->
//  %1 = stream.resource.alloca ...
template <typename Op>
struct ElideImmediateTimepointWait : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    Value awaitTimepoint = op.getAwaitTimepoint();
    if (!awaitTimepoint) {
      return failure(); // nothing awaited
    }
    if (!awaitTimepoint.getDefiningOp<TimepointImmediateOp>()) {
      return failure(); // not known to be already resolved
    }
    rewriter.modifyOpInPlace(
        op, [&]() { op.getAwaitTimepointMutable().clear(); });
    return success();
  }
};
// Chains operand resources produced by a (non-sync) stream.timepoint.await to
// the dependent execution region. This elides host waits and allows for
// device-side wait resolution. Any await the op already had is preserved by
// joining it with the newly chained timepoints.
//
// Example:
//  %0 = stream.cmd.execute with(%resource)
//  %1 = stream.timepoint.await %0 => %resource
//  %2 = stream.cmd.execute with(%resource)
// ->
//  %0 = stream.cmd.execute with(%resource)
//  %2 = stream.cmd.execute await(%0) => with(%resource)
template <typename Op>
struct ChainDependentAwaits : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newTimepoints;
    // (resource operand index, value to use in place of the awaited result)
    SmallVector<std::pair<unsigned, Value>> replacements;
    for (auto operand : llvm::enumerate(op.getResourceOperands())) {
      if (auto awaitOp =
              operand.value().template getDefiningOp<TimepointAwaitOp>()) {
        // Sync awaits must stay host-side waits; leave them untouched.
        if (!awaitOp.getSync()) {
          newTimepoints.push_back(awaitOp.getAwaitTimepoint());
          replacements.push_back(std::make_pair(
              operand.index(), awaitOp.getTiedResultOperand(operand.value())));
        }
      }
    }
    if (replacements.empty()) {
      return failure();
    }

    // Combine with whatever the op already awaited. Setting the awaits to
    // only |newTimepoints| would silently drop a pre-existing dependency.
    // Any join is materialized right before |op|, where all inputs dominate.
    rewriter.setInsertionPoint(op);
    Value newAwaitTimepoint = joinAwaitTimepoints(
        op.getLoc(), op.getAwaitTimepoint(), newTimepoints, rewriter);

    rewriter.modifyOpInPlace(op, [&]() {
      op.getAwaitTimepointMutable().assign(newAwaitTimepoint);
      for (auto replacement : replacements) {
        op.getResourceOperandsMutable()
            .slice(replacement.first, 1)
            .assign(replacement.second);
      }
    });
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// stream.resource.alloc
//===----------------------------------------------------------------------===//
void ResourceAllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  // No canonicalization patterns are registered for this op yet.
  // TODO(benvanik): sink to first user.
}
//===----------------------------------------------------------------------===//
// stream.resource.alloca
//===----------------------------------------------------------------------===//
namespace {
// Removes a transient allocation whose resource is never used, e.g. after its
// consumers were DCEd by other patterns or passes. Users of the alloca's
// result timepoint are redirected to the timepoint the alloca awaited, or to a
// fresh immediate timepoint when it awaited nothing. ElideAllocaDeallocaOp
// covers the case after deallocations have been inserted; this pattern allows
// earlier removal.
struct ElideUnusedAllocaOp : OpRewritePattern<ResourceAllocaOp> {
  using Base::Base;
  LogicalResult matchAndRewrite(ResourceAllocaOp allocaOp,
                                PatternRewriter &rewriter) const override {
    const bool hasResourceUsers = !allocaOp.getResult().use_empty();
    if (hasResourceUsers) {
      return failure(); // >= 1 user
    }
    Value forwardedTimepoint = allocaOp.getAwaitTimepoint();
    if (!forwardedTimepoint) {
      forwardedTimepoint = IREE::Stream::TimepointImmediateOp::create(
          rewriter, allocaOp.getLoc());
    }
    rewriter.replaceAllUsesWith(allocaOp.getResultTimepoint(),
                                forwardedTimepoint);
    rewriter.eraseOp(allocaOp);
    return success();
  }
};
// Removes an alloca/dealloca pair when the dealloca is the allocated
// resource's only user (typically after the real consumers were DCEd).
// Waiters on either op's result timepoint are redirected to that op's own
// await timepoint, or to a fresh immediate timepoint if it awaited nothing.
//
// Example:
//  %resource, %alloca_t = stream.resource.alloca
//  %dealloca_t = stream.resource.dealloca await(%alloca_t) %resource
struct ElideAllocaDeallocaOp : OpRewritePattern<ResourceAllocaOp> {
  using Base::Base;
  LogicalResult matchAndRewrite(ResourceAllocaOp allocaOp,
                                PatternRewriter &rewriter) const override {
    Value resource = allocaOp.getResult();
    if (!resource.hasOneUse()) {
      return failure(); // zero or multiple users
    }
    auto deallocaOp = dyn_cast<IREE::Stream::ResourceDeallocaOp>(
        *resource.getUsers().begin());
    if (!deallocaOp) {
      return failure(); // the single user is not a dealloca
    }

    // Yields |awaited| if set, otherwise materializes an immediate timepoint.
    auto awaitedOrImmediate = [&](Value awaited, Location loc) -> Value {
      if (awaited) {
        return awaited;
      }
      return IREE::Stream::TimepointImmediateOp::create(rewriter, loc);
    };

    // The alloca is handled first: the dealloca may await the alloca's result
    // timepoint, and after this RAUW it will read the forwarded one instead.
    Value allocaReplacement =
        awaitedOrImmediate(allocaOp.getAwaitTimepoint(), allocaOp.getLoc());
    rewriter.replaceAllUsesWith(allocaOp.getResultTimepoint(),
                                allocaReplacement);

    Value deallocaReplacement = awaitedOrImmediate(
        deallocaOp.getAwaitTimepoint(), deallocaOp.getLoc());
    rewriter.replaceAllUsesWith(deallocaOp.getResultTimepoint(),
                                deallocaReplacement);

    // The dealloca goes first because it is the only user of the resource.
    rewriter.eraseOp(deallocaOp);
    rewriter.eraseOp(allocaOp);
    return success();
  }
};
// Finds sequences of chained allocas/deallocas and rewrites them to batch as
// many as possible on a single timepoint. This is done as a canonicalization as
// it is always intended that allocations and deallocations do not wait and we
// can repeatedly optimize when run as part of a larger canonicalization pass
// that cleans up timepoints with other patterns as we modify them here.
//
// A chain continues while each op's result timepoint has exactly one user and
// that user is another op of the same kind. Every op in the chain is made to
// await the head op's await timepoint (or nothing, if the head awaits nothing).
// A join of all their result timepoints replaces the tail's timepoint for all
// later waiters.
//
// Example:
//  %d0 = dealloca await(%t)
//  %d1 = dealloca await(%d0)
//  %d2 = dealloca await(%d1)
//  %d3 = dealloca await(%d2)
//  ... await(%d3)
// ->
//  %d0 = dealloca await(%t)
//  %d1 = dealloca await(%t)
//  %d2 = dealloca await(%t)
//  %d3 = dealloca await(%t)
//  %j = join %d0, %d1, %d2, %d3
//  ... await(%j)
template <typename OpT>
struct BatchAllocaOps : OpRewritePattern<OpT> {
  using OpRewritePattern<OpT>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpT op,
                                PatternRewriter &rewriter) const override {
    // Gather alloca ops chained on timepoints starting from this op.
    SmallVector<OpT> allocaOps;
    OpT nextOp = op;
    while (nextOp) {
      Value resultTimepoint = nextOp.getResultTimepoint();
      allocaOps.push_back(nextOp);
      if (!resultTimepoint.hasOneUse()) {
        break;
      }
      nextOp = dyn_cast<OpT>(*resultTimepoint.user_begin());
    }
    if (allocaOps.size() <= 1) {
      return failure(); // no-op if only one op
    }

    // Gather the result timepoints of all alloca ops so we can join on them.
    // We'll issue all of them concurrently and only join after all
    // deallocations complete.
    SmallVector<Location> allocaLocs;
    SmallVector<Value> allocaTimepoints;
    for (auto allocaOp : allocaOps) {
      allocaLocs.push_back(allocaOp.getLoc());
      allocaTimepoints.push_back(allocaOp.getResultTimepoint());
    }
    rewriter.setInsertionPointAfter(allocaOps.back());
    auto joinOp = IREE::Stream::TimepointJoinOp::create(
        rewriter, rewriter.getFusedLoc(allocaLocs),
        rewriter.getType<IREE::Stream::TimepointType>(), allocaTimepoints);

    // Make all alloca ops wait on the earliest timepoint so they can proceed
    // together. Note that the origin op may be waiting on an immediate
    // timepoint and be nullptr.
    Value awaitTimepoint = op.getAwaitTimepoint();
    for (auto allocaOp : allocaOps) {
      rewriter.modifyOpInPlace(allocaOp, [&]() {
        if (awaitTimepoint) {
          allocaOp.getAwaitTimepointMutable().assign(awaitTimepoint);
        } else {
          // Assigning a null Value would install a null operand; drop the
          // optional await operand instead so the op waits on nothing.
          allocaOp.getAwaitTimepointMutable().clear();
        }
      });
    }

    // Replace the tail timepoint in the alloca chain with the join result so
    // subsequent waiters are waiting on the batch.
    rewriter.replaceAllUsesExcept(allocaOps.back().getResultTimepoint(),
                                  joinOp.getResultTimepoint(), joinOp);
    return success();
  }
};
} // namespace
void ResourceAllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // TODO(benvanik): sink to first user.
  // Registration order matches the historical one-per-line inserts.
  results.insert<ElideUnusedAllocaOp, ElideAllocaDeallocaOp,
                 BatchAllocaOps<ResourceAllocaOp>,
                 ElideImmediateTimepointWait<ResourceAllocaOp>>(context);
}
//===----------------------------------------------------------------------===//
// stream.resource.dealloca
//===----------------------------------------------------------------------===//
void ResourceDeallocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  // TODO(benvanik): move up to producer of timepoint.
  results.insert<BatchAllocaOps<ResourceDeallocaOp>,
                 ElideImmediateTimepointWait<ResourceDeallocaOp>>(context);
}
//===----------------------------------------------------------------------===//
// stream.resource.size
//===----------------------------------------------------------------------===//
// Folds to the size value already associated with the operand resource, as
// located by the type's SizeAwareTypeInterface starting from this op's position
// in its block. Returns a null fold result when no size is found.
OpFoldResult ResourceSizeOp::fold(FoldAdaptor operands) {
  Operation *self = getOperation();
  auto sizeAwareType =
      cast<IREE::Util::SizeAwareTypeInterface>(getOperand().getType());
  Value foundSize = sizeAwareType.findSizeValue(
      getOperand(), self->getBlock(), Block::iterator(self));
  // Folding to our own result would make the folder loop forever.
  if (foundSize == getResult()) {
    return {};
  }
  return foundSize;
}
namespace {
// Rewrites `stream.resource.size` of an `arith.select` between two resources
// into an `arith.select` between the sizes of those two resources. Each size
// is queried with the original op's affinity and folded when possible.
//
// Example:
//  %a = stream... : !stream.resource<*>{%a_sz}
//  %b = stream... : !stream.resource<*>{%b_sz}
//  %c = select %cond, %a, %b : !stream.resource<*>
//  %c_sz = stream.resource.size %c : !stream.resource<*>
// ->
//  %c = select %cond, %a, %b : !stream.resource<*>
//  %c_sz = select %cond, %a_sz, %b_sz : index
struct SelectResourceSizeOp : OpRewritePattern<ResourceSizeOp> {
  using Base::Base;
  LogicalResult matchAndRewrite(ResourceSizeOp op,
                                PatternRewriter &rewriter) const override {
    auto selectOp = op.getOperand().getDefiningOp<mlir::arith::SelectOp>();
    if (!selectOp) {
      return failure();
    }
    Location loc = op.getLoc();
    auto affinityAttr = op.getAffinityAttr();
    auto sizeOf = [&](Value resource) -> Value {
      return rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
          loc, resource, affinityAttr);
    };
    Value trueSize = sizeOf(selectOp.getTrueValue());
    Value falseSize = sizeOf(selectOp.getFalseValue());
    rewriter.replaceOpWithNewOp<mlir::arith::SelectOp>(
        op, selectOp.getCondition(), trueSize, falseSize);
    return success();
  }
};
} // namespace
void ResourceSizeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  // Sizes of selected resources become selects of sizes.
  results.add<SelectResourceSizeOp>(context);
}
//===----------------------------------------------------------------------===//
// stream.resource.try_map
//===----------------------------------------------------------------------===//
void ResourceTryMapOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // TODO(benvanik): fold subviews up into maps to limit range.
  // TODO(benvanik): if mapping for staging then turn into a map?
  // A mapping nobody reads from can simply be dropped.
  results.add<ElideUnusedOp<ResourceTryMapOp>>(context);
}
//===----------------------------------------------------------------------===//
// stream.resource.load
//===----------------------------------------------------------------------===//
namespace {
// Makes a load read directly from the resource underneath a subview. The
// subview's source offset is added to the load offset, and the source size
// replaces the subview size.
//
// Example:
//  %0 = stream.resource.subview %src[%subview_offset] ... -> {%subview_length}
//  %1 = stream.resource.load %0[%offset]
// ->
//  %new_offset = arith.addi %offset, %subview_offset
//  %1 = stream.resource.load %src[%new_offset]
struct FoldSubviewIntoLoadOp : OpRewritePattern<ResourceLoadOp> {
  using Base::Base;
  LogicalResult matchAndRewrite(ResourceLoadOp op,
                                PatternRewriter &rewriter) const override {
    auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getSource());
    if (!subviewOp) {
      return failure(); // not reading through a subview
    }
    Location combinedLoc =
        rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()});
    Value combinedOffset = rewriter.createOrFold<arith::AddIOp>(
        combinedLoc, subviewOp.getSourceOffset(), op.getSourceOffset());
    rewriter.modifyOpInPlace(op, [&]() {
      op.getSourceMutable().assign(subviewOp.getSource());
      op.getSourceSizeMutable().assign(subviewOp.getSourceSize());
      op.getSourceOffsetMutable().assign(combinedOffset);
    });
    return success();
  }
};
} // namespace
void ResourceLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  // TODO(benvanik): if staging resource comes from splat (through transfers)
  //                 then pull splat value.
  // TODO(benvanik): combine multiple loads from the same target if contiguous.
  // TODO(benvanik): value->transfer->load -> value->slice->transfer->load?
  results.insert<FoldSubviewIntoLoadOp, ElideUnusedOp<ResourceLoadOp>>(
      context);
}
//===----------------------------------------------------------------------===//
// stream.resource.store
//===----------------------------------------------------------------------===//
namespace {
// Rewrites a store through a subview into a store to the subview's parent
// resource. The subview's base offset is added to the store's offset.
//
// Example:
//   %0 = stream.resource.subview %dst[%subview_offset] ... -> {%subview_length}
//   stream.resource.store %c123_i32, %0[%offset]
// ->
//   %new_offset = arith.addi %subview_offset, %offset
//   stream.resource.store %c123_i32, %dst[%new_offset]
struct FoldSubviewIntoStoreOp : OpRewritePattern<ResourceStoreOp> {
  using Base::Base;
  LogicalResult matchAndRewrite(ResourceStoreOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the store target is produced by a subview.
    auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getTarget());
    if (!subviewOp) {
      return failure();
    }

    // Translate the subview-relative offset into a parent-relative offset.
    // createOrFold lets constant offsets collapse instead of emitting an addi.
    Value rebasedOffset = rewriter.createOrFold<arith::AddIOp>(
        rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}),
        subviewOp.getSourceOffset(), op.getTargetOffset());
    Value parentResource = subviewOp.getSource();
    Value parentSize = subviewOp.getSourceSize();

    // Point the store directly at the parent resource.
    rewriter.modifyOpInPlace(op, [&]() {
      op.getTargetMutable().assign(parentResource);
      op.getTargetSizeMutable().assign(parentSize);
      op.getTargetOffsetMutable().assign(rebasedOffset);
    });
    return success();
  }
};
} // namespace
// Registers canonicalizations for stream.resource.store. Only the subview
// folding pattern is registered; no unused-op elision pattern is added here.
void ResourceStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  // TODO(benvanik): combine multiple stores to the same target if contiguous.
  // TODO(benvanik): if value is a constant splat then turn into fill?
  results.add<FoldSubviewIntoStoreOp>(context);
}
//===----------------------------------------------------------------------===//
// stream.resource.pack
//===----------------------------------------------------------------------===//
// Folds trivial pack operations that need no actual packing:
//  - zero slices: the total length is 0 and there are no packed offsets.
//  - one slice: the total length is that slice's size, and its packed offset
//    is the optional base offset (or 0 when absent).
// Any pack with two or more slices is left for the real packing pass.
LogicalResult ResourcePackOp::fold(FoldAdaptor operands,
                                   SmallVectorImpl<OpFoldResult> &results) {
  Builder builder(getContext());
  const size_t sliceCount = getPackedOffsets().size();
  if (sliceCount > 1) {
    // Multiple slices require interval packing; nothing to fold.
    return failure();
  }

  if (sliceCount == 0) {
    // No slices: the whole pack collapses to a zero-length slab.
    results.push_back(builder.getZeroAttr(builder.getIndexType()));
    return success();
  }

  // Exactly one slice: total length followed by its (base) offset.
  results.push_back(getDynamicSliceSizes().front());
  if (Value baseOffset = getOffset()) {
    results.push_back(baseOffset);
  } else {
    results.push_back(builder.getZeroAttr(builder.getIndexType()));
  }
  return success();
}
namespace {
// Moves the optional base offset of a pack op onto its packed-offset results.
// The base offset exists only as a convenience for splitting pack ops and
// does not affect packing itself. Once it is added explicitly to each
// returned slice offset, those results fold more readily. The total length
// result is not adjusted.
struct PropagateResourcePackBaseOffset : OpRewritePattern<ResourcePackOp> {
  using Base::Base;
  LogicalResult matchAndRewrite(ResourcePackOp op,
                                PatternRewriter &rewriter) const override {
    auto baseOffset = op.getOffset();
    if (!baseOffset) {
      // Nothing to propagate when the op carries no base offset.
      return failure();
    }

    // The operand is dropped unconditionally; every path below succeeds.
    rewriter.modifyOpInPlace(op, [&]() { op.getOffsetMutable().clear(); });

    // A constant-zero base changes nothing, so skip emitting any adds.
    auto constantBase = baseOffset.getDefiningOp<arith::ConstantIndexOp>();
    if (constantBase && constantBase.getValue() == 0) {
      return success();
    }

    // Rebase every packed offset: users now see (base + slice offset). The
    // add itself must keep consuming the raw slice offset, hence the
    // exception in replaceAllUsesExcept.
    rewriter.setInsertionPointAfter(op);
    auto loc = op.getLoc();
    for (auto sliceOffset : op.getPackedOffsets()) {
      auto rebased =
          arith::AddIOp::create(rewriter, loc, baseOffset, sliceOffset);
      rewriter.replaceAllUsesExcept(sliceOffset, rebased.getResult(), rebased);
    }
    return success();
  }
};

// Rebuilds a pack op so that its slices appear in ascending lifetime-interval
// order. The packing algorithm does not need this, but it makes the IR more
// consistent and overlapping ranges easier for humans to spot.
//
// Example:
//   %0:3 = stream.resource.pack slices({
//     [1, 2] = %size,
//     [0, 4] = %size,
//   }) : index
// ->
//   %0:3 = stream.resource.pack slices({
//     [0, 4] = %size,
//     [1, 2] = %size,
//   }) : index
struct CanonicalizeResourcePackIntervals : OpRewritePattern<ResourcePackOp> {
  using Base::Base;
  LogicalResult matchAndRewrite(ResourcePackOp op,
                                PatternRewriter &rewriter) const override {
    // Sort a copy of the slice descriptors. stable_sort keeps slices that
    // compare equal in their original relative order.
    auto sortedSlices = op.getSlices();
    std::stable_sort(sortedSlices.begin(), sortedSlices.end());

    // If the i-th sorted slice still owns the i-th packed offset result, the
    // op is already canonical and there is nothing to do.
    bool alreadySorted = llvm::all_of(
        llvm::zip_equal(sortedSlices, op.getPackedOffsets()),
        [](auto entry) {
          return std::get<0>(entry).packedOffset == std::get<1>(entry);
        });
    if (alreadySorted) {
      return failure();
    }

    // TODO(benvanik): compact the slice ranges.
    // Flatten the sorted slices back into the op's operand/attribute layout:
    // [start0, end0, start1, end1, ...] plus one dynamic size per slice.
    SmallVector<int64_t> lifetimeIntervals;
    SmallVector<Value> dynamicSliceSizes;
    lifetimeIntervals.reserve(sortedSlices.size() * 2);
    dynamicSliceSizes.reserve(sortedSlices.size());
    for (const auto &slice : sortedSlices) {
      lifetimeIntervals.push_back(slice.lifetimeStart);
      lifetimeIntervals.push_back(slice.lifetimeEnd);
      dynamicSliceSizes.push_back(slice.dynamicSize);
    }

    SmallVector<Type> packedOffsetTypes(sortedSlices.size(),
                                        rewriter.getIndexType());
    auto newOp = ResourcePackOp::create(
        rewriter, op.getLoc(), op.getTotalLength().getType(), packedOffsetTypes,
        op.getOffset(), rewriter.getIndexArrayAttr(lifetimeIntervals),
        dynamicSliceSizes, op.getAffinityAttr());

    // Results are permuted, so rewire each old packed offset to the result at
    // its new sorted position rather than using replaceOp.
    rewriter.replaceAllUsesWith(op.getTotalLength(), newOp.getTotalLength());
    for (auto [slice, newOffset] :
         llvm::zip_equal(sortedSlices, newOp.getPackedOffsets())) {
      rewriter.replaceAllUsesWith(slice.packedOffset, newOffset);
    }
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
// Registers canonicalizations for stream.resource.pack: base-offset
// propagation onto the packed offsets, then slice interval sorting.
void ResourcePackOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.insert<PropagateResourcePackBaseOffset,
                 CanonicalizeResourcePackIntervals>(context);
}
//===----------------------------------------------------------------------===//
// stream.resource.subview
//===----------------------------------------------------------------------===//
OpFoldResult ResourceSubviewOp::fold(FoldAdaptor operands) {
if (getSourceSize() == getResultSize()) {
// Entire range is covered; return it all.
return getSource();