@@ -471,114 +471,122 @@ std::shared_ptr<ov::op::v0::Constant> Tensor::get_ov_constant() const {
471471 " The size of the external data file does not match the byte size of an initializer '" + get_name () +
472472 " ' in the model" );
473473 }
474- } else if (element_count == shape_size (m_shape) && m_tensor_proto != nullptr ) {
475- switch (m_tensor_proto->data_type ()) {
476- case TensorProto_DataType::TensorProto_DataType_FLOAT:
477- case TensorProto_DataType::TensorProto_DataType_DOUBLE:
478- case TensorProto_DataType::TensorProto_DataType_INT32:
479- case TensorProto_DataType::TensorProto_DataType_INT64:
480- case TensorProto_DataType::TensorProto_DataType_UINT32:
481- case TensorProto_DataType::TensorProto_DataType_UINT64:
482- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data_ptr ());
483- break ;
484- case TensorProto_DataType::TensorProto_DataType_INT4:
485- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int8_t >().data ());
486- break ;
487- case TensorProto_DataType::TensorProto_DataType_INT8:
488- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int8_t >().data ());
489- break ;
490- case TensorProto_DataType::TensorProto_DataType_INT16:
491- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int16_t >().data ());
492- break ;
493- case TensorProto_DataType::TensorProto_DataType_UINT4:
494- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint8_t >().data ());
495- break ;
496- case TensorProto_DataType::TensorProto_DataType_UINT8:
497- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint8_t >().data ());
498- break ;
499- case TensorProto_DataType::TensorProto_DataType_UINT16:
500- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint16_t >().data ());
501- break ;
502- case TensorProto_DataType::TensorProto_DataType_BOOL:
503- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<char >().data ());
504- break ;
505- case TensorProto_DataType::TensorProto_DataType_BFLOAT16:
506- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::bfloat16>().data ());
507- break ;
508- case TensorProto_DataType::TensorProto_DataType_FLOAT16:
509- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float16>().data ());
510- break ;
511- case TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN:
512- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float8_e4m3>().data ());
513- break ;
514- case TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2:
515- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float8_e5m2>().data ());
516- break ;
517- case TensorProto_DataType::TensorProto_DataType_STRING:
518- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<std::string>().data ());
519- break ;
520- default :
521- ONNX_UNSUPPORTED_DATA_TYPE (
522- m_tensor_proto->data_type (),
523- " BOOL, BFLOAT16, FLOAT8E4M3FN, FLOAT8E5M2, FLOAT, FLOAT16, DOUBLE, INT4, INT8, INT16, INT32, INT64, "
524- " UINT4, UINT8, UINT16, UINT32, UINT64, STRING" );
474+ } else {
475+ const auto shape_elements = shape_size (m_shape);
476+ if (element_count != shape_elements && !(element_count == 0 && m_shape.empty ())) {
477+ FRONT_END_THROW (
478+ " The number of elements implied by the data size does not match the shape of an initializer '" +
479+ get_name () + " ' in the model" );
525480 }
526- } else if (element_count == shape_size (m_shape) && m_tensor_place != nullptr ) {
527- switch (m_tensor_place->get_element_type ()) {
528- case ov::element::f32 :
529- case ov::element::f64 :
530- case ov::element::i32 :
531- case ov::element::i64 :
532- case ov::element::u32 :
533- case ov::element::u64 :
534- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data_ptr ());
535- break ;
536- case ov::element::i4:
537- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int8_t >().data ());
538- break ;
539- case ov::element::i8 :
540- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int8_t >().data ());
541- break ;
542- case ov::element::i16 :
543- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int16_t >().data ());
544- break ;
545- case ov::element::u4:
546- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint8_t >().data ());
547- break ;
548- case ov::element::u8 :
549- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint8_t >().data ());
550- break ;
551- case ov::element::u16 :
552- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint16_t >().data ());
553- break ;
554- case ov::element::boolean:
555- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<char >().data ());
556- break ;
557- case ov::element::bf16 :
558- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::bfloat16>().data ());
559- break ;
560- case ov::element::f16 :
561- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float16>().data ());
562- break ;
563- case ov::element::f8e4m3:
564- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float8_e4m3>().data ());
565- break ;
566- case ov::element::f8e5m2:
567- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float8_e5m2>().data ());
568- break ;
569- case ov::element::string:
570- constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<std::string>().data ());
571- break ;
572- default :
573- ONNX_UNSUPPORTED_DATA_TYPE (
574- m_tensor_proto->data_type (),
575- " BOOL, BFLOAT16, FLOAT8E4M3FN, FLOAT8E5M2, FLOAT, FLOAT16, DOUBLE, INT4, INT8, INT16, INT32, INT64, "
576- " UINT4, UINT8, UINT16, UINT32, UINT64, STRING" );
481+ if (element_count == 0 ) {
482+ constant = common::make_failsafe_constant (ov_type);
483+ } else if (m_tensor_proto != nullptr ) {
484+ switch (m_tensor_proto->data_type ()) {
485+ case TensorProto_DataType::TensorProto_DataType_FLOAT:
486+ case TensorProto_DataType::TensorProto_DataType_DOUBLE:
487+ case TensorProto_DataType::TensorProto_DataType_INT32:
488+ case TensorProto_DataType::TensorProto_DataType_INT64:
489+ case TensorProto_DataType::TensorProto_DataType_UINT32:
490+ case TensorProto_DataType::TensorProto_DataType_UINT64:
491+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data_ptr ());
492+ break ;
493+ case TensorProto_DataType::TensorProto_DataType_INT4:
494+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int8_t >().data ());
495+ break ;
496+ case TensorProto_DataType::TensorProto_DataType_INT8:
497+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int8_t >().data ());
498+ break ;
499+ case TensorProto_DataType::TensorProto_DataType_INT16:
500+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int16_t >().data ());
501+ break ;
502+ case TensorProto_DataType::TensorProto_DataType_UINT4:
503+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint8_t >().data ());
504+ break ;
505+ case TensorProto_DataType::TensorProto_DataType_UINT8:
506+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint8_t >().data ());
507+ break ;
508+ case TensorProto_DataType::TensorProto_DataType_UINT16:
509+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint16_t >().data ());
510+ break ;
511+ case TensorProto_DataType::TensorProto_DataType_BOOL:
512+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<char >().data ());
513+ break ;
514+ case TensorProto_DataType::TensorProto_DataType_BFLOAT16:
515+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::bfloat16>().data ());
516+ break ;
517+ case TensorProto_DataType::TensorProto_DataType_FLOAT16:
518+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float16>().data ());
519+ break ;
520+ case TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN:
521+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float8_e4m3>().data ());
522+ break ;
523+ case TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2:
524+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float8_e5m2>().data ());
525+ break ;
526+ case TensorProto_DataType::TensorProto_DataType_STRING:
527+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<std::string>().data ());
528+ break ;
529+ default :
530+ ONNX_UNSUPPORTED_DATA_TYPE (m_tensor_proto->data_type (),
531+ " BOOL, BFLOAT16, FLOAT8E4M3FN, FLOAT8E5M2, FLOAT, FLOAT16, DOUBLE, INT4, "
532+ " INT8, INT16, INT32, INT64, "
533+ " UINT4, UINT8, UINT16, UINT32, UINT64, STRING" );
534+ }
535+ } else if (m_tensor_place != nullptr ) {
536+ switch (m_tensor_place->get_element_type ()) {
537+ case ov::element::f32 :
538+ case ov::element::f64 :
539+ case ov::element::i32 :
540+ case ov::element::i64 :
541+ case ov::element::u32 :
542+ case ov::element::u64 :
543+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data_ptr ());
544+ break ;
545+ case ov::element::i4:
546+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int8_t >().data ());
547+ break ;
548+ case ov::element::i8 :
549+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int8_t >().data ());
550+ break ;
551+ case ov::element::i16 :
552+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<int16_t >().data ());
553+ break ;
554+ case ov::element::u4:
555+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint8_t >().data ());
556+ break ;
557+ case ov::element::u8 :
558+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint8_t >().data ());
559+ break ;
560+ case ov::element::u16 :
561+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<uint16_t >().data ());
562+ break ;
563+ case ov::element::boolean:
564+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<char >().data ());
565+ break ;
566+ case ov::element::bf16 :
567+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::bfloat16>().data ());
568+ break ;
569+ case ov::element::f16 :
570+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float16>().data ());
571+ break ;
572+ case ov::element::f8e4m3:
573+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float8_e4m3>().data ());
574+ break ;
575+ case ov::element::f8e5m2:
576+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<ov::float8_e5m2>().data ());
577+ break ;
578+ case ov::element::string:
579+ constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<std::string>().data ());
580+ break ;
581+ default :
582+ ONNX_UNSUPPORTED_DATA_TYPE (m_tensor_proto->data_type (),
583+ " BOOL, BFLOAT16, FLOAT8E4M3FN, FLOAT8E5M2, FLOAT, FLOAT16, DOUBLE, INT4, "
584+ " INT8, INT16, INT32, INT64, "
585+ " UINT4, UINT8, UINT16, UINT32, UINT64, STRING" );
586+ }
587+ } else {
588+ FRONT_END_THROW (" Tensor shape doesn't match data size" );
577589 }
578- } else if (element_count == 0 && m_shape.size () == 0 ) {
579- constant = common::make_failsafe_constant (ov_type);
580- } else {
581- FRONT_END_THROW (" Tensor shape doesn't match data size" );
582590 }
583591
584592 if (m_tensor_proto != nullptr && m_tensor_proto->has_name ()) {
0 commit comments