Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autotp training #6922

Open
wants to merge 74 commits into
base: master
Choose a base branch
from
Open

Autotp training #6922

wants to merge 74 commits into from

Conversation

inkcherry
Copy link
Contributor

@inkcherry inkcherry commented Jan 2, 2025

FYI @tjruwase @GuanhuaWang @delock @skyshine102 context: #5445
changes/support

  • auto tensor parallel training for HF model(zero compatible. I only tested zero1 currently)
  • distributed ckpt save(UCP is not supported).
  • HF model files save(set gather_16bit_weights_on_model_save=True in ds config).
  • Dataloader check.
  • Uts.
  • tp layer refactor by abstract layer design.

HF trainer dependency:
transformer: https://github.com/inkcherry/transformers/tree/ds_tp
accelerate: https://github.com/inkcherry/accelerate/tree/ds_tp
I could send them once ds support these api.

Usage:
Users do not need to modify the client code, they only need to configure the settings in the config file to achieve the desired functionality.
Below is an example of code for fine-tuning a LLaMA 2 model (SFT). It supports Zero3/FSDP training and enables TP training by simply adjusting the configuration

https://github.com/inkcherry/stanford_alpaca/commits/tp_demo_1127/
This branch contains three commits, with the last two commits added for quick experiments and logging purposes.
results
loss curve(gbs=16):
zero3(baseline)
image
tp(this)
image

zero1 with zero1+tp(zero compatible)
image

performance(For your reference only.):
zero3(not enabled any acceleration.) : 18GB 2.3s/it
zero1:38GB 1.30s/it
zero1+tp: 24GB 1.66s/it
extension:
I think async-TP/domino .etc. can be implemented by inheriting a class and overriding the fwd/bwd methods. The logic for gather/partition can be reused to achieve this.(please correct me if I am wrong)

Complex sharding can also be achieved through independent partitioning and gathering. Partitioning is mandatory, while gathering is required for training.
TODO:
embedding vocab parallel
Currently, the parallelism for embeddings is primarily based on hidden_dim parallel combined with allreduce. This approach takes advantage of efficient reduction kernels. and it is not forced to use.
In training, however, the more common method is vocab parallelism. Enabling by default can save a certain amount of GPU memory.

thanks for @delock guidance.
I also verified inference with cpu-inference workloads(Optimized Model List in https://github.com/intel/intel-extension-for-pytorch/tree/main).
many thanks for @xuguangxin @ikurtchen @rogerxfeng8 ,@Yejing-Lai ,@ys950902 .etc. Help review and address matters related to inference.

Returns:
OrderedDict: The consolidated state dictionary if the current process rank is 0, otherwise None.
"""
#TODO: If we use both Zero3 and tensor parallel simultaneously
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify what is meant by the gather mechanism of tensor parallelism?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question.
I could somehow understand as it's a similar function to

def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False):
but specific for TP. The function name can be improved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@skyshine102, thanks for the comment. A key difference between is zero3 and TP is that partitioned zero3 modules materialized using allgather before compute, whereas TP modules compute in a partitioned manner. So, it is unclear to me what requires gathering for TP.

Copy link
Contributor Author

@inkcherry inkcherry Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is currently not invoked during the compute.
In my opinion, the ultimate goal of training such models is to save HF format weights for inference, this is primarily focusing on saving the original HF model (not TP-partitioned) for using. https://github.com/huggingface/accelerate/blob/main/src/accelerate/accelerator.py#L3367
Additionally, by enable specifically designed gathers in the TP layer, it provides the flexibility for customized complex sharding of different layout.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tradition of TP is saving a checkpoint for each rank (related: universal format from ds team) while HF format is usually consolidated so a function is needed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also don't see why need to gather weights/params in TP training/inference. If it is only used for re-collecting weights for single point checkpoint write, then you can use our universal checkpoint feature to convert model parallel strategy after training.

Copy link
Contributor Author

@inkcherry inkcherry Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestion, UCP is a great feature for recovering training from parallel topology changes.

In this patch, my plan is to support two basic Huggingface save APIs:

trainer.save_model(), which does not rely on any checkpoints. After calling this (without post-processing), it saves in the HF format for other purposes. The interface is user-friendly.
trainer.save_state(), which is used for resuming training from checkpoints.
To implement the basic trainer.save_model(), I referred to the current Zero3 strategy. _zero3_consolidated_16bit_state_dict() , For safety, it’s possible after invokesave_model(),the program may still encounter scenarios involving forward/backward/saving again , so it is implemented via context. This is similar to Zero3’s gather, but it will not be used in the forward and backward passes.

Regarding checkpoints, currently, it is non-UCP. The reasons are:

  • I noticed that UCP still requires some local post-processing (according to the documentation).

  • It has mainly been validated on Mega-ds, so I think integrating and validating it with HF-related code would require more effort. Considering my bandwidth, I lean towards we may complete this enhancement in another PR.
    Regarding the gather code, the logic for restoring shards (reverse of partition) can't be omitted (even as post-processing). Managing partition and gather via a class might be more convenient. I think we may use a config switch to directly save in a universal format during runtime if we support it.

I hope my answer is helpful for you : )

@@ -247,6 +248,11 @@ def _post_forward_hook(self, module, input, output):
self._model_times.append(elapsed_time)

def _create_model_parallel_group(self, config):

if is_autotp_training_mode():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conceptually, control flow for training should not come here. I think some refactoring/restructuring is needed for code quality.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it seems there was no suitable file currently, so I created a runtime/tensor_parallel folder now~

Copy link

@skyshine102 skyshine102 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @inkcherry for this contribution. I have spent some time to read this PR and I'm happy to be involved in this discussion. (I'm not from deepspeed team but deepspeed user. My comments are relatively minor though.)

return Yuan_LinearALlreduce(child, self.mp_group)

# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This additional code block is trying to deal with "MLP including chunk layer" (general case), but the returned module/object is in the name of GLM prefix.
It could be better to rename the GLM_LinearLayer to sth like GateUpPack_LinearLayer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the comments, modified:)

@@ -11,10 +11,12 @@
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearALlreduce, Yuan_LinearLayer, GLM_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original coding style is LinearAllreduce instead of LinearALlreduce.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

broadcast_and_check(args, bcast_rank, bcast_group)
broadcast_and_check(kwargs, bcast_rank, bcast_group)

print(f"RANK[{dist.get_rank()}]:The Dataloader has passed the TP group consistency check.")
Copy link

@skyshine102 skyshine102 Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe use the logger at rank 0 instead of print.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, modified.

Returns:
OrderedDict: The consolidated state dictionary if the current process rank is 0, otherwise None.
"""
#TODO: If we use both Zero3 and tensor parallel simultaneously

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question.
I could somehow understand as it's a similar function to

def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False):
but specific for TP. The function name can be improved.

@skyshine102
Copy link

skyshine102 commented Jan 7, 2025

@tjruwase @GuanhuaWang We had internal review of @inkcherry 's PR. This PR allows train HF models with tensor parallel without need for megatron. Which is very friendly to user.

Let us know your plan for Domino integration. @inkcherry 's memory data looks good. With Domino we think it can have less impact on performance since TP communication can overlap with computation.

@inkcherry by design should autotp training work with ZeRO3 as well?

@inkcherry I have the same question. Does this PR support the flow like https://pytorch.org/tutorials/intermediate/TP_tutorial.html#combine-tensor-parallel-with-fully-sharded-data-parallel-together ? (TP to shared weight $W$ to $W_{tp_i}$, then further shard $W_{tp_i}$ by ZeRO-3 to $W_{tp_i, dp_j}$)

@@ -31,6 +31,11 @@ class MoETypeEnum(str, Enum):
standard = "standard"


class AUTOTP_MODE(Enum):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic should ideally be outside of the inference module. For example, in deepspeed/runtime/tensor_parallel module?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, moved to runtime/tensor_parallel folder

@GuanhuaWang
Copy link
Member

GuanhuaWang commented Jan 9, 2025

Hi @inkcherry , @delock
sorry, I was on vacation. just took a quick look of it, this is awesome!
I will first do the code review. And as Tunji suggested, I will try to figure out Domino integration afterwards.

We had internal review of @inkcherry 's PR. This PR allows train HF models with tensor parallel without need for megatron. Which is very friendly to user.

Bravo @inkcherry, this is an excellent technology and massive usability benefit for users. This is really exciting!

Let us know your plan for Domino integration. @inkcherry 's memory data looks good. With Domino we think it can have less impact on performance since TP communication can overlap with computation.

In terms of Domino integration, @GuanhuaWang will take the lead on that.

for Zero3 + TP: Currently, the logic to combine the saving of HF weights for TP & DP has not been implemented, but it is entirely feasible. If needed, it can be implemented in the future.

I would love to prioritize enabling UCP support sooner than later. @inkcherry, can you please share the work needed here?

Copy link
Member

@GuanhuaWang GuanhuaWang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @inkcherry , @delock
Sorry for the delay. I just left some comments. Thanks

if is_inference_mode:
dist.inference_all_reduce(input, group=group)
else:
dist.all_reduce(input.contiguous(), group=group)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there any reason for input.contiguous()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that adding this makes it safer, potentially helping to avoid discontinuity introduced by transpose/permute.
FYI: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py#L23

I am not very clear on the implementation detail of inference_all_reduce, so I have kept the original dist.inference_all_reduce code path.

@staticmethod
def symbolic(graph, input):
"""Symbolic function for tracing."""
return dist.all_reduce(input.contiguous(), dist.get_tensor_model_parallel_group())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar here, is this contiguous() necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is consistent with the previous situation.


@pytest.mark.parametrize("layer_type", ["linear", "linearallreduce"])
def test(self, layer_type):
tp_size = 4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we parametrize and test tp_size of both 2 and 4?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reminder, added

reuse_dist_env = True

def test_save_original_weight(self):
tp_size = 4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, could we parameterize both tp_size 2 and 4?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reminder, added

return

if data_parallel_size is None:
data_parallel_size = dist.get_world_size() // tensor_model_parallel_size
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to consider pipeline_parallel_size?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, this feature does not support the pipeline and pipeline-related logic will not reach this part. Perhaps we can consider adding pipeline support in the future.

self.tp_config = TPConfig()
self.tp_config.tp_size = tp_size
if tp_size <= 1:
self.tp_config.enabled = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see anywhere this flag is used (i.e. there seems no design/code if enabled flag == False)? is this needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing that out. It's not necessary, I was referring to the inference config. I have removed it now.

Returns:
OrderedDict: The consolidated state dictionary if the current process rank is 0, otherwise None.
"""
#TODO: If we use both Zero3 and tensor parallel simultaneously
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also don't see why need to gather weights/params in TP training/inference. If it is only used for re-collecting weights for single point checkpoint write, then you can use our universal checkpoint feature to convert model parallel strategy after training.

tp_size: int = 1
""" Number of devices to split the model across using tensor parallelism. """

tp_grain_size: int = 64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this argument I also did not see any use case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable is used in the autoTP parser to set tile boundaries to accelerate GEMM.

set_tp_grain_size(config.tensor_parallel.tp_grain_size)

it has not been activated in training yet, as it requires support for uneven gather. I have added clearer comments for better understanding.

Comment on lines +497 to +517
class Yuan_LinearAllreduce(LinearAllreduce):

#Yuan2
@torch.no_grad()
def partition(self, params_list):
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
self.tp_world_size, False)
params_list[0].data = weight
if bias is not None:
params_list[1].data = bias


class Yuan_LinearLayer(LinearLayer):
#Yuan2
@torch.no_grad()
def partition(self, params_list):
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
self.tp_world_size, True)
params_list[0].data = move(weight, get_accelerator().current_device_name()).detach()
if bias is not None:
params_list[1].data = move(bias, get_accelerator().current_device_name()).detach()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to make an abstraction of partition method with arguments passed-in for different models? if doing this, we can avoid create 2 new classes (e.g., Yuan_linear & Yuan_linear+allreduce) for every new model structure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they currently only have one method. every specific shard logic should have a corresponding reverse gather logic. The current shard method hasn’t implemented the corresponding gather. I think using a class might help reserve a potential placeholder and make the code more consistent.

return new_obj


class GatherReplacedLayerParams:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need gather TP params during training or inference?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, the reason are integrated into the comments above.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants