
make smoothquant more PT2 friendly #1639

Open
@vkuzo

Description

torchao's smoothquant recently broke after a change to PyTorch core: pytorch/pytorch#145733. We should apply the updates that @anijain2305 suggested in that issue to our code. I actually think we should go a bit further and switch to something like:

```python
from dataclasses import dataclass

import torch

# to_affine_quantized_intx, to_affine_quantized_intx_static, MappingType and
# _get_per_token_block_size are torchao helpers used by the existing code

#
# before
#
class _ActQuantizer:
    def __init__(self, target_dtype, quant_min=-127):
        self.target_dtype = target_dtype
        self.quant_min = quant_min

    def dynamic_quantize(self, input):
        # symmetric, per-token quantization with scales computed at runtime
        return to_affine_quantized_intx(
            input,
            MappingType.SYMMETRIC,
            _get_per_token_block_size(input),
            self.target_dtype,
            self.quant_min,
        )

    def static_quantize(self, input, scale, zero_point):
        # quantization with precomputed scale and zero_point
        return to_affine_quantized_intx_static(
            input,
            scale,
            zero_point,
            list(input.shape),
            self.target_dtype,
            self.quant_min,
        )

#
# after
#
@dataclass
class _ActQuantConfig:
    target_dtype: torch.dtype
    quant_min: int = -127

# then, logic elsewhere chooses whether to call static or dynamic quant
# based on the contents of an instance of `_ActQuantConfig`
```
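The "logic elsewhere" could look roughly like the following. This is a minimal, torch-free sketch under stated assumptions, not torchao code: `ActQuantConfig`, `dynamic_quantize`, `static_quantize` and `quantize_activation` are hypothetical stand-ins that operate on nested Python lists, `target_dtype` is dropped, and `quant_max` and `static_scale` are hypothetical fields added to show how a plain config can drive the static-vs-dynamic choice.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

Matrix = List[List[float]]


@dataclass(frozen=True)
class ActQuantConfig:
    # Hypothetical mirror of `_ActQuantConfig`; `target_dtype` is omitted so this
    # runs without torch. `quant_max` and `static_scale` are hypothetical fields.
    quant_min: int = -127
    quant_max: int = 127
    static_scale: Optional[float] = None  # None means "quantize dynamically"


def _round_clamp(v: float, scale: float, qmin: int, qmax: int) -> int:
    # Round to the nearest integer level, then clamp into [qmin, qmax].
    return max(qmin, min(qmax, round(v / scale)))


def dynamic_quantize(x: Matrix, qmin: int, qmax: int) -> Tuple[List[List[int]], List[float]]:
    # Hypothetical symmetric per-token quantization: one scale per row,
    # derived at runtime from that row's max absolute value.
    q, scales = [], []
    for row in x:
        amax = max(abs(v) for v in row)
        scale = amax / qmax if amax > 0 else 1.0
        scales.append(scale)
        q.append([_round_clamp(v, scale, qmin, qmax) for v in row])
    return q, scales


def static_quantize(x: Matrix, scale: float, qmin: int, qmax: int) -> Tuple[List[List[int]], List[float]]:
    # Hypothetical static quantization: one precomputed scale shared by every row.
    q = [[_round_clamp(v, scale, qmin, qmax) for v in row] for row in x]
    return q, [scale] * len(x)


def quantize_activation(x: Matrix, config: ActQuantConfig):
    # The dispatch lives here, in ordinary code, instead of inside the config.
    if config.static_scale is None:
        return dynamic_quantize(x, config.quant_min, config.quant_max)
    return static_quantize(x, config.static_scale, config.quant_min, config.quant_max)
```

Calling `quantize_activation(x, ActQuantConfig())` takes the dynamic path, while `ActQuantConfig(static_scale=0.5)` takes the static one (for `x = [[1.0, -2.0]]` that yields `[[2, -4]]`). The real static path also needs a zero point and the tensor-level helpers, which this sketch leaves out.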

My feedback here is similar in spirit to #1595: IMO it's simpler and safer to pass around dumb config objects and use them to choose which function to call, instead of encoding the "which function to call" information in the config as a callable object.
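One concrete difference between the two styles is how they compare. The toy sketch below uses hypothetical classes (`CallableQuantizer`, `ActQuantConfig`) and a trivial rounding "quantizer". An object that carries behavior compares by identity, so two identically configured instances (and their bound methods) are unequal. A frozen dataclass compares and hashes by field values. The sketch does not model `torch.compile` itself. It only illustrates why plain data is easier to reason about than callables.

```python
from dataclasses import dataclass


# "Before" style (hypothetical, simplified): the object carries the behavior.
class CallableQuantizer:
    def __init__(self, quant_min=-127):
        self.quant_min = quant_min

    def dynamic_quantize(self, x):
        # toy "quantization": round, then clamp from below
        return [max(self.quant_min, round(v)) for v in x]


# "After" style: immutable plain data; behavior lives in a free function.
@dataclass(frozen=True)
class ActQuantConfig:
    quant_min: int = -127


def dynamic_quantize(x, config):
    # same toy logic, driven only by the config's fields
    return [max(config.quant_min, round(v)) for v in x]


a, b = CallableQuantizer(), CallableQuantizer()
c, d = ActQuantConfig(), ActQuantConfig()

print(a == b)                                    # False: default equality is identity
print(a.dynamic_quantize == b.dynamic_quantize)  # False: bound to different instances
print(c == d, hash(c) == hash(d))                # True True: compared by field values
```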
