An error was encountered setting torch._dynamo.decorators.mark_unbacked #1790

Open
@songh11

Hello, I want the batch dimension to be dynamic, so I mark it with torch._dynamo.mark_dynamic. However, I found that a recompile is triggered when the batch size changes between 1 and 2. I then tried torch._dynamo.decorators.mark_unbacked, but with it quantization fails with the error shown below. Could you take a look at this problem?

My environment:
torch: 2.5.0
torchao: 0.8.0

This is the minimal reproduction code:

import torch


from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int8_weight
)
torch._logging.set_logs(recompiles=True, recompiles_verbose=True)

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 256)

    def forward(self, x):
        return self.linear(x)

model = MyModel().cuda().eval()
model = torch.compile(model, fullgraph=True)

# quant
quantize_(model, int8_dynamic_activation_int8_weight())

example_input = torch.randn(2, 64, 128).cuda()
torch._dynamo.decorators.mark_unbacked(example_input, 0)
torch._dynamo.mark_dynamic(example_input, 0)
model(example_input)

x1 = torch.randn(1, 64, 128).cuda()
x2 = torch.randn(2, 64, 128).cuda()

print("input shape: ", x1.shape)
model(x1)
print("input shape: ", x2.shape)
model(x2)

This is the error log

W0227 10:58:38.277000 1279033 torch/fx/experimental/symbolic_shapes.py:5124] [0/0] failed during evaluate_expr(Ne(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0227 10:58:38.277000 1279033 torch/fx/experimental/recording.py:298] [0/0] failed while running evaluate_expr(*(Ne(u0, 1), None), **{'fx_node': False})
Traceback (most recent call last):
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2132, in run_node
return node.target(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/utils.py", line 433, in _dispatch__torch_function__
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/utils.py", line 412, in wrapper
return func(f, types, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 126, in _
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 83, in _quantized_linear_op
quantized_tensor = input_quant_func(input_tensor, **quant_kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 800, in _int8_symm_per_token_reduced_range_quant
return to_affine_quantized_intx(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 250, in from_hp_to_intx
scale, zero_point = choose_qparams_affine(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 738, in choose_qparams_affine
return _choose_qparams_affine(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
return self._op(*args, **(kwargs or {}))
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 840, in _choose_qparams_affine
shape_for_reduction, reduction_dims = _get_reduction_params(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 229, in _get_reduction_params
if block_size[i] != input_size[i] and block_size[i] > 1:
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/__init__.py", line 680, in __bool__
return self.node.bool_()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 511, in bool_
return self.guard_bool("", 0)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 449, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
return retlog(fn(*args, **kwargs))
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5122, in evaluate_expr
return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5238, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Ne(u0, 1) (unhinted: Ne(u0, 1)). (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to True.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Potential framework code culprit (scroll up for full backtrace):
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 229, in _get_reduction_params
if block_size[i] != input_size[i] and block_size[i] > 1:

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/root/picasso/songh/workspace/QUANT/test_quant.py", line 16, in forward
return self.linear(x)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2017, in get_fake_value
ret_val = wrap_fake_exception(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1574, in wrap_fake_exception
return fn()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2018, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2150, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2132, in run_node
return node.target(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/utils.py", line 433, in _dispatch__torch_function__
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/utils.py", line 412, in wrapper
return func(f, types, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 126, in _
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 83, in _quantized_linear_op
quantized_tensor = input_quant_func(input_tensor, **quant_kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 800, in _int8_symm_per_token_reduced_range_quant
return to_affine_quantized_intx(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 250, in from_hp_to_intx
scale, zero_point = choose_qparams_affine(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 738, in choose_qparams_affine
return _choose_qparams_affine(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
return self._op(*args, **(kwargs or {}))
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 840, in _choose_qparams_affine
shape_for_reduction, reduction_dims = _get_reduction_params(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 229, in _get_reduction_params
if block_size[i] != input_size[i] and block_size[i] > 1:
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/__init__.py", line 680, in __bool__
return self.node.bool_()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 511, in bool_
return self.guard_bool("", 0)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 449, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
return retlog(fn(*args, **kwargs))
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5122, in evaluate_expr
return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5238, in _evaluate_expr
raise self._make_data_dependent_error(
RuntimeError: Failed running call_function (
(FakeTensor(..., device='cuda:0', size=(u0, 64, 128)), LinearActivationQuantizedTensor(AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., device='cuda:0', size=(256, 128), dtype=torch.int8)... , scale=FakeTensor(..., device='cuda:0', size=(256,))... , zero_point=FakeTensor(..., device='cuda:0', size=(256,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 128), shape=torch.Size([256, 128]), device=cuda:0, dtype=torch.float32, requires_grad=False), <function _int8_symm_per_token_reduced_range_quant at 0x7fa631feac20>, quant_kwargs={})), Parameter(FakeTensor(..., device='cuda:0', size=(256,), requires_grad=True))), **{}):
Could not guard on data-dependent expression Ne(u0, 1) (unhinted: Ne(u0, 1)). (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to True.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Potential framework code culprit (scroll up for full backtrace):
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 229, in _get_reduction_params
if block_size[i] != input_size[i] and block_size[i] > 1:

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/root/picasso/songh/workspace/QUANT/test_quant.py", line 16, in forward
return self.linear(x)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/root/picasso/songh/workspace/QUANT/test_quant.py", line 27, in <module>
model(example_input)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
return self._torchdynamo_orig_callable(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
return _compile(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
out_code = transform_code_object(code, transform)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
return fn(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
tracer.run()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
super().run()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
while self.step():
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
self.dispatch_table[inst.opcode](self, inst)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
return inner_fn(self, inst)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
return variables.UserFunctionVariable(fn, source=source).call_function(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
return super().call_function(tx, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
tracer.run()
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
while self.step():
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
self.dispatch_table[inst.opcode](self, inst)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
return inner_fn(self, inst)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py", line 897, in call_function
tensor_variable = wrap_fx_proxy(
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2037, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2124, in wrap_fx_proxy_cls
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2072, in get_fake_value
raise UserError( # noqa: B904
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Ne(u0, 1) (unhinted: Ne(u0, 1)). (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to True.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Potential framework code culprit (scroll up for full backtrace):
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 229, in _get_reduction_params
if block_size[i] != input_size[i] and block_size[i] > 1:

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/root/picasso/songh/workspace/QUANT/test_quant.py", line 16, in forward
return self.linear(x)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example

from user code:
File "/root/picasso/songh/workspace/QUANT/test_quant.py", line 16, in forward
return self.linear(x)
File "/root/picasso/songh/my_venv/py310/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
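
The condition flagged in the log, `if block_size[i] != input_size[i] and block_size[i] > 1:`, can be examined in isolation. What follows is a minimal pure-Python sketch and not torchao or PyTorch code: `UnbackedSize`, `_Unresolved`, `DataDependentError` and both `split_dims_*` helpers are hypothetical stand-ins invented for illustration, and the per-token block size `(1, 1, 128)` is an assumption inferred from the `Ne(u0, 1)` expression in the error. It only illustrates why the first comparison forces a decision on `u0`, and how testing the plain-integer condition first would let `and` short-circuit before `u0` is ever compared.

```python
# Toy model only: these classes are hypothetical stand-ins,
# not PyTorch or torchao APIs.


class DataDependentError(RuntimeError):
    """Stand-in for GuardOnDataDependentSymNode."""


class _Unresolved:
    """A comparison result whose truth value is unknown at trace time."""

    def __init__(self, expr):
        self.expr = expr

    def __bool__(self):
        # Mimics the failure raised when a comparison on an unbacked
        # symbol is forced into a Python bool.
        raise DataDependentError(
            f"Could not guard on data-dependent expression {self.expr}"
        )


class UnbackedSize:
    """Stand-in for an unbacked symbolic size such as u0."""

    def __init__(self, name):
        self.name = name

    def __ne__(self, other):
        # `1 != u0` falls back to this reflected method for plain ints.
        return _Unresolved(f"Ne({self.name}, {other})")


def split_dims_original_order(block_size, input_size):
    """Uses the same operand order as the condition quoted in the traceback."""
    return [
        i
        for i in range(len(block_size))
        if block_size[i] != input_size[i] and block_size[i] > 1
    ]


def split_dims_reordered(block_size, input_size):
    """Tests the plain-int condition first so that `and` can short-circuit."""
    return [
        i
        for i in range(len(block_size))
        if block_size[i] > 1 and block_size[i] != input_size[i]
    ]


if __name__ == "__main__":
    input_size = (UnbackedSize("u0"), 64, 128)
    block_size = (1, 1, 128)  # assumed per-token block size

    try:
        split_dims_original_order(block_size, input_size)
    except DataDependentError as err:
        print("original order:", err)

    print("reordered:", split_dims_reordered(block_size, input_size))
```

Since `and` yields the same truth value in either operand order, the reordered check selects the same dimensions whenever both comparisons can be evaluated. Whether such a reordering is acceptable in torchao itself, or whether `guard_size_oblivious` (as the log suggests) is the better route, is a question for the maintainers.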
