Skip to content

Commit e165692

Browse files
wang2yn84 authored and The tunix Authors committed
Code update
PiperOrigin-RevId: 848200226
1 parent c948ebe commit e165692

File tree

5 files changed: +5 -5 lines changed

tunix/models/gemma/model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ def get_default_sharding(is_sampling: bool = False):
79 79   ffw_weight_fd=('tp', fsdp),
80 80   rms_norm_weight=('tp',),
81 81   act_btd=('fsdp', None, None if is_sampling else 'tp'),
82    - act_btf=('fsdp', None, None),
   82 + act_btf=('fsdp', None, 'tp'),
83 83   act_btnh=('fsdp', None, 'tp', None),
84 84   score_weight_d1=(fsdp, None),
85 85   )

tunix/models/gemma3/model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ def get_default_sharding(is_sampling: bool = False):
72 72   ffw_weight_fd=('tp', fsdp),
73 73   rms_norm_weight=('tp',),
74 74   act_btd=('fsdp', None, None if is_sampling else 'tp'),
75    - act_btf=('fsdp', None, None),
   75 + act_btf=('fsdp', None, 'tp'),
76 76   act_btnh=('fsdp', None, 'tp', None),
77 77   )
78 78

tunix/models/llama3/model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ def get_default_sharding(is_sampling: bool = False):
73 73   ffw_weight_fd=('tp', fsdp),
74 74   rms_norm_weight=('tp',),
75 75   act_btd=('fsdp', None, None if is_sampling else 'tp'),
76    - act_btf=('fsdp', None, None),
   76 + act_btf=('fsdp', None, 'tp'),
77 77   act_btnh=('fsdp', None, 'tp', None),
78 78   )
79 79

tunix/models/qwen2/model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ def get_default_sharding(is_sampling: bool = False):
75 75   ffw_weight_fd=('tp', fsdp),
76 76   rms_norm_weight=('tp',),
77 77   act_btd=('fsdp', None, None if is_sampling else 'tp'),
78    - act_btf=('fsdp', None, None),
   78 + act_btf=('fsdp', None, 'tp'),
79 79   act_btnh=('fsdp', None, 'tp', None),
80 80   exp_weight_cdf=('fsdp', None, 'tp'),
81 81   exp_weight_cfd=('fsdp', 'tp', None),

tunix/models/qwen3/model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ def get_default_sharding(is_sampling: bool = False):
75 75   ffw_weight_fd=('tp', fsdp),
76 76   rms_norm_weight=('tp',),
77 77   act_btd=('fsdp', None, None if is_sampling else 'tp'),
78    - act_btf=('fsdp', None, None),
   78 + act_btf=('fsdp', None, 'tp'),
79 79   act_btnh=('fsdp', None, 'tp', None),
80 80   exp_weight_cdf=('fsdp', None, 'tp'),
81 81   exp_weight_cfd=('fsdp', 'tp', None),

0 commit comments

Comments
 (0)