11import gc
2- import io
32import logging
43import os
54import warnings
@@ -595,32 +594,6 @@ def _save_weight_mapping(self) -> None:
595594 gc .collect ()
596595 torch .cuda .empty_cache ()
597596
598- @needs_refit # type: ignore[misc]
599- def _insert_engine_to_cache (self , hash_val : str , engine : trt .ICudaEngine ) -> None :
600- serialized_engine = engine .serialize ()
601- # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
602- # if not self.compilation_settings.strip_engine_weights:
603- # # set EXCLUDE_WEIGHTS flag to strip weights
604- # serialization_config = engine.create_serialization_config()
605- # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
606- # serialized_engine = engine.serialize_with_config(
607- # serialization_config
608- # )
609-
610- # Cache weighted engine for now
611- self .engine_cache .insert ( # type: ignore[union-attr]
612- hash_val ,
613- (
614- serialized_engine ,
615- self ._input_names ,
616- self ._output_names ,
617- self .input_specs ,
618- self .compilation_settings ,
619- self .weight_name_map ,
620- self .ctx .requires_output_allocator ,
621- ),
622- )
623-
624597 @needs_refit # type: ignore[misc]
625598 def _pull_cached_engine (self , hash_val : str ) -> Optional [TRTInterpreterResult ]:
626599 # query the cached TRT engine
@@ -673,7 +646,6 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
673646 settings = self .compilation_settings ,
674647 weight_name_map = self .weight_name_map ,
675648 )
676- serialized_engine = engine .serialize ()
677649
678650 # TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine
679651 # # EXCLUDE_WEIGHTS flag must be cleared
@@ -686,12 +658,8 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
686658 # )
687659 # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller
688660
689- with io .BytesIO () as engine_bytes :
690- engine_bytes .write (serialized_engine )
691- engine_str = engine_bytes .getvalue ()
692-
693661 return TRTInterpreterResult (
694- engine_str ,
662+ engine ,
695663 self ._input_names ,
696664 self ._output_names ,
697665 self .weight_name_map ,
@@ -774,14 +742,6 @@ def run(
774742 builder_config , self .compilation_settings .timing_cache_path
775743 )
776744
777- # Engine caching only for refittable engines
778- if (
779- not self .compilation_settings .immutable_weights
780- and self .compilation_settings .cache_built_engines
781- and self .engine_cache is not None
782- ):
783- self ._insert_engine_to_cache (hash_val , cuda_engine )
784-
785745 return TRTInterpreterResult (
786746 cuda_engine ,
787747 self ._input_names ,
0 commit comments