diff --git a/docs/source/package_reference/weightlora.md b/docs/source/package_reference/weightlora.md new file mode 100644 index 0000000000..12ef15dfb7 --- /dev/null +++ b/docs/source/package_reference/weightlora.md @@ -0,0 +1,31 @@ + + +# WeightLoRA + +Weight LoRA is a conceptually simple but important PEFT method that adds a scalar weight $w_i$ to each LoRA adapter, where $i$ is the adapter index. In addition to the classical optimisation over all LoRA matrices $A_1, B_1, \dots, A_n, B_n$, this makes it possible to perform an alternative optimisation over the vector of weights $w := (w_1, \dots, w_n)^T \in \mathbb{R}^n$ under a variety of constraints. In our research paper, we consider two approaches: 1) the vector $w$ must lie in the simplex $\Delta_{n-1}$, and 2) the vector $w$ has only $K$ non-zero coordinates. Both approaches identify the most important LoRA adapters in the model and concentrate training on them while disabling the rest. + +The abstract from the paper is: + +*The widespread utilization of language models in modern applications is inconceivable without Parameter-Efficient Fine-Tuning techniques, such as low-rank adaptation (LoRA), which adds trainable adapters to selected layers. Although LoRA may obtain accurate solutions, it requires significant memory to train large models and intuition on which layers to add adapters. In this paper, we propose a novel method, WeightLoRA, which overcomes this issue by adaptive selection of the most critical LoRA heads throughout the optimization process. As a result, we can significantly reduce the number of trainable parameters while maintaining the capability to obtain consistent or even superior metric values. Finally, we conduct experiments for the series of competitive benchmarks and DeBERTa and BART models, comparing our approach with the most popular LoRA modifications. The experimental results demonstrate the efficacy of WeightLoRA and the superior performance of WeightLoRA+ in comparison to the baselines in nearly all cases.* + +## WeightLoraConfig + +[[autodoc]] tuners.weight_lora.config.WeightLoraConfig + +## WeightLoraModel + +[[autodoc]] tuners.weight_lora.model.WeightLoraModel \ No newline at end of file diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 4ceb3dd651..31ffddaca6 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -95,6 +95,8 @@ VeraModel, XLoraConfig, XLoraModel, + WeightLoraConfig, + WeightLoraModel, get_eva_state_dict, initialize_lora_eva_weights, ) @@ -188,6 +190,8 @@ "VeraModel", "XLoraConfig", "XLoraModel", + "WeightLoraConfig", + "WeightLoraModel", "bloom_model_postprocess_past_key_value", "cast_mixed_precision_params", "get_eva_state_dict", diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index 65abbd4046..bffed203c7 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -43,6 +43,7 @@ from .vblora import VBLoRAConfig, VBLoRAModel from .vera import VeraConfig, VeraModel from .xlora import XLoraConfig, XLoraModel +from .weight_lora import WeightLoraConfig, WeightLoraModel __all__ = [ @@ -97,6 +98,8 @@ "VeraModel", "XLoraConfig", "XLoraModel", + "WeightLoraConfig", + "WeightLoraModel", "get_eva_state_dict", "initialize_lora_eva_weights", ] diff --git a/src/peft/tuners/weight_lora/__init__.py b/src/peft/tuners/weight_lora/__init__.py new file mode 100644 index 0000000000..fa7842837c --- /dev/null +++ b/src/peft/tuners/weight_lora/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2023-present the HuggingFace Inc. team.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from peft.utils import register_peft_method + +from .config import WeightLoraConfig +from .layer import Linear, WeightLoraLayer +from .model import WeightLoraModel + +__all__ = ["Linear", "WeightLoraConfig", "WeightLoraLayer", "WeightLoraModel"] + +register_peft_method(name="weightlora", config_cls=WeightLoraConfig, model_cls=WeightLoraModel, prefix="weight_lora_", is_mixed_compatible=True) \ No newline at end of file diff --git a/src/peft/tuners/weight_lora/config.py b/src/peft/tuners/weight_lora/config.py new file mode 100644 index 0000000000..75aa0c8f9f --- /dev/null +++ b/src/peft/tuners/weight_lora/config.py @@ -0,0 +1,110 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.tuners.lora import LoraConfig +from peft.utils import PeftType + + +@dataclass +class WeightLoraConfig(LoraConfig): + """ + Configuration class of [`WeightLoraModel`]. + + Args: + r (`int`): + Lora rank. + lora_alpha (`int`): + The alpha parameter for Lora scaling. + rank_dropout (`float`): + The dropout probability for rank dimension during training. + module_dropout (`float`): + The dropout probability for disabling Lora modules during training. + target_modules (`Optional[Union[List[str], str]]`): + The names of the modules to apply the adapter to. If this is specified, only the modules with the specified + names will be replaced. When passing a string, a regex match will be performed. When passing a list of + strings, either an exact match will be performed or it is checked if the name of the module ends with any + of the passed strings. If this is specified as 'all-linear', then all linear modules are chosen, + excluding the output layer. If this is not specified, modules will be chosen according to the model + architecture. If the architecture is not known, an error will be raised -- in this case, you should specify + the target modules manually. + init_weights (`bool`): + Whether to perform initialization of adapter weights. This defaults to `True`, passing `False` is + discouraged. + layers_to_transform (`Union[List[int], int]`): + The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices + that are specified in this list. If a single integer is passed, it will apply the transformations on the + layer at this index. 
+ layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None`. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. + alpha_pattern (`dict`): + The mapping from layer names or regexp expression to alphas which are different from the default alpha + specified by `alpha`. + modules_to_save (`Optional[List[str]]`): + List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. + """ + + r: int = field(default=8, metadata={"help": "Lora rank"}) + lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"}) + rank_dropout: float = field( + default=0.0, metadata={"help": "The dropout probability for rank dimension during training"} + ) + module_dropout: float = field( + default=0.0, metadata={"help": "The dropout probability for disabling Lora modules during training"} + ) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with Lora." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + "This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer." + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the Lora layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." + ), + }, + ) + layers_to_transform: Optional[Union[List[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform. If this argument is specified, PEFT will transform only the layer indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index." + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different from None and if the layer pattern is not in the common layers pattern." + }, + ) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from Lora layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` is randomly initialized and as such needs to be trainable and saved." + }, + ) + + def __post_init__(self): + super().__post_init__() + self.peft_type = PeftType.WEIGHTLORA diff --git a/src/peft/tuners/weight_lora/layer.py b/src/peft/tuners/weight_lora/layer.py new file mode 100644 index 0000000000..49ca5899d7 --- /dev/null +++ b/src/peft/tuners/weight_lora/layer.py @@ -0,0 +1,201 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import math +from typing import Any, Optional, Set, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from peft.tuners.lycoris_utils import LycorisLayer + + +class WeightLoraLayer(nn.Module, LycorisLayer): + # All names of layers that may contain adapter weights + adapter_layer_names = ( + "weight_lora_A", + "weight_lora_B", + "weight_lora_w", + ) + # other_param_names is defined on parent class + + def __init__(self, base_layer: nn.Module) -> None: + super().__init__() + LycorisLayer.__init__(self, base_layer) + + # Weight LoRA info + self.weight_lora_A = nn.ParameterDict({}) + self.weight_lora_B = nn.ParameterDict({}) + self.weight_lora_w = nn.ParameterDict({}) + + @property + def _available_adapters(self) -> Set[str]: + return { + *self.weight_lora_A, + *self.weight_lora_B, + *self.weight_lora_w, + } + + def create_adapter_parameters( + self, + adapter_name: str, + r: int, + shape + ): + self.weight_lora_A[adapter_name] = nn.Parameter(torch.empty(shape[0], r)) + self.weight_lora_B[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) + self.weight_lora_w[adapter_name] = nn.Parameter(torch.empty(1)) + + def reset_adapter_parameters(self, adapter_name: str): + # Vanilla LoRA initialization + nn.init.kaiming_uniform_(self.weight_lora_A[adapter_name], a=math.sqrt(5)) + nn.init.zeros_(self.weight_lora_B[adapter_name]) + nn.init.ones_(self.weight_lora_w[adapter_name]) + + def update_layer( + self, + adapter_name: str, + r: int, + lora_alpha: float, + rank_dropout: float, + module_dropout: float, + **kwargs, + ) -> None: + """Internal function to create Weight LoRA adapter + + Args: + adapter_name (`str`): Name for the adapter to add. + r (`int`): Rank for the added adapter. + lora_alpha (`float`): Alpha for the added adapter. + rank_dropout (`float`): The dropout probability for rank dimension during training + module_dropout (`float`): The dropout probability for disabling adapter during training. + init_weights (`bool`): Whether to initialize adapter weights. 
+ """ + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.alpha[adapter_name] = lora_alpha + self.scaling[adapter_name] = lora_alpha / r + self.rank_dropout[adapter_name] = rank_dropout + self.module_dropout[adapter_name] = module_dropout + base_layer = self.get_base_layer() + + # Determine shape of Weight LoRA weights + if isinstance(base_layer, nn.Linear): + shape = (base_layer.in_features, base_layer.out_features) + else: + raise TypeError(f"WeightLoRA is not implemented for base layers of type {type(base_layer).__name__}") + + # Create weights with provided shape + self.create_adapter_parameters(adapter_name, r, shape) + + # Initialize weights + self.reset_adapter_parameters(adapter_name) + + # Move new weights to device + self._move_adapter_to_device_of_base_layer(adapter_name) + self.set_adapter(self.active_adapters) + + def get_delta_weight(self, adapter_name: str) -> torch.Tensor: + device = self.weight_lora_B[adapter_name].device + dtype = self.weight_lora_B[adapter_name].dtype + w_A = self.weight_lora_A[adapter_name] + w_B = self.weight_lora_B[adapter_name] + w = self.weight_lora_w[adapter_name] + + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + if cast_to_fp32: + w_A = w_A.float() + w_B = w_B.float() + + # Combine marixes + weight = w * w_A @ w_B * self.scaling[adapter_name] + weight = weight.T + if cast_to_fp32: + weight = weight.to(dtype=dtype) + + self.lora_A[adapter_name].weight.data = w_A.to(dtype) + self.lora_B[adapter_name].weight.data = w_B.to(dtype) + + # Perform rank dropout during training - drop rows of addition weights + rank_dropout = self.rank_dropout[adapter_name] + if self.training and rank_dropout: + drop = (torch.rand(weight.size(0)) > rank_dropout).float() + drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device) + drop /= drop.mean() + weight *= drop + + return weight + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + + # Execute all the adapters + for active_adapter in self.active_adapters: + if active_adapter not in self._available_adapters: + continue + + module_dropout = self.module_dropout[active_adapter] + + # Modify current execution weights + if (not self.training) or (self.training and torch.rand(1) > module_dropout): + result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs) + + result = result.to(previous_dtype) + return result + + +class Linear(WeightLoraLayer): + """WeightLoRA implemented in Linear layer""" + + def __init__( + self, + base_layer: nn.Module, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + adapter_name: str = "default", + r: int = 0, + lora_alpha: float = 1.0, + rank_dropout: float = 0.0, + module_dropout: float = 0.0, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, rank_dropout, module_dropout, **kwargs) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + # don't add bias here, because 
the bias is already included in the output of the base_layer + delta_weight = delta_weight.to(input.dtype) + return F.linear(input, delta_weight) + + def __repr__(self) -> str: + rep = super().__repr__() + return "weight_lora." + rep diff --git a/src/peft/tuners/weight_lora/model.py b/src/peft/tuners/weight_lora/model.py new file mode 100644 index 0000000000..096d8fd116 --- /dev/null +++ b/src/peft/tuners/weight_lora/model.py @@ -0,0 +1,93 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from itertools import chain +from typing import Dict, Type, Union + +import torch +from torch import nn + +from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner + +from .layer import Linear, WeightLoraLayer + + +class WeightLoraModel(LycorisTuner): + """ + Creates a Weight LoRA (weighted low-rank) model from a pretrained model. The method can be roughly described as + `W_new = W_old + w_i * A B`, where `w_i` is the scalar weight of the adapter attached to layer `i`. The current + implementation borrows heavily from https://github.com/huggingface/peft/tree/main/src/peft/tuners/lokr + + Args: + model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached. + config ([`WeightLoraConfig`]): The configuration of the WeightLoRA model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The WeightLoRA model. + + Example: + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import WeightLoraConfig, get_peft_model + + >>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + >>> config = WeightLoraConfig( + ... task_type="CAUSAL_LM", + ... r=4, + ... target_modules=["fc1", "fc2", "k_proj", "out_proj", "q_proj", "v_proj"], + ... lora_alpha=32, + ... lora_dropout=0.05, + ... ) + >>> model = get_peft_model(base_model, config) + ``` + + **Attributes**: + - **model** ([`~torch.nn.Module`]) -- The model to be adapted. + - **peft_config** ([`WeightLoraConfig`]): The configuration of the WeightLoRA model. + """ + + prefix: str = "weight_lora_" + layers_mapping: Dict[Type[torch.nn.Module], Type[WeightLoraLayer]] = { + torch.nn.Linear: Linear, + } + + def _create_and_replace( + self, + config: LycorisConfig, + adapter_name: str, + target: Union[WeightLoraLayer, nn.Module], + target_name: str, + parent: nn.Module, + current_key: str, + ) -> None: + """ + A private method to create and replace the target module with the adapter module.
+ """ + + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys())) + target_name_key = next(filter(lambda key: re.match(rf"(.*\.)?{key}$", current_key), pattern_keys), target_name) + + kwargs = config.to_dict() + kwargs["r"] = config.rank_pattern.get(target_name_key, config.r) + kwargs["lora_alpha"] = config.alpha_pattern.get(target_name_key, config.lora_alpha) + + if isinstance(target, WeightLoraLayer): + target.update_layer(adapter_name, **kwargs) + else: + new_module = self._create_new_module(config, adapter_name, target, **kwargs) + self._replace_module(parent, target_name, new_module, target) \ No newline at end of file diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index a205bd4550..4740de48c1 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -40,6 +40,7 @@ class PeftType(str, enum.Enum): - FOURIERFT - HRA - BONE + - WEIGHTLORA """ PROMPT_TUNING = "PROMPT_TUNING" @@ -64,6 +65,7 @@ class PeftType(str, enum.Enum): CPT = "CPT" BONE = "BONE" TRAINABLE_TOKENS = "TRAINABLE_TOKENS" + WEIGHTLORA = "WEIGHTLORA" class TaskType(str, enum.Enum):