Skip to content

RF-DETR incompatible with torch.export.export #406

@olokobayusuf

Description

@olokobayusuf

Search before asking

  • I have searched the RF-DETR issues and found no similar bug report.

Bug

`torch.export.export` is the recommended way to export PyTorch models to other inference frameworks (CoreML, LiteRT, etc.), but it fails on RF-DETR:

from rfdetr import RFDETRBase
from torch import randn
from torch.export import export
from torch.nn import Module

# Create model
rf_detr = RFDETRBase()
model: Module = rf_detr.model.model.cpu()
model.export()
input_size: int = rf_detr.model.resolution

# Export
exported_program = export(
    model,
    args=(randn(1, 3, input_size, input_size),)
)

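Exception (full traceback below). The failure originates in `gen_encoder_output_proposals` (`rfdetr/models/transformer.py`, line 92). That line calls `torch.tensor([H_ for _ in range(N_)], device=memory.device)`, and `H_` is a 0-dim `int64` tensor rather than a Python `int`. During export that tensor is a `FakeTensor` with no allocated data, so `torch.tensor` cannot read its value: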

.venv/lib/python3.12/site-packages/torch/export/__init__.py:319: in export
    raise e
.venv/lib/python3.12/site-packages/torch/export/__init__.py:286: in export
    return _export(
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1164: in wrapper
    raise e
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1130: in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/exported_program.py:123: in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/_trace.py:2176: in _export
    ep = _export_for_training(
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1164: in wrapper
    raise e
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1130: in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/exported_program.py:123: in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/_trace.py:2037: in _export_for_training
    export_artifact = export_func(
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1979: in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1770: in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1900: in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1685: in _make_fx_helper
    gm = make_fx(
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:2318: in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:2250: in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:2221: in _trace_inner
    t = dispatch_trace(
.venv/lib/python3.12/site-packages/torch/_compile.py:53: in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:929: in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1254: in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1835: in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:850: in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1312: in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
<string>:1: in <lambda>
    ???
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1589: in wrapped_fn
    return tuple(flat_fn(*args))
                 ^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py:184: in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:906: in functional_call
    out = mod(*args[params_len:], **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:825: in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1905: in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:542: in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:818: in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1784: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/export/_trace.py:1884: in forward
    tree_out = mod(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:825: in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1905: in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:542: in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:818: in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1784: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/rfdetr/models/lwdetr.py:222: in forward_export
    hs, ref_unsigmoid, hs_enc, ref_enc = self.transformer(
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:825: in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1905: in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:542: in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:818: in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1784: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/rfdetr/models/transformer.py:226: in forward
    output_memory, output_proposals = gen_encoder_output_proposals(
.venv/lib/python3.12/site-packages/rfdetr/models/transformer.py:92: in gen_encoder_output_proposals
    valid_H = torch.tensor([H_ for _ in range(N_)], device=memory.device)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1360: in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1407: in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <torch._export.non_strict_utils._NonStrictTorchFunctionHandler object at 0x14acc1670>
func = <built-in method tensor of type object at 0x119efcae8>, types = ()
args = ([FakeTensor(..., size=(), dtype=torch.int64)],)
kwargs = {'device': device(type='cpu')}

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if torch.compiler.is_dynamo_compiling():
            return func(*args, **kwargs)
    
        if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
            frame = _find_user_code_frame()
            if frame is not None:
                log.debug(
                    "%s called at %s:%s in %s",
                    func.__qualname__,
                    frame.f_code.co_filename,
                    frame.f_lineno,
                    frame.f_code.co_name,
                )
    
        func, args, kwargs = self._override(func, args, kwargs)
        try:
>           return func(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^
E           RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet.
E           If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
E           If you're using Caffe2, Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

.venv/lib/python3.12/site-packages/torch/_export/non_strict_utils.py:1051: RuntimeError

Environment

  • RF-DETR 1.3.0
  • macOS 26.0
  • Python 3.12
  • PyTorch 2.8.0

Minimal Reproducible Example

from rfdetr import RFDETRBase
from torch import randn
from torch.export import export
from torch.nn import Module

# Create model
rf_detr = RFDETRBase()
model: Module = rf_detr.model.model.cpu()
model.export()
input_size: int = rf_detr.model.resolution

# Export
exported_program = export(
    model,
    args=(randn(1, 3, input_size, input_size),)
)

Additional

No response

Are you willing to submit a PR?

  • Yes, I'd like to help by submitting a PR!

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions