Skip to content

Commit b87d831

Browse files
authored
Fix hang when FP16 enabled and fvs!=bvs (#19)
* fix some miscellaneous issues first * fix hang when fvs!=bvs and fp16 * fix minor issues
1 parent 99fc1c2 commit b87d831

File tree

5 files changed

+29
-11
lines changed

5 files changed

+29
-11
lines changed

megatron/arguments.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -436,10 +436,6 @@ def validate_args(args, defaults={}):
436436
if args.use_distributed_optimizer:
437437
raise RuntimeError(
438438
"SpiralPipe currently does not support distributed optimizer")
439-
if args.spiral_cross_mapping:
440-
if args.spiral_forward_virtual_size != args.spiral_backward_virtual_size:
441-
raise RuntimeError(
442-
"SpiralPipe with cross mapping requires forward and backward virtual size to be the same")
443439

444440
# GQA
445441
if args.num_key_value_heads is None:

megatron/core/parallel_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,7 @@ def destroy_model_parallel():
10031003

10041004
def apply_spiral_cross_mapping(ranks):
10051005
""" Converts sorted pp rank list into list whose index is cm_rank and element is pp_rank.
1006-
Maps into pattern which eliminates inter-node fetch (when fvs==bvs) and minimizes intra-node activation comm.
1006+
Maps into pattern which eliminates inter-node fetch (when fvs==bvs) and minimizes inter-node activation comm.
10071007
10081008
TODO (SpiralPipe) currently assumes nprocs_per_node=4 and stride=2
10091009
"""

megatron/model/module.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,12 @@ def float_conversion(val):
171171

172172
class Float16Module(MegatronModule):
173173

174-
def __init__(self, module, args):
174+
def __init__(self, module, args, spiral_disable_cast=False):
175+
"""Wrap the module with float16/bfloat16 cast.
176+
177+
Arguments:
178+
spiral_disable_cast: bool, whether to disable float32 <-> float16/bfloat16 cast. SpiralPhaseList contains multiple Float16Modules, while the `fp16/bf16 cast` is only needed at the first Float16Module in the first stage (i.e., fid/bid=0, fbp/bbp=0) and the `fp32 cast` is only needed at the last Float16Module in the last stage (i.e., fid/bid=last, fbp/bbp=last*).
179+
"""
175180
super(Float16Module, self).__init__()
176181

177182
if args.fp16:
@@ -186,17 +191,19 @@ def float16_convertor(val):
186191
raise Exception('should not be here')
187192

188193
self.float16_convertor = float16_convertor
194+
if args.spiral:
195+
self.spiral_disable_cast = spiral_disable_cast
189196

190197

191198
def set_input_tensor(self, input_tensor):
192199
return self.module.set_input_tensor(input_tensor)
193200

194201

195202
def forward(self, *inputs, **kwargs):
196-
if mpu.is_pipeline_first_stage():
203+
if mpu.is_pipeline_first_stage() and not self.spiral_disable_cast:
197204
inputs = fp32_to_float16(inputs, self.float16_convertor)
198205
outputs = self.module(*inputs, **kwargs)
199-
if mpu.is_pipeline_last_stage():
206+
if mpu.is_pipeline_last_stage() and not self.spiral_disable_cast:
200207
outputs = float16_to_fp32(outputs)
201208
return outputs
202209

megatron/spiral/init_context.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,6 @@ def _fetch_data(param, non_blocking=False):
457457
).view(param.spiral_shape)
458458
else:
459459
param.data = torch.empty(param.spiral_shape, device=self.local_device, dtype=param.dtype)
460-
assert not get_args().spiral_cross_mapping, "Spiral cross mapping should eliminate remote fetch"
461460
get_thunder_group().FetchRemoteParam(
462461
param.spiral_id,
463462
non_blocking,
@@ -471,7 +470,6 @@ def _fetch_data(param, non_blocking=False):
471470
param.data.copy_(param.spiral_tensor, non_blocking=non_blocking)
472471
else:
473472
param.data = torch.empty(param.spiral_shape, device=self.local_device, dtype=param.dtype)
474-
assert not get_args().spiral_cross_mapping, "Spiral cross mapping should eliminate remote fetch"
475473
get_thunder_group().FetchRemoteParam(
476474
param.spiral_id,
477475
non_blocking,

megatron/training.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,25 @@ def _model_provider_func_wrapper(
382382

383383
# wrap model with Float16Module
384384
if args.fp16 or args.bf16:
385+
# only the first or last stage can require cast, where a stage is a spiral phase list
386+
# all modules in a spiral phase list shares fid and bid
387+
# only the first or last module in the same spiral phase list requires cast
388+
enable_cast = (
389+
(fvr == 0 and fbp == 0 and mpu.is_pipeline_first_stage())
390+
or (
391+
fvr == mpu.get_spiral_forward_virtual_size() - 1
392+
and fbp == sbs.get_spiral_forward_stage_build_phase_size() - 1
393+
and mpu.is_pipeline_last_stage()
394+
)
395+
or (bvr == 0 and bbp == 0 and mpu.is_pipeline_first_stage())
396+
or (
397+
bvr == mpu.get_spiral_backward_virtual_size() - 1
398+
and bbp == sbs.get_spiral_backward_stage_build_phase_size() - 1
399+
and mpu.is_pipeline_last_stage()
400+
)
401+
)
385402
with SpiralWrapperInitContext(enabled=True):
386-
this_model = Float16Module(this_model, args)
403+
this_model = Float16Module(this_model, args, spiral_disable_cast=not enable_cast)
387404

388405
# reset states of the callee
389406
if mpu.is_spiral_forward_stage():

0 commit comments

Comments
 (0)