We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9ee9830 commit ed1b33fCopy full SHA for ed1b33f
paxml/contrib/gpu/scripts_gpu/te_helper.py
@@ -1,13 +1,17 @@
1
import os
2
from contextlib import contextmanager
3
4
+from praxis import base_layer
5
+
6
try:
7
import transformer_engine.jax as te
8
from transformer_engine.common import recipe
9
_IS_TRANSFORMER_ENGINE_INSTALLED = True
10
+ DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME]
11
12
except ModuleNotFoundError as e:
13
_IS_TRANSFORMER_ENGINE_INSTALLED = False
14
+ DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST
15
16
17
class TransformerEngineHelperBase:
0 commit comments