-
Notifications
You must be signed in to change notification settings - Fork 239
Feat (vLLM): initial export support #1444
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Giuseppe5
wants to merge
20
commits into
Xilinx:dev
Choose a base branch
from
Giuseppe5:vllm_export
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
fecfcb6
Fix
Giuseppe5 195443c
Feat (vLLM): initial export support
Giuseppe5 df68ed8
Cleanup
Giuseppe5 19aa9c9
More cleanup
Giuseppe5 aac450d
More bugfix, cleanup
Giuseppe5 fb46fe6
More cleanup and fixes
Giuseppe5 1244425
Removed too much stuff
Giuseppe5 69b1d49
temp
Giuseppe5 6f544c6
Temp 2
Giuseppe5 ed6b8f1
cleanup
Giuseppe5 7225614
requirements
Giuseppe5 2e94286
import
Giuseppe5 0a0c062
import 2
Giuseppe5 fd5edcc
Fix init
Giuseppe5 67be3f8
fix init 2
Giuseppe5 b9ae23a
Fix proxies
Giuseppe5 399363e
Update quantize.py
Giuseppe5 3a7ed83
Update main.py
Giuseppe5 c8716a7
sync
Giuseppe5 79cc073
Fix
Giuseppe5 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,3 +10,4 @@ pydantic | |
| torch>=2.4 | ||
| tqdm | ||
| transformers[sentencepiece]<5.0 | ||
| vllm | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,172 @@ | ||
| from typing import List | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
| from vllm.model_executor.layers.linear import LinearMethodBase | ||
|
|
||
| from brevitas.graph.hadamard import get_hadK | ||
| from brevitas.nn.equalized_layer import RotatedModule | ||
|
|
||
| from ..handler import FloatInferencetHandler | ||
| from ..handler import FloatWeightInferencetHandler | ||
| from ..handler import GroupwiseFloatInferenceHandler | ||
| from ..handler import GroupwiseFloatWeightInferenceHandler | ||
| from ..handler import IntInferencetHandler | ||
| from ..handler import IntWeightInferencetHandler | ||
|
|
||
| class_mapping = { | ||
| 'GroupwiseFloatInferenceHandler': GroupwiseFloatInferenceHandler, | ||
| 'GroupwiseFloatWeightInferenceHandler': GroupwiseFloatWeightInferenceHandler, | ||
| 'FloatInferencetHandler': FloatInferencetHandler, | ||
| 'FloatWeightInferencetHandler': FloatWeightInferencetHandler, | ||
| 'IntWeightInferencetHandler': IntWeightInferencetHandler, | ||
| 'IntInferencetHandler': IntInferencetHandler,} | ||
|
|
||
|
|
||
| class QuantLinear(LinearMethodBase): | ||
|
|
||
| def __init__( | ||
| self, | ||
| input_config=None, | ||
| weight_config=None, | ||
| bias_config=None, | ||
| output_config=None, | ||
| rotation_config=None): | ||
| self.input_quant = self.configure_proxy(input_config) | ||
| if isinstance(weight_config, list): | ||
| self.weight_quant = dict() | ||
| for i, config in enumerate(weight_config): | ||
| self.weight_quant[i] = self.configure_proxy(config) | ||
| else: | ||
| self.weight_quant = self.configure_proxy(weight_config) | ||
| self.bias_quant = self.configure_proxy(bias_config) | ||
| self.output_quant = self.configure_proxy(output_config) | ||
| self.rotation = self.configure_rotation(rotation_config) | ||
|
|
||
| def configure_rotation(self, rotation_config): | ||
| if rotation_config is None: | ||
| return torch.nn.Identity() | ||
| rot_mat_shape = rotation_config['rotation_size']['rot_mat_shape'] | ||
| k = rotation_config['rotation_size']['k'] | ||
| had_mat, _ = get_hadK(rot_mat_shape) | ||
| return RotatedModule(self, had_mat, k) | ||
|
|
||
| def configure_proxy(self, quant_config): | ||
| # No config, no quantizer | ||
| if quant_config is None: | ||
| return torch.nn.Identity() | ||
|
|
||
| # Extract element that are not part of the state dict | ||
| quant_class_name = quant_config['class_type'] | ||
| float_to_int_impl_type = quant_config['float_to_int_impl_type'] | ||
| del quant_config['class_type'] | ||
| del quant_config['float_to_int_impl_type'] | ||
|
|
||
| # Scale and zero-point are the only float elements in the state dict | ||
| for k, v in quant_config.items(): | ||
| if not isinstance(v, torch.Tensor): | ||
| if k == 'scale' or k == 'zero_point': | ||
| quant_config[k] = torch.tensor(v) | ||
| else: | ||
| quant_config[k] = torch.tensor(v, dtype=torch.int) | ||
|
|
||
| # Shapes must be set otherwise the state dict loading will fail | ||
| scale_shape = quant_config['scale'].shape | ||
| zero_point_shape = quant_config['zero_point'].shape | ||
| quant_class_type = class_mapping[quant_class_name] | ||
| quant_class = quant_class_type(scale_shape, zero_point_shape) | ||
|
|
||
| # Set the remaining attributes | ||
| quant_class.float_to_int_impl_type = float_to_int_impl_type | ||
| quant_class.load_state_dict(quant_config) | ||
| return quant_class | ||
|
|
||
| def create_weights( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| input_size_per_partition: int, | ||
| output_partition_sizes: List[int], | ||
| input_size: int, | ||
| output_size: int, | ||
| params_dtype: torch.dtype, | ||
| **extra_weight_attrs, | ||
| ): | ||
| out_per_partition = sum(output_partition_sizes) | ||
| w = torch.empty( | ||
| (out_per_partition, input_size_per_partition), | ||
| device="cuda", | ||
| dtype=params_dtype, | ||
| ) | ||
|
|
||
| layer.weight = torch.nn.Parameter(w, requires_grad=False) | ||
|
|
||
| # Handling the packed weights for loading | ||
| base_loader = extra_weight_attrs.get("weight_loader", None) | ||
|
|
||
| def packed_weight_loader(param, loaded_weight, loaded_shard_id=None, *args, **kwargs): | ||
|
|
||
| if loaded_shard_id is not None: | ||
| if isinstance(loaded_shard_id, int): | ||
| _loaded_shard_id = loaded_shard_id | ||
| else: | ||
| if loaded_shard_id == "q": | ||
| _loaded_shard_id = 0 | ||
| elif loaded_shard_id == "k": | ||
| _loaded_shard_id = 1 | ||
| elif loaded_shard_id == "v": | ||
| _loaded_shard_id = 2 | ||
| else: | ||
| raise ValueError(f"Invalid loaded_shard_id: {loaded_shard_id}") | ||
|
|
||
| logical_widths = list(output_partition_sizes) | ||
| start_idx = sum(logical_widths[:_loaded_shard_id]) | ||
| end_idx = start_idx + logical_widths[_loaded_shard_id] | ||
| weight_quant = self.weight_quant[_loaded_shard_id] | ||
| else: | ||
| start_idx = 0 | ||
| end_idx = out_per_partition | ||
| weight_quant = self.weight_quant | ||
| if weight_quant is not None: | ||
| loaded_weight = weight_quant(loaded_weight.cuda())[0].cpu() | ||
|
|
||
| if base_loader is not None: | ||
| return base_loader(param[start_idx:end_idx], loaded_weight, *args, **kwargs) | ||
| param[start_idx:end_idx].data.copy_(loaded_weight) | ||
|
|
||
| setattr(layer.weight, "weight_loader", packed_weight_loader) | ||
|
|
||
| # If this layer has bias, allocate it | ||
| if getattr(layer, "bias", None) is not None: | ||
| b = torch.empty((out_per_partition,), device="cuda", dtype=params_dtype) | ||
| layer.bias = torch.nn.Parameter(b, requires_grad=False) | ||
| base_bias_loader = extra_weight_attrs.get("bias_loader", None) | ||
|
|
||
| def packed_bias_loader(param, loaded_bias, *args, **kwargs): | ||
| if isinstance(loaded_bias, (list, tuple)): | ||
| loaded_bias = torch.cat(list(loaded_bias), dim=0) | ||
| if base_bias_loader is not None: | ||
| return base_bias_loader(param, loaded_bias, *args, **kwargs) | ||
| param.data.copy_(loaded_bias) | ||
|
|
||
| setattr(layer.bias, "bias_loader", packed_bias_loader) | ||
|
|
||
| # Preserve attrs that vLLM weight loaders may attach | ||
| for k, v in extra_weight_attrs.items(): | ||
| if k in ("weight_loader", "bias_loader"): | ||
| continue | ||
| setattr(layer.weight, k, v) | ||
|
|
||
| def apply( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| bias: Optional[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| # x = self.rotation.rotation_forward(x) | ||
| x = self.input_quant(x) | ||
| bias = self.bias_quant(bias) if bias is not None else None | ||
| y = x.matmul(layer.weight.t()) | ||
| if bias is not None: | ||
| y = y + bias | ||
| y = self.output_quant(y) | ||
| return y |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like vLLM should be an optional dependency.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can do it in a similar way to what we did for lighteval/lm_eval
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm leaving it for now so that test run and I can see what other things I'm breaking in the process, but I'll remove before this PR is merged