Skip to content

Commit 73c97c3

Browse files
Copilotjustinchuby
andcommitted
Changes before error encountered
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent e3ac094 commit 73c97c3

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

onnxruntime/core/optimizer/insert_cast_transformer.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph,
3939
std::vector<onnxruntime::NodeArg*> input_defs = {new_on_input ? new_arg : old_arg};
4040
std::vector<onnxruntime::NodeArg*> output_defs = {new_on_input ? old_arg : new_arg};
4141

42+
// Validate to_type to ensure it's a valid ONNX data type
43+
if (to_type < 0 || to_type > TensorProto_DataType_DataType_MAX) {
44+
ORT_THROW("Invalid data type value: ", to_type,
45+
". Valid range is 0 to ", TensorProto_DataType_DataType_MAX);
46+
}
47+
4248
auto& cast_node = graph.AddNode(node_name, "Cast", "cast node to cast from float16 to float32 on cpu",
4349
input_defs, output_defs);
4450
cast_node.AddAttribute("to", to_type);

0 commit comments

Comments
 (0)