Open
Description
This can be fixed with the following patch:
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index fde31af60..392ff2b39 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -177,6 +177,23 @@ class TorchFXImporter:
return self._call_binary_op(
relax.op.add, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs
)
+ elif isinstance(lhs, int):
+ return self._call_binary_op(
+ relax.op.add, relax.const(lhs, dtype="int64"), rhs
+ )
+ elif isinstance(rhs, int):
+ return self._call_binary_op(
+ relax.op.add, lhs, relax.const(rhs, dtype="int64")
+ )
+ elif isinstance(lhs, float):
+ return self._call_binary_op(
+ relax.op.add, relax.const(lhs, dtype="float32"), rhs
+ )
+ elif isinstance(rhs, float):
+ return self._call_binary_op(
+ relax.op.add, lhs, relax.const(rhs, dtype="float32")
+ )
+
return lhs + rhs
def _max(self, node: fx.node.Node) -> relax.Expr:
Originally posted by @haili-tian in #36 (comment)
After applying this patch, the clip_to_text_embeddings(pipe) function works, but executing vae_to_image(pipe) fails with another error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:670, in OutputGraph.call_user_compiler(self, gm)
669 else:
--> 670 compiled_fn = compiler_fn(gm, self.fake_example_inputs())
671 _step_logger()(logging.INFO, f"done compiler function {name}")
File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py:1055, in wrap_backend_debug.<locals>.debug_wrapper(gm, example_inputs, **kwargs)
1054 else:
-> 1055 compiled_gm = compiler_fn(gm, example_inputs)
1057 return compiled_gm
File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/tvm/relax/frontend/torch/dynamo.py:161, in dynamo_capture_subgraphs.<locals>._capture(graph_module, example_inputs)
160 input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs]
--> 161 mod_ = from_fx(
162 graph_module,
163 input_info,
164 keep_params_as_input=keep_params_as_input,
165 unwrap_unit_return_tuple=True,
166 )
167 new_name = f"subgraph_{len(mod.get_global_vars())}"
File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/tvm/relax/frontend/torch/fx_translator.py:1492, in from_fx(model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple)
1404 """Convert a PyTorch FX GraphModule to a Relax program
1405
1406 Parameters
(...)
1490 check the placeholder rows in the beginning of the tabular.
1491 """
-> 1492 return TorchFXImporter().from_fx(
1493 model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple
1494 )
File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/tvm/relax/frontend/torch/fx_translator.py:1377, in TorchFXImporter.from_fx(self, model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple)
1375 func_name = node.name.rstrip("0123456789_")
1376 assert (
-> 1377 func_name in self.convert_map
1378 ), f"Unsupported function type {func_name}"
1379 self.env[node] = self.convert_map[func_name](node)
AssertionError: Unsupported function type conv2d
The above exception was the direct cause of the following exception:
BackendCompilerFailed Traceback (most recent call last)
Cell In[13], line 1
----> 1 vae = vae_to_image(pipe)
Cell In[10], line 22, in vae_to_image(pipe)
19 vae_to_image = VAEModelWrapper(vae)
21 z = torch.rand((1, 4, 64, 64), dtype=torch.float32)
---> 22 mod = dynamo_capture_subgraphs(
23 vae_to_image.forward,
24 z,
25 keep_params_as_input=True,
26 )
27 assert len(mod.functions) == 1
29 return tvm.IRModule({"vae": mod["subgraph_0"]})
File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/tvm/relax/frontend/torch/dynamo.py:175, in dynamo_capture_subgraphs(model, *params, **kwargs)
172 compiled_model = torch.compile(model, backend=_capture)
174 with torch.no_grad():
--> 175 compiled_model(*params, **kwargs)
177 return mod
It seems that conv2d is captured as a `call_function` node instead of a `call_module` node, so its name (`conv2d`) is not found in `self.convert_map`. However, I do not understand why conv2d becomes a function call instead of a module call.
Metadata
Metadata
Assignees
Labels
No labels