From 10dcc3f5f8180e69d5e80e31a8a563ca780418d0 Mon Sep 17 00:00:00 2001 From: dhx Date: Tue, 8 Apr 2025 18:27:49 +0000 Subject: [PATCH 1/3] add fused and chunked linear-loss function Signed-off-by: dhx --- training/DeepSpeed-Domino/domino/arguments.py | 7 +- .../domino/tensor_parallel/cross_entropy.py | 339 +++++++++++++++++- training/DeepSpeed-Domino/pretrain_gpt.py | 30 +- 3 files changed, 356 insertions(+), 20 deletions(-) diff --git a/training/DeepSpeed-Domino/domino/arguments.py b/training/DeepSpeed-Domino/domino/arguments.py index e846726b9..930011864 100644 --- a/training/DeepSpeed-Domino/domino/arguments.py +++ b/training/DeepSpeed-Domino/domino/arguments.py @@ -12,7 +12,7 @@ import dataclasses from dataclasses import dataclass -from typing import Callable +from typing import Callable, Optional from domino.timer import Timers from megatron.tokenizer import build_tokenizer @@ -206,6 +206,8 @@ def parse_args(): help='Report loss and timing interval.') parser.add_argument('--save-interval', type=int, default=None, help='Number of iterations between checkpoint saves.') + parser.add_argument('--fused-linear-loss', action='store_true', + help='whether to use LigerFusedLinearCrossEntropyFunction()') args = parser.parse_args() @@ -359,6 +361,8 @@ class TransformerConfig(): no_sync_func: Callable = None # grad_sync_func: Callable = None # param_sync_func: Callable = None + + fused_linear_loss: bool = False def __post_init__(self): """ Python dataclass method that is used to modify attributes after initialization. @@ -400,5 +404,6 @@ def core_transformer_config_from_args(args): kw_args['init_method'] = args.init_method kw_args['output_layer_init_method'] = args.init_method kw_args['params_dtype'] = args.params_dtype + kw_args['fused_linear_loss'] = args.fused_linear_loss return TransformerConfig(**kw_args) diff --git a/training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py b/training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py index a87c1f521..dd97132d1 100644 --- a/training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py +++ b/training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py @@ -1,7 +1,8 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
-# This file is adapted from cross_entropy.py in Megatron-LM - +# This file is adapted from cross_entropy.py in Megatron-LM and fused_linear_cross_entropy.py in Liger-Kernel:src/liger_kernel/ops/ import torch +import triton +import triton.language as tl from domino.parallel_state import ( get_tensor_model_parallel_group, @@ -11,24 +12,29 @@ from .utils import VocabUtility +from liger_kernel.ops.utils import amp_custom_bwd +from liger_kernel.ops.utils import amp_custom_fwd +from liger_kernel.ops.utils import element_mul_kernel +from liger_kernel.ops.utils import is_hip class _VocabParallelCrossEntropy(torch.autograd.Function): @staticmethod - def forward(ctx, logits, target): - max_logits = torch.max(logits, dim=-1)[0] + def forward(ctx, logits, target): + max_logits = torch.max(logits, dim=-1)[0] # [batchsize, seq_len, 1] torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()) + logits = logits - max_logits.unsqueeze(dim=-1) - partition_vocab_size = logits.size()[-1] + partition_vocab_size = logits.size()[-1] # 25216 rank = get_tensor_model_parallel_rank() world_size = get_tensor_model_parallel_world_size() vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size(partition_vocab_size, rank, world_size) target_mask = (target < vocab_start) | (target >= vocab_end) - adjusted_target = target.clone() - vocab_start + adjusted_target = target.clone() - vocab_start # relative id adjusted_target[target_mask] = 0 - logits_2d = logits.view(-1, partition_vocab_size) + logits_2d = logits.view(-1, partition_vocab_size) # [batchsize * seq_len, vocab_size] adjusted_target_1d = adjusted_target.view(-1) batch_indices = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) predicted_logits_1d = logits_2d[batch_indices, adjusted_target_1d].clone().contiguous() @@ -39,10 +45,10 @@ def forward(ctx, logits, target): exp_logits = torch.exp(logits) sum_exp_logits = exp_logits.sum(dim=-1) torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) - - loss = torch.log(sum_exp_logits) - predicted_logits + + loss = torch.log(sum_exp_logits) - predicted_logits # [512, 8] + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) - ctx.save_for_backward(exp_logits, target_mask, adjusted_target_1d) return loss @@ -50,16 +56,323 @@ def forward(ctx, logits, target): @staticmethod def backward(ctx, grad_output): softmax, target_mask, adjusted_target_1d = ctx.saved_tensors - + grad_input = softmax.view(-1, softmax.size()[-1]) batch_indices = torch.arange(start=0, end=grad_input.size()[0], device=grad_input.device) softmax_update = 1.0 - target_mask.view(-1).float() grad_input[batch_indices, adjusted_target_1d] -= softmax_update grad_input = grad_input.view_as(softmax) - grad_input.mul_(grad_output.unsqueeze(dim=-1)) + grad_input.mul_(grad_output.unsqueeze(dim=-1)) return grad_input, None - def vocab_parallel_cross_entropy(vocab_parallel_logits, target): return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) + +MAX_FUSED_SIZE = 65536 // 2 + +def fused_linear_cross_entropy_forward_megatron_chunked( + _input, + weight, + target, + bias=None, + reduction="none", +): + device = _input.device + BT, H = _input.shape + V = weight.shape[0] # [V, H] + + inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor + num_chunks = triton.cdiv(BT, chunk_size) # (BT + 
chunk_size - 1) // chunk_size + + grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None + grad_input = torch.zeros_like(_input, device=device) + grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None + # we use fp32 for loss accumulator + loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) + predicted_logits = torch.zeros(BT, dtype=torch.float32, device=device) + + # TODO: evaluate how CUDA synchronization caused by .item() affects the speed + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size(V, rank, world_size) + + target_mask = (target < vocab_start) | (target >= vocab_end) + adjusted_target = target.clone() - vocab_start # relative id + adjusted_target[target_mask] = 0 + adjusted_target_1d = adjusted_target.view(-1) + + handle_grad_input_list = [] + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + # input + _input_chunk = _input[start_idx:end_idx] # chunk_size x H + # when doing matmul, use the original precision + logits_chunk = (_input_chunk @ weight.t()).float() # chunk_size x V # since megatron has .float, I add it here. + + if bias is not None: + logits_chunk = logits_chunk + bias + # handle target + target_chunk = adjusted_target_1d[start_idx:end_idx] # chunk_size, + + # # ensure _input and target are contiguous + # logits_chunk = logits_chunk.contiguous() # [chunk_size, vocab_size] + # target_chunk = target_chunk.contiguous() # [chunk_size] + + max_logits_chunk = torch.max(logits_chunk, dim=-1)[0] + torch.distributed.all_reduce(max_logits_chunk, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group(), async_op=False) + logits_chunk = logits_chunk - max_logits_chunk.unsqueeze(-1) + + sum_exp_logits_chunk = torch.sum(torch.exp(logits_chunk), dim=-1) + torch.distributed.all_reduce(sum_exp_logits_chunk, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=False) + + predicted_logits_chunk = logits_chunk[torch.arange(end_idx-start_idx), target_chunk] + predicted_logits_chunk[target_mask[start_idx:end_idx]] = 0.0 + handle_predicted_logits_chunk = torch.distributed.all_reduce(predicted_logits_chunk, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=True) + + # ==> Compute gradient + grad_logits_chunk = torch.exp(logits_chunk).div_(sum_exp_logits_chunk.unsqueeze(-1)) + grad_logits_chunk[torch.arange(end_idx-start_idx), target_chunk] -= 1.0 - target_mask[start_idx:end_idx].float() # chunk_size x V + grad_input[start_idx:end_idx] = grad_logits_chunk.to(dtype=torch.half) @ weight # fp16 or fp32 will have different memory consumption, loss curves may be the same + + handle_grad_input = torch.distributed.all_reduce(grad_input[start_idx:end_idx], group=get_tensor_model_parallel_group(), async_op=True) + handle_grad_input_list.append(handle_grad_input) + + if grad_weight is not None: + torch.addmm( + input=grad_weight, + mat1=grad_logits_chunk.t().to( + _input_chunk.dtype + ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error. 
+ mat2=_input_chunk, + out=grad_weight, + alpha=1.0, + beta=1.0, # grad_weight accumulation (beta=1.0 brings loss decrease improvement in early iterations) + ) + if bias is not None: + torch.add( + input=grad_bias, + other=grad_logits_chunk.sum(dim=0), + out=grad_bias, + alpha=1.0, + ) + handle_predicted_logits_chunk.wait() + predicted_logits[start_idx:end_idx] = predicted_logits_chunk + loss_chunk = torch.log(sum_exp_logits_chunk) - predicted_logits_chunk + loss_1d[start_idx:end_idx] = loss_chunk + + for handle in handle_grad_input_list: + handle.wait() + + if reduction == "none": + loss = loss_1d + else: + loss = torch.sum(loss_1d) + + return loss, None, grad_input, grad_weight, grad_bias + +def fused_linear_cross_entropy_forward_megatron( + _input, + weight, + target, + bias=None, + reduction="none", +): + device = _input.device + BT, H = _input.shape + V = weight.shape[0] + + grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None + grad_input = torch.zeros_like(_input, device=device) + grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None + # we use fp32 for loss accumulator + loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) + + # TODO: evaluate how CUDA synchronization caused by .item() affects the speed + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size(V, rank, world_size) + + target_mask = (target < vocab_start) | (target >= vocab_end) + adjusted_target = target.clone() - vocab_start # relative id + adjusted_target[target_mask] = 0 + adjusted_target_1d = adjusted_target.view(-1) + + # input + # when doing matmul, use the original precision + logits = (_input @ weight.t()).float() # chunk_size x V + if bias is not None: + logits = logits + bias + + # # ensure _input and target are contiguous + # logits_chunk = logits_chunk.contiguous() # [chunk_size, vocab_size] + # target_chunk = target_chunk.contiguous() # [chunk_size] + + max_logits = torch.max(logits, dim=-1)[0] + torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group(), async_op=False) + logits = logits - max_logits.unsqueeze(-1) + + sum_exp_logits = torch.sum(torch.exp(logits), dim=-1) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=False) + + + predicted_logits = logits[torch.arange(BT, device=logits.device), adjusted_target_1d] + predicted_logits[target_mask] = 0.0 + handle_predicted_logits = torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=True) + + # Compute gradient + grad_logits = torch.exp(logits).div_(sum_exp_logits.unsqueeze(-1)) + grad_logits[torch.arange(BT, device=grad_logits.device), adjusted_target_1d] -= 1.0 - target_mask.float() # chunk_size x V + grad_input = grad_logits.to(dtype=torch.half) @ weight + torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=False) + + if grad_weight is not None: + torch.addmm( + input=grad_weight, + mat1=grad_logits.t().to( + _input.dtype + ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error. 
+ mat2=_input, + out=grad_weight, + alpha=1.0, + beta=1.0, + ) + if bias is not None: + torch.add( + input=grad_bias, + other=grad_logits.sum(dim=0), + out=grad_bias, + alpha=1.0, + ) + handle_predicted_logits.wait() + loss_chunk = torch.log(sum_exp_logits) - predicted_logits + loss_1d = loss_chunk + + if reduction == "none": + loss = loss_1d + else: + loss = torch.sum(loss_1d) + + return loss, None, grad_input, grad_weight, grad_bias + + +def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + BT, H = grad_input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + element_mul_kernel[(n_rows,)]( + grad_input, + grad_input.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_weight, + grad_weight.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + if grad_bias is not None: + V = grad_bias.shape[0] + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_bias, + grad_bias.stride(-1), + grad_output, + 1, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + return grad_input, grad_weight, grad_bias + +class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): + @staticmethod + @amp_custom_fwd + def forward( + ctx, + _input, + weight, + target, + bias=None, + ce_weight=None, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="none", + softcap=None, + return_z_loss: bool = False, + ): + """ + Fusing the last linear layer with cross-entropy loss + Reference: https://github.com/mgmalek/efficient_cross_entropy + + Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding + the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can + compute the gradient at the forward pass. By doing so, we don't have to store the _input and target + for the backward pass. + + _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension. + target: (B*T) where each value is in [0, V-1] + weight: (V, H) where V is the number of classes + bias: (V) where V is the number of classes + ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype + ignore_index: the index to ignore in the target + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. 
+ reduction: reduction to apply + """ + + loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward_megatron_chunked( + _input=_input, + weight=weight, + target=target, + bias=bias, + reduction=reduction, + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + grad_input.detach(), + grad_weight.detach() if grad_weight is not None else None, + grad_bias.detach() if bias is not None else None, + ) + ctx.return_z_loss = return_z_loss + return loss, z_loss + + @staticmethod + @amp_custom_bwd + def backward(ctx, grad_output, grad_output2): + if ctx.return_z_loss: + del grad_output2 # z_loss is only for logging + (grad_input, grad_weight, grad_bias) = ctx.saved_tensors + grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( + grad_output, grad_input, grad_weight, grad_bias + ) + return ( + grad_input, + grad_weight, + None, + grad_bias, + None, + None, + None, + None, + None, + None, + None, + ) + diff --git a/training/DeepSpeed-Domino/pretrain_gpt.py b/training/DeepSpeed-Domino/pretrain_gpt.py index 7fc4650e1..b6f47fc1e 100644 --- a/training/DeepSpeed-Domino/pretrain_gpt.py +++ b/training/DeepSpeed-Domino/pretrain_gpt.py @@ -14,6 +14,7 @@ from domino.language_model import get_language_model from domino.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy +from domino.tensor_parallel.cross_entropy import LigerFusedLinearCrossEntropyFunction _TRAIN_START_TIME = time.time() @@ -24,6 +25,14 @@ def post_language_model_processing(lm_output, labels, logit_weights, parallel_ou loss = loss.transpose(0, 1).contiguous() return loss +def post_language_model_processing_with_liger(lm_output, labels, logit_weights, parallel_output): + b, s = labels.shape + lm_output = lm_output.flatten(0, 1) + labels = labels.transpose(0, 1).flatten(0, 1) + loss, _ = LigerFusedLinearCrossEntropyFunction.apply(lm_output, logit_weights, labels) + loss = loss.view(s, b).transpose(0, 1).contiguous() + return loss + class GPTModel(DominoModule): def __init__( @@ -46,6 +55,7 @@ def __init__( post_process=self.post_process, ) self.initialize_word_embeddings() + self.config = config def set_input_tensor(self, input_tensor): self.language_model.set_input_tensor(input_tensor) @@ -66,12 +76,20 @@ def forward( ) if self.post_process: - return post_language_model_processing( - lm_output, - labels, - self.shared_embedding_or_output_weight(), - self.parallel_output, - ) + if self.config.fused_linear_loss: + return post_language_model_processing_with_liger( + lm_output, + labels, + self.shared_embedding_or_output_weight(), + self.parallel_output, + ) + else: + return post_language_model_processing( + lm_output, + labels, + self.shared_embedding_or_output_weight(), + self.parallel_output, + ) else: return lm_output From f4eefa1f70b28b83ade44c43c39435b7701fe576 Mon Sep 17 00:00:00 2001 From: dhx Date: Tue, 8 Apr 2025 19:30:30 +0000 Subject: [PATCH 2/3] update Signed-off-by: dhx --- .../domino/tensor_parallel/cross_entropy.py | 85 ------------------- 1 file changed, 85 deletions(-) diff --git a/training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py b/training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py index dd97132d1..c682581fb 100644 --- a/training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py +++ b/training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py @@ -173,91 +173,6 @@ def fused_linear_cross_entropy_forward_megatron_chunked( return loss, None, grad_input, grad_weight, grad_bias -def 
fused_linear_cross_entropy_forward_megatron( - _input, - weight, - target, - bias=None, - reduction="none", -): - device = _input.device - BT, H = _input.shape - V = weight.shape[0] - - grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None - grad_input = torch.zeros_like(_input, device=device) - grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None - # we use fp32 for loss accumulator - loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) - - # TODO: evaluate how CUDA synchronization caused by .item() affects the speed - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size(V, rank, world_size) - - target_mask = (target < vocab_start) | (target >= vocab_end) - adjusted_target = target.clone() - vocab_start # relative id - adjusted_target[target_mask] = 0 - adjusted_target_1d = adjusted_target.view(-1) - - # input - # when doing matmul, use the original precision - logits = (_input @ weight.t()).float() # chunk_size x V - if bias is not None: - logits = logits + bias - - # # ensure _input and target are contiguous - # logits_chunk = logits_chunk.contiguous() # [chunk_size, vocab_size] - # target_chunk = target_chunk.contiguous() # [chunk_size] - - max_logits = torch.max(logits, dim=-1)[0] - torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group(), async_op=False) - logits = logits - max_logits.unsqueeze(-1) - - sum_exp_logits = torch.sum(torch.exp(logits), dim=-1) - torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=False) - - - predicted_logits = logits[torch.arange(BT, device=logits.device), adjusted_target_1d] - predicted_logits[target_mask] = 0.0 - handle_predicted_logits = torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=True) - - # Compute gradient - grad_logits = torch.exp(logits).div_(sum_exp_logits.unsqueeze(-1)) - grad_logits[torch.arange(BT, device=grad_logits.device), adjusted_target_1d] -= 1.0 - target_mask.float() # chunk_size x V - grad_input = grad_logits.to(dtype=torch.half) @ weight - torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=False) - - if grad_weight is not None: - torch.addmm( - input=grad_weight, - mat1=grad_logits.t().to( - _input.dtype - ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error. - mat2=_input, - out=grad_weight, - alpha=1.0, - beta=1.0, - ) - if bias is not None: - torch.add( - input=grad_bias, - other=grad_logits.sum(dim=0), - out=grad_bias, - alpha=1.0, - ) - handle_predicted_logits.wait() - loss_chunk = torch.log(sum_exp_logits) - predicted_logits - loss_1d = loss_chunk - - if reduction == "none": - loss = loss_1d - else: - loss = torch.sum(loss_1d) - - return loss, None, grad_input, grad_weight, grad_bias - - def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): From 8a023859b068db1eadc19e23bafc144aa42e972d Mon Sep 17 00:00:00 2001 From: dhx Date: Wed, 9 Apr 2025 19:42:51 +0000 Subject: [PATCH 3/3] Trigger formatting check Signed-off-by: dhx
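Notes on the chunked fused loss (PATCH 1/3): fused_linear_cross_entropy_forward_megatron_chunked avoids materializing the full BT x V logits tensor by processing chunk_size rows of hidden states at a time and forming the input and weight gradients already during the forward pass. With the sizing rule in the patch, assuming for illustration BT = 4096 tokens, H = 4096 and a vocabulary shard of V = 25216 (the partition size noted in a code comment), inc_factor = cdiv(25216, 4096) = 7 and chunk_size = next_power_of_2(cdiv(4096, 7)) = 1024, so the loop runs over four chunks. The sketch below is a minimal single-GPU rendition of the same idea in plain PyTorch, with no Triton kernels and no tensor-parallel all-reduces; the function name chunked_linear_cross_entropy and the shapes in the comments are illustrative, not part of the patch.

    import torch

    def chunked_linear_cross_entropy(x, weight, target):
        """Single-GPU sketch of the chunked fused linear + cross-entropy loss:
        only a chunk_size x V block of logits is alive at any time, and the
        input/weight gradients are built on the fly during the forward pass."""
        BT, H = x.shape          # x: (BT, H) hidden states
        V = weight.shape[0]      # weight: (V, H), target: (BT,) class ids

        # Same sizing rule as the patch: the larger V is relative to H, the
        # smaller the chunk, so each logits block stays roughly O(BT x H).
        inc_factor = (V + H - 1) // H                        # triton.cdiv(V, H)
        chunk = (BT + inc_factor - 1) // inc_factor          # triton.cdiv(BT, inc_factor)
        chunk_size = 1 << max(chunk - 1, 0).bit_length()     # triton.next_power_of_2

        loss = torch.zeros(BT, dtype=torch.float32, device=x.device)
        grad_x = torch.zeros_like(x)
        grad_w = torch.zeros_like(weight)

        for start in range(0, BT, chunk_size):
            end = min(start + chunk_size, BT)
            xc, tc = x[start:end], target[start:end]
            rows = torch.arange(end - start, device=x.device)

            logits = (xc @ weight.t()).float()               # (chunk, V): the only block materialized
            logits -= logits.amax(dim=-1, keepdim=True)      # numerical stability, as in the patch
            sum_exp = torch.exp(logits).sum(dim=-1)
            loss[start:end] = torch.log(sum_exp) - logits[rows, tc]   # -log softmax at the target

            # The loss is the last layer, so its gradient is formed right here:
            # dL/dlogits = softmax(logits) - one_hot(target)
            grad_logits = torch.exp(logits) / sum_exp.unsqueeze(-1)
            grad_logits[rows, tc] -= 1.0
            grad_x[start:end] = grad_logits.to(x.dtype) @ weight
            grad_w += grad_logits.t().to(x.dtype) @ xc       # accumulate across chunks (addmm with beta=1.0)

        return loss, grad_x, grad_w

The tensor-parallel version in the patch additionally masks targets that fall outside the local vocabulary shard and all-reduces max_logits, sum_exp_logits, the predicted target logits and grad_input across the tensor-model-parallel group, issuing the predicted-logits and grad_input all-reduces asynchronously so they overlap with the remaining per-chunk work.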
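LigerFusedLinearCrossEntropyFunction wraps that routine in a torch.autograd.Function: because every gradient already exists after the forward pass, ctx stores only grad_input, grad_weight and grad_bias, and backward merely rescales them by the upstream gradient (the element_mul_kernel path, which is skipped when grad_output is exactly 1.0). Below is a self-contained, unchunked single-GPU sketch of that contract; the class name PrecomputedGradLinearCE and the toy shapes are illustrative only and not part of the patch.

    import torch

    class PrecomputedGradLinearCE(torch.autograd.Function):
        """Compute the loss gradients during forward; only rescale them in backward."""

        @staticmethod
        def forward(ctx, hidden, weight, target):
            # hidden: (BT, H), weight: (V, H), target: (BT,)
            logits = (hidden @ weight.t()).float()
            logits -= logits.amax(dim=-1, keepdim=True)           # numerical stability
            probs = torch.softmax(logits, dim=-1)
            rows = torch.arange(hidden.shape[0], device=hidden.device)
            loss = -torch.log(probs[rows, target]).sum()
            # dL/dlogits = softmax - one_hot(target); fold it into the input and
            # weight gradients now, so hidden/target need not be kept for backward.
            grad_logits = probs
            grad_logits[rows, target] -= 1.0
            ctx.save_for_backward(grad_logits.to(hidden.dtype) @ weight,      # dL/dhidden
                                  grad_logits.t().to(hidden.dtype) @ hidden)  # dL/dweight
            return loss

        @staticmethod
        def backward(ctx, grad_output):
            grad_hidden, grad_weight = ctx.saved_tensors
            # Counterpart of the element_mul_kernel step: scale the stored gradients.
            return grad_hidden * grad_output, grad_weight * grad_output, None

    # Toy usage: gradients reach hidden and weight without the logits being kept.
    hidden = torch.randn(8, 16, requires_grad=True)
    weight = torch.randn(32, 16, requires_grad=True)
    target = torch.randint(0, 32, (8,))
    PrecomputedGradLinearCE.apply(hidden, weight, target).backward()
    print(hidden.grad.shape, weight.grad.shape)    # torch.Size([8, 16]) torch.Size([32, 16])

At runtime the new path is selected by the flag added in arguments.py: --fused-linear-loss sets args.fused_linear_loss, core_transformer_config_from_args copies it into TransformerConfig.fused_linear_loss, and GPTModel.forward then routes post-processing to post_language_model_processing_with_liger instead of the baseline vocab_parallel_cross_entropy path.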