Skip to content

Commit a583906

Browse files
authored
Fix Executing Bugs (#408)
* Fix Executing Bugs * 修复了 unsqueeze 算子在多于一个轴时的顺序错误问题 (fixed the axis-ordering bug in the unsqueeze operator when more than one axis is given) * 修复了 softmax 算子在 opset 11 时默认轴错误的问题 (fixed the wrong default axis of the softmax operator under opset 11) * 修复了 图拷贝 过程中可能因为空值而出现的错误 (fixed an error during graph copy that could occur when a value is None)
1 parent 845086f commit a583906

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

ppq/IR/quantize.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ def source_op_platform(self) -> TargetPlatform:
237237

238238
def copy(self, copy_value: bool = False):
239239
clone = QuantableVariable(super().copy(copy_value))
240-
if copy_value: clone._fp32_value = self._fp32_value.clone()
240+
if copy_value and self._fp32_value is not None:
241+
clone._fp32_value = self._fp32_value.clone()
241242
else: clone._fp32_value = self._fp32_value
242243
return clone
243244

@@ -314,10 +315,10 @@ def dequantize_graph(self, expire_device: str = 'cpu'):
314315
"""一个方便懒人的函数."""
315316
for operation in self.graph.operations.values():
316317
if isinstance(operation, QuantableOperation):
317-
operation.dequantize()
318+
operation.dequantize(expire_device=expire_device)
318319

319320
def restore_quantize_state(self, expire_device: str = 'cpu'):
320321
"""一个方便懒人的函数."""
321322
for operation in self.graph.operations.values():
322323
if isinstance(operation, QuantableOperation):
323-
operation.restore_quantize_state()
324+
operation.restore_quantize_state(expire_device=expire_device)

ppq/executor/op/torch/default.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,7 @@ def Unsqueeze_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacke
10881088
axes = GET_ATTRIBUTE_FROM_OPERATION(op=op, attribute='axes', compulsive=True)
10891089

10901090
if isinstance(axes, list):
1091-
for squeezing_dim in sorted(axes, reverse=True):
1091+
for squeezing_dim in sorted(axes):
10921092
unsqueezing_tensor = torch.unsqueeze(unsqueezing_tensor, squeezing_dim)
10931093
elif isinstance(axes, int):
10941094
unsqueezing_tensor = torch.unsqueeze(unsqueezing_tensor, axes)
@@ -2113,9 +2113,11 @@ def Softmax_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackend
21132113
Returns:
21142114
torch.Tensor: [description]
21152115
"""
2116+
if op.opset.onnx_opset_version() >= 13: default_axis = -1
2117+
else: default_axis = 1
21162118
ASSERT_NUM_OF_INPUT(op=op, values=values, min_num_of_input=1, max_num_of_input=1)
21172119
[input] = values
2118-
axis = GET_ATTRIBUTE_FROM_OPERATION(op=op, attribute='axis', default=-1)
2120+
axis = GET_ATTRIBUTE_FROM_OPERATION(op=op, attribute='axis', default=default_axis)
21192121
output = F.softmax(input, axis)
21202122
return output
21212123

0 commit comments

Comments (0)