Skip to content

Commit fb59044

Browse files
rkayaithclaude
andcommitted
Add verifier to aten.view to reject element type mismatches
`aten.view` maps to the shape-only overload (`aten::view(Tensor, SymInt[])`), which preserves dtype. Without a verifier, invalid IR with mismatched input/output element types reaches `genericViewLikeFold` and crashes with an assertion failure in `DenseElementsAttr::get`. Fixes #4479 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1f2abc8 commit fb59044

File tree

4 files changed

+26
-1
lines changed

4 files changed

+26
-1
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13128,6 +13128,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
1312813128
}
1312913129
}];
1313013130
let hasFolder = 1;
13131+
let hasVerifier = 1;
1313113132
}
1313213133

1313313134
def Torch_AtenViewDtypeOp : Torch_Op<"aten.view.dtype", [

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,6 +1218,21 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
12181218
// AtenViewOp
12191219
//===----------------------------------------------------------------------===//
12201220

1221+
LogicalResult AtenViewOp::verify() {
1222+
auto selfType = dyn_cast<BaseTensorType>(getSelf().getType());
1223+
auto resultType = dyn_cast<BaseTensorType>(getType());
1224+
if (!selfType || !resultType || !selfType.hasDtype() ||
1225+
!resultType.hasDtype())
1226+
return success();
1227+
if (selfType.getDtype() != resultType.getDtype())
1228+
return emitOpError("element type of input (")
1229+
<< selfType.getDtype() << ") does not match element type of result ("
1230+
<< resultType.getDtype()
1231+
<< "); `aten.view` cannot change dtype, use `aten.view.dtype` for "
1232+
"dtype reinterpretation";
1233+
return success();
1234+
}
1235+
12211236
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
12221237
if (auto genericFold = genericViewLikeFold(adaptor.getSelf(), getType()))
12231238
return genericFold;

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,7 @@ def emit_with_mutating_variants(key, **kwargs):
961961
emit("aten::_cast_Float : (Tensor, bool) -> (Tensor)", has_canonicalizer=True)
962962
emit("aten::_cast_Long : (Tensor, bool) -> (Tensor)", has_canonicalizer=True)
963963
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
964-
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
964+
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True, has_verifier=True)
965965
emit("aten::view.dtype : (Tensor, int) -> (Tensor)")
966966
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
967967
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True)

test/Dialect/Torch/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,12 @@ func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -
403403
torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
404404
return %arg0 : !torch.vtensor<[?],f32>
405405
}
406+
407+
// -----
408+
409+
func.func @torch.aten.view$dtype_mismatch(%arg0: !torch.vtensor<[1],f32>) {
410+
%shape = torch.prim.ListConstruct : () -> !torch.list<int>
411+
// expected-error @below {{'torch.aten.view' op element type of input ('f32') does not match element type of result ('bf16')}}
412+
torch.aten.view %arg0, %shape : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],bf16>
413+
return
414+
}

0 commit comments

Comments
 (0)