|
12 | 12 |
|
13 | 13 | import dataclasses
|
14 | 14 | from dataclasses import dataclass
|
15 |
| -from typing import Callable |
| 15 | +from typing import Callable, Optional |
16 | 16 | from domino.timer import Timers
|
17 | 17 | from megatron.tokenizer import build_tokenizer
|
18 | 18 |
|
@@ -206,6 +206,8 @@ def parse_args():
|
206 | 206 | help='Report loss and timing interval.')
|
207 | 207 | parser.add_argument('--save-interval', type=int, default=None,
|
208 | 208 | help='Number of iterations between checkpoint saves.')
|
| 209 | + parser.add_argument('--fused-linear-loss', action='store_true', |
| 210 | + help='whether to use LigerFusedLinearCrossEntropyFunction()') |
209 | 211 |
|
210 | 212 | args = parser.parse_args()
|
211 | 213 |
|
@@ -359,6 +361,8 @@ class TransformerConfig():
|
359 | 361 | no_sync_func: Callable = None
|
360 | 362 | # grad_sync_func: Callable = None
|
361 | 363 | # param_sync_func: Callable = None
|
| 364 | + |
| 365 | + fused_linear_loss: bool = False |
362 | 366 |
|
363 | 367 | def __post_init__(self):
|
364 | 368 | """ Python dataclass method that is used to modify attributes after initialization.
|
@@ -400,5 +404,6 @@ def core_transformer_config_from_args(args):
|
400 | 404 | kw_args['init_method'] = args.init_method
|
401 | 405 | kw_args['output_layer_init_method'] = args.init_method
|
402 | 406 | kw_args['params_dtype'] = args.params_dtype
|
| 407 | + kw_args['fused_linear_loss'] = args.fused_linear_loss |
403 | 408 |
|
404 | 409 | return TransformerConfig(**kw_args)
|
0 commit comments