Skip to content

Commit 10dcc3f

Browse files
committed
add fused and chunked linear-loss function
Signed-off-by: dhx <[email protected]>
1 parent 223665c commit 10dcc3f

File tree

3 files changed

+356
-20
lines changed

3 files changed

+356
-20
lines changed

Diff for: training/DeepSpeed-Domino/domino/arguments.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import dataclasses
1414
from dataclasses import dataclass
15-
from typing import Callable
15+
from typing import Callable, Optional
1616
from domino.timer import Timers
1717
from megatron.tokenizer import build_tokenizer
1818

@@ -206,6 +206,8 @@ def parse_args():
206206
help='Report loss and timing interval.')
207207
parser.add_argument('--save-interval', type=int, default=None,
208208
help='Number of iterations between checkpoint saves.')
209+
parser.add_argument('--fused-linear-loss', action='store_true',
210+
help='whether to use LigerFusedLinearCrossEntropyFunction()')
209211

210212
args = parser.parse_args()
211213

@@ -359,6 +361,8 @@ class TransformerConfig():
359361
no_sync_func: Callable = None
360362
# grad_sync_func: Callable = None
361363
# param_sync_func: Callable = None
364+
365+
fused_linear_loss: bool = False
362366

363367
def __post_init__(self):
364368
""" Python dataclass method that is used to modify attributes after initialization.
@@ -400,5 +404,6 @@ def core_transformer_config_from_args(args):
400404
kw_args['init_method'] = args.init_method
401405
kw_args['output_layer_init_method'] = args.init_method
402406
kw_args['params_dtype'] = args.params_dtype
407+
kw_args['fused_linear_loss'] = args.fused_linear_loss
403408

404409
return TransformerConfig(**kw_args)

0 commit comments

Comments
 (0)