
[Feat]: Add support for kleidiai quantization schemes #1447

Open
wants to merge 2 commits into base: main

Conversation

@ng-05 commented Dec 19, 2024

Description:
Allow int8_dynamic_activation_intx_weight to work with aten _dyn_quant_matmul_4bit op

Needs: pytorch/pytorch#134124 or PyTorch > 2.6.0

pytorch-bot commented Dec 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1447

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @ng-05!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@ng-05 ng-05 marked this pull request as draft December 19, 2024 10:44
@ng-05 (Author) commented Jan 8, 2025

Hello @jerryzh168,
We want to support two different types of int4 schemes:

  1. symmetric_groupwise -> group size in {32, 64, 128, ...}
  2. symmetric_channelwise -> group size equals the channel size of the matmul weights

How should we take this input from the user regarding quantization schemes? The group-size parameter cannot serve the purpose, since the channel size changes across matmuls in a model.

Currently I am using a "scheme" parameter to differentiate between the two:
aarch64_cpu_channelwise.json
aarch64_cpu_groupwise.json

@jerryzh168 (Contributor) commented Jan 8, 2025

How should we take this input from the user regarding quantization schemes? The group-size parameter cannot serve the purpose, since the channel size changes across matmuls in a model.

yeah, you can use PerGroup and PerAxis(axis=0) from https://github.com/pytorch/ao/blob/main/torchao/quantization/granularity.py (assuming the channel dimension is 0), examples:

granularity: Optional[
weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)
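(A minimal sketch of how the two proposed schemes could map onto these granularity objects; the helper name and default group size below are illustrative, not part of the PR.)

```python
from torchao.quantization.granularity import PerAxis, PerGroup

def granularity_for(scheme: str, group_size: int = 32):
    """Map the two proposed int4 schemes onto torchao granularity objects."""
    if scheme == "symmetric_groupwise":
        return PerGroup(group_size)   # group size in {32, 64, 128, ...}
    if scheme == "symmetric_channelwise":
        return PerAxis(axis=0)        # one scale per output channel (axis 0)
    raise ValueError(f"Unknown scheme: {scheme}")
```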

@ng-05 (Author) commented Jan 9, 2025

How should we take this input from the user regarding quantization schemes? The group-size parameter cannot serve the purpose, since the channel size changes across matmuls in a model.

yeah, you can use PerGroup and PerAxis(axis=0) from https://github.com/pytorch/ao/blob/main/torchao/quantization/granularity.py (assuming the channel dimension is 0), examples:

granularity: Optional[
weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)

Thanks for the inputs @jerryzh168.

I have an initial change ready which extends the int4_weight_only quantizer.

The 4-bit KleidiAI kernels quantize the weights in torchao, but quantize the input to 8 bits within the kernel itself, instead of quantizing the input in torchao the way int8_dynamic_activation_int4_weight does.
For this reason I am extending the int4_weight_only API. I am slightly confused whether the intention of this API is to convey NO input quantization to the user.

Currently neither int4_weight_only nor int8_dynamic_activation_int4_weight fully aligns with the way the KleidiAI 4-bit kernels work.

I feel int4_weight_only is closest to what we want to do; what are your thoughts on this?
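(A reference-only sketch of what "the kernel quantizes the input to 8 bits itself" amounts to numerically; plain PyTorch for illustration, not the KleidiAI kernel or the torchao implementation.)

```python
import torch

def dyn_quant_matmul_reference(x: torch.Tensor, q_w: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor:
    # x: (m, k) float activations; q_w: (n, k) int4 values stored in int8; w_scale: (n,) per-channel scales
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    x_scale = (x_max - x_min).clamp(min=1e-6) / 255.0
    x_zp = torch.round(-x_min / x_scale) - 128
    q_x = torch.clamp(torch.round(x / x_scale) + x_zp, -128, 127)  # int8 activations, produced at matmul time
    acc = (q_x - x_zp) @ q_w.to(x.dtype).t()                       # accumulate in the quantized domain
    return acc * x_scale * w_scale                                 # rescale back to float
```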

@jerryzh168 (Contributor) commented Jan 9, 2025

I feel int4_weight_only is closest to what we want to do; what are your thoughts on this?

yeah, int4_weight_only means no input quantization. I think this aligns better with int8_dynamic_activation_int4_weight; you can use a different layout and customize the logic for input quantization.

we also have int8_dynamic_activation_intx_weight, which matches your use case. There are some ongoing refactors/updates there right now as well.

You can also check out: #995

@ng-05 (Author) commented Jan 11, 2025

Hello @jerryzh168, I am planning to migrate the int8_dynamic_activation_intx_weight API to int8_dynamic_activation_intx_weight_v2.
For now I have kept the API separate for review and testing.

Can you please review this change, especially the change in _get_linear_subclass_inserter that allows bias propagation? The bias is needed by torch.ops.aten._dyn_quant_pack_4bit_weight.

I am also not sure whether the int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently. Do you have an example of how torchchat can pass args like granularity and mapping_type from the torchchat CLI to torchao?

target: Target

# Allow bias access via layout
bias: Optional[torch.Tensor] = None
@jerryzh168 (Contributor) commented Jan 11, 2025

layout is more of a "type" actually, so why is the bias Tensor passed here?

the corresponding "storage" is TensorImpl

@ng-05 (Author) replied:

I want to access the bias so it can be packed with my weights and scales. I cannot find any other existing way to pass the bias to the from_plain() API via

tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)

How do you think I should access the bias in the packing function here?

Contributor:

@ng-05 - bias is not required to differentiate this layout, i.e. you can dispatch to this layout with and without bias.

That said, @jerryzh168 - we do need to figure out how to get the bias to the from_plain method. I know it doesn't play nicely with the tensor representation abstraction for AQT; do you have any other suggestions?

Perhaps until then we can just do an add op followed by the gemm, and put a TODO on fixing the APIs?

@jerryzh168 (Contributor) commented Jan 14, 2025

If it does not fit into AQT, I think it's fine to create a new tensor subclass, but putting the bias Tensor in the layout conflicts a bit with the design (a has_bias boolean is fine): since the layout is a "type", it should not store data.

@jerryzh168 (Contributor) left a comment

looks good to me overall, can you add some tests?

@jerryzh168 (Contributor):
I am also not sure whether the int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently. Do you have an example of how torchchat can pass args like granularity and mapping_type from the torchchat CLI to torchao?

I don't think we need to expose these fine-grained args to the torchchat CLI; we just need high-level args like those in https://github.com/pytorch/torchchat/blob/main/torchchat/quant_config/mobile.json

we are also working on migrating torchchat to use the torchao quant API, btw

@metascroy (Contributor):
Hello @jerryzh168, I am planning to migrate the int8_dynamic_activation_intx_weight API to int8_dynamic_activation_intx_weight_v2. For now I have kept the API separate for review and testing.

Can you please review this change, especially the change in _get_linear_subclass_inserter that allows bias propagation? The bias is needed by torch.ops.aten._dyn_quant_pack_4bit_weight.

I am also not sure whether the int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently. Do you have an example of how torchchat can pass args like granularity and mapping_type from the torchchat CLI to torchao?

torchchat does not currently use int8_dynamic_activation_intx_weight, but instead a submodule swap API here: https://github.com/pytorch/ao/blob/main/torchao/experimental/quant_api.py#L438

We will be switching torchchat to use int8_dynamic_activation_intx_weight instead, but I first need to land some changes for perf/clarity: #1553

@kimishpatel (Contributor) left a comment

I understand that this quant API now connects the kernels we landed in aten with the quant API. If the kernels you landed in aten are actually new ops, unlike int4pack_mm and friends, then why did we land them there in the first place? In order to reach those kernels you need the ao dep anyway. (@digantdesai I know you tagged me on that PR but I never really deep-dived into it, so maybe you have context here.)

Besides that, I have a couple of questions.

  • In the current form it only makes the aten op you added available via the tensor subclass API, so what happens to, say, torch.compile (maybe this works?) or the AOTI use case?
  • I would also like to see if we can leverage this op in ExecuTorch, for which integration into AO would have been a better choice than it being an aten op.
  • If Kleidi's op performs better than what's in this repo (and note that @digantdesai has actually integrated some of the Kleidi ops, which I guess you are aware of), then can we just use that op directly, or have a path to Kleidi's impl for the CPU ops that exist under experimental/ops?

@kimishpatel (Contributor):

Hello @jerryzh168, I am planning to migrate the int8_dynamic_activation_intx_weight API to int8_dynamic_activation_intx_weight_v2. For now I have kept the API separate for review and testing.
Can you please review this change, especially the change in _get_linear_subclass_inserter that allows bias propagation? The bias is needed by torch.ops.aten._dyn_quant_pack_4bit_weight.
I am also not sure whether the int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently. Do you have an example of how torchchat can pass args like granularity and mapping_type from the torchchat CLI to torchao?

torchchat does not currently use int8_dynamic_activation_intx_weight, but instead a submodule swap API here: main/torchao/experimental/quant_api.py#L438

We will be switching torchchat to use int8_dynamic_activation_intx_weight instead, but I first need to land some changes for perf/clarity: #1553

Any specific reason to use the subclass API instead of module swap?

@ng-05 (Author) commented Jan 13, 2025

I understand that this quant API now connects the kernels we landed in aten with the quant API. If the kernels you landed in aten are actually new ops, unlike int4pack_mm and friends, then why did we land them there in the first place? In order to reach those kernels you need the ao dep anyway. (@digantdesai I know you tagged me on that PR but I never really deep-dived into it, so maybe you have context here.)

Besides that, I have a couple of questions.

  • In the current form it only makes the aten op you added available via the tensor subclass API, so what happens to, say, torch.compile (maybe this works?) or the AOTI use case?
  • I would also like to see if we can leverage this op in ExecuTorch, for which integration into AO would have been a better choice than it being an aten op.
  • If Kleidi's op performs better than what's in this repo (and note that @digantdesai has actually integrated some of the Kleidi ops, which I guess you are aware of), then can we just use that op directly, or have a path to Kleidi's impl for the CPU ops that exist under experimental/ops?

I am unaware of the ExecuTorch status and what performance you get with the KleidiAI kernels there. I tested this change with torch.compile() and it seems to be working fine.
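(For concreteness, the torch.compile check amounts to something like the sketch below; the model and input shape are placeholders, not the actual test setup.)

```python
import torch

# assuming `model` has already been quantized with the API in this PR
compiled = torch.compile(model)
out = compiled(torch.randn(1, 4096))
```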

@ng-05 (Author) commented Jan 13, 2025

@jerryzh168 @kimishpatel are we testing the 4-bit symmetric quantization anywhere without adding a dequant layer on the result? In my testing I am seeing very poor accuracy with the symmetric 4-bit quant scheme in this PR.
For comparison, the mean relative error jumps from 0.0006 (with the llama.cpp algorithm) to 0.0044 (torchao algorithm) with the KleidiAI kernels.
This is the reference scheme that I am using for the 4-bit symmetric quant: ggerganov/llama.cpp#729
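(For reference, one way such a mean relative error could be computed; the exact metric behind the numbers above is an assumption.)

```python
import torch

def mean_relative_error(actual: torch.Tensor, reference: torch.Tensor) -> float:
    # average of |actual - reference| / |reference|, guarding against division by zero
    return ((actual - reference).abs() / reference.abs().clamp(min=1e-12)).mean().item()
```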

@metascroy (Contributor):

Any specific reason to use the subclass API instead of module swap?

My understanding from @jerryzh168 is that, long-term, torchao plans to support pt2e and subclass/quantize_-based quantization. I believe torchchat is working on (and has already partially completed) moving module-swap based quantization over to quantize_ (cc @Jack-Khuu to keep me honest there).

@metascroy (Contributor) commented Jan 13, 2025

@jerryzh168 @kimishpatel are we testing the 4-bit symmetric quantization anywhere without adding a dequant layer on the result? In my testing I am seeing very poor accuracy with the symmetric 4-bit quant scheme in this PR. For comparison, the mean relative error jumps from 0.0006 (with the llama.cpp algorithm) to 0.0044 (torchao algorithm) with the KleidiAI kernels. This is the reference scheme that I am using for the 4-bit symmetric quant: ggerganov/llama.cpp#729

We have tests in torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py that compare a Python-implemented fallback to the torchao kernel's output (in #1553, this test is renamed to torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py and is based on comparing AQT's PlainLayout to the kernel outputs).
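(Roughly, such a test compares the kernel-backed layout against a plain reference path; the harness below is a sketch with assumed config arguments, not the actual test file.)

```python
import copy
import torch
from torchao.quantization import quantize_

def compare_layouts(model: torch.nn.Module, example_input: torch.Tensor,
                    packed_config, plain_config):
    # quantize one copy with the PlainLayout reference path and one with the packed kernel layout,
    # then check that the outputs agree within tolerance
    m_ref = copy.deepcopy(model)
    quantize_(m_ref, plain_config)
    quantize_(model, packed_config)
    torch.testing.assert_close(model(example_input), m_ref(example_input), rtol=1e-2, atol=1e-2)
```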

@kimishpatel (Contributor):

I am unaware of the ExecuTorch status and what performance you get with the KleidiAI kernels there. I tested this change with torch.compile() and it seems to be working fine.

are you planning to move the dynamic quantized ops in aten to torchao?

@jerryzh168 (Contributor):

Any specific reason to use the subclass API instead of module swap?

you mean the quantize_ API, right? It is the officially supported API for the inference path of torchao.

@jerryzh168 (Contributor):

agree with @kimishpatel that we want these ops in ao instead of aten; I talked to @digantdesai last time and he explains here: pytorch/pytorch#143289 (comment)

@digantdesai (Contributor):

are you planning to move the dynamic quantized ops in aten to torchao?

@kimishpatel - long term yes, but in the short term we might support an aten backend here. Two main blockers:
(1) as I said earlier on the ATen op PR pytorch/pytorch#143289 (comment), Kleidi being in two places, and
(2) solidifying the torchAO APIs (bias packing, layout cleanup from Scott R., module swap to dispatch), moving out of experimental, and integrating the Kleidi ops with the low-bit kernels.

I think if we unblock leveraging the ATen op for now from here for eager/compile, we can solve (1) and (2) without blocking actual LLM use cases; this was a time-sensitive demo request from the Arm side for PT2.6.

@kimishpatel (Contributor):

Thanks for the context, Digant. Let's make sure we make progress on the long-term front.

Description:

Allow int8_dynamic_activation_intx_weight to work with the aten _dyn_quant_matmul_4bit op

Needs: pytorch/pytorch#134124 or PyTorch > 2.6.0

Signed-off-by: Nikhil Gupta <[email protected]>
@ng-05 (Author) commented Jan 15, 2025

I have made the changes as per the review comments and the latest refactoring (PackedLinearInt8DynamicActivationIntxWeightLayout).
Can I please get a review, @metascroy @digantdesai @jerryzh168?

I had to force-push as the previous 3 commits could not be rebased quickly onto the latest refactoring.

@ng-05 (Author) commented Jan 16, 2025

@digantdesai If we add the bias as a post-op then we will take a latency hit for the model. For now I have made changes that do not affect other quantizers and only take care of packing the bias when the target is "aten" with PackedLinearInt8DynamicActivationIntxWeightLayout.

@kimishpatel (Contributor):

I have expressed my high-level concerns, which have been discussed. So I am going to leave the rest of the review to @metascroy and @digantdesai.

@ng-05 ng-05 marked this pull request as ready for review January 17, 2025 13:26
elif target.lower() == "aten":
    return Target.ATEN
else:
    raise ValueError(f"Invalid target: {target}")

class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
    bit_width: Optional[int]
    group_size: Optional[int]
    has_weight_zeros: Optional[bool]
Contributor:

The packed weights from Kleidi have bias packed with them, right? If so, let's add has_bias: Optional[bool] here to layout.
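(A sketch of the layout with the suggested flag; the fields other than has_bias come from the snippet above, while the dataclass form and defaults are assumptions.)

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):  # Layout base class as in the diff above
    bit_width: Optional[int] = None
    group_size: Optional[int] = None
    has_weight_zeros: Optional[bool] = None
    # Records *whether* bias is packed with the weights; the bias tensor itself stays out of the layout "type".
    has_bias: Optional[bool] = None
```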

if target == "aten":
if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) or \
weight_dtype != torch.int4 or \
has_weight_zeros != True or \
Contributor:

It looks like the KleidiAI op does not take the zero points during packing (scale-only quantization)? So shouldn't has_weight_zeros be false?
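(To illustrate what scale-only quantization means here: symmetric int4 groupwise quantization produces scales but no zero points, as in the sketch below.)

```python
import torch

def quantize_int4_symmetric(w: torch.Tensor, group_size: int):
    # w: (n, k) weights, with k divisible by group_size
    wg = w.reshape(-1, group_size)
    scale = wg.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / 7.0  # int4 qmax = 7
    q = torch.clamp(torch.round(wg / scale), -8, 7).to(torch.int8)
    return q.reshape(w.shape), scale.reshape(w.shape[0], -1)          # scales only, no zero_point
```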

@@ -506,6 +512,7 @@ def int8_dynamic_activation_intx_weight(
     weight_dtype: torch.dtype = torch.int4,
     granularity: Union[PerRow, PerGroup] = PerGroup(128),
     has_weight_zeros: bool = False,
+    target: str = "native",
Contributor:

I think it would be better to pass this in the layout's constructor, because it isn't related to the quantization intent but to the packing format / kernel selection, e.g. layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="native").
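(If target moved into the layout constructor as suggested, a call might look like the sketch below; parameter names follow the diff above, while the import paths and exact signature are assumptions.)

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
# int8_dynamic_activation_intx_weight and PackedLinearInt8DynamicActivationIntxWeightLayout
# live in torchao's experimental quant_api / layout modules (paths assumed here)

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096))
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(128),
        has_weight_zeros=False,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)
```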

f"- weight_dtype to be torch.int4,\n"
f"- weight_mapping_type to be MappingType.SYMMETRIC"
)
elif not isinstance(layout, PlainLayout):
Contributor:

Guard this try/except with isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) instead. In case another layout is added in the future, guarding on not-PlainLayout is too broad.

assert TORCH_VERSION_AT_LEAST_2_6, "aten target requires torch version > 2.6.0"
if torch.backends.kleidiai.is_available():
    if isinstance(granularity, PerGroup):
        scale_dtype = torch.bfloat16  # KleidiAI kernel requires bfloat16 scale_dtype
Contributor:

Only bfloat16 on PerGroup, but not on PerRow?

+ " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance."
)

if target == "aten":
@metascroy (Contributor) commented Jan 17, 2025

Can this be something like:

if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
    assert act_mapping_type == MappingType.ASYMMETRIC, \
        "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC"

    if target == "aten":
        # Do KleidiAI-specific checks
        ...

    if target == "native":
        # Do the try/except import logic
        ...
@metascroy (Contributor):

Overall I think it's close. Can you be sure to run torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py to make sure the tests pass? It is not currently enabled in CI.

You should also add some test cases for your new target to that file to check accuracy/exportability.
