Skip to content

Commit 1d58ee4

Browse files
committed
[fix] fix several glitches
- torch executor forward with None is accepted
- add back trainable graph
- set master_by correctly
1 parent d3226fa commit 1d58ee4

File tree

16 files changed

+196
-98
lines changed

16 files changed

+196
-98
lines changed

mppq/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
Minimized PPQ quantizer package.
88
"""
99

10-
__version__ = "0.7.2"
10+
__version__ = "0.7.3"

mppq/api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from mppq.api.extension import register_operation, register_platform
1515
from mppq.api.interface import (
1616
dispatch_graph,
17+
export_config,
1718
export_graph,
1819
export_onnx_graph,
1920
format_graph,
@@ -37,6 +38,7 @@
3738
"register_operation",
3839
"register_platform",
3940
"dispatch_graph",
41+
"export_config",
4042
"export_graph",
4143
"export_onnx_graph",
4244
"format_graph",

mppq/api/extension.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,12 @@ def register_platform(
8686
字典键值可选,作为调度器的命名,和调度器的类类型。
8787
quantizer (Dict[str | None, Type[BaseQuantizer]]): 自定义平台的量化器。
8888
字典键值可选,作为量化器的命名,和量化器的类类型。
89-
parsers (Optional[Dict[str | None, Type[GraphBuilder]]], optional): 自定义平台的图构建器。
90-
字典键值可选,作为图构建器的命名,和图构建器的类类型。 Defaults to None.
91-
exporters (Optional[Dict[str | None, Type[GraphExporter]]], optional): 自定义平台的图导出器。
92-
字典键值可选,作为图导出器的命名,和图导出器的类类型。 Defaults to None.
89+
parsers (Optional[Dict[str | None, Type[GraphBuilder]]], optional):
90+
自定义平台的图构建器。字典键值可选,作为图构建器的命名,和图构建器的类类型。
91+
Defaults to None.
92+
exporters (Optional[Dict[str | None, Type[GraphExporter]]], optional):
93+
自定义平台的图导出器。字典键值可选,作为图导出器的命名,和图导出器的类类型。
94+
Defaults to None.
9395
"""
9496
if platform_id in _PLATFORM_TO_DISPATCHER_:
9597
raise KeyError(f"Platform {platform_id} is already registered.")

mppq/api/interface.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,17 @@ def export_onnx_graph(graph: BaseGraph, f: str | os.PathLike):
8686
export_graph(graph, f, exporter=exporter)
8787

8888

89+
def export_config(graph: BaseGraph, f: str | os.PathLike):
90+
r"""导出PPQ IR Graph的量化配置。
91+
Args:
92+
graph (BaseGraph): 待导出IR Graph对象。
93+
f (str|os.PathLike): 导出文件路径。
94+
"""
95+
96+
exporter = EXPORTER["onnx"]()
97+
exporter.dump_quantization_config(f, graph)
98+
99+
89100
def format_graph(
90101
graph: BaseGraph,
91102
format_constant_input: bool = True,

mppq/data.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
You are not allowed to modify this 请勿修改此文件
44
"""
55

6-
from enum import Enum
6+
from enum import IntEnum
77
from numbers import Number
88
from typing import Any, Literal, Optional, Sequence, overload
99

@@ -14,7 +14,7 @@
1414
from torch import dtype as torch_type
1515

1616

17-
class DataType(Enum):
17+
class DataType(IntEnum):
1818
"""
1919
DataType defines all PPQ internal data type and its enumeration value.
2020
ATTENTION: PPQ shares same data type enumeration value with Onnx.
@@ -113,8 +113,8 @@ def to_numpy(cls, dtype) -> np_type:
113113
DataType.FP32: np_type("float32"),
114114
DataType.FP64: np_type("float64"),
115115
}
116-
assert isinstance(dtype, DataType)
117-
return numpy_converting_dict[dtype]
116+
assert isinstance(dtype, int)
117+
return numpy_converting_dict[DataType(dtype)]
118118

119119
@classmethod
120120
def to_torch(cls, dtype) -> torch_type:
@@ -131,8 +131,8 @@ def to_torch(cls, dtype) -> torch_type:
131131
DataType.FP64: torch.float64,
132132
DataType.FP8_E5M2: torch.float8_e5m2,
133133
}
134-
assert isinstance(dtype, DataType)
135-
return torch_converting_dict[dtype]
134+
assert isinstance(dtype, int)
135+
return torch_converting_dict[DataType(dtype)]
136136

137137

138138
@overload

mppq/executor/base.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ class RuntimeHook:
4646
def __init__(self, operation: Operation, **kwargs) -> None:
4747
self._hook_to = operation
4848

49-
def pre_forward_hook(self, inputs: Sequence[Tensor], **kwargs) -> List[Tensor]:
49+
def pre_forward_hook(
50+
self, inputs: Sequence[Tensor | None], **kwargs
51+
) -> List[Tensor | None]:
5052
"""user-customized pre-processing procedure of input data.
5153
5254
Args:
@@ -57,7 +59,9 @@ def pre_forward_hook(self, inputs: Sequence[Tensor], **kwargs) -> List[Tensor]:
5759
"""
5860
return list(inputs)
5961

60-
def post_forward_hook(self, outputs: Sequence[Tensor], **kwargs) -> List[Tensor]:
62+
def post_forward_hook(
63+
self, outputs: Sequence[Tensor | None], **kwargs
64+
) -> List[Tensor | None]:
6165
"""user-customized post-processing procedure of output data.
6266
6367
Args:
@@ -82,21 +86,21 @@ def __init__(self, operation: Operation, **kwargs) -> None:
8286

8387
def pre_forward_hook(
8488
self,
85-
inputs: Sequence[Tensor],
86-
quant_inputs: Sequence[Tensor] = (),
89+
inputs: Sequence[Tensor | None],
90+
quant_inputs: Sequence[Tensor | None] = (),
8791
quant_configs: Sequence[TensorQuantizationConfig] = (),
8892
**kwargs,
89-
) -> List[Tensor]:
93+
) -> List[Tensor | None]:
9094
assert len(inputs) == len(quant_inputs) == len(quant_configs)
9195
return list(quant_inputs)
9296

9397
def post_forward_hook(
9498
self,
95-
outputs: Sequence[Tensor],
96-
quant_outputs: Sequence[Tensor] = (),
99+
outputs: Sequence[Tensor | None],
100+
quant_outputs: Sequence[Tensor | None] = (),
97101
quant_configs: Sequence[TensorQuantizationConfig] = (),
98102
**kwargs,
99-
) -> List[Tensor]:
103+
) -> List[Tensor | None]:
100104
assert len(outputs) == len(quant_outputs) == len(quant_configs)
101105
return list(quant_outputs)
102106

@@ -155,7 +159,7 @@ def _prepare_input(self, inputs: Optional[GraphInput]) -> Dict[str, Tensor]:
155159
def forward(
156160
self,
157161
inputs: GraphInput,
158-
output_names: Optional[List[str]] = None,
162+
output_names: Optional[Sequence[str]] = None,
159163
hooks: Optional[Mapping[str, RuntimeHook]] = None,
160164
) -> List[torch.Tensor]:
161165
"""Forward a graph from given inputs to required output names.
@@ -169,15 +173,15 @@ def forward(
169173
def tracing_operation_meta(
170174
self,
171175
inputs: GraphInput,
172-
output_names: Optional[List[str]] = None,
176+
output_names: Optional[Sequence[str]] = None,
173177
) -> None:
174178
raise NotImplementedError("Please implement this function first.")
175179

176180
@overload
177181
def forward_single_operation(
178182
self,
179183
op: Operation,
180-
inputs: List[Tensor],
184+
inputs: Sequence[Tensor | None],
181185
ctx: Optional[TorchBackendContext] = None,
182186
return_list: Literal[True] = True,
183187
) -> Tuple[Tensor, ...]:
@@ -187,7 +191,7 @@ def forward_single_operation(
187191
def forward_single_operation(
188192
self,
189193
op: Operation,
190-
inputs: List[Tensor],
194+
inputs: Sequence[Tensor | None],
191195
ctx: Optional[TorchBackendContext] = None,
192196
return_list: bool = True,
193197
) -> Tuple[Tensor, ...] | Tensor:
@@ -196,7 +200,7 @@ def forward_single_operation(
196200
def forward_single_operation(
197201
self,
198202
op: Operation,
199-
inputs: List[Tensor],
203+
inputs: Sequence[Tensor | None],
200204
ctx: Optional[TorchBackendContext] = None,
201205
return_list: bool = True,
202206
) -> Tuple[Tensor, ...] | Tensor:
@@ -210,7 +214,7 @@ def forward_single_operation(
210214
def __call__(
211215
self,
212216
inputs: GraphInput,
213-
output_names: Optional[List[str]] = None,
217+
output_names: Optional[Sequence[str]] = None,
214218
) -> List[torch.Tensor]:
215219
return self.forward(inputs=inputs, output_names=output_names)
216220

mppq/executor/op/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class OperationForwardProtocol(Protocol):
116116
def __call__(
117117
self,
118118
op: Operation,
119-
values: Sequence[torch.Tensor],
119+
values: Sequence[torch.Tensor | None],
120120
ctx: Optional[TorchBackendContext] = None,
121121
**kwargs,
122122
) -> torch.Tensor | Tuple[torch.Tensor, ...]:

0 commit comments

Comments (0)