Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2715,23 +2715,75 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
reducedShape[i] = xShape[i];
auto reducedType =
xType.getWithSizesAndDtype(reducedShape, *stashDtype);

// native_layer_norm preserves input dtype, so when stash_type
// caused a cast of x, use stashDtype for y and cast back after.
auto actualYType = yType;
if (*stashDtype != yType.getOptionalDtype()) {
actualYType = cast<Torch::ValueTensorType>(yType.getWithSizesAndDtype(
yType.getOptionalSizes(), *stashDtype));

// Also cast scale and bias to stash_type so all tensor args
// to native_layer_norm share the same dtype.
Value stashDtypeConst = Torch::getDtypeIntValueForType(
rewriter, binder.getLoc(), *stashDtype);
if (auto scaleTy =
dyn_cast<Torch::ValueTensorType>(scale.getType())) {
auto newScaleTy = scaleTy.getWithSizesAndDtype(
scaleTy.getOptionalSizes(), *stashDtype);
scale = Torch::AtenToDtypeOp::create(
rewriter, binder.getLoc(), newScaleTy, scale,
/*dtype=*/stashDtypeConst,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
}
if (auto bTy = dyn_cast<Torch::ValueTensorType>(b.getType())) {
auto newBTy =
bTy.getWithSizesAndDtype(bTy.getOptionalSizes(), *stashDtype);
b = Torch::AtenToDtypeOp::create(
rewriter, binder.getLoc(), newBTy, b,
/*dtype=*/stashDtypeConst,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
}
}
auto y = Torch::AtenNativeLayerNormOp::create(
rewriter, binder.getLoc(), yType, /*meanType=*/reducedType,
rewriter, binder.getLoc(), actualYType, /*meanType=*/reducedType,
/*invStdDevType=*/reducedType, x, normalized_shape, scale, b,
constEpsilon);

int64_t numResults = binder.op->getNumResults();
if (numResults == 1) {
rewriter.replaceOp(binder.op, y.getResult0());
Value yResult = y.getResult0();
if (*stashDtype != yType.getOptionalDtype()) {
Value yDtypeConst = Torch::getDtypeIntValueForType(
rewriter, binder.getLoc(), yType.getDtype());
yResult = Torch::AtenToDtypeOp::create(
rewriter, binder.getLoc(), yType, yResult,
/*dtype=*/yDtypeConst,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
}
rewriter.replaceOp(binder.op, yResult);
return success();
}

Value yResult = y.getResult0();
Value meanOutput = y.getResult1();
Value varOutput = y.getResult2();
// Convert meanType and varType back if stash_dtype is different
// Convert outputs back if stash_dtype is different
if (binder.tensorResultTypeAtIndex(meanType, 1) ||
binder.tensorResultTypeAtIndex(invStdDevType, 2))
return failure();
if (*stashDtype != yType.getOptionalDtype()) {
Value yDtypeConst = Torch::getDtypeIntValueForType(
rewriter, binder.getLoc(), yType.getDtype());
yResult = Torch::AtenToDtypeOp::create(
rewriter, binder.getLoc(), yType, yResult,
/*dtype=*/yDtypeConst,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
}
if (*stashDtype != meanType.getOptionalDtype()) {
Value constDtype = Torch::getDtypeIntValueForType(
rewriter, binder.getLoc(), meanType.getDtype());
Expand All @@ -2746,7 +2798,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
}
rewriter.replaceOp(binder.op, {y.getResult0(), meanOutput, varOutput});
rewriter.replaceOp(binder.op, {yResult, meanOutput, varOutput});

return success();
});
Expand Down
42 changes: 42 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,48 @@ func.func @test_layer_norm_single_result(%arg0: !torch.vtensor<[1,4,768],f32>, %

// -----

// Test LayerNormalization with stash_type upcasting (f16 input, stash_type=f32).
// When stash_type differs from input dtype, x, scale, and bias must all be
// cast to the stash dtype before calling native_layer_norm, and the result
// must be cast back to the original dtype.
func.func @test_layer_norm_stash_type_f16(%arg0: !torch.vtensor<[2,8,256],f16>, %arg1: !torch.vtensor<[256],f16>, %arg2: !torch.vtensor<[256],f16>) -> !torch.vtensor<[2,8,256],f16>
attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// stash_type = 1 selects an f32 computation dtype (see the f32 casts in the
// CHECK lines below), which differs from the f16 operand dtype; the lowering
// must therefore upcast x/scale/bias before native_layer_norm and cast the
// single requested result back to f16.
%0 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.stash_type = 1 : si64} : (!torch.vtensor<[2,8,256],f16>, !torch.vtensor<[256],f16>, !torch.vtensor<[256],f16>) -> !torch.vtensor<[2,8,256],f16>
return %0 : !torch.vtensor<[2,8,256],f16>
}
// CHECK-LABEL: func.func @test_layer_norm_stash_type_f16
// CHECK-SAME: %[[X:[a-zA-Z0-9]+]]: !torch.vtensor<[2,8,256],f16>
// CHECK-SAME: %[[SCALE:[a-zA-Z0-9]+]]: !torch.vtensor<[256],f16>
// CHECK-SAME: %[[BIAS:[a-zA-Z0-9]+]]: !torch.vtensor<[256],f16>
// CHECK: %[[X_CAST:.*]] = torch.aten.to.dtype %[[X]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[2,8,256],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,8,256],f32>
// CHECK: %[[SCALE_CAST:.*]] = torch.aten.to.dtype %[[SCALE]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[256],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[256],f32>
// CHECK: %[[BIAS_CAST:.*]] = torch.aten.to.dtype %[[BIAS]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[256],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[256],f32>
// CHECK: %[[NORM:.*]], %{{.*}}, %{{.*}} = torch.aten.native_layer_norm %[[X_CAST]], %{{.*}}, %[[SCALE_CAST]], %[[BIAS_CAST]], %{{.*}} : !torch.vtensor<[2,8,256],f32>, !torch.list<int>, !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32>, !torch.float -> !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,1],f32>
// CHECK: torch.aten.to.dtype %[[NORM]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[2,8,256],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,8,256],f16>

// -----

// Test LayerNormalization with stash_type upcasting returning all 3 results.
func.func @test_layer_norm_stash_type_f16_3results(%arg0: !torch.vtensor<[2,8,256],f16>, %arg1: !torch.vtensor<[256],f16>, %arg2: !torch.vtensor<[256],f16>) -> (!torch.vtensor<[2,8,256],f16>, !torch.vtensor<[2,8,1],f16>, !torch.vtensor<[2,8,1],f16>)
attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// All three results (Y, Mean, InvStdDev) are declared f16 while the
// computation runs in the f32 stash dtype, so each native_layer_norm result
// needs its own cast back to f16 (checked by Y_BACK/MEAN_BACK/VAR_BACK below).
%0:3 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.stash_type = 1 : si64} : (!torch.vtensor<[2,8,256],f16>, !torch.vtensor<[256],f16>, !torch.vtensor<[256],f16>) -> (!torch.vtensor<[2,8,256],f16>, !torch.vtensor<[2,8,1],f16>, !torch.vtensor<[2,8,1],f16>)
return %0#0, %0#1, %0#2 : !torch.vtensor<[2,8,256],f16>, !torch.vtensor<[2,8,1],f16>, !torch.vtensor<[2,8,1],f16>
}
// CHECK-LABEL: func.func @test_layer_norm_stash_type_f16_3results
// CHECK-SAME: %[[X:[a-zA-Z0-9]+]]: !torch.vtensor<[2,8,256],f16>
// CHECK-SAME: %[[SCALE:[a-zA-Z0-9]+]]: !torch.vtensor<[256],f16>
// CHECK-SAME: %[[BIAS:[a-zA-Z0-9]+]]: !torch.vtensor<[256],f16>
// CHECK: %[[X_CAST:.*]] = torch.aten.to.dtype %[[X]]
// CHECK: %[[SCALE_CAST:.*]] = torch.aten.to.dtype %[[SCALE]]
// CHECK: %[[BIAS_CAST:.*]] = torch.aten.to.dtype %[[BIAS]]
// CHECK: %[[NORM:.*]], %[[MEAN:.*]], %[[VAR:.*]] = torch.aten.native_layer_norm %[[X_CAST]], %{{.*}}, %[[SCALE_CAST]], %[[BIAS_CAST]], %{{.*}} : !torch.vtensor<[2,8,256],f32>, !torch.list<int>, !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32>, !torch.float -> !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,1],f32>
// CHECK: %[[Y_BACK:.*]] = torch.aten.to.dtype %[[NORM]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[2,8,256],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,8,256],f16>
// CHECK: %[[MEAN_BACK:.*]] = torch.aten.to.dtype %[[MEAN]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[2,8,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f16>
// CHECK: %[[VAR_BACK:.*]] = torch.aten.to.dtype %[[VAR]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[2,8,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f16>
// CHECK: return %[[Y_BACK]], %[[MEAN_BACK]], %[[VAR_BACK]]

// -----

// CHECK-LABEL: func.func @test_leaky_relu
func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 16 : si64} {
// CHECK-DAG: %[[F2:.+]] = torch.constant.float 2
Expand Down