
Commit 823497e
migrate flash xentropy loss
1 parent b770c5b

File tree: 3 files changed, +417 -11 lines

internlm/model/model_ops/ops/cross_entropy.py (+3, -11)

@@ -18,6 +18,7 @@
     CrossEntropyApexVocabParallel,
     CrossEntropyLossApex,
     CrossEntropyPython,
+    CrossEntropyLossFlash,
 )
 from internlm.utils.logger import get_logger
 
@@ -86,17 +87,8 @@ def new_cross_entropy(
 
     assert gpc.get_group(ParallelMode.TENSOR) is not None, "The process group should not be None."
 
-    try:
-        from flash_attn.losses.cross_entropy import (
-            CrossEntropyLoss as FlashCrossEntropyLoss,
-        )
-
-        flash_cross_entropy_impl = True
-    except (ModuleNotFoundError, ImportError):
-        flash_cross_entropy_impl = False
-
     assert (
-        gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl
+        gpc.config.model.get("use_flash_attn", False)
     ), "Only flash cross entropy support parallel_output"
 
     assert (
@@ -108,7 +100,7 @@ def new_cross_entropy(
             which may result loss divergency in long sequence."
         )
 
-    return FlashCrossEntropyLoss(
+    return CrossEntropyLossFlash(
         ignore_index=ignore_index,
         reduction=reduction,
         label_smoothing=label_smoothing,

internlm/model/model_ops/ops/cross_entropy_ops/__init__.py (+2)

@@ -2,10 +2,12 @@
 from .py_naive_loss import CrossEntropyPython
 from .py_vocab_parallel_loss import CrossEntropyApexVocabParallel
 from .sequence_parallel_loss import VocabSequenceParallelCrossEntropyLoss
+from .flash_loss import CrossEntropyLossFlash
 
 __all__ = [
     "CrossEntropyLossApex",
     "CrossEntropyPython",
     "CrossEntropyApexVocabParallel",
     "VocabSequenceParallelCrossEntropyLoss",
+    "CrossEntropyLossFlash",
 ]
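
The third file in this commit, cross_entropy_ops/flash_loss.py (where most of the +417 lines land), is not included in this excerpt. Purely to illustrate the interface the call site above relies on, here is a minimal, hypothetical stand-in that mirrors the constructor arguments and falls back to plain PyTorch. It is not the actual implementation, which presumably wraps the flash-attention cross-entropy kernel.

```python
import torch
from torch import nn


class CrossEntropyLossFlash(nn.Module):
    """Hypothetical stand-in sketch, not the class added by this commit.

    It only mirrors the constructor surface used by new_cross_entropy
    (ignore_index / reduction / label_smoothing) and delegates to
    torch.nn.functional.cross_entropy so it runs without flash_attn.
    """

    def __init__(self, ignore_index: int = -100, reduction: str = "mean", label_smoothing: float = 0.0):
        super().__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # logits: (tokens, vocab) float tensor; target: (tokens,) integer labels
        return torch.nn.functional.cross_entropy(
            logits.float(),
            target,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            label_smoothing=self.label_smoothing,
        )
```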
