33import time
44from typing import Any , Callable , Dict , List , Optional
55
6+ import deepspeed
67import torch
78import torch .distributed as dist
89import transformers
10+ from safetensors .torch import save_file
911from tensordict import TensorDict
1012from transformers import (
1113 AutoConfig ,
1416 get_linear_schedule_with_warmup ,
1517)
1618
17- from arealite .api .cli_args import TrainEngineConfig
19+ from arealite .api .cli_args import MicroBatchSpec , TrainEngineConfig
1820from arealite .api .engine_api import (
1921 FinetuneSpec ,
20- MicroBatchSpec ,
2122 SaveLoadMeta ,
2223 TrainEngine ,
2324 WeightUpdateMeta ,
3435 unsqueeze_mb_list ,
3536)
3637from arealite .utils .fsdp import get_cosine_schedule_with_warmup
37- from arealite .utils .save_load import get_state_dict_from_repo_id_or_path
38+ from arealite .utils .save_load import (
39+ get_state_dict_from_repo_id_or_path ,
40+ is_existing_local_path ,
41+ )
3842from realhf .api .core .data_api import load_hf_tokenizer
3943from realhf .base import logging , name_resolve , names
4044
@@ -54,6 +58,7 @@ def __init__(self, config: TrainEngineConfig):
5458 # initialization
5559 self .initialized = False
5660 self .weight_update_group_initialized = False
61+ self .local_rank = int (os .environ .get ("LOCAL_RANK" , 0 ))
5762 self .world_size = int (os .environ .get ("WORLD_SIZE" , 1 ))
5863
5964 def train (self , mode : bool = True ):
@@ -67,31 +72,24 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
6772
6873 """Initialize distributed communication and model."""
6974 if not dist .is_initialized ():
70- dist .init_process_group (backend = "nccl" )
71- if dist .get_world_size () > 1 :
72- raise RuntimeError (
73- "Distributed training is not supported in this engine. "
74- "Please use FSDP for distributed training."
75- )
75+ deepspeed .init_distributed (dist_backend = "nccl" , world_size = self .world_size )
7676
77- torch .cuda .set_device (int ( os . environ . get ( "LOCAL_RANK" , 0 )) )
78- self .device = torch .device (int ( os . environ . get ( "LOCAL_RANK" , 0 )) )
77+ torch .cuda .set_device (self . local_rank )
78+ self .device = torch .device (f"cuda: { self . local_rank } " )
7979
8080 dtype = torch .bfloat16 if self .config .bf16 else torch .float16
8181 self .model_config = AutoConfig .from_pretrained (
8282 pretrained_model_name_or_path = self .config .path ,
8383 trust_remote_code = True ,
8484 )
8585 self .tokenizer = load_hf_tokenizer (self .config .path )
86- with torch .device ("cuda" ):
87- # initialize scratch model from config
88- model = AutoModelForCausalLM .from_config (
89- self .model_config ,
90- torch_dtype = dtype ,
91- attn_implementation = self .config .attn_impl ,
92- )
86+ model = AutoModelForCausalLM .from_config (
87+ self .model_config ,
88+ torch_dtype = dtype ,
89+ attn_implementation = self .config .attn_impl ,
90+ )
9391
94- self .model = model . to ( "cuda" )
92+ self .model = model
9593
9694 if not self .config .init_from_scratch :
9795 # Load model from a initial checkpoint path,
@@ -102,9 +100,20 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
102100 with_optim = False ,
103101 tokenizer = None ,
104102 base_model_path = self .config .path ,
103+ distribute = False ,
105104 )
106105 self .load (load_meta )
107106
107+ if self .world_size > 1 :
108+ if self ._check_autotp ():
109+ self .model = deepspeed .tp_model_init (
110+ self .model , tp_size = self .config .hf .autotp_size , dtype = dtype
111+ )
112+ else :
113+ raise RuntimeError ("DeepSpeed AutoTP configuration error in HFEngine. " )
114+
115+ self .model = self .model .to (device = self .device , non_blocking = True )
116+
108117 # Set up optimizer
109118 if self .optimizer_config is not None :
110119 assert (
@@ -153,6 +162,21 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
153162
154163 self .initialized = True
155164
165+ def _check_autotp (self ):
166+ tp_size = self .config .hf .autotp_size
167+ config = self .model_config
168+ num_attention_heads = config .num_attention_heads
169+ num_key_value_heads = config .num_key_value_heads
170+ hidden_size = config .hidden_size
171+ intermediate_size = config .intermediate_size
172+
173+ return (
174+ num_attention_heads % tp_size == 0
175+ and num_key_value_heads % tp_size == 0
176+ and hidden_size % tp_size == 0
177+ and intermediate_size % tp_size == 0
178+ )
179+
156180 def destroy (self ):
157181 """Destroy the engine and release GPU memory."""
158182 self .model = None
@@ -164,7 +188,7 @@ def destroy(self):
164188
165189 def save (self , meta : SaveLoadMeta ):
166190 if meta .weight_format == "hf" :
167- self ._save_model_to_hf (meta .path , meta .tokenizer )
191+ self ._save_model_to_hf (meta .path , meta .tokenizer , meta . distribute )
168192 elif meta .weight_format == "dcp" :
169193 # TODO: implement DCP save/load for HF
170194 raise NotImplementedError ("DCP format saving is not implemented yet. " )
@@ -176,7 +200,7 @@ def save(self, meta: SaveLoadMeta):
176200
177201 def load (self , meta : SaveLoadMeta ):
178202 if meta .weight_format == "hf" :
179- self ._load_model_from_hf (meta .path )
203+ self ._load_model_from_hf (meta .path , meta . distribute )
180204 elif meta .weight_format == "dcp" :
181205 # TODO: implement DCP save/load for HF
182206 raise NotImplementedError ("DCP format loading is not implemented yet. " )
@@ -198,27 +222,47 @@ def _load_optimizer_state(self, path: str):
198222 self .optimizer .load_state_dict (optimizer_state_dict )
199223
def _save_model_to_hf(
    self,
    path: str,
    tokenizer: Optional[transformers.PreTrainedTokenizerFast],
    distribute: bool = False,
):
    """Save the model (plus config and tokenizer) in HuggingFace format.

    Args:
        path: Target directory; created by local rank 0 if missing.
        tokenizer: Tokenizer to save alongside the weights. Falls back to
            ``self.tokenizer`` when None.
        distribute: When True, each rank writes its own tensor-parallel
            shard as a safetensors file instead of a merged checkpoint.

    Raises:
        RuntimeError: If the model has not been initialized.
    """
    if self.model is None:
        raise RuntimeError("Model not initialized")

    if self.local_rank == 0:
        os.makedirs(path, exist_ok=True)

    # Ensure the directory exists before any other rank writes into it.
    if self.world_size > 1:
        dist.barrier()

    # Move tensors to CPU so the saved checkpoint does not pin GPU memory.
    state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()}

    if distribute:
        # One shard per rank. NOTE(review): the filename is keyed by LOCAL
        # rank — on a multi-node run with a shared filesystem, ranks on
        # different nodes would overwrite each other's shards; confirm this
        # is only used single-node.
        save_file(
            state_dict,
            os.path.join(path, f"tp_rank_{self.local_rank:02d}_model.safetensors"),
        )
    else:
        self.model.save_pretrained(path, state_dict=state_dict)

    if self.local_rank == 0:
        self.model_config.save_pretrained(path)
        # Fix: honor the explicit `tokenizer` argument; previously it was
        # accepted but silently ignored in favor of self.tokenizer.
        tok = tokenizer if tokenizer is not None else self.tokenizer
        if tok is not None:
            tok.save_pretrained(path)
253+
def _load_model_from_hf(self, path: str, distribute: bool = False):
    """Load model weights from HuggingFace format.

    Args:
        path: Local directory/file or repo id holding the weights.
        distribute: When True, load only this rank's tensor-parallel shard
            (``tp_rank_XX_model.safetensors``) instead of a full checkpoint.
    """
    # Only local rank 0 resolves a remote repo id; any rank may read a path
    # that already exists on the local filesystem.
    # NOTE(review): when `path` is remote, non-zero local ranks skip the
    # load entirely and no broadcast happens here — presumably the weights
    # are distributed to the other ranks later (e.g. by AutoTP init);
    # confirm against the caller.
    if self.local_rank == 0 or is_existing_local_path(path):
        if distribute:
            # Per-rank TP shard file name, matching the save-side layout.
            path = f"{path}/tp_rank_{self.local_rank:02d}_model.safetensors"
        full_state = get_state_dict_from_repo_id_or_path(path)
        # With tied embeddings the checkpoint may omit the duplicated
        # weight, so load non-strictly and re-tie afterwards.
        self.model.load_state_dict(
            full_state, strict=not self.model_config.tie_word_embeddings
        )

    if self.model_config.tie_word_embeddings:
        self.model.tie_weights()
222266
223267 def upload_weights (self , meta : WeightUpdateMeta ):
224268 if meta .type == "nccl" :
0 commit comments