Description
代码如下:
import torch
import time
from onediff.infer_compiler import oneflow_compile
class ConvBiasAddActivation(torch.nn.Module):
    """Conv2d(640 -> 640, 3x3, padding=1) with an optional add, activation and bias add.

    Args:
        bias: Whether the convolution has a bias term.
        activation_cls: Activation module class, instantiated with no
            arguments (e.g. ``torch.nn.ReLU``). ``None`` selects
            ``torch.nn.Identity``.
    """

    def __init__(self, bias=True, activation_cls=None):
        super(ConvBiasAddActivation, self).__init__()
        self.conv = torch.nn.Conv2d(640, 640, 3, padding=1, bias=bias)
        self.act = (
            activation_cls() if activation_cls is not None else torch.nn.Identity()
        )

    def forward(self, x, y=None, alpha=1.0, beta_gamma=None, generator=None):
        """Compute ``act(conv(x) + alpha * y) + beta_gamma[1] * beta_gamma[0]``.

        Args:
            x: Input tensor with 640 channels (presumably NCHW -- confirm).
            y: Optional tensor added to the conv output, scaled by ``alpha``.
            alpha: Scale applied to ``y`` in the add.
            beta_gamma: Optional ``(other, scale)`` pair; ``scale * other`` is
                added after the activation.
            generator: Only checked against ``None``. When it is not ``None``,
                a freshly created ``torch.Generator`` is returned alongside
                the output. Presumably this exists to exercise non-tensor
                outputs during tracing.

        Returns:
            The output tensor, or ``(output, torch.Generator())`` when
            ``generator`` is given.
        """
        x = self.conv(x)
        if y is not None:
            x = x.add(y, alpha=alpha)
        x = self.act(x)
        if beta_gamma is not None:
            x = x.add(beta_gamma[0], alpha=beta_gamma[1])
        return x if generator is None else (x, torch.Generator())
def test_trace_with_kwargs():
    """Compile ConvBiasAddActivation with onediff and time it with kwargs inputs.

    Builds the module on the current CUDA device and compiles it with
    ``oneflow_compile``. It then makes one untimed warm-up call and prints
    the wall-clock time of 10 rounds of 100 forward calls each. Each round is
    wrapped in an NVTX range named ``onediff_test`` so it can be located in a
    profiler timeline.

    Requires a CUDA device.
    """
    device = "cuda"
    with torch.no_grad():
        model = ConvBiasAddActivation(activation_cls=torch.nn.ReLU)
        model.to(device)
        model.eval()
        # Allocate inputs directly on the GPU rather than on the host followed
        # by a .cuda() copy.
        x = torch.ones(128, 640, 16, 16, device=device)
        y = torch.ones(128, 640, 16, 16, device=device)
        args = (x,)
        kwargs = dict(y=y, alpha=0.5, beta_gamma=(1, 0.5))
        onediff_model = oneflow_compile(model)
        # Warm-up call: the first call does the one-time conversion work, so
        # keep it out of the timed loop.
        onediff_model(*args, **kwargs)

        # onediff benchmark
        for _ in range(10):
            torch.cuda.nvtx.range_push("onediff_test")
            try:
                torch.cuda.synchronize()
                start = time.perf_counter()
                for _ in range(100):
                    onediff_model(*args, **kwargs)
                # CUDA work is asynchronous; wait for it before reading the clock.
                torch.cuda.synchronize()
                print(f"model costs time : {time.perf_counter() - start}s")
            finally:
                # Keep NVTX push/pop balanced even if a forward call raises.
                torch.cuda.nvtx.range_pop()
if __name__ == "__main__":
    # Run only when executed as a script. onediff's oneflow backend re-imports
    # the module that defines ConvBiasAddActivation (with torch mocked) to
    # transform the class. An unguarded top-level call would re-run the
    # benchmark during that import. It would then try to compile an
    # already-mocked oneflow module, which fails with AttributeError
    # 'register_load_state_dict_post_hook'.
    test_trace_with_kwargs()
报错信息为:
ERROR [2024-11-05 03:47:56] /usr/local/lib/python3.10/dist-packages/onediff/infer_compiler/backends/oneflow/deployable_module.py:44 - Exception in wrapper: e=NotImplementedError('Transform failed of <class '__main__.ConvBiasAddActivation'>: An exception occurred during class transformation:\nTraceback (most recent call last):\n File "/usr/local/lib/python3.10/dist-packages/onediff/infer_compiler/backends/oneflow/transform/builtin_transform.py", line 62, in proxy_class\n out = transform_mgr.transform_cls(cls)\n File "/usr/local/lib/python3.10/dist-packages/onediff/infer_compiler/backends/oneflow/transform/manager.py", line 84, in transform_cls\n mock_cls = self._transform_entity(mock_full_cls_name)\n File "/usr/local/lib/python3.10/dist-packages/onediff/infer_compiler/backends/oneflow/transform/manager.py", line 61, in _transform_entity\n result = self.mocker.mock_entity(entity)\n File "/usr/local/lib/python3.10/dist-packages/onediff/infer_compiler/backends/oneflow/import_tools/importer.py", line 95, in mock_entity\n return self.load_entity_with_mock(entity)\n File "/usr/local/lib/python3.10/dist-packages/onediff/infer_compiler/backends/oneflow/import_tools/importer.py", line 123, in load_entity_with_mock\n mock_main = DynamicMockModule.from_package(main_name, verbose=False)\n File "/usr/local/lib/python3.10/dist-packages/onediff/infer_compiler/backends/oneflow/import_tools/dyn_mock_mod.py", line 162, in from_package\n obj_entity = importlib.import_module(main_pkg)\n File "/usr/lib/python3.10/importlib/__init__.py", line 126, in import_module\n return _bootstrap._gcd_import(name[level:], package, level)\n File "<frozen importlib._bootstrap>", line 1050, in _gcd_import\n File "<frozen importlib._bootstrap>", line 1027, in _find_and_load\n File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked\n File "<frozen importlib._bootstrap>", line 688, in _load_unlocked\n File "<frozen importlib._bootstrap_external>", line 883, in exec_module\n File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed\n File "/workspace/jit/test_onediff_trace.py", line 48, in <module>\n test_trace_with_kwargs()\n File 
"/workspace/jit/test_onediff_trace.py", line 34, in test_trace_with_kwargs\n onediff_model = oneflow_compile(model)\n File "/usr/local/lib/python3.10/dist-packages/onediff/infer_compiler/backends/compiler.py", line 21, in oneflow_compile\n return compile(torch_module, backend="oneflow", options=options)\n File "/usr/local/lib/python3.10/dist-packages/onediff/infer_compiler/backends/compiler.py", line 16, in compile\n model = backend(torch_module, options=options)\n File "/usr/local/lib/python3.10/dist-packages/onediff/infer_compiler/backends/oneflow/oneflow.py", line 71, in compile\n model._torch_module.register_load_state_dict_post_hook(state_update_hook)\n File "/usr/local/lib/python3.10/dist-packages/oneflow/nn/modules/module.py", line 433, in __getattr__\n raise AttributeError(\nAttributeError: 'ConvBiasAddActivation' object has no attribute 'register_load_state_dict_post_hook'\n\nException: 'ConvBiasAddActivation' object has no attribute 'register_load_state_dict_post_hook'')