Skip to content

Commit e2745b9

Browse files
author
Jayon02
committed
add autotp for hf
1 parent 5e8c167 commit e2745b9

4 files changed

Lines changed: 109 additions & 43 deletions

File tree

arealite/api/cli_args.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ class FSDPEngineConfig:
104104
)
105105

106106

107+
@dataclass
class HFEngineConfig:
    """Configuration for the HuggingFace training engine backend."""

    # Degree of DeepSpeed AutoTP tensor parallelism; 1 disables TP.
    autotp_size: Optional[int] = field(
        default=1,
        metadata={"help": "DeepSpeed AutoTP size"},
    )
113+
114+
107115
@dataclass
108116
class TrainEngineConfig:
109117
experiment_name: str = MISSING
@@ -136,6 +144,7 @@ class TrainEngineConfig:
136144
)
137145
backend: str = ""
138146
fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
147+
hf: HFEngineConfig = field(default_factory=HFEngineConfig)
139148

140149

141150
@dataclass

arealite/api/io_struct.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ class SaveLoadMeta:
175175
with_optim: bool
176176
tokenizer: PreTrainedTokenizerFast | None
177177
base_model_path: str | None
178+
distribute: bool = False
178179

179180

180181
@dataclass

arealite/engine/hf_engine.py

Lines changed: 78 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import time
44
from typing import Any, Callable, Dict, List, Optional
55

6+
import deepspeed
67
import torch
78
import torch.distributed as dist
89
import transformers
10+
from safetensors.torch import save_file
911
from tensordict import TensorDict
1012
from transformers import (
1113
AutoConfig,
@@ -14,10 +16,9 @@
1416
get_linear_schedule_with_warmup,
1517
)
1618

17-
from arealite.api.cli_args import TrainEngineConfig
19+
from arealite.api.cli_args import MicroBatchSpec, TrainEngineConfig
1820
from arealite.api.engine_api import (
1921
FinetuneSpec,
20-
MicroBatchSpec,
2122
SaveLoadMeta,
2223
TrainEngine,
2324
WeightUpdateMeta,
@@ -34,7 +35,10 @@
3435
unsqueeze_mb_list,
3536
)
3637
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
37-
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
38+
from arealite.utils.save_load import (
39+
get_state_dict_from_repo_id_or_path,
40+
is_existing_local_path,
41+
)
3842
from realhf.api.core.data_api import load_hf_tokenizer
3943
from realhf.base import logging, name_resolve, names
4044

@@ -54,6 +58,7 @@ def __init__(self, config: TrainEngineConfig):
5458
# initialization
5559
self.initialized = False
5660
self.weight_update_group_initialized = False
61+
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
5762
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
5863

5964
def train(self, mode: bool = True):
@@ -67,31 +72,24 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
6772

6873
"""Initialize distributed communication and model."""
6974
if not dist.is_initialized():
70-
dist.init_process_group(backend="nccl")
71-
if dist.get_world_size() > 1:
72-
raise RuntimeError(
73-
"Distributed training is not supported in this engine. "
74-
"Please use FSDP for distributed training."
75-
)
75+
deepspeed.init_distributed(dist_backend="nccl", world_size=self.world_size)
7676

77-
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
78-
self.device = torch.device(int(os.environ.get("LOCAL_RANK", 0)))
77+
torch.cuda.set_device(self.local_rank)
78+
self.device = torch.device(f"cuda:{self.local_rank}")
7979

8080
dtype = torch.bfloat16 if self.config.bf16 else torch.float16
8181
self.model_config = AutoConfig.from_pretrained(
8282
pretrained_model_name_or_path=self.config.path,
8383
trust_remote_code=True,
8484
)
8585
self.tokenizer = load_hf_tokenizer(self.config.path)
86-
with torch.device("cuda"):
87-
# initialize scratch model from config
88-
model = AutoModelForCausalLM.from_config(
89-
self.model_config,
90-
torch_dtype=dtype,
91-
attn_implementation=self.config.attn_impl,
92-
)
86+
model = AutoModelForCausalLM.from_config(
87+
self.model_config,
88+
torch_dtype=dtype,
89+
attn_implementation=self.config.attn_impl,
90+
)
9391

94-
self.model = model.to("cuda")
92+
self.model = model
9593

9694
if not self.config.init_from_scratch:
9795
# Load model from a initial checkpoint path,
@@ -102,9 +100,20 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
102100
with_optim=False,
103101
tokenizer=None,
104102
base_model_path=self.config.path,
103+
distribute=False,
105104
)
106105
self.load(load_meta)
107106

107+
if self.world_size > 1:
108+
if self._check_autotp():
109+
self.model = deepspeed.tp_model_init(
110+
self.model, tp_size=self.config.hf.autotp_size, dtype=dtype
111+
)
112+
else:
113+
raise RuntimeError("DeepSpeed AutoTP configuration error in HFEngine. ")
114+
115+
self.model = self.model.to(device=self.device, non_blocking=True)
116+
108117
# Set up optimizer
109118
if self.optimizer_config is not None:
110119
assert (
@@ -153,6 +162,21 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
153162

154163
self.initialized = True
155164

165+
def _check_autotp(self):
166+
tp_size = self.config.hf.autotp_size
167+
config = self.model_config
168+
num_attention_heads = config.num_attention_heads
169+
num_key_value_heads = config.num_key_value_heads
170+
hidden_size = config.hidden_size
171+
intermediate_size = config.intermediate_size
172+
173+
return (
174+
num_attention_heads % tp_size == 0
175+
and num_key_value_heads % tp_size == 0
176+
and hidden_size % tp_size == 0
177+
and intermediate_size % tp_size == 0
178+
)
179+
156180
def destroy(self):
157181
"""Destroy the engine and release GPU memory."""
158182
self.model = None
@@ -164,7 +188,7 @@ def destroy(self):
164188

165189
def save(self, meta: SaveLoadMeta):
166190
if meta.weight_format == "hf":
167-
self._save_model_to_hf(meta.path, meta.tokenizer)
191+
self._save_model_to_hf(meta.path, meta.tokenizer, meta.distribute)
168192
elif meta.weight_format == "dcp":
169193
# TODO: implement DCP save/load for HF
170194
raise NotImplementedError("DCP format saving is not implemented yet. ")
@@ -176,7 +200,7 @@ def save(self, meta: SaveLoadMeta):
176200

177201
def load(self, meta: SaveLoadMeta):
178202
if meta.weight_format == "hf":
179-
self._load_model_from_hf(meta.path)
203+
self._load_model_from_hf(meta.path, meta.distribute)
180204
elif meta.weight_format == "dcp":
181205
# TODO: implement DCP save/load for HF
182206
raise NotImplementedError("DCP format loading is not implemented yet. ")
@@ -198,27 +222,47 @@ def _load_optimizer_state(self, path: str):
198222
self.optimizer.load_state_dict(optimizer_state_dict)
199223

200224
def _save_model_to_hf(
    self,
    path: str,
    tokenizer: Optional[transformers.PreTrainedTokenizerFast],
    distribute: bool = False,
):
    """Save model weights (plus config and tokenizer) in HuggingFace format.

    Args:
        path: Target checkpoint directory.
        tokenizer: Tokenizer to save alongside the weights; falls back to
            ``self.tokenizer`` when None. (Fix: previous version accepted
            this parameter but always saved ``self.tokenizer``, leaving the
            argument dead.)
        distribute: When True, each rank writes its own TP shard as a
            separate safetensors file instead of a consolidated checkpoint.

    Raises:
        RuntimeError: If the model has not been initialized.
    """
    if self.model is None:
        raise RuntimeError("Model not initialized")

    # Only one process per node creates the directory to avoid mkdir races.
    if self.local_rank == 0:
        os.makedirs(path, exist_ok=True)

    if self.world_size > 1:
        dist.barrier()  # ensure the directory exists before any rank writes

    state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()}

    if distribute:
        # NOTE(review): shard filename is keyed on local_rank — this
        # collides across nodes if TP spans more than one node; confirm
        # this path is only used for single-node AutoTP.
        save_file(
            state_dict, f"{path}/tp_rank_{self.local_rank:02d}_model.safetensors"
        )
    else:
        self.model.save_pretrained(path, state_dict=state_dict)

    if self.local_rank == 0:
        self.model_config.save_pretrained(path)
        # Prefer the explicitly passed tokenizer; fall back to the engine's.
        tok = tokenizer if tokenizer is not None else self.tokenizer
        if tok is not None:
            tok.save_pretrained(path)
253+
254+
def _load_model_from_hf(self, path: str, distribute: bool = False):
    """Load model weights from a HuggingFace checkpoint directory or repo id.

    Args:
        path: Checkpoint directory, file, or HuggingFace repo id.
        distribute: When True, each rank loads only its own TP shard file
            (named by local_rank) instead of the consolidated checkpoint.
    """
    # Non-zero local ranks only load when the path already exists locally
    # (a remote repo id is fetched by local rank 0 alone).
    should_load = self.local_rank == 0 or is_existing_local_path(path)
    if should_load:
        if distribute:
            source = f"{path}/tp_rank_{self.local_rank:02d}_model.safetensors"
        else:
            source = path
        full_state = get_state_dict_from_repo_id_or_path(source)
        # Tied embeddings are absent from checkpoints, so relax strictness.
        self.model.load_state_dict(
            full_state, strict=not self.model_config.tie_word_embeddings
        )

    if self.model_config.tie_word_embeddings:
        self.model.tie_weights()
222266

223267
def upload_weights(self, meta: WeightUpdateMeta):
224268
if meta.type == "nccl":

arealite/utils/save_load.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from pathlib import Path
23
from typing import Dict
34

45
import torch
@@ -41,18 +42,21 @@ def get_state_dict_from_repo_id_or_path(repo_id_or_path: str) -> Dict:
4142
else:
4243
# Assume it's a local path
4344
local_path = repo_id_or_path
44-
if not os.path.isdir(local_path):
45-
raise ValueError(
46-
f"Local path {local_path} does not exist or is not a directory, "
47-
f"or {local_path} is a huggingface repo id but huggingface_hub is not installed."
48-
)
4945

5046
# Step 3: Load all .safetensors and .bin files
5147
file_paths_to_load = []
52-
for filename in os.listdir(local_path):
53-
filepath = os.path.join(local_path, filename)
54-
if filename.endswith(".safetensors") or filename.endswith(".bin"):
55-
file_paths_to_load.append(filepath)
48+
if os.path.isdir(local_path):
49+
for filename in os.listdir(local_path):
50+
filepath = os.path.join(local_path, filename)
51+
if filename.endswith(".safetensors") or filename.endswith(".bin"):
52+
file_paths_to_load.append(filepath)
53+
elif os.path.isfile(local_path):
54+
file_paths_to_load.append(local_path)
55+
else:
56+
raise ValueError(
57+
f"Local path {local_path} does not exist or is not a valid path, "
58+
f"or {local_path} is a huggingface repo id but huggingface_hub is not installed."
59+
)
5660

5761
def _load(filepath: str):
5862
if filepath.endswith(".safetensors"):
@@ -82,3 +86,11 @@ def _load(filepath: str):
8286
except Exception as e:
8387
raise RuntimeError(f"Error loading checkpoint from {path}: {e}")
8488
return state_dict
89+
90+
91+
def is_existing_local_path(path: str) -> bool:
    """Return True if *path* points to an existing local file or directory.

    Invalid or inaccessible paths (bad characters, permission errors)
    yield False rather than raising.
    """
    try:
        target = Path(path)
        # is_file()/is_dir() are False for non-existent paths, so no
        # separate exists() check is needed.
        return target.is_file() or target.is_dir()
    except (ValueError, OSError):
        return False

0 commit comments

Comments
 (0)