Skip to content

Commit 4b8a061

Browse files
committed
create moe_layer_freq layer fallback
1 parent 2f3ab49 commit 4b8a061

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

primus/core/projection/training_config.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,11 @@ def megatron_derive_default_args(args):
105105
if args.num_experts is None:
106106
args.moe_pattern = [0] * args.num_layers
107107
else:
108-
if isinstance(args.moe_layer_freq, int):
108+
# Check if moe_layer_freq is defined before accessing it
109+
if not hasattr(args, "moe_layer_freq"):
110+
# If not defined, default to all layers being MoE (equivalent to moe_layer_freq=1)
111+
args.moe_pattern = [1] * args.num_layers
112+
elif isinstance(args.moe_layer_freq, int):
109113
args.moe_pattern = [1 if (i % args.moe_layer_freq == 0) else 0 for i in range(args.num_layers)]
110114
elif isinstance(args.moe_layer_freq, list):
111115
args.moe_pattern = args.moe_layer_freq
@@ -117,6 +121,9 @@ def megatron_derive_default_args(args):
117121
assert (
118122
len(args.moe_pattern) == args.num_layers
119123
), f"Invalid moe_layer_freq length: {len(args.moe_pattern)} (expected {args.num_layers})"
124+
else:
125+
# Fallback for unexpected types
126+
args.moe_pattern = [1] * args.num_layers
120127

121128
# naming conversion
122129
args.sequence_length = args.seq_length

0 commit comments

Comments
 (0)