1 change: 1 addition & 0 deletions requirements/requirements-llm.txt
@@ -10,3 +10,4 @@ pydantic
torch>=2.4
tqdm
transformers[sentencepiece]<5.0
vllm
Collaborator:
I feel like vLLM should be an optional dependency.

Collaborator Author:
Maybe we can do it in a similar way to what we did for lighteval/lm_eval.

@Giuseppe5 (Collaborator Author), Jan 29, 2026:
I'm leaving it for now so that tests run and I can see what else I'm breaking in the process, but I'll remove it before this PR is merged.
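
For reference, a minimal sketch of the guarded-import approach to making vLLM optional, in the spirit of the lighteval/lm_eval handling mentioned above (the helper name and the idea of a requirements extra are illustrative, not existing Brevitas API):

try:
    from vllm.model_executor.layers.linear import LinearMethodBase
    VLLM_AVAILABLE = True
except ImportError:
    # Fallback base class so this module still imports without vLLM installed
    LinearMethodBase = object
    VLLM_AVAILABLE = False

def _require_vllm():
    # Call at the entry point of the vLLM export path
    if not VLLM_AVAILABLE:
        raise ImportError('vLLM is not installed; it is an optional dependency of the vLLM export path.')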

1 change: 1 addition & 0 deletions src/brevitas/config.py
@@ -16,6 +16,7 @@ def env_to_bool(name, default):

REINIT_ON_STATE_DICT_LOAD = env_to_bool('BREVITAS_REINIT_ON_STATE_DICT_LOAD', True)
IGNORE_MISSING_KEYS = env_to_bool('BREVITAS_IGNORE_MISSING_KEYS', False)
IGNORE_EXPORT_KEYS = env_to_bool('BREVITAS_IGNORE_EXPORT_KEYS', True)
# JIT_ENABLED triggers NATIVE_STE_BACKEND_ENABLED to True, but not the other way around
JIT_ENABLED = env_to_bool('BREVITAS_JIT', False) and _enabled
NATIVE_STE_BACKEND_ENABLED = env_to_bool('BREVITAS_NATIVE_STE_BACKEND', False)
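
For context, the new IGNORE_EXPORT_KEYS flag follows the same env_to_bool convention as the other flags, so it can be toggled from the environment before brevitas.config is imported (a sketch, assuming env_to_bool's usual truthy/falsy string parsing):

import os

os.environ['BREVITAS_IGNORE_EXPORT_KEYS'] = '0'  # override the default of True

import brevitas.config as config
print(config.IGNORE_EXPORT_KEYS)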
26 changes: 14 additions & 12 deletions src/brevitas/core/function_wrapper/shape.py
@@ -186,19 +186,21 @@ def __init__(self, group_size, group_dim) -> None:

@brevitas.jit.script_method
def forward(self, x):
return dynamic_over_sub_channel_block_view(x, self.group_size, self.group_dim)

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
x = padding_to_multiple(x, self.group_dim, self.group_size)

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = (
tensor_shape_list[self.group_dim] + self.group_size - 1) // self.group_size
block_dim = self.group_dim + 1 if self.group_dim != -1 else len(tensor_shape_list)
tensor_shape_list.insert(block_dim, self.group_size)
x = x.view(tensor_shape_list)
return x

def dynamic_over_sub_channel_block_view(x, group_size, group_dim):
tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
x = padding_to_multiple(x, group_dim, group_size)

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[group_dim] = (tensor_shape_list[group_dim] + group_size - 1) // group_size
block_dim = group_dim + 1 if group_dim != -1 else len(tensor_shape_list)
tensor_shape_list.insert(block_dim, group_size)
x = x.view(tensor_shape_list)
return x


class StatsInputViewShapeImpl(object):
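
As a quick illustration of the refactored helper above (the import path follows this diff; the shapes are just an example):

import torch

from brevitas.core.function_wrapper.shape import dynamic_over_sub_channel_block_view

x = torch.randn(4, 10)
# dim 1 is padded from 10 to 12 (the next multiple of group_size=4),
# then a block dimension of size 4 is inserted right after it
y = dynamic_over_sub_channel_block_view(x, 4, 1)  # group_size=4, group_dim=1
print(y.shape)  # torch.Size([4, 3, 4])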
362 changes: 259 additions & 103 deletions src/brevitas/export/inference/handler.py

Large diffs are not rendered by default.

8 changes: 3 additions & 5 deletions src/brevitas/export/inference/manager.py
@@ -3,12 +3,10 @@

from functools import partial

from packaging import version
import torch
from torch.nn import Module
import torch.nn as nn

from brevitas import torch_version
from brevitas.export.inference.handler import DynamicFloatInferenceHandler
from brevitas.export.inference.handler import DynamicIntInferenceHandler
from brevitas.export.inference.handler import FloatInferencetHandler
@@ -85,7 +83,7 @@ def __exit__(self, type, value, traceback):
# Disable all caching
# deactivate export mode
# restore return quant tensor
InferenceManager.set_export_mode(self.model, enabled=False)
self.export_manager.set_export_mode(self.model, enabled=False)
self.model.apply(
lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
self.model.apply(
@@ -105,8 +103,8 @@ def hook(self, module, inp, out):
# - Disable return quant tensor since all quant metadata is cached
assert len(self.hook_list) == 1
self.hook_list[0].remove()
self.model.apply(InferenceManager.set_export_handler)
InferenceManager.set_export_mode(self.model, enabled=True)
self.model.apply(self.export_manager.set_export_handler)
self.export_manager.set_export_mode(self.model, enabled=True)
self.return_quant_tensor_state = QuantizationStatusManager.disable_return_quant_tensor(
self.model)
disable_quant_tensor = partial(_override_create_quant_tensor, state=True)
Empty file.
172 changes: 172 additions & 0 deletions src/brevitas/export/inference/vLLM/handler.py
@@ -0,0 +1,172 @@
from typing import List
from typing import Optional

import torch
from vllm.model_executor.layers.linear import LinearMethodBase

from brevitas.graph.hadamard import get_hadK
from brevitas.nn.equalized_layer import RotatedModule

from ..handler import FloatInferencetHandler
from ..handler import FloatWeightInferencetHandler
from ..handler import GroupwiseFloatInferenceHandler
from ..handler import GroupwiseFloatWeightInferenceHandler
from ..handler import IntInferencetHandler
from ..handler import IntWeightInferencetHandler

class_mapping = {
'GroupwiseFloatInferenceHandler': GroupwiseFloatInferenceHandler,
'GroupwiseFloatWeightInferenceHandler': GroupwiseFloatWeightInferenceHandler,
'FloatInferencetHandler': FloatInferencetHandler,
'FloatWeightInferencetHandler': FloatWeightInferencetHandler,
'IntWeightInferencetHandler': IntWeightInferencetHandler,
'IntInferencetHandler': IntInferencetHandler,}


class QuantLinear(LinearMethodBase):

def __init__(
self,
input_config=None,
weight_config=None,
bias_config=None,
output_config=None,
rotation_config=None):
self.input_quant = self.configure_proxy(input_config)
if isinstance(weight_config, list):
self.weight_quant = dict()
for i, config in enumerate(weight_config):
self.weight_quant[i] = self.configure_proxy(config)
else:
self.weight_quant = self.configure_proxy(weight_config)
self.bias_quant = self.configure_proxy(bias_config)
self.output_quant = self.configure_proxy(output_config)
self.rotation = self.configure_rotation(rotation_config)

def configure_rotation(self, rotation_config):
if rotation_config is None:
return torch.nn.Identity()
rot_mat_shape = rotation_config['rotation_size']['rot_mat_shape']
k = rotation_config['rotation_size']['k']
had_mat, _ = get_hadK(rot_mat_shape)
return RotatedModule(self, had_mat, k)

def configure_proxy(self, quant_config):
# No config, no quantizer
if quant_config is None:
return torch.nn.Identity()

        # Extract elements that are not part of the state dict
quant_class_name = quant_config['class_type']
float_to_int_impl_type = quant_config['float_to_int_impl_type']
del quant_config['class_type']
del quant_config['float_to_int_impl_type']

# Scale and zero-point are the only float elements in the state dict
for k, v in quant_config.items():
if not isinstance(v, torch.Tensor):
if k == 'scale' or k == 'zero_point':
quant_config[k] = torch.tensor(v)
else:
quant_config[k] = torch.tensor(v, dtype=torch.int)

# Shapes must be set otherwise the state dict loading will fail
scale_shape = quant_config['scale'].shape
zero_point_shape = quant_config['zero_point'].shape
quant_class_type = class_mapping[quant_class_name]
quant_class = quant_class_type(scale_shape, zero_point_shape)

# Set the remaining attributes
quant_class.float_to_int_impl_type = float_to_int_impl_type
quant_class.load_state_dict(quant_config)
return quant_class

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
out_per_partition = sum(output_partition_sizes)
w = torch.empty(
(out_per_partition, input_size_per_partition),
device="cuda",
dtype=params_dtype,
)

layer.weight = torch.nn.Parameter(w, requires_grad=False)

# Handling the packed weights for loading
base_loader = extra_weight_attrs.get("weight_loader", None)

def packed_weight_loader(param, loaded_weight, loaded_shard_id=None, *args, **kwargs):

if loaded_shard_id is not None:
if isinstance(loaded_shard_id, int):
_loaded_shard_id = loaded_shard_id
else:
if loaded_shard_id == "q":
_loaded_shard_id = 0
elif loaded_shard_id == "k":
_loaded_shard_id = 1
elif loaded_shard_id == "v":
_loaded_shard_id = 2
else:
raise ValueError(f"Invalid loaded_shard_id: {loaded_shard_id}")

logical_widths = list(output_partition_sizes)
start_idx = sum(logical_widths[:_loaded_shard_id])
end_idx = start_idx + logical_widths[_loaded_shard_id]
weight_quant = self.weight_quant[_loaded_shard_id]
else:
start_idx = 0
end_idx = out_per_partition
weight_quant = self.weight_quant
if weight_quant is not None:
loaded_weight = weight_quant(loaded_weight.cuda())[0].cpu()

if base_loader is not None:
return base_loader(param[start_idx:end_idx], loaded_weight, *args, **kwargs)
param[start_idx:end_idx].data.copy_(loaded_weight)

setattr(layer.weight, "weight_loader", packed_weight_loader)

# If this layer has bias, allocate it
if getattr(layer, "bias", None) is not None:
b = torch.empty((out_per_partition,), device="cuda", dtype=params_dtype)
layer.bias = torch.nn.Parameter(b, requires_grad=False)
base_bias_loader = extra_weight_attrs.get("bias_loader", None)

def packed_bias_loader(param, loaded_bias, *args, **kwargs):
if isinstance(loaded_bias, (list, tuple)):
loaded_bias = torch.cat(list(loaded_bias), dim=0)
if base_bias_loader is not None:
return base_bias_loader(param, loaded_bias, *args, **kwargs)
param.data.copy_(loaded_bias)

setattr(layer.bias, "bias_loader", packed_bias_loader)

# Preserve attrs that vLLM weight loaders may attach
for k, v in extra_weight_attrs.items():
if k in ("weight_loader", "bias_loader"):
continue
setattr(layer.weight, k, v)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# x = self.rotation.rotation_forward(x)
x = self.input_quant(x)
bias = self.bias_quant(bias) if bias is not None else None
y = x.matmul(layer.weight.t())
if bias is not None:
y = y + bias
y = self.output_quant(y)
return y
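
A standalone illustration of the shard-slicing arithmetic in packed_weight_loader above: fused q/k/v shard ids map to indices 0/1/2 and the destination rows come from cumulative output_partition_sizes (the widths below are made up):

output_partition_sizes = [128, 64, 64]  # hypothetical q/k/v partition widths
shard_index = {'q': 0, 'k': 1, 'v': 2}['k']
start_idx = sum(output_partition_sizes[:shard_index])      # 128
end_idx = start_idx + output_partition_sizes[shard_index]  # 192
# param[start_idx:end_idx] is the slice the 'k' shard's quantized weight is copied into
print(start_idx, end_idx)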