6 changes: 4 additions & 2 deletions docs/source/training_tutorials/sft_lora_finetune_llm.py
@@ -34,10 +34,12 @@ def training_function(script_args, training_args):
model = AutoModelForCausalLM.from_pretrained(script_args.model_id)

config = LoraConfig(
r=16,
r=64,
lora_alpha=16,
lora_dropout=0.05,
target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
# target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
target_modules=["q_proj", "k_proj", "v_proj"],
# target_modules=["o_proj"],
bias="none",
task_type="CAUSAL_LM",
)
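The tutorial change above raises the LoRA rank from 16 to 64 and narrows target_modules to the query, key and value projections. Independently of how the training script consumes this config, the same LoraConfig can be applied with peft's get_peft_model to inspect which modules receive adapters and how many parameters remain trainable. A minimal sketch follows; the checkpoint id is a placeholder, not part of the tutorial.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; any causal LM exposing q_proj / k_proj / v_proj behaves the same way.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj"],  # the narrowed set from this diff
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
model.print_trainable_parameters()  # only the LoRA adapter weights require gradients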
18 changes: 18 additions & 0 deletions optimum/neuron/distributed/base.py
@@ -26,6 +26,7 @@

import torch
from transformers import PreTrainedModel
from transformers.utils import is_peft_available

from ...utils import logging
from ..utils import is_neuronx_distributed_available, is_torch_xla_available
@@ -607,6 +608,21 @@ def parallelize(
skip_linear_weight_load = hasattr(model, "_weight_map")

requires_grad_information = {n: p.requires_grad for n, p in model.named_parameters()}
if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer

peft_parameters = set()
for mod in model.modules():
if isinstance(mod, BaseTunerLayer):
base_layer = mod.get_base_layer()
for m in mod.modules():
if m is base_layer:
continue
for p in m.parameters():
peft_parameters.add(p)
peft_parameter_names = {n for n, p in model.named_parameters() if p in peft_parameters}
else:
peft_parameter_names = set()

def should_parallelize_layer_predicate_func(layer):
if pp_size == 1:
@@ -757,6 +773,8 @@ def should_parallelize_layer_predicate_func(layer):
elif gqa_qkv_names_to_original_names.get(name, None) in requires_grad_information:
gqa_qkv_name = gqa_qkv_names_to_original_names[name]
parameter.requires_grad = requires_grad_information[gqa_qkv_name]
elif name in peft_parameter_names:
continue
else:
raise ValueError(
f"Could not find information for the parameter {name} to set its `requires_grad` attribute."
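The two hunks above work together: parallelize() records requires_grad for every parameter up front, gathers the parameters found under peft BaseTunerLayer modules (skipping each tuner's wrapped base layer during the walk), and later leaves those names untouched when restoring requires_grad after parallelization. The standalone sketch below reproduces that detection on a toy module; the Toy class and its layer names are illustrative, and it assumes peft's support for custom torch modules.

import torch
from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(8, 8)
        self.o_proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.o_proj(self.q_proj(x))


model = get_peft_model(Toy(), LoraConfig(target_modules=["q_proj"]))

# Same detection as in parallelize(): walk every BaseTunerLayer and collect the
# parameters of its submodules, skipping the wrapped base layer itself.
peft_parameters = set()
for mod in model.modules():
    if isinstance(mod, BaseTunerLayer):
        base_layer = mod.get_base_layer()
        for m in mod.modules():
            if m is base_layer:
                continue
            peft_parameters.update(m.parameters())

peft_parameter_names = {n for n, p in model.named_parameters() if p in peft_parameters}
# Prints the lora_A / lora_B adapter weights, plus the wrapped q_proj weights that are
# swept up because the tuner module itself is traversed before its children.
print(sorted(peft_parameter_names))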
139 changes: 17 additions & 122 deletions optimum/neuron/distributed/parallel_layers.py
@@ -30,14 +30,12 @@
from ..utils.misc import is_main_worker
from ..utils.require_utils import requires_neuronx_distributed
from .utils import (
FakeProj,
OptimumGQAQKVColumnParallelLinear,
WeightInformation,
embedding_to_parallel_embedding,
get_linear_weight_info,
inplace_linears_to_gqa_qkv_column_parallel_linear,
linear_to_parallel_linear,
mark_parameter_init_status_during_parallelization,
maybe_load_weights_to_gqa_qkv_column_parallel_linear,
maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear,
)

@@ -327,124 +325,6 @@ class ParallelSelfAttention(ParallelLayer):

GQA_QKV_PROJ_NAME: str = "qkv_proj"

@classmethod
def get_layer_qualified_name(cls, model: torch.nn.Module, layer: torch.nn.Module) -> str:
layer_to_fully_qualified_name = {id(module): name for name, module in model.named_modules()}
return layer_to_fully_qualified_name[id(layer)]

@classmethod
def patch_proj_to_use_gqa_qkv_column_parallel_linear(
cls,
attention_layer: torch.nn.Module,
attention_layer_qualified_name: str,
proj_qualified_name: str,
proj_name: str,
output_index: int,
):
fake_proj = FakeProj(
proj_qualified_name,
proj_name,
output_index,
lambda: attention_layer,
attention_layer_qualified_name,
cls.GQA_QKV_PROJ_NAME,
)

setattr(attention_layer, proj_name, fake_proj)

@classmethod
@requires_neuronx_distributed
def replace_qkv_by_gqa_qkv_column_parallel_linear(
cls,
model: "torch.nn.Module",
attention_layer: "torch.nn.Module",
sequence_parallel_enabled: bool = False,
kv_size_multiplier: Optional[int] = None,
skip_linear_weight_load: bool = False,
):
from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size

if cls.NUM_KEY_VALUE_HEADS_NAME is None:
raise ValueError(f"{cls} does not define the name of the number of key value heads.")
tp_size = get_tensor_model_parallel_size()
num_key_value_heads = getattr(attention_layer, cls.NUM_KEY_VALUE_HEADS_NAME)
if tp_size < num_key_value_heads:
raise ValueError(
f"The TP size ({tp_size}) is lower than the number of key value heads, using "
"GQAQKVColumnParallelLinear is not needed."
)

num_attention_heads = getattr(attention_layer, cls.NUM_ATTENTION_HEADS_NAME)
query_linear = getattr(attention_layer, cls.QUERIES_NAME)
key_linear = getattr(attention_layer, cls.KEYS_NAME)

hidden_size = query_linear.weight.size(1)
query_in_features = query_linear.weight.size(0)
key_value_in_features = key_linear.weight.size(0)

if kv_size_multiplier is None:
kv_size_multiplier = get_tensor_model_parallel_size() // num_key_value_heads

device = query_linear.weight.device
if device == torch.device("meta"):
device = None

gqa_qkv_column_parallel_linear = OptimumGQAQKVColumnParallelLinear(
cls.QUERIES_NAME,
cls.KEYS_NAME,
cls.VALUES_NAME,
cls.OUTPUT_PROJECTION_NAME,
num_attention_heads,
num_key_value_heads,
hidden_size,
[query_in_features, key_value_in_features],
gather_output=False,
bias=query_linear.bias is not None,
sequence_parallel_enabled=sequence_parallel_enabled,
device=device,
kv_size_multiplier=kv_size_multiplier,
)

setattr(attention_layer, cls.GQA_QKV_PROJ_NAME, gqa_qkv_column_parallel_linear)

maybe_load_weights_to_gqa_qkv_column_parallel_linear(
model,
gqa_qkv_column_parallel_linear,
try_from_checkpoint=not skip_linear_weight_load,
try_from_original_layer=not skip_linear_weight_load,
)

attention_layer_qualified_name = cls.get_layer_qualified_name(model, attention_layer)
fake_q_proj = FakeProj(
f"{attention_layer_qualified_name}.{cls.QUERIES_NAME}",
"q",
0,
lambda: attention_layer,
attention_layer_qualified_name,
cls.GQA_QKV_PROJ_NAME,
)
setattr(attention_layer, cls.QUERIES_NAME, fake_q_proj)

fake_k_proj = FakeProj(
f"{attention_layer_qualified_name}.{cls.KEYS_NAME}",
"k",
1,
lambda: attention_layer,
attention_layer_qualified_name,
cls.GQA_QKV_PROJ_NAME,
)
setattr(attention_layer, cls.KEYS_NAME, fake_k_proj)

fake_v_proj = FakeProj(
f"{attention_layer_qualified_name}.{cls.VALUES_NAME}",
"v",
2,
lambda: attention_layer,
attention_layer_qualified_name,
cls.GQA_QKV_PROJ_NAME,
)
setattr(attention_layer, cls.VALUES_NAME, fake_v_proj)

@classmethod
@requires_neuronx_distributed
def _transform(
@@ -504,9 +384,24 @@ def _transform(
needs_gqa_qkv_column_parallel_linear = False

if needs_gqa_qkv_column_parallel_linear:
cls.replace_qkv_by_gqa_qkv_column_parallel_linear(
tp_size = get_tensor_model_parallel_size()
if cls.NUM_KEY_VALUE_HEADS_NAME is None:
raise ValueError(f"{cls} does not define the name of the number of key value heads.")
if tp_size < num_key_value_heads:
raise ValueError(
f"The TP size ({tp_size}) is lower than the number of key value heads; using "
"GQAQKVColumnParallelLinear is not needed in this case."
)
inplace_linears_to_gqa_qkv_column_parallel_linear(
model,
layer,
cls.GQA_QKV_PROJ_NAME,
cls.QUERIES_NAME,
cls.KEYS_NAME,
cls.VALUES_NAME,
cls.OUTPUT_PROJECTION_NAME,
num_attention_heads,
num_key_value_heads,
sequence_parallel_enabled=sequence_parallel_enabled,
kv_size_multiplier=kv_size_multiplier,
skip_linear_weight_load=skip_linear_weight_load,
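In parallel_layers.py the body of replace_qkv_by_gqa_qkv_column_parallel_linear is removed from ParallelSelfAttention, and _transform now calls the shared inplace_linears_to_gqa_qkv_column_parallel_linear helper, keeping only the head-count sanity checks inline; judging by the removed classmethod, kv_size_multiplier presumably still defaults to tp_size // num_key_value_heads when not supplied. The sketch below is a conceptual, single-device restatement of the replacement pattern: one fused q/k/v projection plus thin stand-ins where the original attributes lived, as the removed FakeProj objects did. FusedQKV, ProjProxy and fuse_qkv_inplace are illustrative names, not optimum-neuron classes, and no tensor parallelism is involved.

import torch


class FusedQKV(torch.nn.Module):
    """Single projection producing the query, key and value outputs in one matmul."""

    def __init__(self, hidden_size, q_out_features, kv_out_features, bias=False):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, q_out_features + 2 * kv_out_features, bias=bias)
        self.splits = (q_out_features, kv_out_features, kv_out_features)

    def forward(self, x):
        return torch.split(self.proj(x), self.splits, dim=-1)  # (q, k, v)


class ProjProxy(torch.nn.Module):
    """Stand-in left at q_proj / k_proj / v_proj that routes through the fused projection."""

    def __init__(self, get_fused, output_index):
        super().__init__()
        # Stored behind a callable so the fused module is not registered as a submodule of
        # every proxy (the removed FakeProj relied on a lambda for the same reason).
        self.get_fused = get_fused
        self.output_index = output_index

    def forward(self, x):
        return self.get_fused()(x)[self.output_index]


def fuse_qkv_inplace(attention_layer, hidden_size, q_out_features, kv_out_features):
    attention_layer.qkv_proj = FusedQKV(hidden_size, q_out_features, kv_out_features)
    attention_layer.q_proj = ProjProxy(lambda: attention_layer.qkv_proj, 0)
    attention_layer.k_proj = ProjProxy(lambda: attention_layer.qkv_proj, 1)
    attention_layer.v_proj = ProjProxy(lambda: attention_layer.qkv_proj, 2)
    return attention_layer

A caller can invoke attention_layer.qkv_proj(hidden_states) once and unpack the three outputs, or keep calling q_proj / k_proj / v_proj individually at the cost of recomputing the fused matmul on each call. The tensor-parallel version in optimum-neuron additionally handles weight sharding and the kv_size_multiplier replication, which this sketch deliberately leaves out.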