diff --git a/configs/local_setup.yml b/configs/local_setup.yml index b8ec4b06a..63d570a6f 100644 --- a/configs/local_setup.yml +++ b/configs/local_setup.yml @@ -22,6 +22,10 @@ "load": "checkpoints", "checkpoint_validation_with_forward_pass": False, + + # "launcher": "openmpi", + #"deepspeed_mpi": true, + "tensorboard_dir": "tensorboard", "log_dir": "logs", } diff --git a/configs/rwkv/1.5B.yml b/configs/rwkv/1.5B.yml new file mode 100644 index 000000000..473bde88e --- /dev/null +++ b/configs/rwkv/1.5B.yml @@ -0,0 +1,103 @@ +{ + # Parallelism is not yet supported for rwkv + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + "num_layers": 24, + "hidden_size": 2048, + "num_attention_heads": 32, # head_size = dim_att / num_attention_heads. + # head_size is 64 for all rwkv models + "seq_length": 4096, + "max_position_embeddings": 4096, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + "train_micro_batch_size_per_gpu": 4, + + "attention_config": [[["rwkv"], 24]], + + "activation": "silu", + + # model settings + + #"pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + "layernorm_fusion": false, + + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0008, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00008, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "data_impl": "mmap", + "num_workers": 1, + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "bf16": { + "bf16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 12, + "hysteresis": 2, + "min_loss_scale": 1, + }, + + # misc. training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "constant", + "warmup": 0.01, + "checkpoint_factor": 100, + "eval_interval": 100000, + "eval_iters": 10, + "seed": 1234, + + # logging + "log_interval": 10, + "steps_per_print": 10, + "wall_clock_breakdown": true, +} diff --git a/configs/rwkv/430M.yml b/configs/rwkv/430M.yml new file mode 100644 index 000000000..a42e1796b --- /dev/null +++ b/configs/rwkv/430M.yml @@ -0,0 +1,103 @@ +{ + # Parallelism is not yet supported for rwkv + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + "num_layers": 24, + "hidden_size": 1024, + "num_attention_heads": 16, # head_size = dim_att / num_attention_heads. 
+ # head_size is 64 for all rwkv models + "seq_length": 4096, + "max_position_embeddings": 4096, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + "train_micro_batch_size_per_gpu": 1, + + "attention_config": [[["rwkv"], 24]], + + "activation": "silu", + + # model settings + + #"pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + "layernorm_fusion": false, + + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0008, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00008, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "data_impl": "mmap", + "num_workers": 1, + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "bf16": { + "bf16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 12, + "hysteresis": 2, + "min_loss_scale": 1, + }, + + # misc. training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "constant", + "warmup": 0.01, + "checkpoint_factor": 100, + "eval_interval": 100000, + "eval_iters": 10, + "seed": 1234, + + # logging + "log_interval": 10, + "steps_per_print": 10, + "wall_clock_breakdown": true, +} diff --git a/configs/rwkv/7B.yml b/configs/rwkv/7B.yml new file mode 100644 index 000000000..7e999d250 --- /dev/null +++ b/configs/rwkv/7B.yml @@ -0,0 +1,102 @@ +{ + # Parallelism is not yet supported for rwkv + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + "num_layers": 32, + "hidden_size": 4096, + "num_attention_heads": 64, # head_size = dim_att / num_attention_heads. 
+ # head_size is 64 for all rwkv models + "seq_length": 4096, + "max_position_embeddings": 4096, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + "train_micro_batch_size_per_gpu": 8, + + "attention_config": [[["rwkv"], 32]], + + "activation": "silu", + + # model settings + + #"pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + "layernorm_fusion": false, + + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0008, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00008, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "data_impl": "mmap", + "num_workers": 1, + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "bf16": { + "bf16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 12, + "hysteresis": 2, + "min_loss_scale": 1, + }, + + # misc. training settings + "train_iters": 500, + "lr_decay_iters": 500, + "distributed_backend": "nccl", + "lr_decay_style": "constant", + "warmup": 0.01, + "checkpoint_factor": 100, + "eval_interval": 100000, + "eval_iters": 10, + + # logging + "log_interval": 10, + "steps_per_print": 10, + "wall_clock_breakdown": true, +} diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index 7899048db..1b6aa9b54 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -258,6 +258,7 @@ def init_specs(self): LayerSpec( RWKVResidualLayerPipe, neox_args=self.neox_args, + init_method=self.init_method, layer_number=i, ) ) diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index b3741a3fc..88d99cd86 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -7,7 +7,17 @@ import torch.nn as nn from torch.nn import functional as F from torch.utils.cpp_extension import load - +from megatron import mpu +from megatron.mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region, scatter_to_model_parallel_region +try: + from fla.ops.rwkv6 import chunk_rwkv6 + import einops +except ModuleNotFoundError: + print( + "Unable to import RWKV FLA kernels. Install them from our requirements/requirements-rwkv.txt, \ + or directly from https://github.com/sustcsonglin/flash-linear-attention.git, or use CUDA kernels." 
+ ) + pass class WKV(torch.autograd.Function): """ @@ -95,6 +105,18 @@ def backward(ctx, gy): def RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u): return WKV.apply(B, T, C, H, r, k, v, w, u) +@torch.compiler.disable(recursive=True) +# torch.compiler introduces errors in numerical precision (torch 2.4) +def RUN_FLA_CHUNK(B, T, C, H, r, k, v, w, u, h=None, scale=1.0, chunk_size=32): + r = r.view(B,T,H,-1).transpose(1,2) + k = k.view(B,T,H,-1).transpose(1,2) + v = v.view(B,T,H,-1).transpose(1,2) + # u can be 3d or 2d (B, H, -1) or just (H, -1) to save VRAM + w = -torch.exp(w.view(B,T,H,-1).transpose(1,2)) + # change to scale=-1.0 when using fp16, this will apply scale to r and k. + o, final_state = chunk_rwkv6(r, k, v, w, u=u, scale=scale, initial_state=h, + output_final_state=False, chunk_size=chunk_size) #initial_state=None and output_final_state=False for rwkv6 + return o.transpose(1,2).reshape(B,T,C), final_state # RWKV6 time mix class RWKV_TimeMix(nn.Module): @@ -104,7 +126,7 @@ class RWKV_TimeMix(nn.Module): TODO: fix jit compiling. """ - def __init__(self, neox_args, layer_number): + def __init__(self, neox_args, layer_number, init_method): super().__init__() self.neox_args = neox_args self.layer_number = layer_number @@ -172,14 +194,46 @@ def __init__(self, neox_args, layer_number): ) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - self.receptance = nn.Linear( - neox_args.hidden_size, neox_args.dim_att, bias=False - ) - self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) - - self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) - self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False) - self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + self.receptance = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.dim_att, + gather_output=False, + init_method=init_method, + bias=False, + ) + self.key = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.dim_att, + gather_output=False, + init_method=init_method, + bias=False, + ) + self.value = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.dim_att, + gather_output=False, + init_method=init_method, + bias=False, + ) + self.output = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.dim_att, + output_size=neox_args.hidden_size, + gather_output=True, + init_method=init_method, + bias=False, + ) + self.gate = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.dim_att, + gather_output=True, + init_method=init_method, + bias=False, + ) self.ln_x = nn.GroupNorm( neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2) ) @@ -200,13 +254,15 @@ def jit_func(self, x): xr = x + xx * (self.time_maa_r + mr) xg = x + xx * (self.time_maa_g + mg) - r = self.receptance(xr) - k = self.key(xk) - v = self.value(xv) - g = F.silu(self.gate(xg)) + r, _ = self.receptance(xr) + k, _ = self.key(xk) + v, _ = self.value(xv) + gated, _ = self.gate(xg) + g = F.silu(gated) ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2 w = self.time_decay + ww + w = scatter_to_model_parallel_region(w) return r, k, v, g, w @@ -215,28 +271,39 @@ def jit_func_2(self, x, g): x = x.view(B * T, C) x = self.ln_x(x).view(B, T, C) - x = self.output(x * g) + x, _ = self.output(x * g) + return x def forward(self, x): B, T, C = x.size() - 
H = self.neox_args.num_attention_heads + C_tp = C//mpu.get_model_parallel_world_size() + H = self.neox_args.num_attention_heads//mpu.get_model_parallel_world_size() r, k, v, g, w = self.jit_func(x) - x = RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u=self.time_faaaa) + if self.neox_args.rwkv_fla: + x, _ = RUN_FLA_CHUNK(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H)) + else: + x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H)) + + x = gather_from_model_parallel_region(x) return self.jit_func_2(x, g) -class RWKV_ChannelMix(nn.Module): +class ParallelRWKV_ChannelMix(nn.Module): """ Channel Mix layer. The ffn in RWKV """ - def __init__(self, neox_args, layer_number): + def __init__(self, neox_args, layer_number, init_method): super().__init__() self.neox_args = neox_args self.layer_number = layer_number + + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) with torch.no_grad(): # fancy init of time_mix @@ -247,21 +314,43 @@ def __init__(self, neox_args, layer_number): self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) - self.key = nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False) - self.receptance = nn.Linear( - neox_args.hidden_size, neox_args.hidden_size, bias=False - ) - self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False) + self.key = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.ffn_dim, + gather_output=False, + init_method=init_method, + bias=False, + ) + + self.receptance = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.hidden_size, + gather_output=True, + init_method=init_method, + bias=False + ) + self.value = mpu.RowParallelLinear( + neox_args=neox_args, + input_size=neox_args.ffn_dim, + output_size=neox_args.hidden_size, + input_is_parallel=True, + init_method=init_method, + parallel_output=False, + bias=False + ) def forward(self, x): xx = self.time_shift(x) - x xk = x + xx * self.time_maa_k xr = x + xx * self.time_maa_r - k = self.key(xk) + k, _ = self.key(xk) k = torch.relu(k) ** 2 - kv = self.value(k) - return torch.sigmoid(self.receptance(xr)) * kv + kv, _ = self.value(k) + receptance, _ = self.receptance(xr) + return torch.sigmoid(receptance) * kv class RWKVResidualLayer(nn.Module): @@ -269,7 +358,7 @@ class RWKVResidualLayer(nn.Module): RWKV layer definition """ - def __init__(self, neox_args, layer_number): + def __init__(self, neox_args, init_method, layer_number): super().__init__() self.neox_args = neox_args self.layer_number = layer_number @@ -277,8 +366,8 @@ def __init__(self, neox_args, layer_number): self.bf16 = neox_args.precision == "bfloat16" assert ( neox_args.intermediate_size == None or neox_args.expansion_factor == None - ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" - if not hasattr(neox_args, "dim_att"): + ), "Must pass either the absolute intermediate size or the relative expansion factor for rwkv" + if not neox_args.dim_att: neox_args.dim_att = neox_args.hidden_size if neox_args.intermediate_size: neox_args.ffn_dim = neox_args.intermediate_size @@ -297,15 +386,16 @@ def __init__(self, 
neox_args, layer_number): self.num_attention_heads = neox_args.num_attention_heads assert neox_args.dim_att % self.num_attention_heads == 0 + self.init_method = init_method if neox_args.attention_dropout > 0: self.drop0 = nn.Dropout(p=neox_args.attention_dropout) self.ln1 = nn.LayerNorm(neox_args.hidden_size) self.ln2 = nn.LayerNorm(neox_args.hidden_size) - self.att = RWKV_TimeMix(neox_args, layer_number) + self.att = RWKV_TimeMix(neox_args, layer_number, init_method=init_method) - self.ffn = RWKV_ChannelMix(neox_args, layer_number) + self.ffn = ParallelRWKV_ChannelMix(neox_args, layer_number, init_method=init_method) if neox_args.attention_dropout > 0: self.drop0 = nn.Dropout(p=neox_args.attention_dropout) @@ -313,27 +403,28 @@ def __init__(self, neox_args, layer_number): self.drop1 = nn.Dropout(p=neox_args.hidden_dropout) if layer_number == 0: - global wkv_cuda - """ - Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not. - """ - wkv_cuda = load( - name="wkv6", - sources=[ - "megatron/model/rwkv/v6/cuda/wkv6_op.cpp", - f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu", - ], - verbose=True, - extra_cuda_cflags=[ - "-res-usage", - "--use_fast_math", - "-O3", - "-Xptxas -O3", - "--extra-device-vectorization", - f"-D_N_={self.neox_args.head_size}", - f"-D_T_={self.neox_args.seq_length}", - ], - ) + if not self.neox_args.rwkv_fla: + global wkv_cuda + """ + Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not. + """ + wkv_cuda = load( + name="wkv6", + sources=[ + "megatron/model/rwkv/v6/cuda/wkv6_op.cpp", + f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={self.neox_args.head_size}", + f"-D_T_={self.neox_args.seq_length}", + ], + ) def forward(self, x): neox_args = self.neox_args @@ -353,7 +444,6 @@ def forward(self, x): return x - class RWKVResidualLayerPipe(RWKVResidualLayer): """ RWKV Pipeline Layer @@ -363,4 +453,9 @@ def forward(self, args): assert len(args) == 2 hidden_states, mask = args neox_args = self.neox_args - return super().forward(hidden_states), mask + if self.layer_number == 0: + hidden_states = hidden_states.transpose(0,1) + hidden_states = super().forward(hidden_states) + if self.layer_number == self.neox_args.num_layers-1: + hidden_states = hidden_states.transpose(0,1) + return hidden_states, mask diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index f3daacd4d..c0a33d4a4 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1113,17 +1113,11 @@ def calculate_derived(self): if isinstance(self.zero_stage, int): assert self.zero_stage <= 2, "Zero stage 3 not compatible with Mamba" assert ( - self.hidden_dropout == 0.0, + self.hidden_dropout != 0.0, ), "Mamba does not yet have dropout implemented" if "rwkv" in self.attention_config: - assert ( - self.model_parallel_size == 1 - ), "RWKV not currently compatible with model parallelism" if isinstance(self.zero_stage, int): assert self.zero_stage <= 2, "Zero stage 3 not compatible with RWKV" - assert ( - self.hidden_dropout == 0.0, - ), "RWKV does not yet have dropout implemented" # Sparsity config if self.sparsity_config is None: diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 9c8d3635f..4846a5718 100644 --- a/megatron/neox_arguments/neox_args.py +++ 
b/megatron/neox_arguments/neox_args.py @@ -277,6 +277,11 @@ class NeoXArgsModel(NeoXArgsTemplate): } """ + rwkv_fla: bool = False + """ + Whether to use the Flash Linear Attention implementation of the RWKV kernel, or the CUDA kernel version. + """ + num_unique_layers: int = None """ Number of unique transformer layers. num-layers should be divisible by this value. Currently only has an effect when pipe_parallel_size=0. @@ -497,7 +502,6 @@ class NeoXArgsModel(NeoXArgsTemplate): # Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905) output_layer_parallelism: Literal["column"] = "column" - """ Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column) """ diff --git a/requirements/requirements-rwkv.txt b/requirements/requirements-rwkv.txt new file mode 100644 index 000000000..38c786d5b --- /dev/null +++ b/requirements/requirements-rwkv.txt @@ -0,0 +1 @@ +git+https://github.com/sustcsonglin/flash-linear-attention \ No newline at end of file
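Context on the tensor-parallel changes above: every RWKV config keeps head_size = hidden_size / num_attention_heads = 64 (2048/32, 1024/16, 4096/64), and under model parallelism each rank now works on C_tp = dim_att / world_size channels and num_attention_heads / world_size heads; the mpu linear layers also return an (output, bias) tuple, which is why the projections are unpacked as r, _ = self.receptance(xr). The sketch below is a rough illustration rather than the library code: it stands in for scatter_to_model_parallel_region with torch.chunk and uses an arbitrary world_size, just to show how the flattened time_faaaa parameter ends up as an (H_local, head_size) slice per rank:

import torch

world_size = 2                    # illustrative model-parallel size (not from the configs)
H, head_size = 32, 64             # 1.5B config: 32 heads, head_size 64
C = H * head_size                 # dim_att = hidden_size = 2048

H_local = H // world_size         # heads handled by one rank
C_local = C // world_size         # channels handled by one rank (C_tp in the diff)

time_faaaa = torch.randn(H, head_size)

# scatter_to_model_parallel_region splits a tensor along its last dimension into
# world_size contiguous pieces; torch.chunk on the flattened parameter mimics that here.
shards = torch.chunk(time_faaaa.view(-1), world_size, dim=-1)
u_local = shards[0].view(H_local, C_local // H_local)     # what rank 0 would see

assert u_local.shape == (H_local, head_size)
assert torch.equal(u_local, time_faaaa[:H_local])         # rank 0 keeps the first H/world_size heads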
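The new RUN_FLA_CHUNK path (enabled with the rwkv_fla flag and the kernel from requirements/requirements-rwkv.txt) hands chunk_rwkv6 head-major tensors: r, k and v are reshaped from (B, T, C) to (B, H, T, head_size), the decay w is moved to log space with -exp(w), and the kernel output is flattened back to (B, T, C). A shape-only sketch of that bookkeeping in plain PyTorch, with illustrative sizes, no fla import, and a stand-in for the kernel output:

import torch

B, T, H, head_size = 2, 8, 4, 64        # illustrative sizes; head_size is 64 for RWKV models
C = H * head_size                        # per-partition channel count (C_tp under model parallelism)

r = torch.randn(B, T, C)
k = torch.randn(B, T, C)
v = torch.randn(B, T, C)
w = torch.randn(B, T, C)

# (B, T, C) -> (B, H, T, head_size): the head-major layout passed to chunk_rwkv6
r_h = r.view(B, T, H, -1).transpose(1, 2)
k_h = k.view(B, T, H, -1).transpose(1, 2)
v_h = v.view(B, T, H, -1).transpose(1, 2)
w_h = -torch.exp(w.view(B, T, H, -1).transpose(1, 2))    # decay in log space, strictly negative

assert r_h.shape == (B, H, T, head_size)

# the kernel returns a head-major output; v_h stands in for it here
o = v_h
o_flat = o.transpose(1, 2).reshape(B, T, C)               # back to (B, T, C) for the rest of the block
assert o_flat.shape == (B, T, C)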
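RWKVResidualLayerPipe now transposes activations on the way into and out of the RWKV stack, apparently because the GPT-NeoX pipeline passes tensors as [seq, batch, hidden] while the mixing layers above compute batch-first with B, T, C = x.size(). A minimal illustration of that convention, with made-up sizes:

import torch

seq_len, batch, hidden = 16, 2, 8
pipe_activations = torch.randn(seq_len, batch, hidden)    # pipeline layout: [seq, batch, hidden]

# first RWKV layer: switch to the batch-first layout the RWKV blocks use
x = pipe_activations.transpose(0, 1)                      # [batch, seq, hidden]
assert x.shape == (batch, seq_len, hidden)

# ... RWKV time-mix / channel-mix layers run on [batch, seq, hidden] ...

# last RWKV layer: hand the result back to the pipeline in its original layout
out = x.transpose(0, 1)
assert out.shape == (seq_len, batch, hidden)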