Skip to content

Commit 4b29940

Browse files
committed
[Stream] Add async.cast for lifetime refinement
HAL ABI import/export lowering sometimes needs to assert that a resource must be usable with a different lifetime without scheduling a copy. Modeling that assertion as stream.async.transfer is too strong: a transfer is real async work, so an external fence chained from the original timeline can signal before the transfer executes and let a caller observe a returned tensor too early. Add stream.async.cast as a tied async-phase passthrough that carries lifetime constraints through ResourceUsageAnalysis. RefineUsage folds the cast when the source can be refined to the requested lifetime and lowers it to stream.async.transfer only when concrete lifetimes cannot match. This keeps lifetime-only ABI transitions from introducing accidental timeline work while preserving real copies where they are required.
1 parent af030e4 commit 4b29940

12 files changed

Lines changed: 353 additions & 32 deletions

File tree

compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,14 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
490490
DFX::Resolution::REQUIRED);
491491
getState() ^= tiedUsage.getState();
492492
})
493+
.Case([&](IREE::Stream::AsyncCastOp op) {
494+
// Cast is a tied passthrough - propagate source usage to result.
495+
// The result type's lifetime constraints are set during init.
496+
auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
497+
*this, Position::forValue(op.getSource()),
498+
DFX::Resolution::REQUIRED);
499+
getState() ^= sourceUsage.getState();
500+
})
493501
.Case([&](IREE::Stream::AsyncTransferOp op) {
494502
removeAssumedBits(NOT_TRANSFER_WRITE);
495503
auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
@@ -877,6 +885,14 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
877885
DFX::Resolution::OPTIONAL);
878886
getState() ^= resultUsage.getState();
879887
})
888+
.Case([&](IREE::Stream::AsyncCastOp op) {
889+
// Cast is a tied passthrough - propagate result usage to source.
890+
// The source type's lifetime constraints are set during init.
891+
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
892+
*this, Position::forValue(op.getResult()),
893+
DFX::Resolution::OPTIONAL);
894+
getState() ^= resultUsage.getState();
895+
})
880896
.Case([&](IREE::Stream::AsyncTransferOp op) {
881897
removeAssumedBits(NOT_TRANSFER_READ);
882898
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(

compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -97,25 +97,25 @@ struct ConvertTensorImportOp
9797
}
9898

9999
// If byte_offset was specified, create a subview at that offset before
100-
// the transfer. This makes the non-zero offset visible to
100+
// the cast. This makes the non-zero offset visible to
101101
// AnnotateDispatchArguments, which computes alignment as
102102
// gcd(base_alignment, offset).
103-
Value transferSource = resource;
104-
Value transferSize = importSize;
103+
Value castSource = resource;
104+
Value castSize = importSize;
105105
if (byteOffset) {
106-
transferSource = IREE::Stream::ResourceSubviewOp::create(
106+
castSource = IREE::Stream::ResourceSubviewOp::create(
107107
rewriter, op.getLoc(), resource, importSize, byteOffset, tensorSize);
108-
transferSize = tensorSize;
108+
castSize = tensorSize;
109109
}
110110

111+
// Cast to unknown lifetime for use in the program.
112+
// This is an async-phase operation that RefineUsage will resolve by
113+
// propagating the external constraint. The cast will fold when types match.
111114
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
112-
Value newImport = IREE::Stream::AsyncTransferOp::create(
113-
rewriter, op.getLoc(), unknownType, transferSource, transferSize,
114-
transferSize,
115-
/*source_affinity=*/executionAffinityAttr,
116-
/*target_affinity=*/executionAffinityAttr);
117-
118-
rewriter.replaceOpWithMultiple(op, {{newImport, transferSize}});
115+
Value newImport = IREE::Stream::AsyncCastOp::create(
116+
rewriter, op.getLoc(), unknownType, castSource, castSize,
117+
executionAffinityAttr);
118+
rewriter.replaceOpWithMultiple(op, {{newImport, castSize}});
119119
return success();
120120
}
121121

@@ -184,17 +184,17 @@ struct ConvertTensorExportOp
184184
transferTensorOperands(op.getLoc(), op.getSource(), adaptor.getSource(),
185185
executionAffinityAttr, rewriter);
186186

187-
// Exporting a produced value - transfer our source value to an externally
188-
// usable resource and directly export it. This will cause an allocation.
189-
Value exportSource = adaptor.getSource().front();
187+
// Exporting a produced value - cast to external lifetime.
188+
// This is an async-phase operation that RefineUsage will resolve by
189+
// propagating the external constraint backward and converting to a transfer
190+
// if needed. The cast will fold when types match.
191+
Value exportSource = source.resource;
190192
auto externalType = rewriter.getType<IREE::Stream::ResourceType>(
191193
IREE::Stream::Lifetime::External);
192194
if (source.resource.getType() != externalType) {
193-
exportSource = IREE::Stream::AsyncTransferOp::create(
195+
exportSource = IREE::Stream::AsyncCastOp::create(
194196
rewriter, op.getLoc(), externalType, source.resource,
195-
source.resourceSize, source.resourceSize,
196-
/*source_affinity=*/source.affinity,
197-
/*target_affinity=*/executionAffinityAttr);
197+
source.resourceSize, executionAffinityAttr);
198198
}
199199

200200
// Export (stream resource to buffer view).

compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ util.func public @importBufferView(%view: !hal.buffer_view) -> tensor<?x?x4xf32>
1111
// CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} : index
1212
// CHECK: %[[RESOURCE:.+]] = stream.tensor.import %[[VIEW]] : !hal.buffer_view ->
1313
// CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[SIZE]]}
14-
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[RESOURCE]] :
14+
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[RESOURCE]] :
1515
// CHECK-SAME: !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
1616
%0 = hal.tensor.import %view : !hal.buffer_view -> tensor<?x?x4xf32>{%dim0, %dim1}
1717
// CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index
@@ -33,7 +33,7 @@ util.func public @importBufferViewWithOffset(%view: !hal.buffer_view, %fence: !h
3333
// CHECK-SAME: : !stream.resource<external>{%[[TOTAL_SIZE]]}
3434
// CHECK: %[[SUBVIEW:.+]] = stream.resource.subview %[[SYNCED]][%[[OFFSET]]] :
3535
// CHECK-SAME: !stream.resource<external>{%[[TOTAL_SIZE]]} -> !stream.resource<external>{%[[TENSOR_SIZE]]}
36-
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[SUBVIEW]]
36+
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[SUBVIEW]]
3737
// CHECK-SAME: : !stream.resource<external>{%[[TENSOR_SIZE]]} -> !stream.resource<*>{%[[TENSOR_SIZE]]}
3838
%0 = hal.tensor.import wait(%fence) => %view offset(%offset) : !hal.buffer_view -> tensor<4xf32>
3939
// CHECK: util.return %[[RESULT]], %[[TENSOR_SIZE]] : !stream.resource<*>, index
@@ -52,7 +52,7 @@ util.func public @importBufferViewWithOffsetNoFence(%view: !hal.buffer_view, %of
5252
// CHECK-SAME: : !hal.buffer_view -> tensor<4xf32> in !stream.resource<external>{%[[TOTAL_SIZE]]}
5353
// CHECK: %[[SUBVIEW:.+]] = stream.resource.subview %[[RESOURCE]][%[[OFFSET]]] :
5454
// CHECK-SAME: !stream.resource<external>{%[[TOTAL_SIZE]]} -> !stream.resource<external>{%[[TENSOR_SIZE]]}
55-
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[SUBVIEW]]
55+
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[SUBVIEW]]
5656
// CHECK-SAME: : !stream.resource<external>{%[[TENSOR_SIZE]]} -> !stream.resource<*>{%[[TENSOR_SIZE]]}
5757
%0 = hal.tensor.import %view offset(%offset) : !hal.buffer_view -> tensor<4xf32>
5858
// CHECK: util.return %[[RESULT]], %[[TENSOR_SIZE]] : !stream.resource<*>, index
@@ -67,7 +67,7 @@ util.func public @importBufferViewBitcasting(%view: !hal.buffer_view) -> tensor<
6767
// CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof tensor<4xbf16>
6868
// CHECK: %[[RESOURCE:.+]] = stream.tensor.import %[[VIEW]] : !hal.buffer_view ->
6969
// CHECK-SAME: tensor<2xui32> in !stream.resource<external>{%[[SIZE]]}
70-
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[RESOURCE]] :
70+
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[RESOURCE]] :
7171
// CHECK-SAME: !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
7272
%0 = hal.tensor.import %view : !hal.buffer_view -> tensor<2xui32> as tensor<4xbf16>
7373
// CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index
@@ -86,7 +86,7 @@ util.func public @importBufferViewAsync(%view: !hal.buffer_view, %fence: !hal.fe
8686
// CHECK: %[[TIMEPOINT:.+]] = stream.timepoint.import %[[FENCE]]
8787
// CHECK: %[[SYNC_RESOURCE:.+]] = stream.timepoint.await %[[TIMEPOINT]] => %[[ASYNC_RESOURCE]]
8888
// CHECK-SAME: : !stream.resource<external>{%[[SIZE]]}
89-
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[SYNC_RESOURCE]]
89+
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[SYNC_RESOURCE]]
9090
// CHECK-SAME: : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
9191
%0 = hal.tensor.import wait(%fence) => %view : !hal.buffer_view -> tensor<4xf32>
9292
// CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index
@@ -98,7 +98,7 @@ util.func public @importBufferViewAsync(%view: !hal.buffer_view, %fence: !hal.fe
9898
// CHECK-LABEL: @exportBufferView
9999
// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
100100
util.func public @exportBufferView(%tensor: tensor<?x?x4xf32>, %dim0: index, %dim1: index) -> !hal.buffer_view {
101-
// CHECK: %[[VIEW:.+]] = stream.async.clone %[[TENSOR]] :
101+
// CHECK: %[[VIEW:.+]] = stream.async.cast %[[TENSOR]] :
102102
// CHECK-SAME: !stream.resource<*>{%[[SIZE]]} -> !stream.resource<external>{%[[SIZE]]}
103103
// CHECK-NEXT: %[[RESULT:.+]] = stream.tensor.export %[[VIEW]] :
104104
// CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[SIZE]]}
@@ -168,7 +168,7 @@ util.func public @importBufferViewCrossDevice(%view: !hal.buffer_view) -> tensor
168168
// CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof on(#hal.device.promise<@dev_a>) tensor<4xf32>
169169
// CHECK: %[[RESOURCE:.+]] = stream.tensor.import on(#hal.device.promise<@dev_a>) %[[VIEW]] : !hal.buffer_view ->
170170
// CHECK-SAME: tensor<4xf32> in !stream.resource<external>{%[[SIZE]]}
171-
// CHECK-NEXT: %[[CLONE:.+]] = stream.async.clone on(#hal.device.promise<@dev_a>) %[[RESOURCE]] :
171+
// CHECK-NEXT: %[[CLONE:.+]] = stream.async.cast on(#hal.device.promise<@dev_a>) %[[RESOURCE]] :
172172
// CHECK-SAME: !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
173173
%0 = hal.tensor.import on(#hal.device.promise<@dev_a>) %view : !hal.buffer_view -> tensor<4xf32>
174174
// CHECK: util.return %[[CLONE]], %[[SIZE]] : !stream.resource<*>, index
@@ -185,7 +185,7 @@ util.func public @importBufferViewCrossDevice(%view: !hal.buffer_view) -> tensor
185185
util.func public @exportBufferViewCrossDevice(%tensor: tensor<4xf32>) -> !hal.buffer_view attributes {
186186
stream.affinity = #hal.device.promise<@dev_a>
187187
} {
188-
// CHECK: %[[CLONE:.+]] = stream.async.clone %[[TENSOR]] :
188+
// CHECK: %[[CLONE:.+]] = stream.async.cast %[[TENSOR]] :
189189
// CHECK-SAME: !stream.resource<*>{%[[SIZE]]} -> !stream.resource<external>{%[[SIZE]]}
190190
// CHECK-NEXT: %[[VIEW:.+]] = stream.tensor.export %[[CLONE]] :
191191
// CHECK-SAME: tensor<4xf32> in !stream.resource<external>{%[[SIZE]]}

compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ util.func public @globalStoreFromExternal(%arg0: !hal.buffer_view) {
8383
%dim0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
8484
// CHECK: %[[SIZE:.+]] = stream.tensor.sizeof tensor<?x4xf32>{%[[DIM0]]} : index
8585
// CHECK: %[[IMPORT:.+]] = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%[[DIM0]]} in !stream.resource<external>{%[[SIZE]]}
86-
// CHECK: %[[T:.+]] = stream.async.transfer %[[IMPORT]] : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
86+
// CHECK: %[[T:.+]] = stream.async.cast %[[IMPORT]] : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
8787
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%dim0}
8888
// CHECK: %[[VAR:.+]] = stream.async.transfer %[[T]] : !stream.resource<*>{%[[SIZE]]} -> !stream.resource<variable>{%[[SIZE]]}
8989
// CHECK: util.global.store %[[VAR]], @var_with_buffer_view_store : !stream.resource<variable>

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2136,6 +2136,55 @@ void AsyncTransferOp::getCanonicalizationPatterns(RewritePatternSet &results,
21362136
results.insert<ElideUnusedOp<AsyncTransferOp>>(context);
21372137
}
21382138

2139+
//===----------------------------------------------------------------------===//
2140+
// stream.async.cast
2141+
//===----------------------------------------------------------------------===//
2142+
2143+
// Folds stream.async.cast by returning an existing value that can stand in for
// the result. Returns an empty OpFoldResult when neither case below applies.
// The `operands` adaptor is not consulted; only the op's own accessors are used.
OpFoldResult AsyncCastOp::fold(FoldAdaptor operands) {
2144+
// Identity cast: source and result types already match, so the source value
// can be used directly in place of the result.
2145+
if (getSource().getType() == getResult().getType()) {
2146+
return getSource();
2147+
}
2148+
// Round-trip chain: cast(cast(x)) -> x when the outer result type matches the
// type of the inner cast's source.
2149+
if (auto sourceCastOp = getSource().getDefiningOp<AsyncCastOp>()) {
2150+
if (sourceCastOp.getSource().getType() == getResult().getType()) {
2151+
return sourceCastOp.getSource();
2152+
}
2153+
}
2154+
// No fold applies; chains ending in a different type are left to the
// canonicalization patterns.
return {};
2155+
}
2156+
2157+
namespace {
2158+
2159+
// Collapses chains of async casts into a single cast to the final type.
// Matches only when the cast's source is itself produced by an AsyncCastOp and
// the chain does not return to the inner source's type (that case is left to
// AsyncCastOp::fold).
2160+
struct CollapseAsyncCastChain : public OpRewritePattern<AsyncCastOp> {
2161+
using OpRewritePattern::OpRewritePattern;
2162+
LogicalResult matchAndRewrite(AsyncCastOp op,
2163+
PatternRewriter &rewriter) const override {
2164+
// Only applies when the source is the result of another cast.
auto sourceCastOp = op.getSource().getDefiningOp<AsyncCastOp>();
2165+
if (!sourceCastOp) {
2166+
return failure();
2167+
}
2168+
// If folding would handle this (source matches result), let fold do it.
2169+
if (sourceCastOp.getSource().getType() == op.getResult().getType()) {
2170+
return failure();
2171+
}
2172+
// Collapse the chain: cast(cast(x, A), B) -> cast(x, B).
// The replacement takes its source value and size from the inner cast and
// its result type and affinity attribute from the outer cast.
2173+
rewriter.replaceOpWithNewOp<AsyncCastOp>(
2174+
op, op.getResult().getType(), sourceCastOp.getSource(),
2175+
sourceCastOp.getSourceSize(), op.getAffinityAttr());
2176+
return success();
2177+
}
2178+
};
2179+
2180+
} // namespace
2181+
2182+
// Registers the canonicalization patterns for stream.async.cast into `results`:
// the cast-chain collapse pattern and the ElideUnusedOp pattern instantiated
// for this op (presumably drops casts whose result has no uses; the pattern is
// defined outside this view, so confirm there).
void AsyncCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2183+
MLIRContext *context) {
2184+
results.insert<CollapseAsyncCastChain>(context);
2185+
results.insert<ElideUnusedOp<AsyncCastOp>>(context);
2186+
}
2187+
21392188
//===----------------------------------------------------------------------===//
21402189
// stream.async.load
21412190
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3032,6 +3032,35 @@ void AsyncTransferOp::getAsyncAccessRanges(
30323032
getResultSize(), getResultSize()});
30333033
}
30343034

3035+
//===----------------------------------------------------------------------===//
3036+
// stream.async.cast
3037+
//===----------------------------------------------------------------------===//
3038+
3039+
// Returns the value produced by TiedOpInterface::findTiedBaseValue for the
// cast's source operand. `resultIndex` is ignored.
Value AsyncCastOp::getTiedResult(unsigned resultIndex) {
3040+
return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource());
3041+
}
3042+
3043+
// Returns the operand index the result is tied to. This is always 0 (the
// source operand), regardless of `resultIndex`.
::std::optional<unsigned>
3044+
AsyncCastOp::getTiedResultOperandIndex(unsigned resultIndex) {
3045+
return {0}; // source
3046+
}
3047+
3048+
// Returns the tied operand indices for the op's results: a single entry, 0,
// naming the source operand.
SmallVector<int64_t> AsyncCastOp::getTiedResultOperandIndices() {
3049+
return {0}; // source
3050+
}
3051+
3052+
// Verifies stream.async.cast: the source and result must each pass
// verifyOpValueSizes against their size operand, and the source size must
// compare equal to the result size.
LogicalResult AsyncCastOp::verify() {
3053+
AsyncCastOp op = *this;
3054+
// Check each resource value against its associated size operand.
if (failed(verifyOpValueSizes(op, op.getSource(), op.getSourceSize())) ||
3055+
failed(verifyOpValueSizes(op, op.getResult(), op.getResultSize()))) {
3056+
return failure();
3057+
}
3058+
// NOTE(review): this compares the two size operands as Values, which is
// presumably SSA-value identity rather than equality of the runtime sizes;
// confirm that is the intended strictness.
if (getSourceSize() != getResultSize()) {
3059+
return emitOpError("source and result sizes must be equal (tied storage)");
3060+
}
3061+
return success();
3062+
}
3063+
30353064
//===----------------------------------------------------------------------===//
30363065
// stream.async.load
30373066
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2564,6 +2564,74 @@ def Stream_AsyncTransferOp : Stream_PureOp<"async.transfer", [
25642564
let hasFolder = 1;
25652565
}
25662566

2567+
// stream.async.cast: one resource operand and one resource result, each with an
// explicit size operand, plus an optional affinity attribute. The three
// Util_TiedOpInterface methods listed below are declared here and implemented
// in C++; the op also has a verifier, a folder, and canonicalization patterns.
def Stream_AsyncCastOp : Stream_PureOp<"async.cast", [
2568+
Stream_AffinityOp,
2569+
Stream_AsyncPhaseOp,
2570+
Stream_TiedResourcePassthrough,
2571+
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
2572+
"getTiedResult",
2573+
"getTiedResultOperandIndex",
2574+
"getTiedResultOperandIndices",
2575+
]>,
2576+
Util_SizeAwareOp,
2577+
]> {
2578+
let summary = [{Casts a resource to a different lifetime within async phase.}];
2579+
let description = [{
2580+
Asserts that a resource has a specific lifetime. This is an async-phase
2581+
operation that participates in scheduling. It must be resolved by
2582+
RefineUsage before leaving the async phase.
2583+
2584+
If RefineUsage cannot make the source type match the target type by
2585+
backward propagation, it converts this to a proper transfer operation.
2586+
If this op survives past the async phase, it indicates a bug in RefineUsage.
2587+
2588+
The source and result share the same underlying storage (tied operation).
2589+
2590+
Example:
2591+
```mlir
2592+
// Assert resource must resolve to external lifetime.
2593+
%1 = stream.async.cast %0 : !stream.resource<*>{%size} -> !stream.resource<external>{%size}
2594+
```
2595+
}];
2596+
2597+
// Operands: the resource being cast, its size, and the size of the result.
// The affinity attribute is optional.
let arguments = (ins
2598+
Stream_AnyStreamResource:$source,
2599+
Stream_Size:$source_size,
2600+
Stream_Size:$result_size,
2601+
OptionalAttr<Stream_AffinityAttr>:$affinity
2602+
);
2603+
let results = (outs
2604+
Stream_AnyStreamResource:$result
2605+
);
2606+
2607+
// Textual form: optional `on(<affinity>)`, then the source with its type and
// size, `->`, and the result type with its size.
let assemblyFormat = [{
2608+
(`on` `(` $affinity^ `)`)?
2609+
$source `:` type($source) `{` $source_size `}`
2610+
`->` type($result) `{` $result_size `}`
2611+
attr-dict
2612+
}];
2613+
2614+
// Convenience builder taking a single size value, which is forwarded as both
// the source size and the result size.
let builders = [
2615+
OpBuilder<(ins
2616+
"Type":$resultType,
2617+
"Value":$source,
2618+
"Value":$size,
2619+
"IREE::Stream::AffinityAttr":$affinity
2620+
), [{
2621+
build($_builder, $_state, resultType, source, size, size, affinity);
2622+
}]>,
2623+
];
2624+
2625+
// Index-taking size accessors. The index argument is ignored and each forwards
// to the zero-argument accessor for the corresponding size operand.
let extraClassDeclaration = [{
2626+
Value getOperandSize(unsigned idx) { return getSourceSize(); }
2627+
Value getResultSize(unsigned idx) { return getResultSize(); }
2628+
}];
2629+
2630+
let hasVerifier = 1;
2631+
let hasFolder = 1;
2632+
let hasCanonicalizer = 1;
2633+
}
2634+
25672635
def Stream_AsyncLoadOp : Stream_PureOp<"async.load", [
25682636
Stream_AsyncPhaseOp,
25692637
Util_SizeAwareOp,

compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,44 @@ util.func private @IntermediateTransferElision(%source: !stream.resource<constan
450450

451451
// -----
452452

453+
// CHECK-LABEL: @FoldAsyncCastSameType
454+
// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<external>, %[[SIZE:.+]]: index)
455+
util.func private @FoldAsyncCastSameType(%source: !stream.resource<external>, %size: index) -> !stream.resource<external> {
456+
// A cast where source and result types match should fold away.
457+
// CHECK-NOT: stream.async.cast
458+
%cast = stream.async.cast %source : !stream.resource<external>{%size} -> !stream.resource<external>{%size}
459+
// CHECK: util.return %[[SOURCE]]
460+
util.return %cast : !stream.resource<external>
461+
}
462+
463+
// -----
464+
465+
// CHECK-LABEL: @FoldAsyncCastChain
466+
// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<external>, %[[SIZE:.+]]: index)
467+
util.func private @FoldAsyncCastChain(%source: !stream.resource<external>, %size: index) -> !stream.resource<external> {
468+
// A chain of casts that returns to the original type should fold away.
469+
// CHECK-NOT: stream.async.cast
470+
%cast0 = stream.async.cast %source : !stream.resource<external>{%size} -> !stream.resource<*>{%size}
471+
%cast1 = stream.async.cast %cast0 : !stream.resource<*>{%size} -> !stream.resource<external>{%size}
472+
// CHECK: util.return %[[SOURCE]]
473+
util.return %cast1 : !stream.resource<external>
474+
}
475+
476+
// -----
477+
478+
// CHECK-LABEL: @CollapseAsyncCastChain
479+
// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<external>, %[[SIZE:.+]]: index)
480+
util.func private @CollapseAsyncCastChain(%source: !stream.resource<external>, %size: index) -> !stream.resource<transient> {
481+
// A chain of casts with different types should collapse into one.
482+
%cast0 = stream.async.cast %source : !stream.resource<external>{%size} -> !stream.resource<*>{%size}
483+
// CHECK: %[[CAST:.+]] = stream.async.cast %[[SOURCE]] : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<transient>{%[[SIZE]]}
484+
%cast1 = stream.async.cast %cast0 : !stream.resource<*>{%size} -> !stream.resource<transient>{%size}
485+
// CHECK: util.return %[[CAST]]
486+
util.return %cast1 : !stream.resource<transient>
487+
}
488+
489+
// -----
490+
453491
// CHECK-LABEL: @FoldAsyncLoadBitcast
454492
util.func private @FoldAsyncLoadBitcast(%arg0: !stream.resource<staging>, %arg1: index) -> f32 {
455493
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)