Feat (export/wave): support wave export #1261

Draft · wants to merge 4 commits into base: dev
2 changes: 2 additions & 0 deletions src/brevitas/export/shark/wave/__init__.py
@@ -0,0 +1,2 @@
from .handler import *
from .manager import wave_inference_mode
95 changes: 95 additions & 0 deletions src/brevitas/export/shark/wave/handler.py
@@ -0,0 +1,95 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from abc import abstractmethod
from typing import Tuple

from iree.turbine.kernel.wave.nn import WaveQuantLinear
import torch
from torch import Tensor
import torch.nn as nn

from brevitas.export.inference.handler import FloatInferencetHandler
from brevitas.export.inference.handler import FloatWeightInferencetHandler
from brevitas.nn import QuantLinear


class InferenceHandler(torch.nn.Module, ABC):

def attach_debug_info(self, module: nn.Module):
pass

@abstractmethod
def prepare_for_export(self, module: nn.Module):
pass


class QuantLinearFp8Handler(InferenceHandler):
handled_layer = QuantLinear

def __init__(self):
super().__init__()
self.weight_quant = FloatWeightInferencetHandler()
self.input_quant = FloatInferencetHandler()
self.wave_linear = None

def validate(self, module):
# TODO: Check that we are quantizing to the correct fp8 type, etc. etc.
pass

def prepare_for_export(self, module):
## Weight export
out_feat, input_feat = module.weight.shape[0], module.weight.shape[1]
if module.weight_quant.is_quant_enabled:
weight_quant = module.weight_quant
self.weight_quant.prepare_for_export(weight_quant)
if module.input_quant.is_quant_enabled:
input_quant = module.input_quant
self.input_quant.prepare_for_export(input_quant)
quant_params = {
'weight_scale': self.weight_quant.scale,
'weight_scale_shape': self.weight_quant.scale.shape,
'input_scale': self.input_quant.scale,
'input_scale_shape': self.input_quant.scale.shape,
'qdtype': torch.float8_e4m3fnuz}
# self.wave_linear = WaveQuantLinear(
# input_feat, out_feat, quant_params, bias=False)
# self.wave_linear.weight.data = module.weight.data
# if module.bias is not None:
# self.wave_linear.bias.data = module.bias.data
self.bias = module.bias
self.weight = module.weight
del module.weight
del module.bias

def forward(self, input):
input_q = self.input_quant.quantize(input, self.input_quant.scale.to(input.device), None)
weight_q = self.weight_quant.quantize(
self.weight, self.weight_quant.scale.to(input.device), None)

if len(input_q.shape) > 2:
B = input_q.shape[0]
output_1 = torch.stack(
[
torch._scaled_mm(
input_q[i].to(torch.float8_e4m3fnuz),
weight_q.t().to(torch.float8_e4m3fnuz),
scale_a=self.input_quant.scale.to(input.device),
scale_b=self.weight_quant.scale.to(input.device),
# bias=self.bias,
out_dtype=torch.float16) for i in range(B)],
dim=0)
else:
output_1 = torch._scaled_mm(
input_q.to(torch.float8_e4m3fnuz),
weight_q.t().to(torch.float8_e4m3fnuz),
scale_a=self.input_quant.scale.to(input.device),
scale_b=self.weight_quant.scale.to(input.device),
# bias=self.bias,
out_dtype=torch.float16)

if self.bias is not None:
output_1 += self.bias

return output_1
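Note on the forward path above: torch._scaled_mm performs the fp8 matmul as (scale_a * input_q) @ (scale_b * weight_q.t()) with per-tensor float32 scales, then adds the optional bias and casts to out_dtype, so the handler only needs the cached Brevitas scales plus the fp8-cast operands. For readers without fp8-capable ROCm/CUDA hardware, here is a minimal CPU sketch of that arithmetic; emulated_scaled_mm and the toy shapes are hypothetical stand-ins, not part of this PR.

import torch

def emulated_scaled_mm(a_q, b_q, scale_a, scale_b, bias=None, out_dtype=torch.float16):
    # Emulates what the handler gets from torch._scaled_mm(a_q, b_q, scale_a=..., scale_b=..., out_dtype=...):
    # (scale_a * a_q) @ (scale_b * b_q), optionally plus bias, cast to out_dtype.
    out = (a_q.float() @ b_q.float()) * scale_a * scale_b
    if bias is not None:
        out = out + bias.float()
    return out.to(out_dtype)

# Toy "already-quantized" operands kept in float so this runs on CPU.
a_q = torch.randint(-8, 8, (4, 16)).float()   # stands in for input_q
b_q = torch.randint(-8, 8, (16, 8)).float()   # stands in for weight_q.t()
out = emulated_scaled_mm(a_q, b_q, scale_a=torch.tensor(0.05), scale_b=torch.tensor(0.02))
print(out.shape)  # torch.Size([4, 8])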
114 changes: 114 additions & 0 deletions src/brevitas/export/shark/wave/manager.py
@@ -0,0 +1,114 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from functools import partial

import torch
from torch.nn import Module
import torch.nn as nn

from brevitas.export.manager import _set_layer_export_handler
from brevitas.export.manager import _set_layer_export_mode
from brevitas.export.manager import _set_proxy_export_handler
from brevitas.export.manager import _set_proxy_export_mode
from brevitas.export.manager import _set_recurrent_layer_export_handler
from brevitas.export.manager import _set_recurrent_layer_export_mode
from brevitas.export.manager import BaseManager
from brevitas.export.shark.wave import QuantLinearFp8Handler
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor


def _override_caching_mode(m: nn.Module, attr: str, enabled: bool, metadata_only: bool = True):
cache_var = 'cache_inference_quant_' + attr
cache_var_metadata_only = cache_var + '_metadata_only'
if hasattr(m, cache_var):
setattr(m, cache_var, enabled)
setattr(m, cache_var_metadata_only, metadata_only)


def _override_bias_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
_override_caching_mode(m, 'bias', enabled, metadata_only)


def _override_act_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = True):
_override_caching_mode(m, 'act', enabled, metadata_only)


def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bool = False):
_override_caching_mode(m, 'weight', enabled, metadata_only)


def _override_create_quant_tensor(m: nn.Module, state: bool):
if hasattr(m, 'skip_create_quant_tensor'):
m.skip_create_quant_tensor = state


class wave_inference_mode:

def __init__(self, model, cache_quant_weight=False, enabled=True):
self.model = model
self.enabled = enabled
self.cache_quant_weight = cache_quant_weight
self.export_manager = SharkWaveManager
self.hook_list = []
self.return_quant_tensor_state = dict()

def __enter__(self):
if self.enabled:
# Register the hook and store it in the list so that it can be removed by the hook itself when called
handle = self.model.register_forward_hook(self.hook)
self.hook_list.append(handle)

# Enable bias for everything. Optionally, store the fully fake-quantized weights
self.model.apply(
lambda m: _override_bias_caching_mode(m, enabled=True, metadata_only=True))
self.model.apply(lambda m: _override_act_caching_mode(m, enabled=True))
self.model.apply(
lambda m: _override_weight_caching_mode(
m, enabled=True, metadata_only=not self.cache_quant_weight))
torch._dynamo.reset()

def __exit__(self, type, value, traceback):
# Disable all caching
# deactivate export mode
# restore return quant tensor
SharkWaveManager.set_export_mode(self.model, enabled=False)
self.model.apply(
lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
self.model.apply(
lambda m: _override_act_caching_mode(m, enabled=False, metadata_only=False))
if self.cache_quant_weight:
self.model.apply(
lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
enable_quant_tensor = partial(_override_create_quant_tensor, state=False)
self.model.apply(enable_quant_tensor)

def hook(self, module, inp, out):
# After one forward pass with caching enabled, we can:
# - Set the model in export mode
# - Attach export handlers
# - Disable return quant tensor since all quant metadata is cached
assert len(self.hook_list) == 1
self.hook_list[0].remove()
self.model.apply(SharkWaveManager.set_export_handler)
SharkWaveManager.set_export_mode(self.model, enabled=True)
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
disable_quant_tensor = partial(_override_create_quant_tensor, state=True)
self.model.apply(disable_quant_tensor)


# Inheritance from BaseManager is not technically needed
class SharkWaveManager(BaseManager):
handlers = [QuantLinearFp8Handler]

@classmethod
def set_export_mode(cls, model: Module, enabled: bool):
_set_layer_export_mode(model, enabled)
_set_recurrent_layer_export_mode(model, enabled)

@classmethod
def set_export_handler(cls, module: Module):
_set_layer_export_handler(cls, module)
_set_recurrent_layer_export_handler(cls, module)
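Intended usage of the context manager mirrors the change to brevitas_examples/llm/main.py further below: the first forward pass inside wave_inference_mode runs with caching enabled, the forward hook then attaches QuantLinearFp8Handler to the QuantLinear layers and switches the model into export mode, and every later pass goes through the fp8 _scaled_mm path. A minimal sketch follows, assuming a model quantized with fp8 (e4m3) weight/input quantizers and an fp8-capable ROCm/CUDA device; quant_model, calib_batch, and eval_batch are placeholders, not names from this PR.

import torch

from brevitas.export.shark.wave import wave_inference_mode

# quant_model: any Brevitas-quantized model whose QuantLinear layers use fp8 weight/input quantizers
# calib_batch / eval_batch: dict-style batches matching the model's forward signature
quant_model.eval()
with torch.no_grad(), wave_inference_mode(quant_model):
    quant_model(**calib_batch)            # pass 1: cache scales, attach handlers, enter export mode
    outputs = quant_model(**eval_batch)   # later passes: QuantLinearFp8Handler fp8 matmul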
1 change: 0 additions & 1 deletion src/brevitas/nn/quant_layer.py
@@ -139,7 +139,6 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(inp)
self._set_global_is_quant_layer(False)
return out

quant_input = self.input_quant(inp)
7 changes: 6 additions & 1 deletion src/brevitas_examples/llm/main.py
@@ -531,10 +531,15 @@ def quantize_llm(args, extra_args=None):
if args.eval and not args.no_quantize:

print("Model eval...")
with torch.no_grad(), quant_inference_mode(model, compile=args.compile_eval):
from brevitas.export.shark.wave import wave_inference_mode
with torch.no_grad(), wave_inference_mode(model):
model(**calibration_loader[0])
quant_ppl = compute_perplexity(
model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
# with torch.no_grad(), quant_inference_mode(model, compile=args.compile_eval):
# model(**calibration_loader[0])
# quant_ppl = compute_perplexity(
# model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")

few_shot_eval_results = dict()