Fix Per Row scaling for inference #2253

Merged
merged 1 commit into main from drisspg/stack/56 on May 27, 2025

Conversation

drisspg (Contributor) commented May 23, 2025

Stacked PRs:


Summary

Somewhere in the myriad of refactors we broke per-row scaling. This had no visible effect because dequant just so happened to work with a global scale, so we now add a test for it.
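
In case it helps, here is a minimal sketch of the kind of check such a test performs, assuming the config-style quantize_ API, a GPU with float8 support, and the original_weight_tensor/tensor_impl attribute path (all assumptions about the current subclass layout, not a copy of the added test):

    import torch
    import torch.nn as nn

    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        PerRow,
        quantize_,
    )

    # Hypothetical layer sizes; any (in_features, out_features) pair works.
    model = nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

    # Request per-row (row-wise) float8 scaling instead of a single tensorwise scale.
    quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

    # With per-row scaling the weight should carry one scale per output row
    # rather than a single global scale for the whole tensor.
    # The attribute path below is an assumption about the tensor subclass layout.
    weight_scale = model.weight.original_weight_tensor.tensor_impl.scale
    assert weight_scale.numel() == model.out_features  # 256 scales, one per row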

pytorch-bot bot commented May 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2253

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d0b71cc with merge base a776b1f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from 8e43100 to ae92547 Compare May 23, 2025 18:32
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 23, 2025
@drisspg drisspg added bug Something isn't working high priority topic: bug fix Use this tag for PRs that fix bugs labels May 23, 2025
@drisspg drisspg marked this pull request as draft May 23, 2025 18:44
drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from ae92547 to bd0df40 Compare May 23, 2025 20:16
drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from bd0df40 to a3ec2a2 Compare May 23, 2025 21:02
@drisspg drisspg marked this pull request as ready for review May 23, 2025 21:02
drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from a3ec2a2 to a2f2f09 Compare May 23, 2025 21:11
drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from a2f2f09 to 83640b7 Compare May 23, 2025 21:14
drisspg added a commit that referenced this pull request May 23, 2025
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from 83640b7 to f550778 Compare May 23, 2025 21:19
@drisspg drisspg force-pushed the drisspg/stack/56 branch from f550778 to 5fa37cb Compare May 23, 2025 21:38

Copilot AI (Contributor) left a comment

Pull Request Overview

This PR restores per-row scaling support for float8 quantization and updates related APIs and tests.

  • Add block_size support to quant parameter selection and dequantization for block-wise scaling
  • Normalize granularity in the quantization API post-init
  • Fix row-wise scale handling in float8 layouts and update tests for verifying per-row scale shapes
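
To make the block_size-driven scaling concrete, here is a rough standalone sketch (a simplification for intuition, not torchao's actual choose_qparams_affine_float8; it only handles blocks that either span a full axis or have size 1):

    import torch

    def choose_float8_scale(x: torch.Tensor, block_size: tuple) -> torch.Tensor:
        # Reduce the absolute max over every axis that a block fully spans.
        #   block_size == x.shape          -> one global (tensorwise) scale
        #   block_size == (1, x.shape[1])  -> one scale per row
        dims = tuple(d for d, b in enumerate(block_size) if b == x.shape[d])
        amax = x.abs().amax(dim=dims, keepdim=True).float()
        return amax / torch.finfo(torch.float8_e4m3fn).max

    w = torch.randn(256, 128)
    assert choose_float8_scale(w, (256, 128)).shape == (1, 1)  # per-tensor
    assert choose_float8_scale(w, (1, 128)).shape == (256, 1)  # per-row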

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Show a summary per file

File                            Description
quant_primitives.py             Introduce block_size argument for block-wise scale compute and update dequant logic
quant_api.py                    Normalize granularity tuple in __post_init__
float8_layout.py                Adjust row-wise scale transpose logic
affine_quantized_tensor.py      Pass block_size into choose_qparams_affine_float8
test_affine_quantized_float.py  Parametrize and verify per-row scale shapes in tests

Comments suppressed due to low confidence (4)

torchao/quantization/quant_primitives.py:1985

  • The comment about supporting only tensorwise scaling is outdated since block-wise scaling is now implemented; please update or remove this comment to avoid confusion.
# only tensorwise scaling is supported for now:

torchao/quantization/quant_api.py:1579

  • [nitpick] This normalization block appears duplicated from the class definition in the context excerpt; consider consolidating to avoid redundant logic.
activation_granularity, weight_granularity = _normalize_granularity(

torchao/dtypes/floatx/float8_layout.py:341

  • Removing the unsqueeze(-1) before transpose changes the tensor's rank and may break row-wise scale alignment; restore the unsqueeze or adjust downstream logic to match expected dimensions.
w_scale = w_scale.T

torchao/dtypes/affine_quantized_tensor.py:465

  • The variable block_size is used here but not defined or passed into this scope; ensure block_size is available or passed from the calling context.
scale = choose_qparams_affine_float8(
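
For context on the row-wise transpose comment above: in a row-wise-scaled float8 matmul the two scales must line up with the operand layouts. A small shape-bookkeeping sketch, assuming torch._scaled_mm's row-wise convention of an (M, 1) scale for the first operand and a (1, N) scale for the second (needs a recent PyTorch and hardware with row-wise float8 support):

    import torch

    M, K, N = 16, 64, 32
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)  # weight stored as (out, in)

    # One scale per activation row and one per weight output row.
    a_scale = a.abs().amax(dim=1, keepdim=True).float() / fp8_max  # (M, 1)
    w_scale = w.abs().amax(dim=1, keepdim=True).float() / fp8_max  # (N, 1)

    a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)
    w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)

    # The weight is transposed for the matmul, so its (N, 1) per-row scale must be
    # transposed to (1, N) to stay aligned with the columns of the output.
    out = torch._scaled_mm(
        a_fp8,
        w_fp8.t(),
        scale_a=a_scale,
        scale_b=w_scale.t(),
        out_dtype=torch.bfloat16,
    )
    assert out.shape == (M, N)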

@drisspg drisspg force-pushed the drisspg/stack/56 branch 2 times, most recently from 5d02444 to 4d7c98f Compare May 23, 2025 22:38
@drisspg drisspg force-pushed the drisspg/stack/56 branch 2 times, most recently from e6800d3 to f106d46 Compare May 23, 2025 23:24
@drisspg drisspg requested a review from danielvegamyhre May 23, 2025 23:28
@drisspg drisspg force-pushed the drisspg/stack/56 branch from f106d46 to b10c3cc Compare May 24, 2025 00:00
@drisspg drisspg mentioned this pull request May 24, 2025
@drisspg drisspg force-pushed the drisspg/stack/56 branch from b10c3cc to 92a01f5 Compare May 24, 2025 17:12
stack-info: PR: #2253, branch: drisspg/stack/56
@drisspg drisspg force-pushed the drisspg/stack/56 branch from 92a01f5 to d0b71cc Compare May 24, 2025 17:42

drisspg (Contributor, Author) commented May 24, 2025

Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 15.06it/s, est. speed input: 108.45 toks/s, output: 195.81 toks/s]
Prompt: 'Why is Pytorch 2.0 the best machine learning compiler?', Generated text: " Ok,\nI am five months after changing? This one's"
Prompt: 'Hello, my name is', Generated text: ' Samuel Rodriguez.\r\n\r\nIf you want to know some regulars of'
Prompt: 'The president of the United States is', Generated text: ' revered has taken old wine beefs, morgs.\n'
Prompt: 'The capital of France is', Generated text: ' ruling Tauresearchednow.com offers a regularsolutions'
Prompt: 'The future of AI is', Generated text: ' moving... I thought about implementing logic, physics are one transform'
[rank0]:[W524 10:43:22.504024005 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before progr

Per Tensor is broken; it has to do with how we are slicing.
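
For readers following the slicing remark: when a quantized weight is narrowed (as vLLM does for tensor parallelism and fused projections), a per-row scale has to be narrowed together with the rows it describes, while a per-tensor scale must pass through unchanged. A plain-tensor sketch of that invariant (not torchao's slicing code):

    import torch

    N, K = 8, 16
    w = torch.randn(N, K)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    per_tensor_scale = w.abs().amax().float() / fp8_max                  # scalar
    per_row_scale = w.abs().amax(dim=1, keepdim=True).float() / fp8_max  # (N, 1)

    rows = slice(0, 4)  # e.g. one shard of the output rows

    # Per-row: slice the scale together with the rows it describes.
    w_shard, row_scale_shard = w[rows], per_row_scale[rows]
    assert row_scale_shard.shape == (4, 1)

    # Per-tensor: the single scale applies unchanged to every shard.
    assert per_tensor_scale.ndim == 0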

drisspg (Contributor, Author) commented May 27, 2025

Okay, I verified on the base PR that I am seeing weird behavior in vLLM; I am going to make a separate issue for this.

@drisspg drisspg merged commit 1017c7e into main May 27, 2025
19 checks passed

Review thread on the quantization code:

    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    return scale.to(dtype=scale_dtype)
    if block_size is None:

Contributor:
is this per tensor? should we standardize on this? some other places are probably using -1 to represent this

drisspg (Contributor, Author):
I think -1 is fine as well.
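
For reference, the two conventions under discussion: None meaning no blocking at all (per-tensor), versus a -1 entry in the block_size tuple meaning the block spans that axis. A tiny hypothetical normalization helper, not torchao code:

    def normalize_block_size(block_size, shape):
        # None -> the whole tensor is a single block (per-tensor scaling).
        if block_size is None:
            return tuple(shape)
        # -1 on an axis -> the block spans that axis entirely.
        return tuple(s if b == -1 else b for b, s in zip(block_size, shape))

    shape = (256, 128)
    assert normalize_block_size(None, shape) == (256, 128)       # per-tensor
    assert normalize_block_size((-1, -1), shape) == (256, 128)   # per-tensor via -1
    assert normalize_block_size((1, -1), shape) == (1, 128)      # per-row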
