llama3 not fx traced #33966

Open

myungjin opened this issue Oct 4, 2024 · 1 comment

myungjin commented Oct 4, 2024

System Info

  • transformers version: 4.45.1
  • Platform: Linux-5.4.247-162.350.amzn2.x86_64-x86_64-with-glibc2.26
  • Python version: 3.10.12
  • Huggingface_hub version: 0.24.0
  • Safetensors version: 0.4.3
  • Accelerate version: 0.32.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: Tesla V100-SXM2-16GB

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

```python
from transformers.utils.fx import symbolic_trace
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
gm = symbolic_trace(model, input_names=["input_ids", "attention_mask", "past_key_values"])
```
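For context on the deprecation warning that follows: the tracer builds `past_key_values` in the legacy tuple-of-tuples layout. Below is a hedged sketch of that layout and of the conversion the warning suggests; shapes are illustrative for Llama-3.1-8B and are not part of the repro.

```python
# Illustrative sketch (not from the issue): the legacy `past_key_values`
# layout is one (key, value) tensor pair per decoder layer. Shapes below
# assume Llama-3.1-8B (32 layers, 8 KV heads, head_dim 128); past_len is arbitrary.
import torch

batch, num_kv_heads, past_len, head_dim = 1, 8, 16, 128
num_layers = 32
legacy_cache = tuple(
    (
        torch.zeros(batch, num_kv_heads, past_len, head_dim),  # key states
        torch.zeros(batch, num_kv_heads, past_len, head_dim),  # value states
    )
    for _ in range(num_layers)
)

# The warning suggests migrating to a `Cache` class instead:
from transformers import DynamicCache

cache = DynamicCache.from_legacy_cache(legacy_cache)
```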
Running the reproduction emits a deprecation warning and then fails during tracing:

```
We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 1503, in symbolic_trace
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 1326, in trace
    self.graph = super().trace(root, concrete_args=concrete_args)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 822, in trace
    (self.create_arg(fn(*args)),),
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 800, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 1190, in call_module
    return super().call_module(m, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 518, in call_module
    ret_val = forward(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 793, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1000, in forward
    layer_outputs = decoder_layer(
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 800, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 1190, in call_module
    return super().call_module(m, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 518, in call_module
    ret_val = forward(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 793, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 729, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 800, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 1190, in call_module
    return super().call_module(m, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 518, in call_module
    ret_val = forward(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 793, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 646, in forward
    if query_states.device.type == "cuda" and causal_mask is not None:
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 669, in __bool__
    return super().__bool__()
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/proxy.py", line 447, in __bool__
    return self.tracer.to_bool(self)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/proxy.py", line 307, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
```
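The root cause is visible in the last frames: `modeling_llama.py` evaluates a traced value (`causal_mask`) inside a Python `if`, and torch.fx cannot branch on a `Proxy`. A minimal sketch (my own, not from the issue) reproducing the same class of failure:

```python
# Minimal illustration of the failure mode: symbolic tracing replaces tensors
# with Proxy objects, and calling bool() on a Proxy (as any `if` does) raises
# TraceError, just like the `causal_mask` check in modeling_llama.py.
import torch
from torch.fx import symbolic_trace
from torch.fx.proxy import TraceError

class Toy(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:  # data-dependent control flow
            return x + 1
        return x - 1

try:
    symbolic_trace(Toy())
except TraceError as e:
    print(e)  # symbolically traced variables cannot be used as inputs to control flow
```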

Expected behavior

The trace should complete without error. `symbolic_trace()` with `input_names=["input_ids", "attention_mask"]` runs fine (see the sketch below), but adding `"past_key_values"` to `input_names` raises the error above. If passing `past_key_values` is not supported, the call should warn and abort before attempting to trace the model.
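For reference, a sketch of the working variant the report describes (same model and environment as the repro above):

```python
# Per the report, tracing succeeds when `past_key_values` is omitted
# from `input_names`.
from transformers.utils.fx import symbolic_trace
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
gm = symbolic_trace(model, input_names=["input_ids", "attention_mask"])  # completes without error
```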

Although a fix for a related error (#29923) is included in the released version, some bug evidently remains.

myungjin added the bug label Oct 4, 2024

ArthurZucker (Collaborator) commented:

Hey! Thanks for reporting! cc @michaelbenayoun. I can reproduce, but I have no idea how to avoid this. It's also not new, so I think the past_key_values path was never tested!
