File tree 3 files changed +417
-11
lines changed
internlm/model/model_ops/ops
3 files changed +417
-11
lines changed Original file line number Diff line number Diff line change 18
18
CrossEntropyApexVocabParallel ,
19
19
CrossEntropyLossApex ,
20
20
CrossEntropyPython ,
21
+ CrossEntropyLossFlash ,
21
22
)
22
23
from internlm .utils .logger import get_logger
23
24
@@ -86,17 +87,8 @@ def new_cross_entropy(
86
87
87
88
assert gpc .get_group (ParallelMode .TENSOR ) is not None , "The process group should not be None."
88
89
89
- try :
90
- from flash_attn .losses .cross_entropy import (
91
- CrossEntropyLoss as FlashCrossEntropyLoss ,
92
- )
93
-
94
- flash_cross_entropy_impl = True
95
- except (ModuleNotFoundError , ImportError ):
96
- flash_cross_entropy_impl = False
97
-
98
90
assert (
99
- gpc .config .model .get ("use_flash_attn" , False ) and flash_cross_entropy_impl
91
+ gpc .config .model .get ("use_flash_attn" , False )
100
92
), "Only flash cross entropy support parallel_output"
101
93
102
94
assert (
@@ -108,7 +100,7 @@ def new_cross_entropy(
108
100
which may result loss divergency in long sequence."
109
101
)
110
102
111
- return FlashCrossEntropyLoss (
103
+ return CrossEntropyLossFlash (
112
104
ignore_index = ignore_index ,
113
105
reduction = reduction ,
114
106
label_smoothing = label_smoothing ,
Original file line number Diff line number Diff line change 2
2
from .py_naive_loss import CrossEntropyPython
3
3
from .py_vocab_parallel_loss import CrossEntropyApexVocabParallel
4
4
from .sequence_parallel_loss import VocabSequenceParallelCrossEntropyLoss
5
+ from .flash_loss import CrossEntropyLossFlash
5
6
6
7
__all__ = [
7
8
"CrossEntropyLossApex" ,
8
9
"CrossEntropyPython" ,
9
10
"CrossEntropyApexVocabParallel" ,
10
11
"VocabSequenceParallelCrossEntropyLoss" ,
12
+ "CrossEntropyLossFlash" ,
11
13
]
You can’t perform that action at this time.
0 commit comments