Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13128,6 +13128,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
}
}];
let hasFolder = 1;
let hasVerifier = 1;
}

def Torch_AtenViewDtypeOp : Torch_Op<"aten.view.dtype", [
Expand Down
15 changes: 15 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,21 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// AtenViewOp
//===----------------------------------------------------------------------===//

// Verifies that `aten.view` does not change the element type: when both the
// operand and the result are tensor types with a known dtype, the dtypes must
// match. If either side is not a tensor type or lacks a dtype, the op is
// accepted conservatively (nothing to check).
LogicalResult AtenViewOp::verify() {
  auto inputTy = dyn_cast<BaseTensorType>(getSelf().getType());
  auto outputTy = dyn_cast<BaseTensorType>(getType());
  // No dtype information available on either side: nothing to verify.
  if (!inputTy || !outputTy)
    return success();
  if (!inputTy.hasDtype() || !outputTy.hasDtype())
    return success();
  if (inputTy.getDtype() == outputTy.getDtype())
    return success();
  // Dtype reinterpretation is the job of `aten.view.dtype`, not `aten.view`.
  return emitOpError("element type of input (")
         << inputTy.getDtype() << ") does not match element type of result ("
         << outputTy.getDtype()
         << "); `aten.view` cannot change dtype, use `aten.view.dtype` for "
            "dtype reinterpretation";
}

OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
if (auto genericFold = genericViewLikeFold(adaptor.getSelf(), getType()))
return genericFold;
Expand Down
21 changes: 12 additions & 9 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8592,14 +8592,17 @@ class DecomposeAtenNativeBatchNormOp
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
runningStatsShapeInt[1] =
cast<BaseTensorType>(runningMean.getType()).getSizes()[0];
Type dtype = cast<ValueTensorType>(input.getType()).getOptionalDtype();
Type reshapeType = ValueTensorType::get(
context, llvm::ArrayRef(runningStatsShapeInt), dtype);

runningMean = AtenViewOp::create(rewriter, loc, reshapeType, runningMean,
runningStatsSizeList);
runningVar = AtenViewOp::create(rewriter, loc, reshapeType, runningVar,
runningStatsSizeList);
auto reshapeType = [&](Value v) {
auto dtype = cast<ValueTensorType>(v.getType()).getOptionalDtype();
return ValueTensorType::get(context, llvm::ArrayRef(runningStatsShapeInt),
dtype);
};

runningMean = AtenViewOp::create(rewriter, loc, reshapeType(runningMean),
runningMean, runningStatsSizeList);
runningVar = AtenViewOp::create(rewriter, loc, reshapeType(runningVar),
runningVar, runningStatsSizeList);

// normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)).
Value inputSubMean = AtenSubTensorOp::create(
Expand All @@ -8621,7 +8624,7 @@ class DecomposeAtenNativeBatchNormOp
std::optional<unsigned> weightRank = getTensorRank(weight);
if (!weightRank || *weightRank != 1)
return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
weight = AtenViewOp::create(rewriter, loc, reshapeType, weight,
weight = AtenViewOp::create(rewriter, loc, reshapeType(weight), weight,
runningStatsSizeList);
batchNormOutput = AtenMulTensorOp::create(
rewriter, loc, batchNormOutput.getType(), batchNormOutput, weight);
Expand All @@ -8631,7 +8634,7 @@ class DecomposeAtenNativeBatchNormOp
std::optional<unsigned> biasRank = getTensorRank(bias);
if (!biasRank || *biasRank != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
bias = AtenViewOp::create(rewriter, loc, reshapeType, bias,
bias = AtenViewOp::create(rewriter, loc, reshapeType(bias), bias,
runningStatsSizeList);
batchNormOutput =
AtenAddTensorOp::create(rewriter, loc, batchNormOutput.getType(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::_cast_Float : (Tensor, bool) -> (Tensor)", has_canonicalizer=True)
emit("aten::_cast_Long : (Tensor, bool) -> (Tensor)", has_canonicalizer=True)
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True, has_verifier=True)
emit("aten::view.dtype : (Tensor, int) -> (Tensor)")
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True)
Expand Down
24 changes: 24 additions & 0 deletions test/Dialect/Torch/decompose-complex-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1013,3 +1013,27 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
%0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32>
return %0 : !torch.vtensor<[1,8,4,4],f32>
}

// -----

// CHECK-LABEL: func.func @native_batch_norm_mixed_precision(
// CHECK-SAME: %{{.*}}: !torch.vtensor<[1,3,4,4],bf16>, %{{.*}}: !torch.vtensor<[3],f32>, %{{.*}}: !torch.vtensor<[3],f32>, %[[MEAN:.*]]: !torch.vtensor<[3],f32>, %[[VAR:.*]]: !torch.vtensor<[3],f32>)
// Verify that the running stats reshape uses f32 (matching running stats dtype),
// not bf16 (the input dtype).
// Regression test: the decomposition must derive each view's result dtype from
// the tensor being reshaped, not from the batch-norm input tensor.
// CHECK: torch.aten.view %[[MEAN]], %{{.*}} : !torch.vtensor<[3],f32>, !torch.list<int> -> !torch.vtensor<[1,3,1,1],f32>
// CHECK: torch.aten.view %[[VAR]], %{{.*}} : !torch.vtensor<[3],f32>, !torch.list<int> -> !torch.vtensor<[1,3,1,1],f32>
func.func @native_batch_norm_mixed_precision(
    %input: !torch.vtensor<[1,3,4,4],bf16>,
    %weight: !torch.vtensor<[3],f32>,
    %bias: !torch.vtensor<[3],f32>,
    %running_mean: !torch.vtensor<[3],f32>,
    %running_var: !torch.vtensor<[3],f32>
) -> (!torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32>) {
  %false = torch.constant.bool false
  %float1e-5 = torch.constant.float 1.000000e-05
  %float0.1 = torch.constant.float 1.000000e-01
  %out:3 = torch.aten.native_batch_norm %input, %weight, %bias, %running_mean, %running_var, %false, %float0.1, %float1e-5 :
      !torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float
      -> !torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32>
  return %out#0, %out#1, %out#2 : !torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32>
}
9 changes: 9 additions & 0 deletions test/Dialect/Torch/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,12 @@ func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -
torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
return %arg0 : !torch.vtensor<[?],f32>
}

// -----

// The aten.view verifier must reject an op whose input (f32) and result (bf16)
// element types differ; the diagnostic points users at aten.view.dtype.
func.func @torch.aten.view$dtype_mismatch(%arg0: !torch.vtensor<[1],f32>) {
  %shape = torch.prim.ListConstruct : () -> !torch.list<int>
  // expected-error @below {{'torch.aten.view' op element type of input ('f32') does not match element type of result ('bf16')}}
  torch.aten.view %arg0, %shape : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],bf16>
  return
}
Loading