Skip to content

Commit e0a17de

Browse files
committed
Adding the stream.resource.cast op.
We had a potential correctness issue when using chained external fences and returning values (vs writing to output arguments), where we'd insert a `stream.async.transfer` as effectively just a cast to external lifetime for returned tensors. The problem is that the `stream.timepoint.barrier` feeding the chain_external op was placed *before* the transfer, meaning that if the user waited on the fence and then consumed the returned value, they could be consuming it *before* the transfer had executed. We're mostly saved today by most usage being through the synchronous ABI or torch placing results into outputs, as well as by most transfers being elided, but correctness is not guaranteed. This adds a `stream.resource.cast` op that only performs lifetime assertions and pins values in usage refinement. This lets import/export use a cast instead of a transfer, avoiding any potential for copies to arise. Future changes will use this op in a timeline verification pass that checks that resources produced by every StreamableOp are consumed using an appropriate timeline.
1 parent cdcbfd3 commit e0a17de

10 files changed

Lines changed: 218 additions & 21 deletions

File tree

compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,14 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
402402
DFX::Resolution::REQUIRED);
403403
getState() ^= tiedUsage.getState();
404404
})
405+
.Case([&](IREE::Stream::ResourceCastOp op) {
406+
// Cast is a tied passthrough - propagate source usage to result.
407+
// The result type's lifetime constraints are set during init.
408+
auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
409+
*this, Position::forValue(op.getSource()),
410+
DFX::Resolution::REQUIRED);
411+
getState() ^= sourceUsage.getState();
412+
})
405413
.Case([&](IREE::Stream::AsyncTransferOp op) {
406414
removeAssumedBits(NOT_TRANSFER_WRITE);
407415
auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
@@ -753,6 +761,14 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
753761
DFX::Resolution::OPTIONAL);
754762
getState() ^= resultUsage.getState();
755763
})
764+
.Case([&](IREE::Stream::ResourceCastOp op) {
765+
// Cast is a tied passthrough - propagate result usage to source.
766+
// The source type's lifetime constraints are set during init.
767+
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
768+
*this, Position::forValue(op.getResult()),
769+
DFX::Resolution::OPTIONAL);
770+
getState() ^= resultUsage.getState();
771+
})
756772
.Case([&](IREE::Stream::AsyncTransferOp op) {
757773
removeAssumedBits(NOT_TRANSFER_READ);
758774
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(

compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,13 @@ struct ConvertTensorImportOp
8484
.getResult(0);
8585
}
8686

87+
// Cast to unknown lifetime for use in the program.
88+
// This is a compile-time assertion that the imported resource shares
89+
// storage with the program's view of it. RefineUsage will propagate the
90+
// external constraint and the cast will fold when types match.
8791
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
88-
Value newImport = IREE::Stream::AsyncTransferOp::create(
89-
rewriter, op.getLoc(), unknownType, resource, resultSize, resultSize,
90-
/*source_affinity=*/executionAffinityAttr,
91-
/*target_affinity=*/executionAffinityAttr);
92+
Value newImport = IREE::Stream::ResourceCastOp::create(
93+
rewriter, op.getLoc(), unknownType, resource, resultSize);
9294
rewriter.replaceOpWithMultiple(op, {{newImport, resultSize}});
9395
return success();
9496
}
@@ -158,17 +160,17 @@ struct ConvertTensorExportOp
158160
transferTensorOperands(op.getLoc(), op.getSource(), adaptor.getSource(),
159161
executionAffinityAttr, rewriter);
160162

161-
// Exporting a produced value - transfer our source value to an externally
162-
// usable resource and directly export it. This will cause an allocation.
163+
// Exporting a produced value - cast it to external lifetime.
164+
// This is a compile-time assertion that the source resource shares storage
165+
// with the exported buffer. RefineUsage will propagate the external
166+
// constraint backward and the cast will fold when types match.
163167
Value exportSource = adaptor.getSource().front();
164168
auto externalType = rewriter.getType<IREE::Stream::ResourceType>(
165169
IREE::Stream::Lifetime::External);
166170
if (source.resource.getType() != externalType) {
167-
exportSource = IREE::Stream::AsyncTransferOp::create(
171+
exportSource = IREE::Stream::ResourceCastOp::create(
168172
rewriter, op.getLoc(), externalType, source.resource,
169-
source.resourceSize, source.resourceSize,
170-
/*source_affinity=*/source.affinity,
171-
/*target_affinity=*/executionAffinityAttr);
173+
source.resourceSize);
172174
}
173175

174176
// Export (stream resource to buffer view).

compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ util.func public @importBufferView(%view: !hal.buffer_view) -> tensor<?x?x4xf32>
1111
// CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} : index
1212
// CHECK: %[[RESOURCE:.+]] = stream.tensor.import %[[VIEW]] : !hal.buffer_view ->
1313
// CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[SIZE]]}
14-
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[RESOURCE]] :
14+
// CHECK-NEXT: %[[RESULT:.+]] = stream.resource.cast %[[RESOURCE]] :
1515
// CHECK-SAME: !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
1616
%0 = hal.tensor.import %view : !hal.buffer_view -> tensor<?x?x4xf32>{%dim0, %dim1}
1717
// CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index
@@ -26,7 +26,7 @@ util.func public @importBufferViewBitcasting(%view: !hal.buffer_view) -> tensor<
2626
// CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof tensor<4xbf16>
2727
// CHECK: %[[RESOURCE:.+]] = stream.tensor.import %[[VIEW]] : !hal.buffer_view ->
2828
// CHECK-SAME: tensor<2xui32> in !stream.resource<external>{%[[SIZE]]}
29-
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[RESOURCE]] :
29+
// CHECK-NEXT: %[[RESULT:.+]] = stream.resource.cast %[[RESOURCE]] :
3030
// CHECK-SAME: !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
3131
%0 = hal.tensor.import %view : !hal.buffer_view -> tensor<2xui32> as tensor<4xbf16>
3232
// CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index
@@ -45,7 +45,7 @@ util.func public @importBufferViewAsync(%view: !hal.buffer_view, %fence: !hal.fe
4545
// CHECK: %[[TIMEPOINT:.+]] = stream.timepoint.import %[[FENCE]]
4646
// CHECK: %[[SYNC_RESOURCE:.+]] = stream.timepoint.await %[[TIMEPOINT]] => %[[ASYNC_RESOURCE]]
4747
// CHECK-SAME: : !stream.resource<external>{%[[SIZE]]}
48-
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[SYNC_RESOURCE]]
48+
// CHECK-NEXT: %[[RESULT:.+]] = stream.resource.cast %[[SYNC_RESOURCE]]
4949
// CHECK-SAME: : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
5050
%0 = hal.tensor.import wait(%fence) => %view : !hal.buffer_view -> tensor<4xf32>
5151
// CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index
@@ -57,7 +57,7 @@ util.func public @importBufferViewAsync(%view: !hal.buffer_view, %fence: !hal.fe
5757
// CHECK-LABEL: @exportBufferView
5858
// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
5959
util.func public @exportBufferView(%tensor: tensor<?x?x4xf32>, %dim0: index, %dim1: index) -> !hal.buffer_view {
60-
// CHECK: %[[VIEW:.+]] = stream.async.clone %[[TENSOR]] :
60+
// CHECK: %[[VIEW:.+]] = stream.resource.cast %[[TENSOR]] :
6161
// CHECK-SAME: !stream.resource<*>{%[[SIZE]]} -> !stream.resource<external>{%[[SIZE]]}
6262
// CHECK-NEXT: %[[RESULT:.+]] = stream.tensor.export %[[VIEW]] :
6363
// CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[SIZE]]}
@@ -127,7 +127,7 @@ util.func public @importBufferViewCrossDevice(%view: !hal.buffer_view) -> tensor
127127
// CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof on(#hal.device.promise<@dev_a>) tensor<4xf32>
128128
// CHECK: %[[RESOURCE:.+]] = stream.tensor.import on(#hal.device.promise<@dev_a>) %[[VIEW]] : !hal.buffer_view ->
129129
// CHECK-SAME: tensor<4xf32> in !stream.resource<external>{%[[SIZE]]}
130-
// CHECK-NEXT: %[[CLONE:.+]] = stream.async.clone on(#hal.device.promise<@dev_a>) %[[RESOURCE]] :
130+
// CHECK-NEXT: %[[CLONE:.+]] = stream.resource.cast %[[RESOURCE]] :
131131
// CHECK-SAME: !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
132132
%0 = hal.tensor.import on(#hal.device.promise<@dev_a>) %view : !hal.buffer_view -> tensor<4xf32>
133133
// CHECK: util.return %[[CLONE]], %[[SIZE]] : !stream.resource<*>, index
@@ -144,7 +144,7 @@ util.func public @importBufferViewCrossDevice(%view: !hal.buffer_view) -> tensor
144144
util.func public @exportBufferViewCrossDevice(%tensor: tensor<4xf32>) -> !hal.buffer_view attributes {
145145
stream.affinity = #hal.device.promise<@dev_a>
146146
} {
147-
// CHECK: %[[CLONE:.+]] = stream.async.clone %[[TENSOR]] :
147+
// CHECK: %[[CLONE:.+]] = stream.resource.cast %[[TENSOR]] :
148148
// CHECK-SAME: !stream.resource<*>{%[[SIZE]]} -> !stream.resource<external>{%[[SIZE]]}
149149
// CHECK-NEXT: %[[VIEW:.+]] = stream.tensor.export %[[CLONE]] :
150150
// CHECK-SAME: tensor<4xf32> in !stream.resource<external>{%[[SIZE]]}

compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ util.func public @globalStoreFromExternal(%arg0: !hal.buffer_view) {
8383
%dim0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
8484
// CHECK: %[[SIZE:.+]] = stream.tensor.sizeof tensor<?x4xf32>{%[[DIM0]]} : index
8585
// CHECK: %[[IMPORT:.+]] = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%[[DIM0]]} in !stream.resource<external>{%[[SIZE]]}
86-
// CHECK: %[[T:.+]] = stream.async.transfer %[[IMPORT]] : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
86+
// CHECK: %[[T:.+]] = stream.resource.cast %[[IMPORT]] : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
8787
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%dim0}
8888
// CHECK: %[[VAR:.+]] = stream.async.transfer %[[T]] : !stream.resource<*>{%[[SIZE]]} -> !stream.resource<variable>{%[[SIZE]]}
8989
// CHECK: util.global.store %[[VAR]], @var_with_buffer_view_store : !stream.resource<variable>

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,6 +1049,80 @@ void ResourceSubviewOp::getCanonicalizationPatterns(RewritePatternSet &results,
10491049
results.insert<SinkSubviewAcrossSelectOps>(context);
10501050
}
10511051

1052+
//===----------------------------------------------------------------------===//
1053+
// stream.resource.cast
1054+
//===----------------------------------------------------------------------===//
1055+
1056+
OpFoldResult ResourceCastOp::fold(FoldAdaptor operands) {
1057+
// Fold away cast if source and result types match.
1058+
if (getSource().getType() == getResult().getType()) {
1059+
return getSource();
1060+
}
1061+
// Fold chain: cast(cast(x)) -> cast(x) when outer result type matches x.
1062+
if (auto sourceCastOp = getSource().getDefiningOp<ResourceCastOp>()) {
1063+
if (sourceCastOp.getSource().getType() == getResult().getType()) {
1064+
return sourceCastOp.getSource();
1065+
}
1066+
}
1067+
return {};
1068+
}
1069+
1070+
namespace {
1071+
1072+
// Collapses chains of casts into a single cast to the final type.
1073+
// This handles cases where fold() can't apply because the types don't match.
1074+
struct CollapseResourceCastChain : public OpRewritePattern<ResourceCastOp> {
1075+
using OpRewritePattern::OpRewritePattern;
1076+
LogicalResult matchAndRewrite(ResourceCastOp op,
1077+
PatternRewriter &rewriter) const override {
1078+
auto sourceCastOp = op.getSource().getDefiningOp<ResourceCastOp>();
1079+
if (!sourceCastOp)
1080+
return failure();
1081+
// Skip if fold() would apply.
1082+
if (sourceCastOp.getSource().getType() == op.getResult().getType())
1083+
return failure();
1084+
// Replace with single cast from original source to final type.
1085+
rewriter.replaceOpWithNewOp<ResourceCastOp>(op, op.getResult().getType(),
1086+
sourceCastOp.getSource(),
1087+
sourceCastOp.getSourceSize());
1088+
return success();
1089+
}
1090+
};
1091+
1092+
// Commutes cast through subview: cast(subview(x)) -> subview(cast(x)).
1093+
// This pushes casts earlier to increase folding opportunities.
1094+
struct CommuteCastThroughSubview : public OpRewritePattern<ResourceCastOp> {
1095+
using OpRewritePattern::OpRewritePattern;
1096+
LogicalResult matchAndRewrite(ResourceCastOp op,
1097+
PatternRewriter &rewriter) const override {
1098+
auto subviewOp = op.getSource().getDefiningOp<ResourceSubviewOp>();
1099+
if (!subviewOp)
1100+
return failure();
1101+
// Only commute if subview has single use (the cast).
1102+
if (!subviewOp->hasOneUse())
1103+
return failure();
1104+
// Create cast of the subview source with the target lifetime.
1105+
auto castResultType = op.getResult().getType();
1106+
auto newCast = ResourceCastOp::create(rewriter, op.getLoc(), castResultType,
1107+
subviewOp.getSource(),
1108+
subviewOp.getSourceSize());
1109+
// Create new subview with cast result.
1110+
rewriter.replaceOpWithNewOp<ResourceSubviewOp>(
1111+
op, castResultType, newCast, subviewOp.getSourceSize(),
1112+
subviewOp.getSourceOffset(), subviewOp.getResultSize());
1113+
// Clean up the old subview.
1114+
rewriter.eraseOp(subviewOp);
1115+
return success();
1116+
}
1117+
};
1118+
1119+
} // namespace
1120+
1121+
void ResourceCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1122+
MLIRContext *context) {
1123+
results.insert<CollapseResourceCastChain, CommuteCastThroughSubview>(context);
1124+
}
1125+
10521126
//===----------------------------------------------------------------------===//
10531127
// stream.file.read
10541128
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,6 +1690,30 @@ IREE::Stream::ResourceSubviewOp ResourceSubviewOp::findSubviewOp(Value value) {
16901690
return {};
16911691
}
16921692

1693+
//===----------------------------------------------------------------------===//
1694+
// stream.resource.cast
1695+
//===----------------------------------------------------------------------===//
1696+
1697+
Value ResourceCastOp::getTiedResult(unsigned resultIndex) {
1698+
return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource());
1699+
}
1700+
1701+
::std::optional<unsigned>
1702+
ResourceCastOp::getTiedResultOperandIndex(unsigned resultIndex) {
1703+
return {0}; // source
1704+
}
1705+
1706+
SmallVector<int64_t> ResourceCastOp::getTiedResultOperandIndices() {
1707+
return {0}; // source
1708+
}
1709+
1710+
LogicalResult ResourceCastOp::verify() {
1711+
if (getSourceSize() != getResultSize()) {
1712+
return emitOpError("source and result sizes must be equal (tied storage)");
1713+
}
1714+
return success();
1715+
}
1716+
16931717
//===----------------------------------------------------------------------===//
16941718
// stream.resource.transients
16951719
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,67 @@ def Stream_ResourceSubviewOp : Stream_PureOp<"resource.subview", [
730730
let hasFolder = 1;
731731
}
732732

733+
def Stream_ResourceCastOp : Stream_PureOp<"resource.cast", [
734+
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
735+
"getTiedResult",
736+
"getTiedResultOperandIndex",
737+
"getTiedResultOperandIndices",
738+
]>,
739+
Util_SizeAwareOp,
740+
]> {
741+
let summary = [{Casts a resource to a different lifetime.}];
742+
let description = [{
743+
Asserts that a resource has a specific lifetime. This is a compile-time
744+
constraint that influences lifetime resolution during RefineUsage. The
745+
source and result share the same underlying storage (tied operation).
746+
747+
This is NOT a StreamableOp - it has no execution semantics and does not
748+
participate in timeline verification. It purely constrains type resolution.
749+
750+
The source and result sizes must be equal (verified at compile time).
751+
752+
Example:
753+
```mlir
754+
// Assert resource must resolve to external lifetime.
755+
%1 = stream.resource.cast %0 : !stream.resource<*>{%size} -> !stream.resource<external>{%size}
756+
```
757+
}];
758+
759+
let arguments = (ins
760+
Stream_AnyStreamResource:$source,
761+
Stream_Size:$source_size,
762+
Stream_Size:$result_size
763+
);
764+
let results = (outs
765+
Stream_AnyStreamResource:$result
766+
);
767+
768+
let assemblyFormat = [{
769+
$source `:` type($source) `{` $source_size `}`
770+
`->` type($result) `{` $result_size `}`
771+
attr-dict
772+
}];
773+
774+
let builders = [
775+
OpBuilder<(ins
776+
"Type":$resultType,
777+
"Value":$source,
778+
"Value":$size
779+
), [{
780+
build($_builder, $_state, resultType, source, size, size);
781+
}]>,
782+
];
783+
784+
let extraClassDeclaration = [{
785+
Value getOperandSize(unsigned idx) { return getSourceSize(); }
786+
Value getResultSize(unsigned idx) { return getResultSize(); }
787+
}];
788+
789+
let hasVerifier = 1;
790+
let hasFolder = 1;
791+
let hasCanonicalizer = 1;
792+
}
793+
733794
def Stream_ResourceTransientsOp : Stream_PureOp<"resource.transients", [
734795
AllTypesMatch<["resource", "result"]>,
735796
DeclareOpInterfaceMethods<Stream_AffinityOp, [

compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,8 +500,9 @@ static void insertUsageRefinementPatterns(MLIRContext *context,
500500
ApplyGenericOp<mlir::scf::ConditionOp>,
501501
ApplyGenericOp<mlir::scf::YieldOp>,
502502
ApplyGenericOp<IREE::Stream::TimepointAwaitOp>,
503-
ApplyGenericOp<IREE::Stream::TimepointBarrierOp>>(context,
504-
analysis);
503+
ApplyGenericOp<IREE::Stream::TimepointBarrierOp>,
504+
ApplyGenericOp<IREE::Stream::ResourceCastOp>>(context,
505+
analysis);
505506
patterns.insert<ApplyStreamableOp<IREE::Stream::ResourceAllocOp>,
506507
ApplyStreamableOp<IREE::Stream::ResourceAllocaOp>,
507508
ApplyStreamableOp<IREE::Stream::ResourceTransientsOp>,

compiler/src/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ util.func public @simple_mul(%arg0: !hal.buffer_view) -> !hal.buffer_view attrib
3232
// CHECK: hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("tensor") shape([%0, %c4]) type(%[[ELEMENT_TYPE]]) encoding(%[[ENCODING_TYPE]])
3333
// CHECK: %[[ARG0_SIZE:.+]] = stream.tensor.sizeof tensor<?x4xf32>{%[[DIM0]]} : index
3434
// CHECK: %[[ARG0_IMPORT:.+]] = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%[[DIM0]]} in !stream.resource<external>{%[[ARG0_SIZE]]}
35-
// CHECK: %[[ARG0_T:.+]] = stream.async.transfer %[[ARG0_IMPORT]] : !stream.resource<external>{%[[ARG0_SIZE]]} -> !stream.resource<*>{%[[ARG0_SIZE]]}
35+
// CHECK: %[[ARG0_T:.+]] = stream.resource.cast %[[ARG0_IMPORT]] : !stream.resource<external>{%[[ARG0_SIZE]]} -> !stream.resource<*>{%[[ARG0_SIZE]]}
3636
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%dim0}
3737

3838
%c1 = arith.constant 1 : index
@@ -41,7 +41,7 @@ util.func public @simple_mul(%arg0: !hal.buffer_view) -> !hal.buffer_view attrib
4141
// CHECK: %[[RET0:.+]] = stream.tensor.dispatch @executable::@dispatch[%c2, %c1, %c1](%[[ARG0_T]]) : (tensor<?x4xf32>{%[[DIM0]]} in !stream.resource<*>{%[[ARG0_SIZE]]}) -> tensor<?xf32>{%[[DIM0]]} in !stream.resource<*>{%[[RET0_SIZE]]}
4242
%1 = flow.dispatch @executable::@dispatch[%c2, %c1, %c1](%0) : (tensor<?x4xf32>{%dim0}) -> tensor<?xf32>{%dim0}
4343

44-
// CHECK: %[[RET0_T:.+]] = stream.async.transfer %[[RET0]] : !stream.resource<*>{%[[RET0_SIZE]]} -> !stream.resource<external>{%[[RET0_SIZE]]}
44+
// CHECK: %[[RET0_T:.+]] = stream.resource.cast %[[RET0]] : !stream.resource<*>{%[[RET0_SIZE]]} -> !stream.resource<external>{%[[RET0_SIZE]]}
4545
// CHECK: %[[RET0_EXPORT:.+]] = stream.tensor.export %[[RET0_T]] : tensor<?xf32>{%[[DIM0]]} in !stream.resource<external>{%[[RET0_SIZE]]} -> !hal.buffer_view
4646
%2 = hal.tensor.export %1 : tensor<?xf32>{%dim0} -> !hal.buffer_view
4747
// CHECK: util.return %[[RET0_EXPORT]] : !hal.buffer_view

compiler/src/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,3 +554,22 @@ util.func public @transients_external_result(%size: index, %storage_size: index)
554554
// CHECK: util.return {{.+}} !stream.resource<external>
555555
util.return %awaited : !stream.resource<*>
556556
}
557+
558+
// -----
559+
560+
// Tests that stream.resource.cast propagates its constraint to the source value.
561+
// The cast asserts the result must be external, so the source (splat) must also
562+
// become external. After refinement, the cast folds away since types match.
563+
564+
// CHECK-LABEL: @resourceCastPropagation
565+
util.func public @resourceCastPropagation(%size: index) -> !stream.resource<external> {
566+
%c123_i32 = arith.constant 123 : i32
567+
// The splat should be refined to external due to the cast constraint.
568+
// CHECK: %[[SPLAT:.+]] = stream.async.splat %c123_i32 {{.+}} -> !stream.resource<external>
569+
%splat = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size}
570+
// The cast should fold away after the source type is refined to match.
571+
// CHECK-NOT: stream.resource.cast
572+
%cast = stream.resource.cast %splat : !stream.resource<*>{%size} -> !stream.resource<external>{%size}
573+
// CHECK: util.return %[[SPLAT]] : !stream.resource<external>
574+
util.return %cast : !stream.resource<external>
575+
}

0 commit comments

Comments
 (0)