Search before asking
- I have searched the RF-DETR issues and found no similar bug report.
Bug
torch.export.export is the recommended way to export models to other inference frameworks (CoreML, LiteRT, etc.), but it breaks:
from rfdetr import RFDETRBase
from torch import randn
from torch.export import export
from torch.nn import Module
# Create model
rf_detr = RFDETRBase()
model: Module = rf_detr.model.model.cpu()
model.export()
input_size: int = rf_detr.model.resolution
# Export
exported_program = export(
    model,
    args=(randn(1, 3, input_size, input_size),)
)
Exception:
.venv/lib/python3.12/site-packages/torch/export/__init__.py:319: in export
raise e
.venv/lib/python3.12/site-packages/torch/export/__init__.py:286: in export
return _export(
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1164: in wrapper
raise e
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1130: in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/exported_program.py:123: in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/_trace.py:2176: in _export
ep = _export_for_training(
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1164: in wrapper
raise e
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1130: in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/exported_program.py:123: in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/_trace.py:2037: in _export_for_training
export_artifact = export_func(
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1979: in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1770: in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1900: in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1685: in _make_fx_helper
gm = make_fx(
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:2318: in wrapped
return make_fx_tracer.trace(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:2250: in trace
return self._trace_inner(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:2221: in _trace_inner
t = dispatch_trace(
.venv/lib/python3.12/site-packages/torch/_compile.py:53: in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:929: in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1254: in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1835: in trace
res = super().trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:850: in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1312: in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
<string>:1: in <lambda>
???
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1589: in wrapped_fn
return tuple(flat_fn(*args))
^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py:184: in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:906: in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:825: in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1905: in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:542: in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:818: in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1884: in forward
tree_out = mod(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:825: in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1905: in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:542: in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:818: in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/rfdetr/models/lwdetr.py:222: in forward_export
hs, ref_unsigmoid, hs_enc, ref_enc = self.transformer(
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:825: in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1905: in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:542: in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:818: in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/rfdetr/models/transformer.py:226: in forward
output_memory, output_proposals = gen_encoder_output_proposals(
.venv/lib/python3.12/site-packages/rfdetr/models/transformer.py:92: in gen_encoder_output_proposals
valid_H = torch.tensor([H_ for _ in range(N_)], device=memory.device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1360: in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1407: in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <torch._export.non_strict_utils._NonStrictTorchFunctionHandler object at 0x14acc1670>
func = <built-in method tensor of type object at 0x119efcae8>, types = ()
args = ([FakeTensor(..., size=(), dtype=torch.int64)],)
kwargs = {'device': device(type='cpu')}
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if torch.compiler.is_dynamo_compiling():
            return func(*args, **kwargs)
        if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
            frame = _find_user_code_frame()
            if frame is not None:
                log.debug(
                    "%s called at %s:%s in %s",
                    func.__qualname__,
                    frame.f_code.co_filename,
                    frame.f_lineno,
                    frame.f_code.co_name,
                )
        func, args, kwargs = self._override(func, args, kwargs)
        try:
>           return func(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
E RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet.
E If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
E If you're using Caffe2, Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.
.venv/lib/python3.12/site-packages/torch/_export/non_strict_utils.py:1051: RuntimeError
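The last frame points at gen_encoder_output_proposals in rfdetr/models/transformer.py, where a Python list of 0-d tensors is re-wrapped with torch.tensor() while the exporter is tracing with fake tensors. A minimal sketch of one possible workaround (my assumption, not a confirmed fix), keeping the H_, N_ and memory names from that function and assuming H_ is a 0-d tensor or plain scalar, is to broadcast the value instead of rebuilding a tensor from a Python list:
# Hypothetical rewrite of the failing line (transformer.py:92); names as in
# gen_encoder_output_proposals. Broadcasting keeps the value symbolic, so the
# fake-tensor tracer never has to materialize data for a freshly built tensor.
valid_H = torch.as_tensor(H_, device=memory.device).expand(N_)
# The analogous width line, if there is one, would presumably need the same treatment.
The error message itself also suggests wrapping any custom kernel in an opaque custom op (see the linked custom-ops page), which may be the more general fix if this list construction is only the first of several trace-unfriendly spots.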
Environment
- RF-DETR 1.3.0
- macOS 26.0
- Python 3.12
- PyTorch 2.8.0
Minimal Reproducible Example
from rfdetr import RFDETRBase
from torch import randn
from torch.export import export
from torch.nn import Module
# Create model
rf_detr = RFDETRBase()
model: Module = rf_detr.model.model.cpu()
model.export()
input_size: int = rf_detr.model.resolution
# Export
exported_program = export(
    model,
    args=(randn(1, 3, input_size, input_size),)
)
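For reference, once the export succeeds, the resulting program can be sanity-checked against the eager model roughly like this (a sketch reusing the names above; ExportedProgram.module() returns a runnable GraphModule):
# Run the same dummy input through the eager model and the exported graph.
dummy = randn(1, 3, input_size, input_size)
eager_out = model(dummy)
exported_out = exported_program.module()(dummy)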
Additional
No response
Are you willing to submit a PR?
- Yes, I'd like to help by submitting a PR!