Skip to content

Commit 5966cc5

Browse files
committed
Save O_PROJ on Fuji 70B-v2 for neuron
1 parent 2146b74 commit 5966cc5

File tree

9 files changed

+9 -8 lines changed

9 files changed

+9
-8
lines changed

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237-
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237+
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*linear1_[01]'
238238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237-
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237+
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*linear1_[01]'
238238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237-
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237+
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*linear1_[01]'
238238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237-
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237+
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*linear1_[01]'
238238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237-
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237+
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*linear1_[01]'
238238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237-
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237+
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*linear1_[01]'
238238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237-
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237+
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*linear1_[01]'
238238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237-
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237+
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*linear1_[01]'
238238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'

axlearn/experiments/text/gpt/fuji.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,7 @@ def get_trainer_kwargs(
810 810   names_which_can_be_saved="|".join(
811 811   [
812 812   RematRegexSavePatterns.QKV_PROJ.value,
    813 + RematRegexSavePatterns.O_PROJ.value,
813 814   RematRegexSavePatterns.LINEAR1_X.value,
814 815   ]
815 816   ),

0 commit comments

Comments
 (0)