@@ -358,31 +358,27 @@ def create_constant(
358
358
shape = trt .Dims ()
359
359
else :
360
360
shape = list (torch_value .shape )
361
- if torch_value is not None :
362
- if torch_value .dtype == torch .bfloat16 :
363
- torch_value_fp32 = torch_value .to (torch .float32 )
364
- numpy_value = torch_value_fp32 .numpy ()
365
- else :
366
- numpy_value = torch_value .numpy ()
367
361
368
- ctx .mapping [name + " CONSTANT" ] = numpy_value .reshape (- 1 )
369
- constant = ctx .net .add_constant (
370
- shape ,
371
- numpy_value ,
372
- )
373
- constant .name = name
374
- if torch_value .dtype == torch .bfloat16 :
375
- return cast_trt_tensor (
376
- ctx ,
377
- constant .get_output (0 ),
378
- trt .DataType .BF16 ,
379
- name + "_bf16_cast" ,
380
- )
381
- return constant .get_output (0 )
362
+ if torch_value .dtype == torch .bfloat16 :
363
+ torch_value_fp32 = torch_value .to (torch .float32 )
364
+ numpy_value = torch_value_fp32 .numpy ()
382
365
else :
383
- raise ValueError (
384
- f"Cannot convert tensor '{ name } ' to a TensorRT constant because its value is None."
366
+ numpy_value = torch_value .numpy ()
367
+
368
+ ctx .mapping [name + " CONSTANT" ] = numpy_value .reshape (- 1 )
369
+ constant = ctx .net .add_constant (
370
+ shape ,
371
+ numpy_value ,
372
+ )
373
+ constant .name = name
374
+ if torch_value .dtype == torch .bfloat16 :
375
+ return cast_trt_tensor (
376
+ ctx ,
377
+ constant .get_output (0 ),
378
+ trt .DataType .BF16 ,
379
+ name + "_bf16_cast" ,
385
380
)
381
+ return constant .get_output (0 )
386
382
387
383
388
384
def get_trt_tensor (
@@ -423,53 +419,6 @@ def get_trt_tensor(
423
419
raise AssertionError (f"Cannot convert { input_val } to TRT constant" )
424
420
425
421
426
- def to_torch (
427
- value : Optional [Union [torch .Tensor , np .ndarray , int , float , bool ]],
428
- dtype : Optional [Union [torch .dtype , np .dtype , TRTDataType , _enums .dtype ]] = None ,
429
- ) -> Optional [torch .Tensor ]:
430
- """
431
- Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU
432
- Args:
433
- value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
434
- A PyTorch tensor, Numpy array, int, float, or bool
435
- dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
436
- If a dtype is given, we will convert the type of the given `value` to this dtype.
437
- Returns:
438
- A PyTorch tensor or None, if the input was None.
439
- """
440
-
441
- cpu_device = torch .device ("cpu" )
442
- torch_dtype = (
443
- _enums .dtype ._from (dtype ).to (torch .dtype , use_default = True ) if dtype else None
444
- )
445
-
446
- with unset_fake_temporarily ():
447
- if value is None :
448
- return None
449
-
450
- elif isinstance (value , torch .Tensor ):
451
- output = value .to (cpu_device ).contiguous ()
452
-
453
- elif isinstance (value , np .ndarray ):
454
- output = torch .from_numpy (value ).to (cpu_device ).contiguous ()
455
-
456
- elif isinstance (value , int ):
457
- output = torch .tensor ([value ], device = cpu_device , dtype = torch .int32 )
458
-
459
- elif isinstance (value , float ):
460
- output = torch .tensor ([value ], device = cpu_device , dtype = torch .float32 )
461
-
462
- elif isinstance (value , bool ):
463
- output = torch .tensor ([value ], device = cpu_device , dtype = torch .bool )
464
-
465
- else :
466
- raise AssertionError (
467
- f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: { type (value )} "
468
- )
469
-
470
- return output .to (torch_dtype ) if torch_dtype else output
471
-
472
-
473
422
@overload
474
423
def get_positive_dim (dim : int , dim_size : int ) -> int : ...
475
424
@@ -633,42 +582,92 @@ def to_numpy(
633
582
Returns:
634
583
A Numpy array or None, if the input was None.
635
584
"""
636
- output = None
585
+ with unset_fake_temporarily ():
586
+ output = None
637
587
638
- if value is None or isinstance (value , np .ndarray ):
639
- output = value
588
+ if value is None or isinstance (value , np .ndarray ):
589
+ output = value
640
590
641
- elif isinstance (value , torch .Tensor ):
642
- if value .is_quantized :
643
- value = value .dequantize ()
644
- elif value .dtype == torch .bfloat16 :
645
- # TODO: Remove when numpy has a BF16 type
646
- _LOGGER .warning (
647
- "Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation" ,
591
+ elif isinstance (value , torch .Tensor ):
592
+ if value .is_quantized :
593
+ value = value .dequantize ()
594
+ elif value .dtype == torch .bfloat16 :
595
+ # TODO: Remove when numpy has a BF16 type
596
+ _LOGGER .warning (
597
+ "Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation" ,
598
+ )
599
+ value = value .to (torch .float )
600
+
601
+ output = value .cpu ().detach ().contiguous ().numpy ()
602
+
603
+ elif isinstance (value , int ):
604
+ output = np .array ([value ], dtype = np .int32 )
605
+
606
+ elif isinstance (value , float ):
607
+ output = np .array ([value ], dtype = np .float32 )
608
+
609
+ elif isinstance (value , bool ):
610
+ output = np .array ([value ], dtype = np .bool_ )
611
+
612
+ if isinstance (output , np .ndarray ) or output is None :
613
+ return (
614
+ output
615
+ if (dtype is None or output is None )
616
+ else output .astype (
617
+ _enums .dtype ._from (dtype ).to (np .dtype , use_default = True )
618
+ )
619
+ )
620
+ else :
621
+ raise AssertionError (
622
+ f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: { value } "
648
623
)
649
- value = value .to (torch .float )
650
624
651
- output = value .cpu ().detach ().contiguous ().numpy ()
652
625
653
- elif isinstance (value , int ):
654
- output = np .array ([value ], dtype = np .int32 )
626
+ def to_torch (
627
+ value : Optional [Union [torch .Tensor , np .ndarray , int , float , bool ]],
628
+ dtype : Optional [Union [torch .dtype , np .dtype , TRTDataType , _enums .dtype ]] = None ,
629
+ ) -> Optional [torch .Tensor ]:
630
+ """
631
+ Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU
632
+ Args:
633
+ value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
634
+ A PyTorch tensor, Numpy array, int, float, or bool
635
+ dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
636
+ If a dtype is given, we will convert the type of the given `value` to this dtype.
637
+ Returns:
638
+ A PyTorch tensor or None, if the input was None.
639
+ """
655
640
656
- elif isinstance (value , float ):
657
- output = np .array ([value ], dtype = np .float32 )
641
+ cpu_device = torch .device ("cpu" )
642
+ torch_dtype = (
643
+ _enums .dtype ._from (dtype ).to (torch .dtype , use_default = True ) if dtype else None
644
+ )
658
645
659
- elif isinstance (value , bool ):
660
- output = np .array ([value ], dtype = np .bool_ )
646
+ with unset_fake_temporarily ():
647
+ if value is None :
648
+ return None
661
649
662
- if isinstance (output , np .ndarray ) or output is None :
663
- return (
664
- output
665
- if (dtype is None or output is None )
666
- else output .astype (_enums .dtype ._from (dtype ).to (np .dtype , use_default = True ))
667
- )
668
- else :
669
- raise AssertionError (
670
- f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: { value } "
671
- )
650
+ elif isinstance (value , torch .Tensor ):
651
+ output = value .to (cpu_device ).contiguous ()
652
+
653
+ elif isinstance (value , np .ndarray ):
654
+ output = torch .from_numpy (value ).to (cpu_device ).contiguous ()
655
+
656
+ elif isinstance (value , int ):
657
+ output = torch .tensor ([value ], device = cpu_device , dtype = torch .int32 )
658
+
659
+ elif isinstance (value , float ):
660
+ output = torch .tensor ([value ], device = cpu_device , dtype = torch .float32 )
661
+
662
+ elif isinstance (value , bool ):
663
+ output = torch .tensor ([value ], device = cpu_device , dtype = torch .bool )
664
+
665
+ else :
666
+ raise AssertionError (
667
+ f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: { type (value )} "
668
+ )
669
+
670
+ return output .to (torch_dtype ) if torch_dtype else output
672
671
673
672
674
673
def flatten_dims (
0 commit comments