@@ -46,7 +46,9 @@ class RuntimeHook:
4646 def __init__ (self , operation : Operation , ** kwargs ) -> None :
4747 self ._hook_to = operation
4848
49- def pre_forward_hook (self , inputs : Sequence [Tensor ], ** kwargs ) -> List [Tensor ]:
49+ def pre_forward_hook (
50+ self , inputs : Sequence [Tensor | None ], ** kwargs
51+ ) -> List [Tensor | None ]:
5052 """user-customized pre-processing procedure of input data.
5153
5254 Args:
@@ -57,7 +59,9 @@ def pre_forward_hook(self, inputs: Sequence[Tensor], **kwargs) -> List[Tensor]:
5759 """
5860 return list (inputs )
5961
60- def post_forward_hook (self , outputs : Sequence [Tensor ], ** kwargs ) -> List [Tensor ]:
62+ def post_forward_hook (
63+ self , outputs : Sequence [Tensor | None ], ** kwargs
64+ ) -> List [Tensor | None ]:
6165 """user-customized post-processing procedure of output data.
6266
6367 Args:
@@ -82,21 +86,21 @@ def __init__(self, operation: Operation, **kwargs) -> None:
8286
8387 def pre_forward_hook (
8488 self ,
85- inputs : Sequence [Tensor ],
86- quant_inputs : Sequence [Tensor ] = (),
89+ inputs : Sequence [Tensor | None ],
90+ quant_inputs : Sequence [Tensor | None ] = (),
8791 quant_configs : Sequence [TensorQuantizationConfig ] = (),
8892 ** kwargs ,
89- ) -> List [Tensor ]:
93+ ) -> List [Tensor | None ]:
9094 assert len (inputs ) == len (quant_inputs ) == len (quant_configs )
9195 return list (quant_inputs )
9296
9397 def post_forward_hook (
9498 self ,
95- outputs : Sequence [Tensor ],
96- quant_outputs : Sequence [Tensor ] = (),
99+ outputs : Sequence [Tensor | None ],
100+ quant_outputs : Sequence [Tensor | None ] = (),
97101 quant_configs : Sequence [TensorQuantizationConfig ] = (),
98102 ** kwargs ,
99- ) -> List [Tensor ]:
103+ ) -> List [Tensor | None ]:
100104 assert len (outputs ) == len (quant_outputs ) == len (quant_configs )
101105 return list (quant_outputs )
102106
@@ -155,7 +159,7 @@ def _prepare_input(self, inputs: Optional[GraphInput]) -> Dict[str, Tensor]:
155159 def forward (
156160 self ,
157161 inputs : GraphInput ,
158- output_names : Optional [List [str ]] = None ,
162+ output_names : Optional [Sequence [str ]] = None ,
159163 hooks : Optional [Mapping [str , RuntimeHook ]] = None ,
160164 ) -> List [torch .Tensor ]:
161165 """Forward a graph from given inputs to required output names.
@@ -169,15 +173,15 @@ def forward(
169173 def tracing_operation_meta (
170174 self ,
171175 inputs : GraphInput ,
172- output_names : Optional [List [str ]] = None ,
176+ output_names : Optional [Sequence [str ]] = None ,
173177 ) -> None :
174178 raise NotImplementedError ("Please implement this function first." )
175179
176180 @overload
177181 def forward_single_operation (
178182 self ,
179183 op : Operation ,
180- inputs : List [Tensor ],
184+ inputs : Sequence [Tensor | None ],
181185 ctx : Optional [TorchBackendContext ] = None ,
182186 return_list : Literal [True ] = True ,
183187 ) -> Tuple [Tensor , ...]:
@@ -187,7 +191,7 @@ def forward_single_operation(
187191 def forward_single_operation (
188192 self ,
189193 op : Operation ,
190- inputs : List [Tensor ],
194+ inputs : Sequence [Tensor | None ],
191195 ctx : Optional [TorchBackendContext ] = None ,
192196 return_list : bool = True ,
193197 ) -> Tuple [Tensor , ...] | Tensor :
@@ -196,7 +200,7 @@ def forward_single_operation(
196200 def forward_single_operation (
197201 self ,
198202 op : Operation ,
199- inputs : List [Tensor ],
203+ inputs : Sequence [Tensor | None ],
200204 ctx : Optional [TorchBackendContext ] = None ,
201205 return_list : bool = True ,
202206 ) -> Tuple [Tensor , ...] | Tensor :
@@ -210,7 +214,7 @@ def forward_single_operation(
210214 def __call__ (
211215 self ,
212216 inputs : GraphInput ,
213- output_names : Optional [List [str ]] = None ,
217+ output_names : Optional [Sequence [str ]] = None ,
214218 ) -> List [torch .Tensor ]:
215219 return self .forward (inputs = inputs , output_names = output_names )
216220
0 commit comments