import os
import functools
from pathlib import Path

import torch
import torch.distributed as dist

from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
)

from torch.distributed.fsdp import (
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

from transformers import (
    MistralForCausalLM,
    MistralConfig,
    default_data_collator,
)
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from optimum.bettertransformer import BetterTransformer

from higgsfield.checkpoint.fsdp_checkpoint import (
    save_distributed_model_rank0,
    fsdp_model_state_dict_rank0,
)

from higgsfield.mistral.mistral_utils import (
    load_mistral_from_checkpoint,
    load_mistral_from_config,
)


class Mistral(FSDP):
    '''
    FSDP-wrapped MistralForCausalLM with activation checkpointing
    applied to every MistralDecoderLayer.
    '''

    def __init__(
        self,
        model_name,
        checkpoint_path=None,
        zero_stage=3,
        fast_attn=False,
        precision="bf16",
        cpu_init_rank0=False,
        cpu_offload=False,
        num_embeddings=None,
        cache_dir=None,
    ):
        rank = dist.get_rank()

        # Load the pretrained weights on every rank.
        model = MistralForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)

        if num_embeddings:
            model.resize_token_embeddings(num_embeddings)

        if fast_attn:
            # Swap in Optimum's BetterTransformer kernels for faster attention.
            model = BetterTransformer.transform(model)

        # FSDP mixed-precision policies.
        fpSixteen = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )

        bfSixteen_mixed = MixedPrecision(
            param_dtype=torch.float32,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )

        pure_bf16 = False
        if precision == "fp16":
            mixed_precision_policy = fpSixteen
        elif precision == "bf16":
            # Pure bf16: cast the whole model instead of using an FSDP policy.
            mixed_precision_policy = None
            pure_bf16 = True
        elif precision == "bf16_mixed":
            mixed_precision_policy = bfSixteen_mixed
        else:
            mixed_precision_policy = None

        if pure_bf16:
            model.to(torch.bfloat16)

        # Auto-wrap each MistralDecoderLayer in its own FSDP unit.
        wrapping_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                MistralDecoderLayer,
            },
        )

        # Map ZeRO-style stages onto FSDP sharding strategies.
        if zero_stage == 0:
            # No sharding: parameters are fully replicated (DDP-like).
            sharding_strategy = ShardingStrategy.NO_SHARD
        elif zero_stage == 1:
            raise NotImplementedError("stage 1 is not supported, only 0, 2, 3")
        elif zero_stage == 2:
            # Shard gradients and optimizer states, keep full parameters.
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        elif zero_stage == 3:
            # Shard parameters, gradients, and optimizer states.
            sharding_strategy = ShardingStrategy.FULL_SHARD
        else:
            raise NotImplementedError("stage can only be 0, 2, or 3")

        if cpu_init_rank0 and rank != 0:
            # With cpu_init_rank0, non-zero ranks materialize empty parameters
            # directly on the GPU and receive rank 0's weights through
            # sync_module_states below.
            param_init_fn = lambda module: module.to_empty(
                device=torch.device("cuda"),
                recurse=False,
            )
        else:
            param_init_fn = None

        if cpu_offload:
            # Offload sharded parameters to CPU to save GPU memory.
            cpu_offload = CPUOffload(offload_params=True)
        else:
            cpu_offload = None

        super().__init__(
            model,
            auto_wrap_policy=wrapping_policy,
            cpu_offload=cpu_offload,
            mixed_precision=mixed_precision_policy,
            sharding_strategy=sharding_strategy,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=True,
            sync_module_states=cpu_init_rank0,
            param_init_fn=param_init_fn,
        )

        # Apply non-reentrant activation checkpointing to every decoder layer.
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )

        check_fn = lambda submodule: isinstance(submodule, MistralDecoderLayer)

        apply_activation_checkpointing(
            self,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=check_fn,
        )

        self.precision = precision
        self.fsdp = True
        self.model_name = model_name
        self.num_embeddings = num_embeddings

    def __call__(self, batch):
        # Move every tensor in the batch to this process's GPU and return the
        # causal LM loss; fp16 runs under autocast.
        local_rank = int(os.environ["LOCAL_RANK"])

        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)

        if self.precision == "fp16":
            with torch.cuda.amp.autocast():
                loss = super().__call__(**batch).loss
        else:
            loss = super().__call__(**batch).loss

        return loss
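    # In a training loop this is typically used as
    #   loss = model(batch); loss.backward(); optimizer.step()
    # (see the hedged usage sketch at the end of this file).
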
    def save_model(self, save_path):
        '''
        Save the model's weights on the master node under
        ~/.cache/higgsfield/{save_path}
        '''
        if save_path.startswith("/"):
            save_path = save_path[1:]

        head, tail = os.path.split(save_path)

        path = Path.home() / ".cache/higgsfield" / head
        path.mkdir(exist_ok=True, parents=True)

        save_distributed_model_rank0(path / tail, self)
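
    # Illustration with a hypothetical filename: save_model("mistral/epoch_3.pt")
    # resolves to ~/.cache/higgsfield/mistral/epoch_3.pt and saves there on rank 0.
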
    def save_huggingface_model(self, save_path):
        '''
        Save the model's weights in Hugging Face format on the master node under
        ~/.cache/higgsfield/{save_path}
        '''
        if save_path.startswith("/"):
            save_path = save_path[1:]

        head, tail = os.path.split(save_path)

        path = Path.home() / ".cache/higgsfield" / head
        path.mkdir(exist_ok=True, parents=True)

        # Gather the full state dict onto rank 0's CPU, then rebuild a plain
        # Hugging Face model there to serialize.
        cpu_state = fsdp_model_state_dict_rank0(self)

        if dist.get_rank() == 0:
            model = load_mistral_from_config(self.model_name, num_embeddings=self.num_embeddings)
            model.load_state_dict(cpu_state)
            model.save_pretrained(path / tail)
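
    # Illustration with a hypothetical path: save_huggingface_model("mistral/hf_export")
    # writes a save_pretrained() directory on rank 0, which can later be reloaded with
    # MistralForCausalLM.from_pretrained(Path.home() / ".cache/higgsfield/mistral/hf_export").
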
    def push_to_hub(self, repo_id, token):
        '''
        Push the model's weights to the Hugging Face Hub as {repo_id}.
        '''
        cpu_state = fsdp_model_state_dict_rank0(self)

        if dist.get_rank() == 0:
            model = load_mistral_from_config(self.model_name, num_embeddings=self.num_embeddings)
            model.load_state_dict(cpu_state)
            model.push_to_hub(repo_id, token=token)
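

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module). It shows how the
# wrapper above is typically constructed and stepped. It assumes a multi-GPU
# launch via torchrun (which sets RANK / LOCAL_RANK / WORLD_SIZE); the model
# name, learning rate, sequence length, and save path below are placeholders
# rather than values prescribed by higgsfield.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = Mistral(
        "mistralai/Mistral-7B-v0.1",  # placeholder checkpoint name
        zero_stage=3,
        precision="bf16",
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    # One dummy optimization step with random token ids (32000 = Mistral vocab
    # size), just to show the call pattern; a real run would iterate over a
    # DataLoader of tokenized text.
    batch = {
        "input_ids": torch.randint(0, 32000, (1, 128)),
        "attention_mask": torch.ones(1, 128, dtype=torch.long),
        "labels": torch.randint(0, 32000, (1, 128)),
    }
    optimizer.zero_grad()
    loss = model(batch)  # moves the batch to the local GPU and returns the LM loss
    loss.backward()
    optimizer.step()

    model.save_model("mistral/example.pt")  # -> ~/.cache/higgsfield/mistral/example.pt
    dist.destroy_process_group()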