
import torch
import deepspeed
+from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
import numpy as np

from megatron.utils import (
@@ -301,7 +302,7 @@ def get_batch(neox_args, data_iterator):
    )


-def get_batch_pipe(data, neox_args):
+def get_batch_pipe(data, neox_args, curr_scheduler=None):
    """A modification of get_batch() to work with the latest batch instead of an iterator."""
    # Items and their type.
    keys = ["text"]
@@ -310,12 +311,31 @@ def get_batch_pipe(data, neox_args):
    tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(
        neox_args, neox_args.tokenizer, keys, data, datatype
    )
+    if curr_scheduler is not None:
+        # iteration + 1 to align with how/when DeepSpeed updates the buffers
+        curriculum_seqlen = curr_scheduler.update_difficulty(neox_args.iteration + 1)
+        if curriculum_seqlen < tokens.size()[1]:
+            # seqlen-based curriculum learning
+            # input_ids, position_ids, labels have size [batch size, seqlen]
+            # input_ids = input_ids[:, :curriculum_seqlen].contiguous()
+            tokens = tokens[:, :curriculum_seqlen].contiguous()
+            position_ids = position_ids[:, :curriculum_seqlen].contiguous()
+            if labels is not None:
+                labels = labels[:, :curriculum_seqlen].contiguous()
+            if loss_mask is not None:
+                loss_mask = loss_mask[:, :curriculum_seqlen].contiguous()
+            # attention_mask has size [1, 1, seqlen, seqlen]
+            attention_mask = attention_mask[
+                :, :, :curriculum_seqlen, :curriculum_seqlen
+            ].contiguous()

    # unpack data
    return (tokens, position_ids, attention_mask), (labels, loss_mask)


-def forward_step(data_iterator, model, neox_args, timers, return_logits=False):
+def forward_step(
+    data_iterator, model, neox_args, timers, return_logits=False, is_train=False
+):
    """Forward step."""
    if neox_args.is_pipe_parallel:
        return model.eval_batch(data_iterator, return_logits=return_logits)
@@ -326,10 +346,18 @@ def forward_step(data_iterator, model, neox_args, timers, return_logits=False):
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        neox_args=neox_args, data_iterator=data_iterator
    )
+
    if timers is not None:
        timers("batch generator").stop()

-    outputs = model((tokens, position_ids, attention_mask))
+    outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args)
+    if (
+        is_train
+        and neox_args.curriculum_learning
+        and neox_args.curriculum_seqlen < neox_args.seq_length
+    ):
+        loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous()
+        labels = labels[:, : neox_args.curriculum_seqlen].contiguous()
    loss = cross_entropy(
        outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy
    )
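
A note on the slicing above, for readers new to seqlen-based curriculum learning: when the model only processes the first `curriculum_seqlen` tokens, the targets have to be cut to the same length before the loss is computed, otherwise the masked reduction fails on a shape mismatch. The toy sketch below illustrates that shape bookkeeping with plain PyTorch; the tensor sizes are made up, and `torch.nn.functional.cross_entropy` stands in for the fused `cross_entropy` imported in this file.

```python
import torch
import torch.nn.functional as F

# Toy shapes, illustrative only.
batch, seq_len, vocab, curriculum_seqlen = 2, 8, 11, 4

# Pretend the model was run at the shortened sequence length.
logits = torch.randn(batch, curriculum_seqlen, vocab)

# Full-length targets must be truncated to match the model output.
labels = torch.randint(vocab, (batch, seq_len))[:, :curriculum_seqlen]
loss_mask = torch.ones(batch, seq_len)[:, :curriculum_seqlen]

# Token-level losses, then the usual masked mean.
losses = F.cross_entropy(
    logits.reshape(-1, vocab), labels.reshape(-1), reduction="none"
).view(batch, curriculum_seqlen)
loss = (losses * loss_mask).sum() / loss_mask.sum()
print(loss)  # scalar
```
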
@@ -589,7 +617,17 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):

    if neox_args.is_pipe_parallel:
        model.set_has_attention_mask(True)
-        model.set_batch_fn(partial(get_batch_pipe, neox_args=neox_args))
+        if neox_args.curriculum_learning:
+            curr_scheduler = CurriculumScheduler(neox_args.curriculum_learning)
+            if iteration is not None and iteration > 0:
+                curr_scheduler.update_difficulty(iteration)
+        else:
+            curr_scheduler = None
+        model.set_batch_fn(
+            partial(
+                get_batch_pipe, neox_args=neox_args, curr_scheduler=curr_scheduler
+            )
+        )
    else:
        raise ValueError("Must be using deepspeed to run neox")

@@ -647,6 +685,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler)
                timers=timers,
                data_iterator=data_iterator,
                model=model,
+                is_train=True,
            )
            timers("forward").stop()
            losses.append(loss)
@@ -736,6 +775,7 @@ def train(
            lr_scheduler=lr_scheduler,
        )
        iteration += 1
+        neox_args.iteration = iteration

        overflow_monitor.check(skipped_iter)  # check for repeated overflow
        if neox_args.log_gradient_noise_scale:  # log noise scale if applicable
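
To make the scheduler wiring in `setup_model_and_optimizer` easier to follow, here is a minimal standalone sketch of how `CurriculumScheduler` is typically constructed and stepped. It assumes the config dict follows DeepSpeed's seqlen-based curriculum-learning schema (`min_difficulty`, `max_difficulty`, `schedule_type`, `schedule_config`); the particular values are illustrative, not the ones used in this PR.

```python
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler

# Illustrative config only; keys follow DeepSpeed's curriculum-learning schema,
# the real values live in the NeoX config under `curriculum_learning`.
curriculum_config = {
    "curriculum_type": "seqlen",
    "min_difficulty": 64,
    "max_difficulty": 2048,
    "schedule_type": "fixed_linear",
    "schedule_config": {
        "total_curriculum_step": 10_000,  # steps until max_difficulty is reached
        "difficulty_step": 8,             # keep seqlen a multiple of 8
    },
}

scheduler = CurriculumScheduler(curriculum_config)

# update_difficulty() both advances the schedule and returns the current
# sequence length, which is how get_batch_pipe() above uses it.
for step in (1, 2_500, 5_000, 10_000, 20_000):
    print(step, scheduler.update_difficulty(step))
```

In the pipeline-parallel path, `get_batch_pipe` calls `update_difficulty(neox_args.iteration + 1)` on every batch, so resuming from a checkpoint only needs the one-off `update_difficulty(iteration)` call seen above to fast-forward the schedule.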