Describe the feature request
We tried to leverage per-channel quantization in QAT and export the trained model in ONNX format:
```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = ...  # our dummy PyTorch model (any torch.nn.Module)

inp = torch.rand((32, 3, 384, 384))
print(inp.shape)
example_inputs = (inp,)

export_model = torch.export.export_for_training(model, example_inputs).module()
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True)
)
prepared_model = prepare_qat_pt2e(export_model, quantizer)
# ... QAT fine-tuning runs here ...
quantized_model = convert_pt2e(prepared_model)

onnx_program = torch.onnx.export(
    quantized_model,          # model to export
    example_inputs,           # inputs of the model
    "my_model.onnx",          # filename of the ONNX model
    opset_version=20,         # the ONNX opset version to export the model to
    verbose=True,
    input_names=["input"],    # rename inputs for the ONNX model
    output_names=["output"],  # the model's output names
    dynamic=True,             # create a dynamic ONNX model
    dynamo=True,              # use the dynamo-based exporter
    dynamic_axes={
        "input": {0: "batch_size"},   # variable-length axes
        "output": {0: "batch_size"},
    },
    verify=True,              # check the model and all its submodules
)
```
We got the following error:
```
DispatchError                             Traceback (most recent call last)
File ~/miniconda3/envs/ims/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_core.py:553, in _add_nodes(exported_program, model, lower, registry)
    552 if lower == "at_conversion":
--> 553     _handle_call_function_node_with_lowering(
    554         model,
    555         node,
    556         node_name_to_values,
    557         constant_farm,
    558         registry=registry,
    559         opset=opset,
    560     )
    561 else:
    562     # No lowering
File ~/miniconda3/envs/ims/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_core.py:444, in _handle_call_function_node_with_lowering(model, node, node_name_to_values, constant_farm, registry, opset)
    442 if onnx_function is None:
    443     # TODO(justinchuby): Fall back to ATen op or do something else?
--> 444     raise _errors.DispatchError(
    445         f"No ONNX function found for {node.target!r}. Failure message: {message}"
    446     )
    448 # Map FX inputs to ONNX inputs and fill optional inputs.
    449 # torch_args and torch_kwargs are for op-level validation
DispatchError: No ONNX function found for <OpOverload(op='quantized_decomposed.dequantize_per_channel', overload='default')>. Failure message: No decompositions registered for the real-valued input
...
<class 'torch.onnx._internal.exporter._errors.DispatchError'>: No ONNX function found for <OpOverload(op='quantized_decomposed.dequantize_per_channel', overload='default')>. Failure message: No decompositions registered for the real-valued input
⬆️
<class 'torch.onnx._internal.exporter._errors.ConversionError'>: Error when translating node %dequantize_per_channel : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.dequantize_per_channel.default](args = (%b__frozen_param0, %b__scale_0, %b__zero_point_0, 0, -127, 127, torch.int8), kwargs = {}). See the stack trace for more information.
```
Describe scenario use case
Export a QAT PyTorch model in ONNX format.
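A possible workaround, sketched below and untested: ONNX's DequantizeLinear natively supports per-axis (per-channel) scales and zero points, so on PyTorch builds whose torch.onnx.export accepts a custom_translation_table, the unsupported op could be mapped to it by hand. The helper onnx_dequantize_per_channel is hypothetical, and the Cast assumes int8 weights as in the repro above (quant_min=-127, quant_max=127, torch.int8).

```python
import onnx
import torch
from onnxscript import opset20 as op

# Hypothetical translation: lower quantized_decomposed.dequantize_per_channel
# to ONNX DequantizeLinear, which handles per-axis scales/zero points natively.
def onnx_dequantize_per_channel(
    input, scales, zero_points, axis: int, quant_min: int, quant_max: int, dtype=None
):
    # DequantizeLinear requires the zero-point dtype to match the input dtype,
    # while PyTorch stores per-channel zero points as a wider integer tensor;
    # cast to int8 here (assumed from the repro's torch.int8 weights).
    zero_points = op.Cast(zero_points, to=onnx.TensorProto.INT8)
    return op.DequantizeLinear(input, scales, zero_points, axis=axis)

onnx_program = torch.onnx.export(
    quantized_model,
    example_inputs,
    "my_model.onnx",
    opset_version=20,
    dynamo=True,
    custom_translation_table={
        torch.ops.quantized_decomposed.dequantize_per_channel.default:
            onnx_dequantize_per_channel,
    },
)
```

If custom_translation_table is not available in the installed PyTorch version, native exporter support for the quantized_decomposed per-channel ops would still be needed.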