
Fix dtype error during quantized generation #1371

Open
wants to merge 2 commits into main

Conversation

@mreso (Contributor) commented Aug 19, 2024

This PR fixes #1349 RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source

Repro:
See #1349 for repro
Cmd:

$ tune run generate --config generate_quant.yaml

Output:

INFO:torchtune.utils.logging:Running InferenceRecipe with resolved config:

checkpointer:
  _component_: torchtune.utils.FullModelTorchTuneCheckpointer
  checkpoint_dir: quantized/
  checkpoint_files:
  - meta_model_0-8da4w.pt
  model_type: LLAMA3
  output_dir: model-output
device: cuda
dtype: bf16
enable_kv_cache: true
max_new_tokens: 300
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
  max_seq_len: 2048
prompt: 'Amanda: I baked  cookies. Do you want some?\nJerry: Sure \nAmanda: I will
  bring you tomorrow :-)'
quantizer:
  _component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /home/mreso/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/8c22764a7e3675c50d4c7c9a4edb474456022b16/original/tokenizer.model
top_k: 300

DEBUG:torchtune.utils.logging:Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
linear: layers.0.attn.q_proj, in=4096, out=4096
linear: layers.0.attn.k_proj, in=4096, out=1024
linear: layers.0.attn.v_proj, in=4096, out=1024
linear: layers.0.attn.output_proj, in=4096, out=4096
....
linear: layers.31.mlp.w1, in=4096, out=14336
linear: layers.31.mlp.w2, in=14336, out=4096
linear: layers.31.mlp.w3, in=4096, out=14336
linear: output, in=4096, out=128256
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:Starting compilation to improve generation performance ...
INFO:torchtune.utils.logging:Warmup run for quantized model takes: 33.09 sec
INFO:torchtune.utils.logging:Amanda: I baked  cookies. Do you want some?\nJerry: Sure \nAmanda: I will bring you tomorrow :-) \nJerry: What are you talking about? I want them now\nAmanda: I said tomorrow, you can't just eat them all by yourself, that's not fair. (She laughs.) \nJerry: I'm starving\nAmanda: Go ask someone else, I'm not going to be your cookie m
other \nJerry: That's not fair\nAmanda: I know, but it's not my responsibility to feed you all day \nJerry: Fine, I'll just ask someone else. \nAmanda: And don't eat all the cookies, they're for sharing, remember?\nAmanda: I'm going to go, bye \nJerry: Wait, come back, I'll pay you, I'll buy you more cookies\nAmanda: I don't need your money, I don't need your
cookies, I'll just eat them myself, bye \nJerry: You're really mean\nAmanda: I'm not mean, I'm just firm, that's all\nAmanda: I'm going to go, bye \nJerry: Wait, come back, I'll get you an ice cream, you like ice cream, don't you?\nAmanda: No, I don't want your ice cream, I don't want your cookies, I don't want your money, I'll just go, bye \nJerry: You're rea
lly mean, I'm going to tell everyone, you're mean\nAmanda
INFO:torchtune.utils.logging:Time for inference: 47.79 sec total, 6.28 tokens/sec
INFO:torchtune.utils.logging:Bandwidth achieved: 61.40 GB/s
INFO:torchtune.utils.logging:Memory used: 22.97 GB

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • Cast q, k, v to the kv_cache dtype
  • Clone the token before feeding it to generate_next_token (avoids a torch.compile/CUDAGraph overwritten-tensor issue)
  • Remove the size parameter in KVCache, as it is incompatible with torch.compile

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;

…argument to generate_next_token to avoid torch.compile issue (CUDAGraph overwritten); Remove size parameter in kvcache

pytorch-bot bot commented Aug 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1371

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b030750 with merge base 3c580fc (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Aug 19, 2024
@felipemello1 (Contributor) left a comment


Thanks for the PR! Very clean changes. I left a few comments

@@ -159,7 +159,7 @@ def generate(
tokens = custom_generate_next_token(
model,
input_pos=curr_input_pos,
x=tokens,
Contributor

Can you help me understand why we need to do .clone? I worry about performance issues

Contributor Author

Sure, without it we're running into this error:

RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/home/mreso/torchtune/torchtune/utils/_generation.py", line 48, in generate_next_token
    return sample(logits, temperature, top_k)
  File "/home/mreso/torchtune/torchtune/utils/_generation.py", line 34, in sample
    return multinomial_sample_one(probs)
  File "/home/mreso/torchtune/torchtune/utils/_generation.py", line 16, in multinomial_sample_one
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.

I could not get this to work with torch.compiler.cudagraph_mark_step_begin() on the first try, so I opted for cloning. I can retry making cudagraph_mark_step_begin() work if you think performance will take a big hit.
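For reference, a minimal sketch of the two workarounds (my illustration, not code from this PR; it assumes a CUDA device, and the compiled lambda is just a stand-in for the compiled generate_next_token step):

import torch

# "reduce-overhead" enables CUDAGraphs, which is where the overwritten-output
# error can appear; the lambda stands in for the compiled generation step.
step = torch.compile(lambda x: x * 2, mode="reduce-overhead")

x = torch.randn(4, device="cuda")

# Option 1 (what this PR does): clone the output outside the compiled region so
# a later graph replay cannot overwrite the tensor we still hold on to.
out1 = step(x).clone()

# Option 2: mark the start of each step before invoking the model, telling
# CUDAGraphs that outputs from the previous step may now be overwritten.
torch.compiler.cudagraph_mark_step_begin()
out2 = step(x)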


k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
Contributor

question: We need to do this because the base model has one dtype, e.g. float32, but the cache is in fp16, for example, is that right?

After we save k_val and v_val in fp16, we have to return the entire cache to be used by the attention. Don't we have to upcast it back to float32, since the model is in float32? I don't see that happening anywhere.

Contributor Author

The dtype in the yaml config is actually bfloat16, so the model is cast to that dtype here. It's actually the quantization component wrapping the linear layers that returns float32, so we need to explicitly downcast to the same dtype as the model. See here why this returns float32. I think we would need to upstream changes to torchao to change this.

Alternatively, we could do the downcast right after computing q, k, v, but the kv cache already has the dtype given during initialization. Otherwise we would either need to derive the dtype for q, k, v from the model parameters or change the interface for the attention layer creation to accept a dtype.
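To illustrate the failure mode (a standalone sketch with made-up shapes, not the PR diff itself): an index_put into a bf16 cache with a float32 source raises exactly the error from #1349, and an explicit downcast avoids it.

import torch

# Hypothetical cache shape: [batch, heads, max_seq_len, head_dim]
k_cache = torch.zeros(1, 8, 16, 64, dtype=torch.bfloat16)
k_val = torch.randn(1, 8, 1, 64, dtype=torch.float32)  # e.g. output of a quantized projection
input_pos = torch.tensor([3])

try:
    k_cache[:, :, input_pos] = k_val  # index_put with mismatched dtypes
except RuntimeError as e:
    print(e)  # Index put requires the source and destination dtypes match ...

k_cache[:, :, input_pos] = k_val.to(k_cache.dtype)  # explicit downcast, as done in this PR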

@felipemello1 (Contributor) commented Aug 21, 2024

Its actually the quantization component wrapping the linear layers that returns float32

@mreso, thanks for explaining it! Let's try to fix the root cause then. I don't think it is the intended behavior of the quantizer to return float32 if the input is float16. I took a look at the wrapper and it says:

"precision: precision of input and output. e.g. torch.float32 means input
activation is float32 and output is float32.
scales_precision: precision of per group scale."

If you have bandwidth, would you mind doing a quick test with the quantizer dtype inputs? If the input is indeed fp16, and precision=fp16 (and scales_precision?), and it still outputs fp32, then I will ping the torchao folks.

Contributor Author

Here is a quick test:

import torch
import torch.nn as nn
from torchtune.utils.quantization import Int8DynActInt4WeightQuantizer


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(256, 256, bias=False)
        self.linear2 = nn.Linear(256, 256, bias=False)

    def forward(self, x):
        x = self.linear(x)
        x = self.linear2(x)
        return x


# Quantizer with the default precision (torch.float32)
quantizer_f32 = Int8DynActInt4WeightQuantizer()

foo = Foo()
foo = quantizer_f32.quantize(foo)
foo = foo.to(dtype=torch.bfloat16)

# foo = torch.compile(foo)

x_bf16 = torch.rand(1, 256, dtype=torch.bfloat16)
print(f"{x_bf16.dtype=}")
y = foo(x_bf16)
print(f"{y.dtype=}")

x_f32 = torch.rand(1, 256, dtype=torch.float32)
print(f"{x_f32.dtype=}")
y = foo(x_f32)
print(f"{y.dtype=}")

# Quantizer with precision explicitly set to bfloat16
quantizer_bf16 = Int8DynActInt4WeightQuantizer(precision=torch.bfloat16)

bar = Foo()
bar = quantizer_bf16.quantize(bar)
bar = bar.to(dtype=torch.bfloat16)

# bar = torch.compile(bar)

x = torch.rand(1, 256, dtype=torch.bfloat16)
print(f"{x.dtype=}")
y = bar(x)
print(f"{y.dtype=}")

As the docs state, the output precision depends on quantizer.precision, so for bfloat16 in it should be bfloat16 out:

linear: linear, in=256, out=256
linear: linear2, in=256, out=256
x_bf16.dtype=torch.bfloat16
y.dtype=torch.float32
x_f32.dtype=torch.float32
y.dtype=torch.float32
linear: linear, in=256, out=256
linear: linear2, in=256, out=256
x.dtype=torch.bfloat16
y.dtype=torch.bfloat16

This brings us back to the two problems described in the original issue:

  1. We currently cannot set the precision with the current implementation of ao + torchtune, as precision has to be a dtype but torchtune only passes the string read from the config.
  2. If we somehow set precision to the correct dtype, as e.g. described here, AND enable torch.compile (which the generate recipe does), we end up with the inductor error.
Traceback (most recent call last):
  File "/home/mreso/torchtune/test_quantizer.py", line 47, in <module>
    y=bar(x)
 ...
  File "/tmp/torchinductor_mreso/ji/cjib4exg2swki3pwhhi7cgq5qh7l6yriypsqejd3wdcajrx55zrc.py", line 69, in call
    extern_kernels.mm(buf6, reinterpret_tensor(buf8, (256, 256), (1, 256), 0), out=buf9)
RuntimeError: Expected out tensor to have dtype float, but got c10::BFloat16 instead

So if we don't want to solve the problem by casting, we need to fix the inductor issue on the ao side.

Contributor

Nice!

  1. This one is easy to fix. We can just use this mapping to go from the string to torch.bfloat16 (see the sketch below).

  2. This one seems to need attention.
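A minimal sketch of point 1, assuming torchtune's existing string-to-dtype helper (torchtune.utils.get_dtype, as used by the recipes); the exact mapping referenced above may differ:

from torchtune import utils
from torchtune.utils.quantization import Int8DynActInt4WeightQuantizer

precision = utils.get_dtype("bf16")  # maps the config string to torch.bfloat16
quantizer = Int8DynActInt4WeightQuantizer(precision=precision)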

@msaroufim, can you take a look at this issue?

TLDR:
using compile + Int8DynActInt4WeightQuantizer with precision=bf16 breaks.

quantizer_bf16 = Int8DynActInt4WeightQuantizer(precision=torch.bfloat16)

bar = Foo()
bar = quantizer_bf16.quantize(bar)
bar = bar.to(dtype=torch.bfloat16)
bar = torch.compile(bar)

x = torch.rand(1,256, dtype=torch.bfloat16)
y=bar(x)

raises the error:

Traceback (most recent call last):
  File "/home/mreso/torchtune/test_quantizer.py", line 47, in <module>
    y=bar(x)
 ...
  File "/tmp/torchinductor_mreso/ji/cjib4exg2swki3pwhhi7cgq5qh7l6yriypsqejd3wdcajrx55zrc.py", line 69, in call
    extern_kernels.mm(buf6, reinterpret_tensor(buf8, (256, 256), (1, 256), 0), out=buf9)
RuntimeError: Expected out tensor to have dtype float, but got c10::BFloat16 instead

If we don't add precision=torch.bfloat16, then for a bf16 input we get an fp32 output.

Member

@jerryzh168 any reason why we're not using quantize_() here instead?

Contributor

For reference, here is where the component is defined in the config: https://github.com/pytorch/torchtune/blob/main/recipes/configs/quantization.yaml#L27

Contributor

And here is where it's instantiated:

self._quantizer = config.instantiate(cfg.quantizer)

Contributor

@msaroufim there were some issues with loading the tensor subclass quantized weights in generate.py before (with map_location="cpu":

map_location="cpu",
mmap=mmap,
), though things might have changed. Also, we need to clarify the plan for the QAT flow, I think. cc @andrewor14: is the recent tensor subclass refactor tested in torchtune?

Contributor Author

@msaroufim @jerryzh168 @felipemello1 is there any short-term solution to unblock us here, be it for the original issue or this PR?

@@ -211,7 +211,7 @@ def forward(
s_y = y.shape[1]

if self.kv_cache and input_pos is None:
cache_size = self.kv_cache.size
cache_size = self.kv_cache.k_cache.dim(2)
Contributor
This is great! Much better than doing .item(). Can you keep self.size inside of the cache and do this change there? I think it's a bit safer and cleaner.

FYI, a bit related to this PR: #1364

Contributor Author

Yes, good idea. I have now just removed the .item() call, so size will be a tensor, which can still be used in the same way as the scalar before.
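Rough sketch of the idea (illustrative only, not the actual torchtune KVCache): keep the running size as a 0-dim tensor instead of calling .item(), so it stays traceable under torch.compile but can still be used like a scalar.

import torch
import torch.nn as nn

class TinyKVCache(nn.Module):
    def __init__(self, batch, heads, max_seq_len, head_dim, dtype=torch.bfloat16):
        super().__init__()
        shape = (batch, heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("size", torch.zeros((), dtype=torch.long))

    def update(self, input_pos, k_val, v_val):
        # Downcast to the cache dtype (the fix discussed above), then write in place.
        self.k_cache[:, :, input_pos] = k_val.to(self.k_cache.dtype)
        self.v_cache[:, :, input_pos] = v_val.to(self.v_cache.dtype)
        self.size = input_pos.max() + 1  # still a tensor, no .item() call
        return self.k_cache, self.v_cache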

Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
5 participants