Fix dtype error during quantized generation #1371
base: main
Conversation
…argument to generate_next_token to avoid torch.compile issue (CUDAGraph overwritten); remove size parameter in KVCache
🔗 Helpful links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1371
Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit b030750 with merge base 3c580fc. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks for the PR! Very clean changes. I left a few comments.
@@ -159,7 +159,7 @@ def generate(
        tokens = custom_generate_next_token(
            model,
            input_pos=curr_input_pos,
            x=tokens,
Can you help me understand why we need to do .clone()? I worry about performance issues.
Sure, without it we're running into this error:
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/home/mreso/torchtune/torchtune/utils/_generation.py", line 48, in generate_next_token
return sample(logits, temperature, top_k)
File "/home/mreso/torchtune/torchtune/utils/_generation.py", line 34, in sample
return multinomial_sample_one(probs)
File "/home/mreso/torchtune/torchtune/utils/_generation.py", line 16, in multinomial_sample_one
return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.
I could not get this to work with torch.compiler.cudagraph_mark_step_begin() on the first try, so I opted for cloning. I can retry to make cudagraph_mark_step_begin() work if you think performance will take a big hit.
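For illustration, here is a minimal sketch of the two workarounds the error message suggests. The compiled module below is a hypothetical stand-in, not the actual torchtune generation loop:

import torch
import torch.nn as nn

model = nn.Linear(8, 8).cuda()
step = torch.compile(model, mode="reduce-overhead")  # this mode uses CUDA graphs

x = torch.randn(1, 8, device="cuda")
for _ in range(3):
    # Option 1: clone the fed-back output outside the compiled region, so the
    # next graph replay cannot overwrite memory this iteration still reads.
    x = step(x).clone()
    # Option 2 (instead of cloning): mark the step boundary before each
    # invocation so outputs of the previous replay may safely be reused.
    # torch.compiler.cudagraph_mark_step_begin()
    # x = step(x)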
k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
question: We need to do it because the base model has one dtype, e.g. float32, but the cache is in fp16, for example, is that right?
After we save the k_val and v_val in fp16, we have to return the entire cache to be used by the attention. Don't we have to upcast it back to float32, since the model is in float32? I don't see it happening anywhere.
The dtype in the yaml config is actually bfloat16, so the model is cast to that dtype here. It's actually the quantization component wrapping the linear layers that returns float32, so we need to explicitly downcast to the same dtype as the model. See here why this returns float32. I think we would need to upstream changes to torchao to change this.
Alternatively, we could do the downcast right after computing q, k, v, but the kv cache already has the dtype given during initialization. Otherwise we would either need to derive the dtype for q, k, v from the model parameters or change the interface for the attention layer creation to accept a dtype.
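For illustration, a minimal sketch of the downcast being discussed, applied at the cache update shown in the snippet above; this is one way to resolve the dtype mismatch, not necessarily the exact change in this PR:

# k_val/v_val can come back as float32 from the quantized linear layers while the
# preallocated cache is bfloat16, so cast to the cache dtype before the index-put.
k_out[:, :, input_pos] = k_val.to(dtype=k_out.dtype)
v_out[:, :, input_pos] = v_val.to(dtype=v_out.dtype)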
> It's actually the quantization component wrapping the linear layers that returns float32
@mreso, thanks for explaining it! Let's try to fix the root cause then. I don't think it is the intended behavior of the quantizer to return float32 if the input is float16. I took a look at the wrapper and it says that:
"precision: precision of input and output. e.g. torch.float32 means input
activation is float32 and output is float32.
scales_precision: precision of per group scale."
If you have bandwidth, would you mind doing a quick test with the quantizer dtype inputs? If the input is indeed fp16, and precision=fp16 (and scales_precision?), and it still outputs fp32, then I will ping the torchao folks.
Here is a quick test:
import torch
import torch.nn as nn
from torchtune.utils.quantization import Int8DynActInt4WeightQuantizer

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(256, 256, bias=False)
        self.linear2 = nn.Linear(256, 256, bias=False)

    def forward(self, x):
        x = self.linear(x)
        x = self.linear2(x)
        return x

# Default quantizer (precision=torch.float32)
quantizer_f32 = Int8DynActInt4WeightQuantizer()
foo = Foo()
foo = quantizer_f32.quantize(foo)
foo = foo.to(dtype=torch.bfloat16)
# foo = torch.compile(foo)

x_bf16 = torch.rand(1, 256, dtype=torch.bfloat16)
print(f"{x_bf16.dtype=}")
y = foo(x_bf16)
print(f"{y.dtype=}")

x_f32 = torch.rand(1, 256, dtype=torch.float32)
print(f"{x_f32.dtype=}")
y = foo(x_f32)
print(f"{y.dtype=}")

# Quantizer with precision explicitly set to bfloat16
quantizer_bf16 = Int8DynActInt4WeightQuantizer(precision=torch.bfloat16)
bar = Foo()
bar = quantizer_bf16.quantize(bar)
bar = bar.to(dtype=torch.bfloat16)
# bar = torch.compile(bar)

x = torch.rand(1, 256, dtype=torch.bfloat16)
print(f"{x.dtype=}")
y = bar(x)
print(f"{y.dtype=}")
As the docs state, the output's precision depends on quantizer.precision, so for bfloat16 in it should be bfloat16 out:
linear: linear, in=256, out=256
linear: linear2, in=256, out=256
x_bf16.dtype=torch.bfloat16
y.dtype=torch.float32
x_f32.dtype=torch.float32
y.dtype=torch.float32
linear: linear, in=256, out=256
linear: linear2, in=256, out=256
x.dtype=torch.bfloat16
y.dtype=torch.bfloat16
This brings us back to the two problems described in the original issue:
- We currently cannot set the precision with the current implementation of ao + torchtune, as precision has to be a dtype but torchtune only passes the string read from the config.
- If we somehow set precision to the correct dtype, e.g. as described here, AND enable torch.compile (which the generate recipe does), we end up with the inductor error below.
Traceback (most recent call last):
File "/home/mreso/torchtune/test_quantizer.py", line 47, in <module>
y=bar(x)
...
File "/tmp/torchinductor_mreso/ji/cjib4exg2swki3pwhhi7cgq5qh7l6yriypsqejd3wdcajrx55zrc.py", line 69, in call
extern_kernels.mm(buf6, reinterpret_tensor(buf8, (256, 256), (1, 256), 0), out=buf9)
RuntimeError: Expected out tensor to have dtype float, but got c10::BFloat16 instead
So if we don't want to solve the problem by casting, we need to fix the inductor issue on the ao side.
Nice!
- The first one is easy to fix. We can just use this mapping to go from string to torch.bfloat16 (see the sketch after this list).
- The second one seems to need attention.
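To illustrate the first point, a hypothetical sketch of that string-to-dtype mapping; the dict and function name below are made up for illustration and are not the actual torchtune helper:

import torch

# Map the dtype string read from the yaml config to a torch.dtype before
# constructing the quantizer (torchtune ships a similar utility).
_DTYPE_MAP = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

def resolve_dtype(name: str) -> torch.dtype:
    return _DTYPE_MAP[name]

# e.g. Int8DynActInt4WeightQuantizer(precision=resolve_dtype("bf16"))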
@msaroufim, can you take a look at this issue?
TLDR: using compile + Int8DynActInt4WeightQuantizer with precision=bf16 breaks.
quantizer_bf16 = Int8DynActInt4WeightQuantizer(precision=torch.bfloat16)
bar = Foo()
bar = quantizer_bf16.quantize(bar)
bar = bar.to(dtype=torch.bfloat16)
bar = torch.compile(bar)
x = torch.rand(1, 256, dtype=torch.bfloat16)
y = bar(x)
raises the error:
Traceback (most recent call last):
File "/home/mreso/torchtune/test_quantizer.py", line 47, in <module>
y=bar(x)
...
File "/tmp/torchinductor_mreso/ji/cjib4exg2swki3pwhhi7cgq5qh7l6yriypsqejd3wdcajrx55zrc.py", line 69, in call
extern_kernels.mm(buf6, reinterpret_tensor(buf8, (256, 256), (1, 256), 0), out=buf9)
RuntimeError: Expected out tensor to have dtype float, but got c10::BFloat16 instead
If we don't add precision=torch.bfloat16, then for an fp16 input we get fp32 output.
@jerryzh168 any reason why we're not using quantize_() here instead?
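For context, a hedged sketch of the torchao quantize_ API being referred to; the exact import path and config helper name vary across torchao versions, so treat these as assumptions:

import torch.nn as nn
# Assumed torchao API surface; check against the installed torchao version.
from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight

model = nn.Sequential(nn.Linear(256, 256, bias=False))
# In-place module conversion to int8 dynamic activation / int4 weight quantization.
quantize_(model, int8_dynamic_activation_int4_weight())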
For reference, here is where the component is defined in the config: https://github.com/pytorch/torchtune/blob/main/recipes/configs/quantization.yaml#L27
And here is where it's instantiated (Line 51 in 9e65fa9):
self._quantizer = config.instantiate(cfg.quantizer)
@msaroufim there were some issues with loading the tensor subclass quantized weights in generate.py before (with map_location="cpu"):
torchtune/torchtune/utils/_checkpointing/_checkpointer_utils.py, Lines 110 to 111 in f9f75bb:
    map_location="cpu",
    mmap=mmap,
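For readers without the file open, a minimal sketch of the load call those lines belong to; the checkpoint path is a placeholder and the surrounding checkpointer logic is omitted:

import torch

# Memory-mapped CPU load of a (possibly tensor-subclass) checkpoint, mirroring
# the map_location/mmap arguments referenced above.
state_dict = torch.load("/path/to/checkpoint.pt", map_location="cpu", mmap=True)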
@msaroufim @jerryzh168 @felipemello1 is there any short-term solution to unblock us here, either for the original issue or for this PR?
@@ -211,7 +211,7 @@ def forward(
        s_y = y.shape[1]

        if self.kv_cache and input_pos is None:
-            cache_size = self.kv_cache.size
+            cache_size = self.kv_cache.k_cache.dim(2)
This is great! Much better than doing .item(). Can you keep self.size inside of the cache and do this change there? I think it's a bit safer and cleaner.
FYI, a bit related to this PR: #1364
Yes, good idea. I now just removed the .item() call, so size will be a tensor, which can still be used in the same way as the scalar before.
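A rough, hypothetical sketch of the arrangement being discussed: the cache keeps its own size, tracked as a tensor instead of being materialized via .item(), so callers can keep using self.kv_cache.size. Names and shapes are illustrative, not the actual torchtune KVCache:

import torch
import torch.nn as nn

class ToyKVCache(nn.Module):
    def __init__(self, batch_size, max_seq_len, num_heads, head_dim, dtype):
        super().__init__()
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))
        # Filled length tracked as a tensor (no .item()), so reading it does
        # not force a device-to-host sync under torch.compile.
        self.register_buffer("size", torch.zeros(1, dtype=torch.int64))

    def update(self, input_pos, k_val, v_val):
        # k_val / v_val: (batch, num_heads, seq_len, head_dim)
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        self.size += k_val.shape[2]  # stays a tensor, usable like the old scalar
        return self.k_cache, self.v_cache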
This PR fixes #1349: RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source
Repro:
See #1349 for repro
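Since the full repro lives in #1349, here is only a toy illustration of the underlying index-put dtype mismatch (made-up shapes, unrelated to the actual model):

import torch

dst = torch.zeros(4, dtype=torch.bfloat16)  # e.g. the preallocated kv cache
src = torch.ones(2, dtype=torch.float32)    # e.g. values from the quantized layers
dst[torch.tensor([0, 1])] = src             # RuntimeError: Index put requires the source and destination dtypes match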
Cmd:
Output:
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
- pre-commit install
- pytest tests
- pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:
torchtune/torchtune/modules/vision_transformer.py
Line 285 in 6a7951f
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models