Skip to content

Commit 19fd4f8

Browse files
author
Li Wei
committed
[fix] some bugs caused by bumping
1 parent 5ccca64 commit 19fd4f8

File tree

4 files changed

+36
-3
lines changed

4 files changed

+36
-3
lines changed

mppq/api/interface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def dispatch_graph(
169169
quantize_operations: Collection[str] = DEFAULT_QUANTIZE_OP,
170170
dispatcher: Optional[str | GraphDispatcher] = None,
171171
dispatching_override: Optional[Dict[str, TargetPrecision]] = None,
172-
ignored_scope: Optional[list | IgnoredScope] = None,
172+
ignored_scope: Optional[dict | list | IgnoredScope] = None,
173173
quant_precision: TargetPrecision = TargetPrecision.INT8,
174174
**kwargs,
175175
) -> BaseGraph:
@@ -216,6 +216,8 @@ def dispatch_graph(
216216
)
217217

218218
if ignored_scope is not None:
219+
if isinstance(ignored_scope, dict):
220+
ignored_scope = IgnoredScope(**ignored_scope.pop("type"))
219221
if isinstance(ignored_scope, list):
220222
ignored_scope = IgnoredScope(operations=ignored_scope)
221223
assert isinstance(ignored_scope, IgnoredScope)

mppq/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def from_numpy(cls, dtype: np_type):
5858
np_type("float16"): DataType.FP16,
5959
np_type("float32"): DataType.FP32,
6060
np_type("float64"): DataType.FP64,
61-
np.bool: DataType.BOOL,
61+
np.bool_: DataType.BOOL,
6262
np.uint8: DataType.UINT8,
6363
np.int8: DataType.INT8,
6464
np.int16: DataType.INT16,

mppq/executor/torch.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,18 @@ def post_forward_hook(self, outputs: list, **kwargs) -> list:
310310
you cannot select a GPU for execution yet,
311311
graph will always be sent to the very first visible CUDA device.
312312
]. Defaults to 'cuda'.
313+
314+
feedback_tensors=[('output_1', 'input_2')]
315+
# iter 1, executor({input_2: init_tensor_2})
316+
# iter 2, input_2 will be omitted and use output_1's value.
313317
"""
314318

315319
def __init__(
316-
self, graph: BaseGraph, fp16_mode: bool = True, device: str = "cuda"
320+
self,
321+
graph: BaseGraph,
322+
fp16_mode: bool = True,
323+
device: str = "cuda",
324+
feedback_tensors: dict | None = None,
317325
) -> None:
318326
self._default_quant_fn = ppq_fake_quant
319327
self._deployed = False
@@ -325,6 +333,16 @@ def __init__(
325333
# fp16 is not available for now.
326334
self.fp16_mode = fp16_mode
327335
self.deploy()
336+
if feedback_tensors:
337+
self.feedback_tensors = {}
338+
self.feedback_dict = {}
339+
for src, dst in feedback_tensors:
340+
if src not in graph.outputs:
341+
raise ValueError(f"{src} is not an output of the graph.")
342+
if dst not in graph.inputs:
343+
raise ValueError(f"{dst} is not an input of the graph.")
344+
self.feedback_tensors[src] = None
345+
self.feedback_dict[dst] = src
328346

329347
def register_quantize_delegate(
330348
self, config: TensorQuantizationConfig, delegator: TorchQuantizeDelegator
@@ -505,6 +523,11 @@ def _forward_operations( # noqa: C901
505523
hooks: Optional[Mapping[str, RuntimeHook]] = None,
506524
) -> List[torch.Tensor]:
507525
for key, value in inputs.items():
526+
if value is None:
527+
assert key in self.feedback_dict
528+
feedback_value = self.feedback_tensors[self.feedback_dict[key]]
529+
self._graph.inputs[key].value = feedback_value
530+
continue
508531
if not isinstance(value, torch.Tensor):
509532
raise TypeError(
510533
"TorchExecutor can only accept tensor as its input, "
@@ -622,6 +645,12 @@ def _forward_operations( # noqa: C901
622645
result_collector[output_names.index(output_var.name)] = outputs[
623646
output_idx
624647
]
648+
# collect feedback tensors
649+
if (
650+
hasattr(self, "feedback_dict")
651+
and output_var.name in self.feedback_dict.values()
652+
):
653+
self.feedback_tensors[output_var.name] = outputs[output_idx]
625654
except Exception as e:
626655
raise RuntimeError(f"Op Execution Error: {str(operation)}") from e
627656

mppq/quantizer/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def quantize(
104104
graph=self._graph,
105105
dataloader=calib_dataloader,
106106
executor=executor,
107+
collate_fn=collate_fn,
108+
calib_steps=calib_steps,
107109
verbose=self._verbose,
108110
)
109111

0 commit comments

Comments (0)