
Commit 7d4a78c

requirements
1 parent 73e9264 commit 7d4a78c

File tree

3 files changed: +10 -22 lines


requirements/requirements-llm.txt

Lines changed: 1 addition & 0 deletions
@@ -10,3 +10,4 @@ pydantic
 torch>=2.4
 tqdm
 transformers[sentencepiece]<5.0
+vllm
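The only change here is adding vllm to the LLM extra requirements. A minimal, hypothetical sketch (not part of this commit; the flag name is invented) of how code using the new vLLM export path might guard against the dependency being missing:

```python
# Hypothetical guard (not from this commit) for the optional vllm
# dependency that requirements/requirements-llm.txt now pulls in.
try:
    import vllm  # noqa: F401
    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

if not VLLM_AVAILABLE:
    raise RuntimeError(
        "The vLLM export path needs the 'vllm' package; install the "
        "requirements-llm.txt extras to use it.")
```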

src/brevitas/export/inference/handler.py

Lines changed: 6 additions & 2 deletions
@@ -301,19 +301,23 @@ def __init__(self):
         self.register_buffer('exponent_bit_width', torch.ones(()))
         self.register_buffer('exponent_bias', torch.ones(()))
         self.register_buffer('fp_internal_scale_min', torch.ones(()))
-        self.register_buffer('saturating', torch.ones(()).to(torch.bool))
+        self.register_buffer('saturating_t', torch.ones(()).to(torch.bool))
         self.inf_values = None
         self.nan_values = None
         self.eps = 1e-8  # torch.finfo(self.scale.dtype).tiny

+    @property
+    def saturating(self):
+        return bool(self.saturating_t.item())
+
     def prepare_for_export(self, module):
         FloatToIntMixin.prepare_for_export(self, module)
         if module.is_quant_enabled:

             self.exponent_bit_width = module.exponent_bit_width()
             self.mantissa_bit_width = module.mantissa_bit_width()
             self.exponent_bias = module.exponent_bias()
-            self.saturating = torch.tensor(module.is_saturating())
+            self.saturating_t = torch.tensor(module.is_saturating())
             self.inf_values = module.inf_values()
             self.nan_values = module.nan_values()
             if module.tensor_quant is not None:
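A note on the change above: the tensor flag is renamed to saturating_t and a read-only saturating property is layered on top, so the value stays in a registered buffer (tracked by state_dict and .to(...)) while callers keep reading a plain Python bool. A minimal self-contained sketch of that pattern, assuming this is the intent; SaturationFlagExample and its prepare_for_export signature are illustrative, not Brevitas code:

```python
import torch
import torch.nn as nn


class SaturationFlagExample(nn.Module):
    """Illustrative buffer-plus-property pattern, not the Brevitas handler."""

    def __init__(self):
        super().__init__()
        # Tensor-valued flag registered as a buffer so it is carried by
        # state_dict() and follows the module through .to(device/dtype).
        self.register_buffer('saturating_t', torch.ones(()).to(torch.bool))

    @property
    def saturating(self) -> bool:
        # Callers keep reading `self.saturating` as a plain Python bool.
        return bool(self.saturating_t.item())

    def prepare_for_export(self, is_saturating: bool) -> None:
        # Assign to the buffer name (`saturating_t`), not the property,
        # so the registered buffer itself is updated.
        self.saturating_t = torch.tensor(is_saturating)


m = SaturationFlagExample()
m.prepare_for_export(False)
print(m.saturating)  # -> False
```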

src/brevitas/export/inference/vLLM/manager.py

Lines changed: 3 additions & 20 deletions
@@ -13,7 +13,6 @@
 from torch.nn import Module
 import torch.nn as nn
 from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.layers.linear import MergedColumnParallelLinear
 from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.layers.linear import RowParallelLinear
@@ -23,24 +22,7 @@
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

 import brevitas.config as config
-from brevitas.export.inference.handler import DynamicFloatInferenceHandler
-from brevitas.export.inference.handler import DynamicIntInferenceHandler
-from brevitas.export.inference.handler import FloatInferencetHandler
-from brevitas.export.inference.handler import FloatWeightInferencetHandler
-from brevitas.export.inference.handler import GroupwiseFloatInferenceHandler
-from brevitas.export.inference.handler import GroupwiseFloatWeightInferenceHandler
-from brevitas.export.inference.handler import GroupwiseIntInferenceHandler
-from brevitas.export.inference.handler import GroupwiseIntWeightInferenceHandler
-from brevitas.export.inference.handler import IntInferencetHandler
-from brevitas.export.inference.handler import IntWeightInferencetHandler
 from brevitas.export.inference.vLLM.handler import QuantLinear
-from brevitas.export.manager import _set_proxy_export_handler
-from brevitas.export.manager import _set_proxy_export_mode
-from brevitas.export.manager import _set_recurrent_layer_export_handler
-from brevitas.export.manager import _set_recurrent_layer_export_mode
-from brevitas.export.manager import BaseManager
-from brevitas.graph.calibrate import QuantizationStatusManager
-from brevitas.nn.equalized_layer import EqualizedModule
 from brevitas.nn.equalized_layer import RotatedModule
 from brevitas.nn.mixin import QuantLayerMixin
 from brevitas.proxy.quant_proxy import QuantProxyFromInjector
@@ -178,8 +160,9 @@ def export(self, model, filepath):
                 proxy_dict['class_type'] = export_handler.__class__.__name__
             if isinstance(module, self.wrap_layers):
                 layer_dict['rotation_config'] = dict()
-                layer_dict['rotation_config']['rot_mat_shape'] = module.had_mat.shape[0] if module.had_mat is not None else None
+                layer_dict['rotation_config']['rot_mat_shape'] = module.had_mat.shape[
+                    0] if module.had_mat is not None else None
                 layer_dict['rotation_config']['k'] = module.k
-
+
         with open(json_filename, 'w') as f:
             json.dump(json_to_save, f, cls=EncodeTensor)
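The export hunk above writes a per-layer rotation_config (rot_mat_shape, k) into a JSON file via a custom encoder, EncodeTensor, whose definition is not part of this diff. A minimal sketch of how such a tensor-aware json.JSONEncoder could work; TensorJSONEncoder, the output file name, and the example values are assumptions for illustration:

```python
import json

import torch


class TensorJSONEncoder(json.JSONEncoder):
    """Illustrative stand-in for a tensor-aware encoder such as EncodeTensor;
    the real Brevitas class may behave differently."""

    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            # Convert tensors to nested Python lists so json can serialize them.
            return obj.detach().cpu().tolist()
        return super().default(obj)


# Example payload shaped like the rotation_config written in the hunk above;
# the concrete values (64, 1) are made up for illustration.
layer_dict = {
    'rotation_config': {
        'rot_mat_shape': 64,  # module.had_mat.shape[0] when a rotation matrix exists, else None
        'k': 1,
    },
    'scale': torch.ones(4),  # tensors are handled by the custom encoder
}

with open('quant_export_example.json', 'w') as f:
    json.dump(layer_dict, f, cls=TensorJSONEncoder)
```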

0 commit comments
