diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp index 00f6606f0638..74f8a196614b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp @@ -490,6 +490,14 @@ class ValueResourceUsage : public AbstractResourceUsage { DFX::Resolution::REQUIRED); getState() ^= tiedUsage.getState(); }) + .Case([&](IREE::Stream::AsyncCastOp op) { + // Cast is a tied passthrough - propagate source usage to result. + // The result type's lifetime constraints are set during init. + auto &sourceUsage = solver.getElementFor( + *this, Position::forValue(op.getSource()), + DFX::Resolution::REQUIRED); + getState() ^= sourceUsage.getState(); + }) .Case([&](IREE::Stream::AsyncTransferOp op) { removeAssumedBits(NOT_TRANSFER_WRITE); auto &sourceUsage = solver.getElementFor( @@ -877,6 +885,14 @@ class ValueResourceUsage : public AbstractResourceUsage { DFX::Resolution::OPTIONAL); getState() ^= resultUsage.getState(); }) + .Case([&](IREE::Stream::AsyncCastOp op) { + // Cast is a tied passthrough - propagate result usage to source. + // The source type's lifetime constraints are set during init. + auto &resultUsage = solver.getElementFor( + *this, Position::forValue(op.getResult()), + DFX::Resolution::OPTIONAL); + getState() ^= resultUsage.getState(); + }) .Case([&](IREE::Stream::AsyncTransferOp op) { removeAssumedBits(NOT_TRANSFER_READ); auto &resultUsage = solver.getElementFor( diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp index 3fbada57dd16..60af1fe6a381 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp @@ -97,25 +97,25 @@ struct ConvertTensorImportOp } // If byte_offset was specified, create a subview at that offset before - // the transfer. This makes the non-zero offset visible to + // the cast. This makes the non-zero offset visible to // AnnotateDispatchArguments, which computes alignment as // gcd(base_alignment, offset). - Value transferSource = resource; - Value transferSize = importSize; + Value castSource = resource; + Value castSize = importSize; if (byteOffset) { - transferSource = IREE::Stream::ResourceSubviewOp::create( + castSource = IREE::Stream::ResourceSubviewOp::create( rewriter, op.getLoc(), resource, importSize, byteOffset, tensorSize); - transferSize = tensorSize; + castSize = tensorSize; } + // Cast to unknown lifetime for use in the program. + // This is an async-phase operation that RefineUsage will resolve by + // propagating the external constraint. The cast will fold when types match. auto unknownType = rewriter.getType(); - Value newImport = IREE::Stream::AsyncTransferOp::create( - rewriter, op.getLoc(), unknownType, transferSource, transferSize, - transferSize, - /*source_affinity=*/executionAffinityAttr, - /*target_affinity=*/executionAffinityAttr); - - rewriter.replaceOpWithMultiple(op, {{newImport, transferSize}}); + Value newImport = IREE::Stream::AsyncCastOp::create( + rewriter, op.getLoc(), unknownType, castSource, castSize, + executionAffinityAttr); + rewriter.replaceOpWithMultiple(op, {{newImport, castSize}}); return success(); } @@ -184,17 +184,17 @@ struct ConvertTensorExportOp transferTensorOperands(op.getLoc(), op.getSource(), adaptor.getSource(), executionAffinityAttr, rewriter); - // Exporting a produced value - transfer our source value to an externally - // usable resource and directly export it. This will cause an allocation. - Value exportSource = adaptor.getSource().front(); + // Exporting a produced value - cast to external lifetime. + // This is an async-phase operation that RefineUsage will resolve by + // propagating the external constraint backward and converting to a transfer + // if needed. The cast will fold when types match. + Value exportSource = source.resource; auto externalType = rewriter.getType( IREE::Stream::Lifetime::External); if (source.resource.getType() != externalType) { - exportSource = IREE::Stream::AsyncTransferOp::create( + exportSource = IREE::Stream::AsyncCastOp::create( rewriter, op.getLoc(), externalType, source.resource, - source.resourceSize, source.resourceSize, - /*source_affinity=*/source.affinity, - /*target_affinity=*/executionAffinityAttr); + source.resourceSize, executionAffinityAttr); } // Export (stream resource to buffer view). diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir index 309c20f667cf..8f8376f60127 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir @@ -11,7 +11,7 @@ util.func public @importBufferView(%view: !hal.buffer_view) -> tensor // CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof tensor{%[[DIM0]], %[[DIM1]]} : index // CHECK: %[[RESOURCE:.+]] = stream.tensor.import %[[VIEW]] : !hal.buffer_view -> // CHECK-SAME: tensor{%[[DIM0]], %[[DIM1]]} in !stream.resource{%[[SIZE]]} - // CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[RESOURCE]] : + // CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[RESOURCE]] : // CHECK-SAME: !stream.resource{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]} %0 = hal.tensor.import %view : !hal.buffer_view -> tensor{%dim0, %dim1} // CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index @@ -33,7 +33,7 @@ util.func public @importBufferViewWithOffset(%view: !hal.buffer_view, %fence: !h // CHECK-SAME: : !stream.resource{%[[TOTAL_SIZE]]} // CHECK: %[[SUBVIEW:.+]] = stream.resource.subview %[[SYNCED]][%[[OFFSET]]] : // CHECK-SAME: !stream.resource{%[[TOTAL_SIZE]]} -> !stream.resource{%[[TENSOR_SIZE]]} - // CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[SUBVIEW]] + // CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[SUBVIEW]] // CHECK-SAME: : !stream.resource{%[[TENSOR_SIZE]]} -> !stream.resource<*>{%[[TENSOR_SIZE]]} %0 = hal.tensor.import wait(%fence) => %view offset(%offset) : !hal.buffer_view -> tensor<4xf32> // CHECK: util.return %[[RESULT]], %[[TENSOR_SIZE]] : !stream.resource<*>, index @@ -52,7 +52,7 @@ util.func public @importBufferViewWithOffsetNoFence(%view: !hal.buffer_view, %of // CHECK-SAME: : !hal.buffer_view -> tensor<4xf32> in !stream.resource{%[[TOTAL_SIZE]]} // CHECK: %[[SUBVIEW:.+]] = stream.resource.subview %[[RESOURCE]][%[[OFFSET]]] : // CHECK-SAME: !stream.resource{%[[TOTAL_SIZE]]} -> !stream.resource{%[[TENSOR_SIZE]]} - // CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[SUBVIEW]] + // CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[SUBVIEW]] // CHECK-SAME: : !stream.resource{%[[TENSOR_SIZE]]} -> !stream.resource<*>{%[[TENSOR_SIZE]]} %0 = hal.tensor.import %view offset(%offset) : !hal.buffer_view -> tensor<4xf32> // CHECK: util.return %[[RESULT]], %[[TENSOR_SIZE]] : !stream.resource<*>, index @@ -67,7 +67,7 @@ util.func public @importBufferViewBitcasting(%view: !hal.buffer_view) -> tensor< // CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof tensor<4xbf16> // CHECK: %[[RESOURCE:.+]] = stream.tensor.import %[[VIEW]] : !hal.buffer_view -> // CHECK-SAME: tensor<2xui32> in !stream.resource{%[[SIZE]]} - // CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[RESOURCE]] : + // CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[RESOURCE]] : // CHECK-SAME: !stream.resource{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]} %0 = hal.tensor.import %view : !hal.buffer_view -> tensor<2xui32> as tensor<4xbf16> // CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index @@ -86,7 +86,7 @@ util.func public @importBufferViewAsync(%view: !hal.buffer_view, %fence: !hal.fe // CHECK: %[[TIMEPOINT:.+]] = stream.timepoint.import %[[FENCE]] // CHECK: %[[SYNC_RESOURCE:.+]] = stream.timepoint.await %[[TIMEPOINT]] => %[[ASYNC_RESOURCE]] // CHECK-SAME: : !stream.resource{%[[SIZE]]} - // CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[SYNC_RESOURCE]] + // CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[SYNC_RESOURCE]] // CHECK-SAME: : !stream.resource{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]} %0 = hal.tensor.import wait(%fence) => %view : !hal.buffer_view -> tensor<4xf32> // CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index @@ -98,7 +98,7 @@ util.func public @importBufferViewAsync(%view: !hal.buffer_view, %fence: !hal.fe // CHECK-LABEL: @exportBufferView // CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index) util.func public @exportBufferView(%tensor: tensor, %dim0: index, %dim1: index) -> !hal.buffer_view { - // CHECK: %[[VIEW:.+]] = stream.async.clone %[[TENSOR]] : + // CHECK: %[[VIEW:.+]] = stream.async.cast %[[TENSOR]] : // CHECK-SAME: !stream.resource<*>{%[[SIZE]]} -> !stream.resource{%[[SIZE]]} // CHECK-NEXT: %[[RESULT:.+]] = stream.tensor.export %[[VIEW]] : // CHECK-SAME: tensor{%[[DIM0]], %[[DIM1]]} in !stream.resource{%[[SIZE]]} @@ -168,7 +168,7 @@ util.func public @importBufferViewCrossDevice(%view: !hal.buffer_view) -> tensor // CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof on(#hal.device.promise<@dev_a>) tensor<4xf32> // CHECK: %[[RESOURCE:.+]] = stream.tensor.import on(#hal.device.promise<@dev_a>) %[[VIEW]] : !hal.buffer_view -> // CHECK-SAME: tensor<4xf32> in !stream.resource{%[[SIZE]]} - // CHECK-NEXT: %[[CLONE:.+]] = stream.async.clone on(#hal.device.promise<@dev_a>) %[[RESOURCE]] : + // CHECK-NEXT: %[[CLONE:.+]] = stream.async.cast on(#hal.device.promise<@dev_a>) %[[RESOURCE]] : // CHECK-SAME: !stream.resource{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]} %0 = hal.tensor.import on(#hal.device.promise<@dev_a>) %view : !hal.buffer_view -> tensor<4xf32> // CHECK: util.return %[[CLONE]], %[[SIZE]] : !stream.resource<*>, index @@ -185,7 +185,7 @@ util.func public @importBufferViewCrossDevice(%view: !hal.buffer_view) -> tensor util.func public @exportBufferViewCrossDevice(%tensor: tensor<4xf32>) -> !hal.buffer_view attributes { stream.affinity = #hal.device.promise<@dev_a> } { - // CHECK: %[[CLONE:.+]] = stream.async.clone %[[TENSOR]] : + // CHECK: %[[CLONE:.+]] = stream.async.cast %[[TENSOR]] : // CHECK-SAME: !stream.resource<*>{%[[SIZE]]} -> !stream.resource{%[[SIZE]]} // CHECK-NEXT: %[[VIEW:.+]] = stream.tensor.export %[[CLONE]] : // CHECK-SAME: tensor<4xf32> in !stream.resource{%[[SIZE]]} diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir index 34b8bbade4c2..b858163ffdad 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir @@ -83,7 +83,7 @@ util.func public @globalStoreFromExternal(%arg0: !hal.buffer_view) { %dim0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index // CHECK: %[[SIZE:.+]] = stream.tensor.sizeof tensor{%[[DIM0]]} : index // CHECK: %[[IMPORT:.+]] = stream.tensor.import %arg0 : !hal.buffer_view -> tensor{%[[DIM0]]} in !stream.resource{%[[SIZE]]} - // CHECK: %[[T:.+]] = stream.async.transfer %[[IMPORT]] : !stream.resource{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]} + // CHECK: %[[T:.+]] = stream.async.cast %[[IMPORT]] : !stream.resource{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]} %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor{%dim0} // CHECK: %[[VAR:.+]] = stream.async.transfer %[[T]] : !stream.resource<*>{%[[SIZE]]} -> !stream.resource{%[[SIZE]]} // CHECK: util.global.store %[[VAR]], @var_with_buffer_view_store : !stream.resource diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp index 9bc9d0655d44..f43749edb9f9 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp @@ -2136,6 +2136,55 @@ void AsyncTransferOp::getCanonicalizationPatterns(RewritePatternSet &results, results.insert>(context); } +//===----------------------------------------------------------------------===// +// stream.async.cast +//===----------------------------------------------------------------------===// + +OpFoldResult AsyncCastOp::fold(FoldAdaptor operands) { + // Fold away cast if source and result types match. + if (getSource().getType() == getResult().getType()) { + return getSource(); + } + // Fold chain: cast(cast(x)) -> x when outer result type matches inner source. + if (auto sourceCastOp = getSource().getDefiningOp()) { + if (sourceCastOp.getSource().getType() == getResult().getType()) { + return sourceCastOp.getSource(); + } + } + return {}; +} + +namespace { + +// Collapses chains of async casts into a single cast to the final type. +struct CollapseAsyncCastChain : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AsyncCastOp op, + PatternRewriter &rewriter) const override { + auto sourceCastOp = op.getSource().getDefiningOp(); + if (!sourceCastOp) { + return failure(); + } + // If folding would handle this (source matches result), let fold do it. + if (sourceCastOp.getSource().getType() == op.getResult().getType()) { + return failure(); + } + // Collapse the chain: cast(cast(x, A), B) -> cast(x, B). + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), sourceCastOp.getSource(), + sourceCastOp.getSourceSize(), op.getAffinityAttr()); + return success(); + } +}; + +} // namespace + +void AsyncCastOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); + results.insert>(context); +} + //===----------------------------------------------------------------------===// // stream.async.load //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp index be59659079bb..d789a8f88f70 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp @@ -3032,6 +3032,35 @@ void AsyncTransferOp::getAsyncAccessRanges( getResultSize(), getResultSize()}); } +//===----------------------------------------------------------------------===// +// stream.async.cast +//===----------------------------------------------------------------------===// + +Value AsyncCastOp::getTiedResult(unsigned resultIndex) { + return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource()); +} + +::std::optional +AsyncCastOp::getTiedResultOperandIndex(unsigned resultIndex) { + return {0}; // source +} + +SmallVector AsyncCastOp::getTiedResultOperandIndices() { + return {0}; // source +} + +LogicalResult AsyncCastOp::verify() { + AsyncCastOp op = *this; + if (failed(verifyOpValueSizes(op, op.getSource(), op.getSourceSize())) || + failed(verifyOpValueSizes(op, op.getResult(), op.getResultSize()))) { + return failure(); + } + if (getSourceSize() != getResultSize()) { + return emitOpError("source and result sizes must be equal (tied storage)"); + } + return success(); +} + //===----------------------------------------------------------------------===// // stream.async.load //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index 09d9728dca10..f74ecb4b869b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -2564,6 +2564,74 @@ def Stream_AsyncTransferOp : Stream_PureOp<"async.transfer", [ let hasFolder = 1; } +def Stream_AsyncCastOp : Stream_PureOp<"async.cast", [ + Stream_AffinityOp, + Stream_AsyncPhaseOp, + Stream_TiedResourcePassthrough, + DeclareOpInterfaceMethods, + Util_SizeAwareOp, +]> { + let summary = [{Casts a resource to a different lifetime within async phase.}]; + let description = [{ + Asserts that a resource has a specific lifetime. This is an async-phase + operation that participates in scheduling. It must be resolved by + RefineUsage before leaving the async phase. + + If RefineUsage cannot make the source type match the target type by + backward propagation, it converts this to a proper transfer operation. + If this op survives past the async phase, it indicates a bug in RefineUsage. + + The source and result share the same underlying storage (tied operation). + + Example: + ```mlir + // Assert resource must resolve to external lifetime. + %1 = stream.async.cast %0 : !stream.resource<*>{%size} -> !stream.resource{%size} + ``` + }]; + + let arguments = (ins + Stream_AnyStreamResource:$source, + Stream_Size:$source_size, + Stream_Size:$result_size, + OptionalAttr:$affinity + ); + let results = (outs + Stream_AnyStreamResource:$result + ); + + let assemblyFormat = [{ + (`on` `(` $affinity^ `)`)? + $source `:` type($source) `{` $source_size `}` + `->` type($result) `{` $result_size `}` + attr-dict + }]; + + let builders = [ + OpBuilder<(ins + "Type":$resultType, + "Value":$source, + "Value":$size, + "IREE::Stream::AffinityAttr":$affinity + ), [{ + build($_builder, $_state, resultType, source, size, size, affinity); + }]>, + ]; + + let extraClassDeclaration = [{ + Value getOperandSize(unsigned idx) { return getSourceSize(); } + Value getResultSize(unsigned idx) { return getResultSize(); } + }]; + + let hasVerifier = 1; + let hasFolder = 1; + let hasCanonicalizer = 1; +} + def Stream_AsyncLoadOp : Stream_PureOp<"async.load", [ Stream_AsyncPhaseOp, Util_SizeAwareOp, diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir index 927bb255d17b..ee6c739932f7 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir @@ -450,6 +450,44 @@ util.func private @IntermediateTransferElision(%source: !stream.resource, %[[SIZE:.+]]: index) +util.func private @FoldAsyncCastSameType(%source: !stream.resource, %size: index) -> !stream.resource { + // A cast where source and result types match should fold away. + // CHECK-NOT: stream.async.cast + %cast = stream.async.cast %source : !stream.resource{%size} -> !stream.resource{%size} + // CHECK: util.return %[[SOURCE]] + util.return %cast : !stream.resource +} + +// ----- + +// CHECK-LABEL: @FoldAsyncCastChain +// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource, %[[SIZE:.+]]: index) +util.func private @FoldAsyncCastChain(%source: !stream.resource, %size: index) -> !stream.resource { + // A chain of casts that returns to the original type should fold away. + // CHECK-NOT: stream.async.cast + %cast0 = stream.async.cast %source : !stream.resource{%size} -> !stream.resource<*>{%size} + %cast1 = stream.async.cast %cast0 : !stream.resource<*>{%size} -> !stream.resource{%size} + // CHECK: util.return %[[SOURCE]] + util.return %cast1 : !stream.resource +} + +// ----- + +// CHECK-LABEL: @CollapseAsyncCastChain +// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource, %[[SIZE:.+]]: index) +util.func private @CollapseAsyncCastChain(%source: !stream.resource, %size: index) -> !stream.resource { + // A chain of casts with different types should collapse into one. + %cast0 = stream.async.cast %source : !stream.resource{%size} -> !stream.resource<*>{%size} + // CHECK: %[[CAST:.+]] = stream.async.cast %[[SOURCE]] : !stream.resource{%[[SIZE]]} -> !stream.resource{%[[SIZE]]} + %cast1 = stream.async.cast %cast0 : !stream.resource<*>{%size} -> !stream.resource{%size} + // CHECK: util.return %[[CAST]] + util.return %cast1 : !stream.resource +} + +// ----- + // CHECK-LABEL: @FoldAsyncLoadBitcast util.func private @FoldAsyncLoadBitcast(%arg0: !stream.resource, %arg1: index) -> f32 { %c0 = arith.constant 0 : index diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir index bf2b981226cf..2bc0ae52e147 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir @@ -182,6 +182,26 @@ util.func private @asyncTransferAffinities(%arg0: !stream.resource, %a // ----- +// CHECK-LABEL: @asyncCast +util.func private @asyncCast(%arg0: !stream.resource, %arg1: index) -> !stream.resource<*> { + // CHECK: = stream.async.cast %arg0 : !stream.resource{%arg1} -> !stream.resource<*>{%arg1} + %0 = stream.async.cast %arg0 : !stream.resource{%arg1} -> !stream.resource<*>{%arg1} + util.return %0 : !stream.resource<*> +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @asyncCastWithAffinity +util.func private @asyncCastWithAffinity(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource { + // CHECK: = stream.async.cast on(#hal.device.affinity<@device>) %arg0 : !stream.resource<*>{%arg1} -> !stream.resource{%arg1} + %0 = stream.async.cast on(#hal.device.affinity<@device>) %arg0 : !stream.resource<*>{%arg1} -> !stream.resource{%arg1} + util.return %0 : !stream.resource +} + +// ----- + // CHECK-LABEL: @asyncLoad util.func private @asyncLoad(%arg0: !stream.resource, %arg1: index) -> f32 { %c0 = arith.constant 0 : index diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp index 18a565f5ddc0..2ea9618e1368 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp @@ -495,6 +495,39 @@ struct ApplyAsyncTransferOp } }; +// Converts async.cast to async.transfer when source and result lifetimes are +// both concrete but differ. The cast op is a type assertion that should fold +// away when types match, but when they can't match (e.g., constant->external), +// we need an actual transfer operation to perform the copy. +struct ConvertAsyncCastToTransfer + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(IREE::Stream::AsyncCastOp op, + PatternRewriter &rewriter) const override { + auto sourceType = + llvm::cast(op.getSource().getType()); + auto resultType = + llvm::cast(op.getResult().getType()); + + // If either type is still Unknown, wait for refinement to resolve it. + if (sourceType.getLifetime() == IREE::Stream::Lifetime::Unknown || + resultType.getLifetime() == IREE::Stream::Lifetime::Unknown) { + return failure(); + } + + // If types match after refinement, let folding handle it. + if (sourceType == resultType) { + return failure(); + } + + // Lifetimes are concrete and different - need an actual transfer. + rewriter.replaceOpWithNewOp( + op, resultType, op.getSource(), op.getSourceSize(), op.getResultSize(), + op.getAffinityAttr(), op.getAffinityAttr()); + return success(); + } +}; + static void insertUsageRefinementPatterns(MLIRContext *context, ResourceUsageAnalysis &analysis, RewritePatternSet &patterns) { @@ -508,8 +541,8 @@ static void insertUsageRefinementPatterns(MLIRContext *context, ApplyGenericOp, ApplyGenericOp, ApplyGenericOp, - ApplyGenericOp>(context, - analysis); + ApplyGenericOp, + ApplyGenericOp>(context, analysis); patterns.insert, ApplyStreamableOp, ApplyStreamableOp, @@ -537,6 +570,8 @@ static void insertUsageRefinementPatterns(MLIRContext *context, ApplyStreamableOp, ApplyStreamableOp, ApplyStreamableOp>(context, analysis); + // Convert async.cast to async.transfer when lifetimes can't match. + patterns.insert(context); } //===----------------------------------------------------------------------===// @@ -570,6 +605,17 @@ struct RefineUsagePass applyPatternsGreedily(moduleOp, frozenPatterns, rewriteConfig))) { return signalPassFailure(); } + + bool hasUnresolvedCast = false; + moduleOp.walk( + [&](IREE::Stream::AsyncCastOp castOp) { + castOp.emitError() + << "unresolved async.cast after stream usage refinement"; + hasUnresolvedCast = true; + }); + if (hasUnresolvedCast) { + return signalPassFailure(); + } } }; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir index 2f488407476a..5d6c0e9abe3b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir @@ -32,7 +32,7 @@ util.func public @simple_mul(%arg0: !hal.buffer_view) -> !hal.buffer_view attrib // CHECK: hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("tensor") shape([%0, %c4]) type(%[[ELEMENT_TYPE]]) encoding(%[[ENCODING_TYPE]]) // CHECK: %[[ARG0_SIZE:.+]] = stream.tensor.sizeof tensor{%[[DIM0]]} : index // CHECK: %[[ARG0_IMPORT:.+]] = stream.tensor.import %arg0 : !hal.buffer_view -> tensor{%[[DIM0]]} in !stream.resource{%[[ARG0_SIZE]]} - // CHECK: %[[ARG0_T:.+]] = stream.async.transfer %[[ARG0_IMPORT]] : !stream.resource{%[[ARG0_SIZE]]} -> !stream.resource<*>{%[[ARG0_SIZE]]} + // CHECK: %[[ARG0_T:.+]] = stream.async.cast %[[ARG0_IMPORT]] : !stream.resource{%[[ARG0_SIZE]]} -> !stream.resource<*>{%[[ARG0_SIZE]]} %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor{%dim0} %c1 = arith.constant 1 : index @@ -41,7 +41,7 @@ util.func public @simple_mul(%arg0: !hal.buffer_view) -> !hal.buffer_view attrib // CHECK: %[[RET0:.+]] = stream.tensor.dispatch @executable::@dispatch[%c2, %c1, %c1](%[[ARG0_T]]) : (tensor{%[[DIM0]]} in !stream.resource<*>{%[[ARG0_SIZE]]}) -> tensor{%[[DIM0]]} in !stream.resource<*>{%[[RET0_SIZE]]} %1 = flow.dispatch @executable::@dispatch[%c2, %c1, %c1](%0) : (tensor{%dim0}) -> tensor{%dim0} - // CHECK: %[[RET0_T:.+]] = stream.async.transfer %[[RET0]] : !stream.resource<*>{%[[RET0_SIZE]]} -> !stream.resource{%[[RET0_SIZE]]} + // CHECK: %[[RET0_T:.+]] = stream.async.cast %[[RET0]] : !stream.resource<*>{%[[RET0_SIZE]]} -> !stream.resource{%[[RET0_SIZE]]} // CHECK: %[[RET0_EXPORT:.+]] = stream.tensor.export %[[RET0_T]] : tensor{%[[DIM0]]} in !stream.resource{%[[RET0_SIZE]]} -> !hal.buffer_view %2 = hal.tensor.export %1 : tensor{%dim0} -> !hal.buffer_view // CHECK: util.return %[[RET0_EXPORT]] : !hal.buffer_view diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir index 5fb293f2f644..b015a22bfa00 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/refine_usage.mlir @@ -591,6 +591,28 @@ stream.executable private @ex { // ----- +// Tests that async.cast is treated as a tied passthrough when deciding whether +// a clone result is mutated. The clone should still inherit the source's +// external lifetime through the cast before feeding a read-only dispatch. + +// CHECK-LABEL: @clone_external_cast_inherits_source +util.func private @clone_external_cast_inherits_source(%external: !stream.resource, %size: index) -> !stream.resource<*> { + %c0 = arith.constant 0 : index + // CHECK: %[[CLONE:.+]] = stream.async.clone {{.*}} !stream.resource{{.*}} -> !stream.resource + %clone = stream.async.clone on(#hal.device.affinity<@device>) + %external : !stream.resource{%size} -> !stream.resource<*>{%size} + // CHECK-NOT: stream.async.cast + %cast = stream.async.cast on(#hal.device.affinity<@device>) + %clone : !stream.resource<*>{%size} -> !stream.resource<*>{%size} + // CHECK: stream.async.dispatch {{.*}}(%[[CLONE]]{{.*}}) : (!stream.resource{{.*}}) -> !stream.resource + %result = stream.async.dispatch on(#hal.device.affinity<@device>) + @ex::@dispatch(%cast[%c0 to %size for %size]) + : (!stream.resource<*>{%size}) -> !stream.resource<*>{%size} + util.return %result : !stream.resource<*> +} + +// ----- + // Tests that a clone of an external resource whose result IS mutated (tied // use from dispatch) gets its own lifetime determined by the result's uses, // not the source's lifetime. This is the data-isolation case where the clone @@ -736,3 +758,36 @@ util.func private @parameter_load_refines_to_constant(%scope: !util.buffer, %key // CHECK: util.return {{.*}} : !stream.resource util.return %ready : !stream.resource<*> } + +// ----- + +// Tests that stream.async.cast propagates its constraint to the source value. +// The cast asserts the result must be external, so the source (splat) must also +// become external. After refinement, the cast folds away since types match. + +// CHECK-LABEL: @asyncCastPropagation +util.func public @asyncCastPropagation(%size: index) -> !stream.resource { + %c123_i32 = arith.constant 123 : i32 + // The splat should be refined to external due to the cast constraint. + // CHECK: %[[SPLAT:.+]] = stream.async.splat %c123_i32 {{.+}} -> !stream.resource + %splat = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size} + // The cast should fold away after the source type is refined to match. + // CHECK-NOT: stream.async.cast + %cast = stream.async.cast %splat : !stream.resource<*>{%size} -> !stream.resource{%size} + // CHECK: util.return %[[SPLAT]] : !stream.resource + util.return %cast : !stream.resource +} + +// ----- + +// Tests that a cast between incompatible concrete lifetimes becomes a transfer. + +// CHECK-LABEL: @asyncCastConcreteMismatchTransfers +// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource, %[[SIZE:.+]]: index) +util.func public @asyncCastConcreteMismatchTransfers(%source: !stream.resource, %size: index) -> !stream.resource { + // CHECK-NOT: stream.async.cast + // CHECK: %[[TRANSFER:.+]] = stream.async.transfer %[[SOURCE]] : !stream.resource{%[[SIZE]]} -> !stream.resource{%[[SIZE]]} + %cast = stream.async.cast %source : !stream.resource{%size} -> !stream.resource{%size} + // CHECK: util.return %[[TRANSFER]] : !stream.resource + util.return %cast : !stream.resource +}