Commit f4eefa1

committed: update
Signed-off-by: dhx <[email protected]>
1 parent: 10dcc3f · commit: f4eefa1

1 file changed: +0 -85 lines changed

Diff for: training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py

@@ -173,91 +173,6 @@ def fused_linear_cross_entropy_forward_megatron_chunked(
 
     return loss, None, grad_input, grad_weight, grad_bias
 
-def fused_linear_cross_entropy_forward_megatron(
-    _input,
-    weight,
-    target,
-    bias=None,
-    reduction="none",
-):
-    device = _input.device
-    BT, H = _input.shape
-    V = weight.shape[0]
-
-    grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
-    grad_input = torch.zeros_like(_input, device=device)
-    grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
-    # we use fp32 for loss accumulator
-    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
-
-    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
-    rank = get_tensor_model_parallel_rank()
-    world_size = get_tensor_model_parallel_world_size()
-    vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size(V, rank, world_size)
-
-    target_mask = (target < vocab_start) | (target >= vocab_end)
-    adjusted_target = target.clone() - vocab_start  # relative id
-    adjusted_target[target_mask] = 0
-    adjusted_target_1d = adjusted_target.view(-1)
-
-    # input
-    # when doing matmul, use the original precision
-    logits = (_input @ weight.t()).float()  # chunk_size x V
-    if bias is not None:
-        logits = logits + bias
-
-    # # ensure _input and target are contiguous
-    # logits_chunk = logits_chunk.contiguous()  # [chunk_size, vocab_size]
-    # target_chunk = target_chunk.contiguous()  # [chunk_size]
-
-    max_logits = torch.max(logits, dim=-1)[0]
-    torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group(), async_op=False)
-    logits = logits - max_logits.unsqueeze(-1)
-
-    sum_exp_logits = torch.sum(torch.exp(logits), dim=-1)
-    torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=False)
-
-
-    predicted_logits = logits[torch.arange(BT, device=logits.device), adjusted_target_1d]
-    predicted_logits[target_mask] = 0.0
-    handle_predicted_logits = torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=True)
-
-    # Compute gradient
-    grad_logits = torch.exp(logits).div_(sum_exp_logits.unsqueeze(-1))
-    grad_logits[torch.arange(BT, device=grad_logits.device), adjusted_target_1d] -= 1.0 - target_mask.float()  # chunk_size x V
-    grad_input = grad_logits.to(dtype=torch.half) @ weight
-    torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=False)
-
-    if grad_weight is not None:
-        torch.addmm(
-            input=grad_weight,
-            mat1=grad_logits.t().to(
-                _input.dtype
-            ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
-            mat2=_input,
-            out=grad_weight,
-            alpha=1.0,
-            beta=1.0,
-        )
-    if bias is not None:
-        torch.add(
-            input=grad_bias,
-            other=grad_logits.sum(dim=0),
-            out=grad_bias,
-            alpha=1.0,
-        )
-    handle_predicted_logits.wait()
-    loss_chunk = torch.log(sum_exp_logits) - predicted_logits
-    loss_1d = loss_chunk
-
-    if reduction == "none":
-        loss = loss_1d
-    else:
-        loss = torch.sum(loss_1d)
-
-    return loss, None, grad_input, grad_weight, grad_bias
-
-
 def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
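
For reference, the removed path computes a numerically stable, vocab-parallel softmax cross entropy: subtract the (globally all-reduced) max logit, all-reduce the sum of exponentials and the target logits across the tensor-parallel group, then take log(sum_exp) minus the target logit. Below is a minimal single-process sketch of that loss math, not part of the commit; the name `vocab_parallel_ce_sketch` and its arguments are illustrative, and the distributed all_reduce steps are omitted, so it only matches the removed code when the tensor-parallel world size is 1.

# Illustrative sketch only: per-token loss math of the removed function,
# without the tensor-parallel all_reduce steps.
import torch

def vocab_parallel_ce_sketch(logits, target, vocab_start, vocab_end):
    # logits: [BT, V_local] shard of the vocabulary; target: [BT] global token ids
    target_mask = (target < vocab_start) | (target >= vocab_end)
    # map global ids to shard-relative ids; out-of-shard targets are parked at 0
    adjusted_target = (target - vocab_start).masked_fill(target_mask, 0)

    # subtract the max logit for numerical stability (all-reduced with MAX in the removed code)
    logits = logits - logits.max(dim=-1, keepdim=True).values

    sum_exp_logits = torch.exp(logits).sum(dim=-1)
    predicted_logits = logits[torch.arange(logits.shape[0]), adjusted_target]
    # zero out contributions from targets that live on another shard
    predicted_logits = predicted_logits.masked_fill(target_mask, 0.0)

    # per-token loss: log(sum of exponentials) minus the target logit
    return torch.log(sum_exp_logits) - predicted_logits

# With a single shard covering the full vocab this matches torch's cross_entropy:
logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
loss = vocab_parallel_ce_sketch(logits, target, vocab_start=0, vocab_end=10)
ref = torch.nn.functional.cross_entropy(logits, target, reduction="none")
assert torch.allclose(loss, ref, atol=1e-5)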
