
feat: add IA3 prompt tuning #2

Draft: wants to merge 9 commits into base: main
3 changes: 2 additions & 1 deletion configs/local_setup.yml
@@ -26,5 +26,6 @@
"log-dir": "logs",
"use_wandb": True,
"wandb_host": "https://api.wandb.ai",
"wandb_project": "neox"
"wandb_project": "neox",
"ia3_tuning": False
}
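For context, a minimal sketch of the overrides an IA3 fine-tuning config might add (a hypothetical fragment; the key names are the ones this PR introduces). Note that in this diff ia3_tuning only relaxes checkpoint loading; the IA3 classes themselves are selected through the two *_cls arguments added in megatron/neox_arguments/neox_args.py below:

{
  "ia3_tuning": True,
  "self_attention_cls": "ParallelSelfAttentionIA3",
  "parallel_mlp_cls": "ParallelMLPIA3"
}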
3 changes: 3 additions & 0 deletions megatron/checkpointing.py
@@ -227,6 +227,8 @@ def load_checkpoint(
):
"""Load a model checkpoint and return the iteration."""
if neox_args.deepspeed:
+        if neox_args.ia3_tuning:
+            neox_args.load_module_strict = False
load_optim_and_scheduler = (
not neox_args.no_load_optim
) # TODO: These should be configured by separate args
@@ -241,6 +243,7 @@
load_optimizer_states=load_optim_and_scheduler,
load_lr_scheduler_states=load_optim_and_scheduler,
tag=tag,
+            load_module_strict=neox_args.load_module_strict,
)

if checkpoint_name is None:
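The load_module_strict escape hatch matters because a pretrained checkpoint contains no IA3 vectors (l_k, l_v, l_ff), so a strict state-dict load would fail on missing keys. A toy illustration of the failure mode in plain PyTorch (not NeoX code; the Base/WithIA3 names are made up):

import torch
import torch.nn as nn

class Base(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

class WithIA3(Base):
    def __init__(self):
        super().__init__()
        self.l_ff = nn.Parameter(torch.ones(8))  # extra key, absent from the base checkpoint

state = Base().state_dict()  # pretrained weights; no "l_ff" key
model = WithIA3()
# strict=True would raise: Missing key(s) in state_dict: "l_ff".
missing, unexpected = model.load_state_dict(state, strict=False)
print(missing)  # ['l_ff'] -- left at its ones init, to be trained by IA3

DeepSpeed's load_checkpoint accepts a load_module_strict flag and forwards it to the module's state-dict load, which is what the change above relies on.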
Binary file added megatron/model/.transformer.py.swp
Binary file not shown.
238 changes: 233 additions & 5 deletions megatron/model/transformer.py
@@ -18,6 +18,7 @@
"""Transformer."""

import math
+import sys
import torch
import torch.nn.functional as F
import torch.nn as nn
@@ -88,20 +89,21 @@ def __init__(

# auto scale so geglu has equal parameters
ff_mult = 4 * 2 / 3 if self.activation_type == "geglu" else 4
-        ff_dim = (
+        self.ff_dim = (
int(ff_mult * neox_args.hidden_size) * 2
if self.activation_type == "geglu"
else ff_mult * neox_args.hidden_size
)

self.dense_h_to_4h = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
-            output_size=ff_dim,
+            output_size=self.ff_dim,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
)
-        ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim
+        ff_dim_in = self.ff_dim // 2 if self.activation_type == "geglu" else self.ff_dim
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
neox_args=neox_args,
Expand Down Expand Up @@ -134,6 +136,56 @@ def forward(self, hidden_states):
return output, output_bias


class ParallelMLPIA3(ParallelMLP):
"""MLP.

MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.

Applies IA3 rescaling of each column after non-linearity:
https://arxiv.org/pdf/2205.05638.pdf
"""

def __init__(
self, neox_args, init_method, output_layer_init_method, parallel_output=False
):
super().__init__(
neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
parallel_output=parallel_output
)

world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(self.ff_dim, world_size) # 4hp
self.l_ff = create_ia3_parameter(self.hidden_size_per_partition, neox_args)

def forward(self, hidden_states):

# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

if (
self.activation_type == "gelu" and self.bias_gelu_fusion
) or self.activation_type == "geglu":
intermediate_parallel = self.activation_func(
intermediate_parallel, bias_parallel
)
else:
intermediate_parallel = self.activation_func(
intermediate_parallel + bias_parallel
)

# Apply IA3 rescaling:
intermediate_parallel *= self.l_ff

# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias


class ParallelLinear(nn.Module):
"""
A Parallel Linear Layer transforming the transformer outputs from hidden_size -> vocab_size
@@ -590,6 +642,154 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
return output, bias


class ParallelSelfAttentionIA3(ParallelSelfAttention):
"""Applies IA3 rescaling to key and query vectors per:
https://arxiv.org/pdf/2205.05638.pdf
"""
def __init__(
self,
neox_args,
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
rpe=None,
rotary=False,
use_cache=False,
parallel_output=False,
):
super().__init__(
neox_args,
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
rpe=rpe,
rotary=rotary,
use_cache=use_cache,
parallel_output=parallel_output,
)
self.l_k = create_ia3_parameter(self.hidden_size_per_partition, neox_args)
self.l_v = create_ia3_parameter(self.hidden_size_per_partition, neox_args)


def forward(self, hidden_states, attention_mask, layer_past=None):

# hidden_states: [sq, b, h]

# =====================
# Query, Key, and Value
# =====================

# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(
mixed_x_layer, 3
)

def _apply_ia3_rescaling(layer, scale_vector):
"""Apply IA3 rescaling:

Reshapes: [sq, b, np, hn] -> [sq, b, np * hn] to perform
rescaling and then back to [sq, b, np, hn].

Note: np * hn == h/p == self.hidden_size_per_partition
"""
layer_size = layer.shape
layer = layer.reshape(layer_size[0], layer_size[1], -1)
layer *= scale_vector
return layer.reshape(layer_size)

key_layer = _apply_ia3_rescaling(key_layer, self.l_k)
value_layer = _apply_ia3_rescaling(value_layer, self.l_v)

if exists(self.rotary_emb):
if exists(self.rotary_ndims):
# partial rotary
query_rot, query_pass = (
query_layer[..., : self.rotary_ndims],
query_layer[..., self.rotary_ndims :],
)
key_rot, key_pass = (
key_layer[..., : self.rotary_ndims],
key_layer[..., self.rotary_ndims :],
)
else:
# full rotary
query_rot, key_rot = query_layer, key_layer
apply_rotary_fn = (
apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
)

seq_len = key_layer.shape[0]
offset = 0
if exists(layer_past) and layer_past.numel() > 0:
offset = layer_past[0].shape[0]
seq_len += offset
cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
query_layer, key_layer = apply_rotary_fn(
query_rot, key_rot, cos, sin, offset=offset
)

if exists(self.rotary_ndims):
query_layer = torch.cat((query_layer, query_pass), dim=-1)
key_layer = torch.cat((key_layer, key_pass), dim=-1)

# ==================================
# Cache key and value for inference
# ==================================

if exists(layer_past) and layer_past.numel() > 0:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
value_layer = torch.cat(
(past_value.type_as(value_layer), value_layer), dim=0
)

if self.use_cache:
present = torch.stack((key_layer, value_layer))

if self.use_flash_attention:
context_layer = self.flash_attention(query_layer, key_layer, value_layer)
elif not self.sparse:
context_layer = self.attention(
query_layer, key_layer, value_layer, layer_past, attention_mask
)
else:
context_layer = self.sparse_attention(
query_layer, key_layer, value_layer, attention_mask
)

# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size_per_partition,
)
context_layer = context_layer.view(*new_context_layer_shape)

# =================
# Output. [sq, b, h]
# =================

output, bias = self.dense(context_layer)

if self.use_cache:
output = [output, present]

return output, bias


class ParallelTransformerLayer(nn.Module):
"""A single transformer layer.

@@ -625,9 +825,10 @@ def __init__(

if self.gpt_j_residual:
self.reduce = mpu.mappings.reduce_from_model_parallel_region
+        self_attention_cls = getattr(sys.modules[__name__], neox_args.self_attention_cls)

# Self attention.
-        self.attention = ParallelSelfAttention(
+        self.attention = self_attention_cls(
neox_args=neox_args,
attention_mask_func=attention_mask_func,
init_method=init_method,
@@ -645,7 +846,8 @@ def __init__(
self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps)

# MLP
-        self.mlp = ParallelMLP(
+        parallel_mlp_cls = getattr(sys.modules[__name__], neox_args.parallel_mlp_cls)
+        self.mlp = parallel_mlp_cls(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
@@ -804,3 +1006,29 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
return logits_parallel

return mpu.gather_from_model_parallel_region(logits_parallel)


def create_ia3_parameter(param_size, neox_args):
"""Create a parameter vector for use in IA3 scaling, per:
https://arxiv.org/pdf/2205.05638.pdf
"""
if neox_args.use_cpu_initialization:
param = torch.nn.Parameter(
torch.empty(
param_size, dtype=neox_args.params_dtype
)
)
else:
param = torch.nn.Parameter(
torch.empty(
param_size,
device=torch.cuda.current_device(),
dtype=neox_args.params_dtype,
)
)
param.model_parallel = True
param.partition_dim = 0
# Always initialize to ones.
with torch.no_grad():
torch.nn.init.ones_(param)
return param
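For reference, the two subclasses above implement the paper's elementwise rescalings: attention becomes softmax(Q (l_k ⊙ K)^T / sqrt(d_k)) (l_v ⊙ V), and the MLP becomes (l_ff ⊙ γ(W1 x)) W2, with all three vectors initialized to ones so tuning starts at the identity. A self-contained sketch of the tensor manipulation on toy shapes (assuming a model-parallel world size of 1; shape names follow the comments above):

import torch

# Toy shapes: sequence, batch, attention heads, head dim.
sq, b, heads, hn = 5, 2, 4, 16
h = heads * hn

key = torch.randn(sq, b, heads, hn)
value = torch.randn(sq, b, heads, hn)
ff = torch.randn(sq, b, 4 * h)  # post-activation MLP intermediate

# Learned IA3 vectors; ones means "no change" until trained.
l_k = torch.ones(h)
l_v = torch.ones(h)
l_ff = torch.ones(4 * h)

# Attention: fold (heads, hn) -> h, rescale each column, restore the shape,
# mirroring _apply_ia3_rescaling above.
key = (key.reshape(sq, b, -1) * l_k).reshape(sq, b, heads, hn)
value = (value.reshape(sq, b, -1) * l_v).reshape(sq, b, heads, hn)

# MLP: rescale the intermediate activation column-wise, as in ParallelMLPIA3.
ff = ff * l_ff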
4 changes: 2 additions & 2 deletions megatron/model/utils.py
@@ -48,14 +48,14 @@ def get_params_for_weight_decay_optimization(module, neox_args):
[
p
for n, p in list(module_._parameters.items())
if p is not None and n != "bias"
if p is not None and n not in neox_args.no_weight_decay_params
]
)
no_weight_decay_params["params"].extend(
[
p
for n, p in list(module_._parameters.items())
if p is not None and n == "bias"
if p is not None and n in neox_args.no_weight_decay_params
]
)
if neox_args.weight_decay == 0.0:
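This generalization replaces the hard-coded bias check with a configurable name list, so the new IA3 vectors also skip weight decay. A toy sketch of the new predicate on a plain nn.Linear, iterating the same private _parameters mapping the helper above uses (illustrative only; the list mirrors the PR's default in neox_args.py below):

import torch.nn as nn

no_weight_decay_params = ["bias", "l_ff", "l_v", "l_k"]

module = nn.Linear(4, 4)
decay = [
    p for n, p in module._parameters.items()
    if p is not None and n not in no_weight_decay_params
]
no_decay = [
    p for n, p in module._parameters.items()
    if p is not None and n in no_weight_decay_params
]
print(len(decay), len(no_decay))  # 1 1 -> "weight" decays, "bias" does not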
27 changes: 26 additions & 1 deletion megatron/neox_arguments/neox_args.py
@@ -13,7 +13,7 @@
# limitations under the License.

import subprocess
-from dataclasses import dataclass
+from dataclasses import dataclass, field

try:
from .template import NeoXArgsTemplate
@@ -355,11 +355,36 @@ class NeoXArgsModel(NeoXArgsTemplate):
"""

    output_layer_parallelism: Literal["row", "column"] = "row"
    """
    Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
    """

+    ia3_tuning: bool = False
+    """
+    Run IA3 tuning based off:
+    Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning
+    https://arxiv.org/pdf/2205.05638.pdf
+    """

+    self_attention_cls: str = "ParallelSelfAttention"
+    """
+    Class to use for the transformer self-attention block
+    """
+
+    parallel_mlp_cls: str = "ParallelMLP"
+    """
+    Class to use for the transformer MLP block
+    """
+
+    no_weight_decay_params: list = field(default_factory=lambda: ["bias", "l_ff", "l_v", "l_k"])
+    """
+    Parameter names to exclude from weight decay
+    """
+
+    load_module_strict: bool = True
+    """
+    Whether to strictly enforce that the keys in the state_dict of the module and the checkpoint match.
+    """

@dataclass
class NeoXArgsOptimizer(NeoXArgsTemplate):
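The two *_cls strings are resolved with getattr against the transformer module at layer-construction time (see the ParallelTransformerLayer changes above). A minimal sketch of that mechanism with stand-in classes:

import sys

class ParallelSelfAttention: ...
class ParallelSelfAttentionIA3(ParallelSelfAttention): ...

# Resolve the class named by the config string from the current module,
# as is now done with neox_args.self_attention_cls / neox_args.parallel_mlp_cls.
cls_name = "ParallelSelfAttentionIA3"  # e.g. taken from the config
self_attention_cls = getattr(sys.modules[__name__], cls_name)
print(self_attention_cls().__class__.__name__)  # ParallelSelfAttentionIA3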