Skip to content

Commit ed1b33f

Browse files
mingxu1067ashors1
authored andcommitted
Fix missing DEFAULT_INIT_MUTABLE_LIST
1 parent 9ee9830 commit ed1b33f

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

paxml/contrib/gpu/scripts_gpu/te_helper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import os
22
from contextlib import contextmanager
33

4+
from praxis import base_layer
5+
46
try:
57
import transformer_engine.jax as te
68
from transformer_engine.common import recipe
79
_IS_TRANSFORMER_ENGINE_INSTALLED = True
10+
DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME]
811

912
except ModuleNotFoundError as e:
1013
_IS_TRANSFORMER_ENGINE_INSTALLED = False
14+
DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST
1115

1216

1317
class TransformerEngineHelperBase:

0 commit comments

Comments
 (0)