
Commit 68d223c

Authored by Quentin-Anthony, Quentin TastyRice, Dashiell Stander, github-actions, and dashstander
Curriculum Learning Support (#695)
* Remove deprecated deepspeed.utils.distributed call
* Initial curriculum learning support
* Add is_train flag for curriculum learning
* Update NeoXArgs docs automatically
* add comment arg  Signed-off-by: Dashiell Stander <[email protected]>
* Update NeoXArgs docs automatically
* Add slurm stuff
* Update NeoXArgs docs automatically
* Allow json
* Update NeoXArgs docs automatically
* Apply curriculum learning seq_len to pipeline parallel data loading  Signed-off-by: Dashiell Stander <[email protected]>
* Update NeoXArgs docs automatically
* Actually updating the curriculum seq_len  Signed-off-by: Dashiell Stander <[email protected]>
* Update NeoXArgs docs automatically
* Actually updating the curriculum seq_len  Signed-off-by: Dashiell Stander <[email protected]>
* Update NeoXArgs docs automatically
* Actually updating the curriculum seq_len  Signed-off-by: Dashiell Stander <[email protected]>
* Update NeoXArgs docs automatically
* Actually updating the curriculum seq_len  Signed-off-by: Dashiell Stander <[email protected]>
* Update NeoXArgs docs automatically
* Iteration + 1  Signed-off-by: Dashiell Stander <[email protected]>
* Update NeoXArgs docs automatically
* Clean up comments and debug print statements  Signed-off-by: Dashiell Stander <[email protected]>
* Update NeoXArgs docs automatically
* Debug print again  Signed-off-by: Dashiell Stander <[email protected]>
* Update NeoXArgs docs automatically
* more print statements  Signed-off-by: Dashiell Stander <[email protected]>
* Update NeoXArgs docs automatically
* Remove debug print statements  Signed-off-by: Dashiell Stander <[email protected]>
* Update NeoXArgs docs automatically
* Pre-commit
* Update NeoXArgs docs automatically
* Update NeoXArgs docs automatically
* Update NeoXArgs docs automatically

---------

Signed-off-by: Dashiell Stander <[email protected]>
Co-authored-by: Quentin TastyRice <[email protected]>
Co-authored-by: Dashiell Stander <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: Dashiell Stander <[email protected]>
1 parent 2b84f9a commit 68d223c

File tree

6 files changed (+111, -7 lines):

  configs/neox_arguments.md
  megatron/logging.py
  megatron/model/utils.py
  megatron/neox_arguments/arguments.py
  megatron/neox_arguments/deepspeed_args.py
  megatron/training.py

configs/neox_arguments.md

Lines changed: 17 additions & 1 deletion

@@ -111,7 +111,7 @@ Logging Arguments
 
 - **git_hash**: str
 
-    Default = d49acf3
+    Default = cbed1b5
 
     current git hash of repository
 
@@ -1676,6 +1676,22 @@ Args for deepspeed config
 
 
 
+- **curriculum_learning**: dict
+
+    Default = None
+
+
+
+
+
+- **curriculum_seqlen**: int
+
+    Default = 0
+
+    Internal var for tracking the current seqlen
+
+
+
 - **steps_per_print**: int
 
     Default = 10
megatron/logging.py

Lines changed: 10 additions & 0 deletions

@@ -297,6 +297,16 @@ def add_to_logging(name):
             1, neox_args.log_interval - total_loss_dict[skipped_iters_key]
         )
 
+        # log curriculum learning
+        if neox_args.curriculum_learning:
+            tb_wandb_log(
+                "curriculum_seqlen",
+                neox_args.curriculum_seqlen,
+                iteration,
+                use_wandb=neox_args.use_wandb,
+                tensorboard_writer=neox_args.tensorboard_writer,
+            )
+
         # log tflop / gpu
         flops_per_s_per_gpu = get_flops(
             neox_args=neox_args, model=model, iter_time_s=iteration_time

megatron/model/utils.py

Lines changed: 28 additions & 1 deletion

@@ -21,6 +21,7 @@
 from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm
 from megatron.model.fused_softmax import SoftmaxFusionTypes
 from types import GeneratorType
+import torch.distributed as dist
 
 
 def get_params_for_weight_decay_optimization(module, neox_args):
@@ -120,7 +121,33 @@ def train_mode(self):
         """
         _set_use_cache(self.sequential, False)
 
-    def forward(self, forward_input):
+    def forward(
+        self, forward_input, curriculum_seqlen=None, labels=None, neox_args=None
+    ):
+
+        if (
+            curriculum_seqlen is not None
+            and isinstance(forward_input, tuple)
+            and len(forward_input) == 3
+        ):
+            neox_args.update_value("curriculum_seqlen", curriculum_seqlen)
+            tokens = forward_input[0]
+            input_ids = forward_input[1]
+            attention_mask = forward_input[2]
+            if curriculum_seqlen < input_ids.size()[1]:
+                # seqlen-based curriculum learning
+                # input_ids, position_ids, labels have size [batch size, seqlen]
+                input_ids = input_ids[:, :curriculum_seqlen].contiguous()
+                tokens = tokens[:, :curriculum_seqlen].contiguous()
+                # position_ids = position_ids[:, :curriculum_seqlen].contiguous()
+                if labels is not None:
+                    labels = labels[:, :curriculum_seqlen].contiguous()
+                # attention_mask has size [1, 1, seqlen, seqlen]
+                attention_mask = attention_mask[
+                    :, :, :curriculum_seqlen, :curriculum_seqlen
+                ].contiguous()
+            forward_input = (tokens, input_ids, attention_mask)
+
         def exec_range_func(start, end):
             """Helper function to be used with checkpoint()
             Adapted from torch.utils.checkpoint:checkpoint_sequential()
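
To make the effect of this slicing concrete, here is a tiny standalone sketch with toy tensors; the shapes mirror the comments in the diff, and all names and sizes here are illustrative rather than taken from the repository.

import torch

batch_size, seqlen, curriculum_seqlen = 2, 8, 4
tokens = torch.arange(batch_size * seqlen).view(batch_size, seqlen)  # [batch size, seqlen]
attention_mask = torch.tril(torch.ones(1, 1, seqlen, seqlen))        # [1, 1, seqlen, seqlen]

# seqlen-based curriculum learning: keep only the first curriculum_seqlen positions
tokens = tokens[:, :curriculum_seqlen].contiguous()
# the causal mask is square in the sequence dimension, so both trailing dims are sliced
attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous()

print(tokens.shape)          # torch.Size([2, 4])
print(attention_mask.shape)  # torch.Size([1, 1, 4, 4])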

megatron/neox_arguments/arguments.py

Lines changed: 4 additions & 1 deletion

@@ -334,7 +334,10 @@ def consume_deepy_args(cls):
         conf_files = [os.path.join(args_parsed.conf_dir, f) for f in conf_files]
 
         # enables us to pass in `small` instead of `small.yml`
-        conf_files = [(cf if cf.endswith(".yml") else cf + ".yml") for cf in conf_files]
+        conf_files = [
+            (cf if cf.endswith(".yml") or cf.endswith(".json") else cf + ".yml")
+            for cf in conf_files
+        ]
 
         # determine overwrite values
         overwrite_values = dict()

megatron/neox_arguments/deepspeed_args.py

Lines changed: 8 additions & 0 deletions

@@ -102,6 +102,14 @@ class NeoXArgsDeepspeedConfig(NeoXArgsTemplate):
     zero_optimization: dict = None
     """"""
 
+    curriculum_learning: dict = None
+    """"""
+
+    curriculum_seqlen: int = 0
+    """
+    Internal var for tracking the current seqlen
+    """
+
     steps_per_print: int = 10
     """
     Print train loss every N steps.

megatron/training.py

Lines changed: 44 additions & 4 deletions

@@ -27,6 +27,7 @@
 
 import torch
 import deepspeed
+from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
 import numpy as np
 
 from megatron.utils import (
@@ -301,7 +302,7 @@ def get_batch(neox_args, data_iterator):
     )
 
 
-def get_batch_pipe(data, neox_args):
+def get_batch_pipe(data, neox_args, curr_scheduler=None):
     """A modification of get_batch() to work with the latest batch instead of an iterator."""
     # Items and their type.
     keys = ["text"]
@@ -310,12 +311,31 @@ def get_batch_pipe(data, neox_args):
     tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(
         neox_args, neox_args.tokenizer, keys, data, datatype
     )
+    if curr_scheduler is not None:
+        # iteration + 1 to align with how/when DeepSpeed updates the buffers
+        curriculum_seqlen = curr_scheduler.update_difficulty(neox_args.iteration + 1)
+        if curriculum_seqlen < tokens.size()[1]:
+            # seqlen-based curriculum learning
+            # input_ids, position_ids, labels have size [batch size, seqlen]
+            # input_ids = input_ids[:, :curriculum_seqlen].contiguous()
+            tokens = tokens[:, :curriculum_seqlen].contiguous()
+            position_ids = position_ids[:, :curriculum_seqlen].contiguous()
+            if labels is not None:
+                labels = labels[:, :curriculum_seqlen].contiguous()
+            if loss_mask is not None:
+                loss_mask = loss_mask[:, :curriculum_seqlen].contiguous()
+            # attention_mask has size [1, 1, seqlen, seqlen]
+            attention_mask = attention_mask[
+                :, :, :curriculum_seqlen, :curriculum_seqlen
+            ].contiguous()
 
     # unpack data
     return (tokens, position_ids, attention_mask), (labels, loss_mask)
 
 
-def forward_step(data_iterator, model, neox_args, timers, return_logits=False):
+def forward_step(
+    data_iterator, model, neox_args, timers, return_logits=False, is_train=False
+):
     """Forward step."""
     if neox_args.is_pipe_parallel:
         return model.eval_batch(data_iterator, return_logits=return_logits)
@@ -326,10 +346,18 @@ def forward_step(data_iterator, model, neox_args, timers, return_logits=False):
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
         neox_args=neox_args, data_iterator=data_iterator
    )
+
     if timers is not None:
         timers("batch generator").stop()
 
-    outputs = model((tokens, position_ids, attention_mask))
+    outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args)
+    if (
+        is_train
+        and neox_args.curriculum_learning
+        and neox_args.curriculum_seqlen < neox_args.seq_length
+    ):
+        loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous()
+        labels = labels[:, : neox_args.curriculum_seqlen].contiguous()
     loss = cross_entropy(
         outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy
     )
@@ -589,7 +617,17 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
 
     if neox_args.is_pipe_parallel:
         model.set_has_attention_mask(True)
-        model.set_batch_fn(partial(get_batch_pipe, neox_args=neox_args))
+        if neox_args.curriculum_learning:
+            curr_scheduler = CurriculumScheduler(neox_args.curriculum_learning)
+            if iteration is not None and iteration > 0:
+                curr_scheduler.update_difficulty(iteration)
+        else:
+            curr_scheduler = None
+        model.set_batch_fn(
+            partial(
+                get_batch_pipe, neox_args=neox_args, curr_scheduler=curr_scheduler
+            )
+        )
     else:
         raise ValueError("Must be using deepspeed to run neox")
 
@@ -647,6 +685,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler)
                 timers=timers,
                 data_iterator=data_iterator,
                 model=model,
+                is_train=True,
             )
             timers("forward").stop()
             losses.append(loss)
@@ -736,6 +775,7 @@ def train(
             lr_scheduler=lr_scheduler,
         )
         iteration += 1
+        neox_args.iteration = iteration
 
         overflow_monitor.check(skipped_iter)  # check for repeated overflow
         if neox_args.log_gradient_noise_scale:  # log noise scale if applicable
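
For context on how these pieces fit together, here is a minimal, self-contained usage sketch of DeepSpeed's CurriculumScheduler, exercised the way get_batch_pipe uses it (with iteration + 1). The config values are illustrative assumptions following DeepSpeed's seqlen curriculum schema, not values from this commit.

from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler

# Illustrative seqlen curriculum config (assumed values, mirroring the sketch above).
curriculum_config = {
    "enabled": True,
    "curriculum_type": "seqlen",
    "min_difficulty": 64,
    "max_difficulty": 2048,
    "schedule_type": "fixed_linear",
    "schedule_config": {"total_curriculum_step": 10000, "difficulty_step": 8},
}

scheduler = CurriculumScheduler(curriculum_config)

for iteration in (0, 99, 4999, 9999, 19999):
    # get_batch_pipe passes neox_args.iteration + 1 so the scheduled seqlen
    # matches the step DeepSpeed is about to execute.
    curriculum_seqlen = scheduler.update_difficulty(iteration + 1)
    print(f"iteration {iteration}: curriculum_seqlen = {curriculum_seqlen}")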
