diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 82d11ec0737a..dcf2f78fd8a2 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -13128,6 +13128,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [ } }]; let hasFolder = 1; + let hasVerifier = 1; } def Torch_AtenViewDtypeOp : Torch_Op<"aten.view.dtype", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 80aef691b92f..8bb220dea274 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1218,6 +1218,21 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // AtenViewOp //===----------------------------------------------------------------------===// +LogicalResult AtenViewOp::verify() { + auto selfType = dyn_cast<BaseTensorType>(getSelf().getType()); + auto resultType = dyn_cast<BaseTensorType>(getType()); + if (!selfType || !resultType || !selfType.hasDtype() || + !resultType.hasDtype()) + return success(); + if (selfType.getDtype() != resultType.getDtype()) + return emitOpError("element type of input (") + << selfType.getDtype() << ") does not match element type of result (" + << resultType.getDtype() + << "); `aten.view` cannot change dtype, use `aten.view.dtype` for " + "dtype reinterpretation"; + return success(); +} + OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { if (auto genericFold = genericViewLikeFold(adaptor.getSelf(), getType())) return genericFold; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9a386ec35f30..e0a4bf063b62 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -8592,14 +8592,17 @@ class DecomposeAtenNativeBatchNormOp SmallVector<int64_t> runningStatsShapeInt(inputRank, 1); runningStatsShapeInt[1] =
cast<BaseTensorType>(runningMean.getType()).getSizes()[0]; - Type dtype = cast<BaseTensorType>(input.getType()).getOptionalDtype(); - Type reshapeType = ValueTensorType::get( - context, llvm::ArrayRef(runningStatsShapeInt), dtype); - runningMean = AtenViewOp::create(rewriter, loc, reshapeType, runningMean, - runningStatsSizeList); - runningVar = AtenViewOp::create(rewriter, loc, reshapeType, runningVar, - runningStatsSizeList); + auto reshapeType = [&](Value v) { + auto dtype = cast<BaseTensorType>(v.getType()).getOptionalDtype(); + return ValueTensorType::get(context, llvm::ArrayRef(runningStatsShapeInt), + dtype); + }; + + runningMean = AtenViewOp::create(rewriter, loc, reshapeType(runningMean), + runningMean, runningStatsSizeList); + runningVar = AtenViewOp::create(rewriter, loc, reshapeType(runningVar), + runningVar, runningStatsSizeList); // normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)). Value inputSubMean = AtenSubTensorOp::create( @@ -8621,7 +8624,7 @@ class DecomposeAtenNativeBatchNormOp std::optional<unsigned> weightRank = getTensorRank(weight); if (!weightRank || *weightRank != 1) return rewriter.notifyMatchFailure(op, "expected weight to be rank 1"); - weight = AtenViewOp::create(rewriter, loc, reshapeType, weight, + weight = AtenViewOp::create(rewriter, loc, reshapeType(weight), weight, runningStatsSizeList); batchNormOutput = AtenMulTensorOp::create( rewriter, loc, batchNormOutput.getType(), batchNormOutput, weight); @@ -8631,7 +8634,7 @@ class DecomposeAtenNativeBatchNormOp std::optional<unsigned> biasRank = getTensorRank(bias); if (!biasRank || *biasRank != 1) return rewriter.notifyMatchFailure(op, "expected bias to be rank 1"); - bias = AtenViewOp::create(rewriter, loc, reshapeType, bias, + bias = AtenViewOp::create(rewriter, loc, reshapeType(bias), bias, runningStatsSizeList); batchNormOutput = AtenAddTensorOp::create(rewriter, loc, batchNormOutput.getType(), diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 7d5e65c21cef..e80face04b3e 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -961,7 +961,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::_cast_Float : (Tensor, bool) -> (Tensor)", has_canonicalizer=True) emit("aten::_cast_Long : (Tensor, bool) -> (Tensor)", has_canonicalizer=True) emit("aten::type_as : (Tensor, Tensor) -> (Tensor)") - emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True) + emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True, has_verifier=True) emit("aten::view.dtype : (Tensor, int) -> (Tensor)") emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)") emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True) diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index d8dc39375bb6..b41de67d4019 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -1013,3 +1013,27 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens %0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32> return %0 : !torch.vtensor<[1,8,4,4],f32> } + +// ----- + +// CHECK-LABEL: func.func @native_batch_norm_mixed_precision( +// CHECK-SAME: %{{.*}}: !torch.vtensor<[1,3,4,4],bf16>, %{{.*}}: !torch.vtensor<[3],f32>, %{{.*}}: !torch.vtensor<[3],f32>, %[[MEAN:.*]]: !torch.vtensor<[3],f32>, %[[VAR:.*]]: !torch.vtensor<[3],f32>) +// Verify that the running stats reshape uses f32 (matching running stats dtype), +// not bf16 (the input dtype). 
+// CHECK: torch.aten.view %[[MEAN]], %{{.*}} : !torch.vtensor<[3],f32>, !torch.list<int> -> !torch.vtensor<[1,3,1,1],f32> +// CHECK: torch.aten.view %[[VAR]], %{{.*}} : !torch.vtensor<[3],f32>, !torch.list<int> -> !torch.vtensor<[1,3,1,1],f32> +func.func @native_batch_norm_mixed_precision( + %input: !torch.vtensor<[1,3,4,4],bf16>, + %weight: !torch.vtensor<[3],f32>, + %bias: !torch.vtensor<[3],f32>, + %running_mean: !torch.vtensor<[3],f32>, + %running_var: !torch.vtensor<[3],f32> +) -> (!torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32>) { + %false = torch.constant.bool false + %float1e-5 = torch.constant.float 1.000000e-05 + %float0.1 = torch.constant.float 1.000000e-01 + %out:3 = torch.aten.native_batch_norm %input, %weight, %bias, %running_mean, %running_var, %false, %float0.1, %float1e-5 : + !torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float + -> !torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32> + return %out#0, %out#1, %out#2 : !torch.vtensor<[1,3,4,4],bf16>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32> +} diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index c863e93fa5fa..26f7b0c55acd 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -403,3 +403,12 @@ func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) - torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> return %arg0 : !torch.vtensor<[?],f32> } + +// ----- + +func.func @torch.aten.view$dtype_mismatch(%arg0: !torch.vtensor<[1],f32>) { + %shape = torch.prim.ListConstruct : () -> !torch.list<int> + // expected-error @below {{'torch.aten.view' op element type of input ('f32') does not match element type of result ('bf16')}} + torch.aten.view %arg0, %shape : !torch.vtensor<[1],f32>, !torch.list<int> ->
!torch.vtensor<[],bf16> + return +}