Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,14 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
DFX::Resolution::REQUIRED);
getState() ^= tiedUsage.getState();
})
.Case([&](IREE::Stream::AsyncCastOp op) {
// Cast is a tied passthrough - propagate source usage to result.
// The result type's lifetime constraints are set during init.
auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(op.getSource()),
DFX::Resolution::REQUIRED);
getState() ^= sourceUsage.getState();
})
.Case([&](IREE::Stream::AsyncTransferOp op) {
removeAssumedBits(NOT_TRANSFER_WRITE);
auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
Expand Down Expand Up @@ -877,6 +885,14 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
DFX::Resolution::OPTIONAL);
getState() ^= resultUsage.getState();
})
.Case([&](IREE::Stream::AsyncCastOp op) {
// Cast is a tied passthrough - propagate result usage to source.
// The source type's lifetime constraints are set during init.
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(op.getResult()),
DFX::Resolution::OPTIONAL);
getState() ^= resultUsage.getState();
})
.Case([&](IREE::Stream::AsyncTransferOp op) {
removeAssumedBits(NOT_TRANSFER_READ);
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,25 +97,25 @@ struct ConvertTensorImportOp
}

// If byte_offset was specified, create a subview at that offset before
// the transfer. This makes the non-zero offset visible to
// the cast. This makes the non-zero offset visible to
// AnnotateDispatchArguments, which computes alignment as
// gcd(base_alignment, offset).
Value transferSource = resource;
Value transferSize = importSize;
Value castSource = resource;
Value castSize = importSize;
if (byteOffset) {
transferSource = IREE::Stream::ResourceSubviewOp::create(
castSource = IREE::Stream::ResourceSubviewOp::create(
rewriter, op.getLoc(), resource, importSize, byteOffset, tensorSize);
transferSize = tensorSize;
castSize = tensorSize;
}

// Cast to unknown lifetime for use in the program.
// This is an async-phase operation that RefineUsage will resolve by
// propagating the external constraint. The cast will fold when types match.
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
Value newImport = IREE::Stream::AsyncTransferOp::create(
rewriter, op.getLoc(), unknownType, transferSource, transferSize,
transferSize,
/*source_affinity=*/executionAffinityAttr,
/*target_affinity=*/executionAffinityAttr);

rewriter.replaceOpWithMultiple(op, {{newImport, transferSize}});
Value newImport = IREE::Stream::AsyncCastOp::create(
rewriter, op.getLoc(), unknownType, castSource, castSize,
executionAffinityAttr);
rewriter.replaceOpWithMultiple(op, {{newImport, castSize}});
return success();
}

Expand Down Expand Up @@ -184,17 +184,17 @@ struct ConvertTensorExportOp
transferTensorOperands(op.getLoc(), op.getSource(), adaptor.getSource(),
executionAffinityAttr, rewriter);

// Exporting a produced value - transfer our source value to an externally
// usable resource and directly export it. This will cause an allocation.
Value exportSource = adaptor.getSource().front();
// Exporting a produced value - cast to external lifetime.
// This is an async-phase operation that RefineUsage will resolve by
// propagating the external constraint backward and converting to a transfer
// if needed. The cast will fold when types match.
Value exportSource = source.resource;
auto externalType = rewriter.getType<IREE::Stream::ResourceType>(
IREE::Stream::Lifetime::External);
if (source.resource.getType() != externalType) {
exportSource = IREE::Stream::AsyncTransferOp::create(
exportSource = IREE::Stream::AsyncCastOp::create(
rewriter, op.getLoc(), externalType, source.resource,
source.resourceSize, source.resourceSize,
/*source_affinity=*/source.affinity,
/*target_affinity=*/executionAffinityAttr);
source.resourceSize, executionAffinityAttr);
}

// Export (stream resource to buffer view).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ util.func public @importBufferView(%view: !hal.buffer_view) -> tensor<?x?x4xf32>
// CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} : index
// CHECK: %[[RESOURCE:.+]] = stream.tensor.import %[[VIEW]] : !hal.buffer_view ->
// CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[SIZE]]}
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[RESOURCE]] :
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[RESOURCE]] :
// CHECK-SAME: !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
%0 = hal.tensor.import %view : !hal.buffer_view -> tensor<?x?x4xf32>{%dim0, %dim1}
// CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index
Expand All @@ -33,7 +33,7 @@ util.func public @importBufferViewWithOffset(%view: !hal.buffer_view, %fence: !h
// CHECK-SAME: : !stream.resource<external>{%[[TOTAL_SIZE]]}
// CHECK: %[[SUBVIEW:.+]] = stream.resource.subview %[[SYNCED]][%[[OFFSET]]] :
// CHECK-SAME: !stream.resource<external>{%[[TOTAL_SIZE]]} -> !stream.resource<external>{%[[TENSOR_SIZE]]}
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[SUBVIEW]]
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[SUBVIEW]]
// CHECK-SAME: : !stream.resource<external>{%[[TENSOR_SIZE]]} -> !stream.resource<*>{%[[TENSOR_SIZE]]}
%0 = hal.tensor.import wait(%fence) => %view offset(%offset) : !hal.buffer_view -> tensor<4xf32>
// CHECK: util.return %[[RESULT]], %[[TENSOR_SIZE]] : !stream.resource<*>, index
Expand All @@ -52,7 +52,7 @@ util.func public @importBufferViewWithOffsetNoFence(%view: !hal.buffer_view, %of
// CHECK-SAME: : !hal.buffer_view -> tensor<4xf32> in !stream.resource<external>{%[[TOTAL_SIZE]]}
// CHECK: %[[SUBVIEW:.+]] = stream.resource.subview %[[RESOURCE]][%[[OFFSET]]] :
// CHECK-SAME: !stream.resource<external>{%[[TOTAL_SIZE]]} -> !stream.resource<external>{%[[TENSOR_SIZE]]}
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[SUBVIEW]]
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[SUBVIEW]]
// CHECK-SAME: : !stream.resource<external>{%[[TENSOR_SIZE]]} -> !stream.resource<*>{%[[TENSOR_SIZE]]}
%0 = hal.tensor.import %view offset(%offset) : !hal.buffer_view -> tensor<4xf32>
// CHECK: util.return %[[RESULT]], %[[TENSOR_SIZE]] : !stream.resource<*>, index
Expand All @@ -67,7 +67,7 @@ util.func public @importBufferViewBitcasting(%view: !hal.buffer_view) -> tensor<
// CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof tensor<4xbf16>
// CHECK: %[[RESOURCE:.+]] = stream.tensor.import %[[VIEW]] : !hal.buffer_view ->
// CHECK-SAME: tensor<2xui32> in !stream.resource<external>{%[[SIZE]]}
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[RESOURCE]] :
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[RESOURCE]] :
// CHECK-SAME: !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
%0 = hal.tensor.import %view : !hal.buffer_view -> tensor<2xui32> as tensor<4xbf16>
// CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index
Expand All @@ -86,7 +86,7 @@ util.func public @importBufferViewAsync(%view: !hal.buffer_view, %fence: !hal.fe
// CHECK: %[[TIMEPOINT:.+]] = stream.timepoint.import %[[FENCE]]
// CHECK: %[[SYNC_RESOURCE:.+]] = stream.timepoint.await %[[TIMEPOINT]] => %[[ASYNC_RESOURCE]]
// CHECK-SAME: : !stream.resource<external>{%[[SIZE]]}
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.clone %[[SYNC_RESOURCE]]
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.cast %[[SYNC_RESOURCE]]
// CHECK-SAME: : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
%0 = hal.tensor.import wait(%fence) => %view : !hal.buffer_view -> tensor<4xf32>
// CHECK: util.return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index
Expand All @@ -98,7 +98,7 @@ util.func public @importBufferViewAsync(%view: !hal.buffer_view, %fence: !hal.fe
// CHECK-LABEL: @exportBufferView
// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
util.func public @exportBufferView(%tensor: tensor<?x?x4xf32>, %dim0: index, %dim1: index) -> !hal.buffer_view {
// CHECK: %[[VIEW:.+]] = stream.async.clone %[[TENSOR]] :
// CHECK: %[[VIEW:.+]] = stream.async.cast %[[TENSOR]] :
// CHECK-SAME: !stream.resource<*>{%[[SIZE]]} -> !stream.resource<external>{%[[SIZE]]}
// CHECK-NEXT: %[[RESULT:.+]] = stream.tensor.export %[[VIEW]] :
// CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[SIZE]]}
Expand Down Expand Up @@ -168,7 +168,7 @@ util.func public @importBufferViewCrossDevice(%view: !hal.buffer_view) -> tensor
// CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof on(#hal.device.promise<@dev_a>) tensor<4xf32>
// CHECK: %[[RESOURCE:.+]] = stream.tensor.import on(#hal.device.promise<@dev_a>) %[[VIEW]] : !hal.buffer_view ->
// CHECK-SAME: tensor<4xf32> in !stream.resource<external>{%[[SIZE]]}
// CHECK-NEXT: %[[CLONE:.+]] = stream.async.clone on(#hal.device.promise<@dev_a>) %[[RESOURCE]] :
// CHECK-NEXT: %[[CLONE:.+]] = stream.async.cast on(#hal.device.promise<@dev_a>) %[[RESOURCE]] :
// CHECK-SAME: !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
%0 = hal.tensor.import on(#hal.device.promise<@dev_a>) %view : !hal.buffer_view -> tensor<4xf32>
// CHECK: util.return %[[CLONE]], %[[SIZE]] : !stream.resource<*>, index
Expand All @@ -185,7 +185,7 @@ util.func public @importBufferViewCrossDevice(%view: !hal.buffer_view) -> tensor
util.func public @exportBufferViewCrossDevice(%tensor: tensor<4xf32>) -> !hal.buffer_view attributes {
stream.affinity = #hal.device.promise<@dev_a>
} {
// CHECK: %[[CLONE:.+]] = stream.async.clone %[[TENSOR]] :
// CHECK: %[[CLONE:.+]] = stream.async.cast %[[TENSOR]] :
// CHECK-SAME: !stream.resource<*>{%[[SIZE]]} -> !stream.resource<external>{%[[SIZE]]}
// CHECK-NEXT: %[[VIEW:.+]] = stream.tensor.export %[[CLONE]] :
// CHECK-SAME: tensor<4xf32> in !stream.resource<external>{%[[SIZE]]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ util.func public @globalStoreFromExternal(%arg0: !hal.buffer_view) {
%dim0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
// CHECK: %[[SIZE:.+]] = stream.tensor.sizeof tensor<?x4xf32>{%[[DIM0]]} : index
// CHECK: %[[IMPORT:.+]] = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%[[DIM0]]} in !stream.resource<external>{%[[SIZE]]}
// CHECK: %[[T:.+]] = stream.async.transfer %[[IMPORT]] : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
// CHECK: %[[T:.+]] = stream.async.cast %[[IMPORT]] : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%dim0}
// CHECK: %[[VAR:.+]] = stream.async.transfer %[[T]] : !stream.resource<*>{%[[SIZE]]} -> !stream.resource<variable>{%[[SIZE]]}
// CHECK: util.global.store %[[VAR]], @var_with_buffer_view_store : !stream.resource<variable>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2136,6 +2136,55 @@ void AsyncTransferOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.insert<ElideUnusedOp<AsyncTransferOp>>(context);
}

//===----------------------------------------------------------------------===//
// stream.async.cast
//===----------------------------------------------------------------------===//

OpFoldResult AsyncCastOp::fold(FoldAdaptor operands) {
  auto castInput = getSource();
  auto targetType = getResult().getType();

  // Identity cast: the input already has the requested type, so the op can be
  // replaced by its input directly.
  if (castInput.getType() == targetType) {
    return castInput;
  }

  // Round-trip cast: when the input is itself produced by a cast whose own
  // input already has the requested type, both casts cancel out and the
  // innermost value is returned.
  auto producerCastOp = castInput.getDefiningOp<AsyncCastOp>();
  if (producerCastOp && producerCastOp.getSource().getType() == targetType) {
    return producerCastOp.getSource();
  }

  // Nothing to fold.
  return {};
}

namespace {

// Rewrites a cast that consumes the result of another cast so that it reads
// from the innermost source instead: cast(cast(x, A), B) -> cast(x, B).
// The replacement uses the inner cast's source size and the outer cast's
// affinity attribute.
struct CollapseAsyncCastChain : public OpRewritePattern<AsyncCastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AsyncCastOp op,
                                PatternRewriter &rewriter) const override {
    auto innerCastOp = op.getSource().getDefiningOp<AsyncCastOp>();
    auto outerResultType = op.getResult().getType();
    // Bail when the operand is not produced by a cast, or when the chain
    // returns to the inner source type: that round-trip case is left to
    // AsyncCastOp::fold, which removes both casts.
    if (!innerCastOp ||
        innerCastOp.getSource().getType() == outerResultType) {
      return failure();
    }
    auto innerSource = innerCastOp.getSource();
    auto innerSourceSize = innerCastOp.getSourceSize();
    rewriter.replaceOpWithNewOp<AsyncCastOp>(op, outerResultType, innerSource,
                                             innerSourceSize,
                                             op.getAffinityAttr());
    return success();
  }
};

} // namespace

void AsyncCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Registers, in order: the cast-chain collapsing pattern and the pattern
  // eliding unused ops. Both are constructed from the same context.
  results.insert<CollapseAsyncCastChain, ElideUnusedOp<AsyncCastOp>>(context);
}

//===----------------------------------------------------------------------===//
// stream.async.load
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 29 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3032,6 +3032,35 @@ void AsyncTransferOp::getAsyncAccessRanges(
getResultSize(), getResultSize()});
}

//===----------------------------------------------------------------------===//
// stream.async.cast
//===----------------------------------------------------------------------===//

Value AsyncCastOp::getTiedResult(unsigned resultIndex) {
  // The result is tied to the source operand; resultIndex is not consulted.
  // The returned value is whatever findTiedBaseValue resolves for the source.
  auto tiedSource = getSource();
  return IREE::Util::TiedOpInterface::findTiedBaseValue(tiedSource);
}

::std::optional<unsigned>
AsyncCastOp::getTiedResultOperandIndex(unsigned resultIndex) {
  // The result is always tied to operand #0 (the source resource), whatever
  // resultIndex is passed.
  constexpr unsigned kSourceOperandIndex = 0;
  return kSourceOperandIndex;
}

SmallVector<int64_t> AsyncCastOp::getTiedResultOperandIndices() {
  // One entry for the op's single result: it is tied to operand #0 (source).
  SmallVector<int64_t> tiedOperandIndices;
  tiedOperandIndices.push_back(0);
  return tiedOperandIndices;
}

LogicalResult AsyncCastOp::verify() {
  AsyncCastOp op = *this;
  // Check the size operand attached to the source first, then the one
  // attached to the result; stop at the first failure.
  if (failed(verifyOpValueSizes(op, op.getSource(), op.getSourceSize()))) {
    return failure();
  }
  if (failed(verifyOpValueSizes(op, op.getResult(), op.getResultSize()))) {
    return failure();
  }
  // Source and result are tied, so both must be described by the same size
  // value (the size operands are compared as SSA values).
  if (getSourceSize() == getResultSize()) {
    return success();
  }
  return emitOpError("source and result sizes must be equal (tied storage)");
}

//===----------------------------------------------------------------------===//
// stream.async.load
//===----------------------------------------------------------------------===//
Expand Down
68 changes: 68 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2564,6 +2564,74 @@ def Stream_AsyncTransferOp : Stream_PureOp<"async.transfer", [
let hasFolder = 1;
}

// stream.async.cast: changes the lifetime type of a resource while keeping
// the result tied to the source operand (see the tied-op interface methods
// declared below and implemented in StreamOps.cpp).
def Stream_AsyncCastOp : Stream_PureOp<"async.cast", [
Stream_AffinityOp,
Stream_AsyncPhaseOp,
Stream_TiedResourcePassthrough,
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
"getTiedResult",
"getTiedResultOperandIndex",
"getTiedResultOperandIndices",
]>,
Util_SizeAwareOp,
]> {
let summary = [{Casts a resource to a different lifetime within async phase.}];
let description = [{
Asserts that a resource has a specific lifetime. This is an async-phase
operation that participates in scheduling. It must be resolved by
RefineUsage before leaving the async phase.

If RefineUsage cannot make the source type match the target type by
backward propagation, it converts this to a proper transfer operation.
If this op survives past the async phase, it indicates a bug in RefineUsage.

The source and result share the same underlying storage (tied operation).

Example:
```mlir
// Assert resource must resolve to external lifetime.
%1 = stream.async.cast %0 : !stream.resource<*>{%size} -> !stream.resource<external>{%size}
```
}];

// The op carries separate source and result size operands; the verifier
// rejects the op unless both are the same value.
let arguments = (ins
Stream_AnyStreamResource:$source,
Stream_Size:$source_size,
Stream_Size:$result_size,
OptionalAttr<Stream_AffinityAttr>:$affinity
);
let results = (outs
Stream_AnyStreamResource:$result
);

// Textual form: optional `on(<affinity>)`, then the source with its type and
// size, `->`, and the result type with its size.
let assemblyFormat = [{
(`on` `(` $affinity^ `)`)?
$source `:` type($source) `{` $source_size `}`
`->` type($result) `{` $result_size `}`
attr-dict
}];

// Convenience builder taking a single size that is forwarded as both
// source_size and result_size.
let builders = [
OpBuilder<(ins
"Type":$resultType,
"Value":$source,
"Value":$size,
"IREE::Stream::AffinityAttr":$affinity
), [{
build($_builder, $_state, resultType, source, size, size, affinity);
}]>,
];

// Index-based size accessors (presumably the hooks expected by
// Util_SizeAwareOp; confirm against the interface definition). The index
// argument is ignored: getOperandSize always returns the source size and
// getResultSize(idx) forwards to the zero-argument getResultSize() accessor.
let extraClassDeclaration = [{
Value getOperandSize(unsigned idx) { return getSourceSize(); }
Value getResultSize(unsigned idx) { return getResultSize(); }
}];

let hasVerifier = 1;
let hasFolder = 1;
let hasCanonicalizer = 1;
}

def Stream_AsyncLoadOp : Stream_PureOp<"async.load", [
Stream_AsyncPhaseOp,
Util_SizeAwareOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,44 @@ util.func private @IntermediateTransferElision(%source: !stream.resource<constan

// -----

// CHECK-LABEL: @FoldAsyncCastSameType
// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<external>, %[[SIZE:.+]]: index)
util.func private @FoldAsyncCastSameType(%source: !stream.resource<external>, %size: index) -> !stream.resource<external> {
// Identity fold: a cast whose source and result types are both
// !stream.resource<external> should fold away, leaving the function to
// return its argument directly.
// CHECK-NOT: stream.async.cast
%cast = stream.async.cast %source : !stream.resource<external>{%size} -> !stream.resource<external>{%size}
// CHECK: util.return %[[SOURCE]]
util.return %cast : !stream.resource<external>
}

// -----

// CHECK-LABEL: @FoldAsyncCastChain
// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<external>, %[[SIZE:.+]]: index)
util.func private @FoldAsyncCastChain(%source: !stream.resource<external>, %size: index) -> !stream.resource<external> {
// Round-trip fold: external -> * -> external returns to the original type,
// so both casts should disappear and the argument is returned directly.
// CHECK-NOT: stream.async.cast
%cast0 = stream.async.cast %source : !stream.resource<external>{%size} -> !stream.resource<*>{%size}
%cast1 = stream.async.cast %cast0 : !stream.resource<*>{%size} -> !stream.resource<external>{%size}
// CHECK: util.return %[[SOURCE]]
util.return %cast1 : !stream.resource<external>
}

// -----

// CHECK-LABEL: @CollapseAsyncCastChain
// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<external>, %[[SIZE:.+]]: index)
util.func private @CollapseAsyncCastChain(%source: !stream.resource<external>, %size: index) -> !stream.resource<transient> {
// A chain of casts with different end types (external -> * -> transient)
// should collapse into exactly one cast from the original source to the
// final type. The negative checks on both sides of the expected cast make
// sure the intermediate cast does not survive.
// CHECK-NOT: stream.async.cast
%cast0 = stream.async.cast %source : !stream.resource<external>{%size} -> !stream.resource<*>{%size}
// CHECK: %[[CAST:.+]] = stream.async.cast %[[SOURCE]] : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<transient>{%[[SIZE]]}
// CHECK-NOT: stream.async.cast
%cast1 = stream.async.cast %cast0 : !stream.resource<*>{%size} -> !stream.resource<transient>{%size}
// CHECK: util.return %[[CAST]]
util.return %cast1 : !stream.resource<transient>
}

// -----

// CHECK-LABEL: @FoldAsyncLoadBitcast
util.func private @FoldAsyncLoadBitcast(%arg0: !stream.resource<staging>, %arg1: index) -> f32 {
%c0 = arith.constant 0 : index
Expand Down
Loading
Loading