Commit 61bc432

add mistral support
1 parent 5d4bba7 commit 61bc432

File tree

5 files changed: +351 -0 lines changed

README.md

Lines changed: 1 addition & 0 deletions

@@ -84,6 +84,7 @@ You can use whatever you want, whenever you want. We just introduce a simple int
 
 **Clouds we have tested on:**
 
+- Azure
 - LambdaLabs
 - FluidStack

higgsfield/mistral/__init__.py

Whitespace-only changes.

higgsfield/mistral/mistral.py

Lines changed: 219 additions & 0 deletions

@@ -0,0 +1,219 @@
import os
import functools
from pathlib import Path

import torch
import torch.distributed as dist

from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
)
from torch.distributed.fsdp import (
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

from transformers import (
    MistralForCausalLM,
    MistralConfig,
    default_data_collator,
)
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from optimum.bettertransformer import BetterTransformer

from higgsfield.checkpoint.fsdp_checkpoint import (
    save_distributed_model_rank0,
    fsdp_model_state_dict_rank0,
)
from higgsfield.mistral.mistral_utils import (
    load_mistral_from_checkpoint,
    load_mistral_from_config,
)


class Mistral(FSDP):
    def __init__(
        self,
        model_name,
        checkpoint_path=None,
        zero_stage=3,
        fast_attn=False,
        precision="bf16",
        cpu_init_rank0=False,
        cpu_offload=False,
        num_embeddings=None,
        cache_dir=None,
    ):
        rank = dist.get_rank()

        model = MistralForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)

        if num_embeddings:
            model.resize_token_embeddings(num_embeddings)

        if fast_attn:
            #raise NotImplementedError("Fast attention is not supported yet")
            model = BetterTransformer.transform(model)

        # Precision setup: "fp16" uses a float16 MixedPrecision policy,
        # "bf16" casts the whole model to bfloat16 (no policy),
        # "bf16_mixed" keeps fp32 params with bfloat16 reductions and buffers.
        fpSixteen = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )

        bfSixteen_mixed = MixedPrecision(
            param_dtype=torch.float32,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )

        pure_bf16 = False
        if precision == "fp16":
            mixed_precision_policy = fpSixteen
        elif precision == "bf16":
            mixed_precision_policy = None
            pure_bf16 = True
        elif precision == "bf16_mixed":
            mixed_precision_policy = bfSixteen_mixed
        else:
            mixed_precision_policy = None

        if pure_bf16:
            model.to(torch.bfloat16)

        # Shard at the granularity of Mistral decoder layers.
        wrapping_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                MistralDecoderLayer,
            },
        )

        # Map ZeRO-style stages onto FSDP sharding strategies (stage 1 is not supported).
        if zero_stage == 0:
            sharding_strategy = ShardingStrategy.NO_SHARD
        elif zero_stage == 1:
            raise NotImplementedError("stage 1 is not supported. Only 0, 2, 3")
        elif zero_stage == 2:
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        elif zero_stage == 3:
            sharding_strategy = ShardingStrategy.FULL_SHARD
        else:
            raise NotImplementedError("stage can only be 0, 2 or 3")

        # With cpu_init_rank0, only rank 0 holds real weights; other ranks allocate
        # empty parameters on GPU and receive the weights via sync_module_states.
        if cpu_init_rank0 and rank != 0:
            param_init_fn = lambda module: module.to_empty(
                device=torch.device("cuda"),
                recurse=False,
            )
        else:
            param_init_fn = None

        if cpu_offload:
            cpu_offload = CPUOffload(offload_params=True)
        else:
            cpu_offload = None

        super().__init__(
            model,
            auto_wrap_policy=wrapping_policy,
            cpu_offload=cpu_offload,
            mixed_precision=mixed_precision_policy,
            sharding_strategy=sharding_strategy,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=True,
            sync_module_states=cpu_init_rank0,
            param_init_fn=param_init_fn,
        )

        # Wrap each decoder layer with non-reentrant activation checkpointing.
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )

        check_fn = lambda submodule: isinstance(submodule, MistralDecoderLayer)

        apply_activation_checkpointing(
            self,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=check_fn,
        )

        fsdp = True
        self.precision = precision
        self.fsdp = fsdp
        self.model_name = model_name
        self.num_embeddings = num_embeddings

    def __call__(self, batch):
        # Move the batch to the local GPU and return the causal-LM loss.
        local_rank = int(os.environ["LOCAL_RANK"])

        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)

        if self.precision == "fp16":
            with torch.cuda.amp.autocast():
                loss = super().__call__(**batch).loss
        else:
            loss = super().__call__(**batch).loss

        return loss

    def save_model(self, save_path):
        '''
        Save the model's weights on the master node under
        ~/.cache/higgsfield/{save_path}
        '''
        if "/" == save_path[0]:
            save_path = save_path[1:]

        head, tail = os.path.split(save_path)

        path = Path.home() / ".cache/higgsfield" / head
        path.mkdir(exist_ok=True, parents=True)

        save_distributed_model_rank0(path / tail, self)

    def save_huggingface_model(self, save_path):
        '''
        Save the model's weights in Hugging Face format on the master node under
        ~/.cache/higgsfield/{save_path}
        '''
        if "/" == save_path[0]:
            save_path = save_path[1:]

        head, tail = os.path.split(save_path)

        path = Path.home() / ".cache/higgsfield" / head
        path.mkdir(exist_ok=True, parents=True)

        cpu_state = fsdp_model_state_dict_rank0(self)

        if dist.get_rank() == 0:
            model = load_mistral_from_config(self.model_name, num_embeddings=self.num_embeddings)
            model.load_state_dict(cpu_state)
            model.save_pretrained(path / tail)

    def push_to_hub(self, repo_id, token):
        cpu_state = fsdp_model_state_dict_rank0(self)

        if dist.get_rank() == 0:
            model = load_mistral_from_config(self.model_name, num_embeddings=self.num_embeddings)
            model.load_state_dict(cpu_state)
            model.push_to_hub(repo_id, token=token)
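
For context, here is a minimal training-loop sketch using the Mistral wrapper added above. It is not part of this commit: it assumes torch.distributed is already initialized and LOCAL_RANK is set (as the higgsfield launcher would do), and the loader argument, optimizer settings, and save path are illustrative.

# Sketch only, not part of this commit. Assumes torch.distributed is initialized
# and LOCAL_RANK is set; loader, optimizer settings and save path are illustrative.
import torch

from higgsfield.mistral.mistral import Mistral


def train(loader, epochs=1, lr=1e-5):
    model = Mistral(
        "mistralai/Mistral-7B-v0.1",
        zero_stage=3,       # FULL_SHARD
        precision="bf16",   # casts the model to bfloat16, no mixed-precision policy
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for _ in range(epochs):
        for batch in loader:           # each batch is a dict of tensors
            optimizer.zero_grad()
            loss = model(batch)        # __call__ moves the batch to the local GPU and returns the LM loss
            loss.backward()
            optimizer.step()

    # Written on the master node under ~/.cache/higgsfield/my_mistral/final.pt
    model.save_model("my_mistral/final.pt")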

Lines changed: 110 additions & 0 deletions

@@ -0,0 +1,110 @@
import torch.distributed as dist

from torch.utils.data import (
    DistributedSampler,
    DataLoader,
)

from transformers import (
    AutoTokenizer,
    default_data_collator,
)

from higgsfield.dataset import TorchCompletionDataset

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "<|pad|>"
DEFAULT_EOS_TOKEN = "<|endoftext|>"
DEFAULT_UNK_TOKEN = "<|unk|>"


def get_tokenizer(model_name, max_length, cache_dir=None):
    # Slow tokenizer with right padding; registers <|pad|> (and, if missing,
    # EOS/UNK) as special tokens.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        model_max_length=max_length,
        padding_side="right",
        use_fast=False,
        pad_token=DEFAULT_PAD_TOKEN,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    tokenizer.add_special_tokens(special_tokens_dict)

    return tokenizer


class HiggsfieldSampler(DistributedSampler):
    # DistributedSampler that picks up rank and world size from the process group.
    def __init__(
        self,
        dataset,
        shuffle=True,
        seed=0,
        drop_last=False,
    ):
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()

        super(HiggsfieldSampler, self).__init__(
            dataset=dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )


class MistralLoader(DataLoader):
    # DataLoader that wraps a raw dataset in TorchCompletionDataset and shards
    # it across ranks with HiggsfieldSampler.
    def __init__(
        self,
        dataset,
        tokenizer=None,
        max_sequence_length=2048,
        batch_size_per_gpu=1,
        shuffle=True,
        seed=0,
        num_workers=0,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        multiprocessing_context=None,
        *,
        prefetch_factor=None,
        persistent_workers=False,
        pin_memory_device="",
    ):
        if not tokenizer:
            tokenizer = get_tokenizer("mistralai/Mistral-7B-v0.1", max_sequence_length)

        dataset = TorchCompletionDataset(
            dataset,
            tokenizer,
            max_sequence_length,
        )

        sampler = HiggsfieldSampler(dataset, shuffle=shuffle, seed=seed)

        super(MistralLoader, self).__init__(
            dataset,
            batch_size=batch_size_per_gpu,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            multiprocessing_context=multiprocessing_context,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
            pin_memory_device=pin_memory_device,
        )
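
One detail worth calling out in get_tokenizer above: registering the <|pad|> token grows the vocabulary, so the model's embedding table has to be resized to match, which is what the num_embeddings argument of the Mistral wrapper is for. The sketch below is not part of this commit and uses transformers directly; the model name and lengths are only examples.

# Sketch only, not part of this commit: how the added pad token relates to
# num_embeddings. get_tokenizer above performs the equivalent special-token setup.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    model_max_length=2048,
    padding_side="right",
    use_fast=False,
)
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

# The vocabulary is now larger than the pretrained embedding matrix, so pass
# num_embeddings=len(tokenizer) to Mistral(...), which calls
# model.resize_token_embeddings with it.
num_embeddings = len(tokenizer)
print(num_embeddings)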

higgsfield/mistral/mistral_utils.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
import torch

from transformers import (
    MistralConfig,
    MistralForCausalLM,
)

from higgsfield.checkpoint import fsdp_model_state_dict_rank0


def load_mistral_from_config(model_name, num_embeddings=None):
    config = MistralConfig.from_pretrained(model_name)
    model = MistralForCausalLM(config)

    if num_embeddings:
        model.resize_token_embeddings(num_embeddings)

    return model


def load_mistral_from_checkpoint(model_name, checkpoint_path, num_embeddings=None):
    model = load_mistral_from_config(model_name, num_embeddings=num_embeddings)
    state_dict = torch.load(checkpoint_path)
    model.load_state_dict(state_dict)
    return model
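
As a usage note (not part of this commit): a checkpoint written by Mistral.save_model can be loaded back in a single process with load_mistral_from_checkpoint. The path below just mirrors the ~/.cache/higgsfield/{save_path} convention, the file name is illustrative, and it assumes save_distributed_model_rank0 wrote a plain state_dict.

# Sketch only, not part of this commit. The file name is illustrative.
from pathlib import Path

from higgsfield.mistral.mistral_utils import load_mistral_from_checkpoint

checkpoint = Path.home() / ".cache/higgsfield" / "my_mistral" / "final.pt"

model = load_mistral_from_checkpoint(
    "mistralai/Mistral-7B-v0.1",
    checkpoint,
    num_embeddings=None,  # match whatever was used at training time, if anything
)
model.eval()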
