File tree Expand file tree Collapse file tree 9 files changed +9
-8
lines changed
testdata/axlearn.experiments.text.gpt.c4_trainer Expand file tree Collapse file tree 9 files changed +9
-8
lines changed Original file line number Diff line number Diff line change @@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234
234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235
235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236
236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237
- mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237
+ mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.* linear1_[01]'
238
238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239
239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240
240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'
Original file line number Diff line number Diff line change @@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234
234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235
235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236
236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237
- mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237
+ mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.* linear1_[01]'
238
238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239
239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240
240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'
Original file line number Diff line number Diff line change @@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234
234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235
235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236
236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237
- mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237
+ mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.* linear1_[01]'
238
238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239
239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240
240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'
Original file line number Diff line number Diff line change @@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234
234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235
235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236
236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237
- mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237
+ mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.* linear1_[01]'
238
238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239
239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240
240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'
Original file line number Diff line number Diff line change @@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234
234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235
235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236
236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237
- mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237
+ mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.* linear1_[01]'
238
238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239
239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240
240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'
Original file line number Diff line number Diff line change @@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234
234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235
235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236
236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237
- mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237
+ mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.* linear1_[01]'
238
238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239
239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240
240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'
Original file line number Diff line number Diff line change @@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234
234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235
235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236
236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237
- mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237
+ mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.* linear1_[01]'
238
238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239
239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240
240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'
Original file line number Diff line number Diff line change @@ -234,7 +234,7 @@ mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif
234
234
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
235
235
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
236
236
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
237
- mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
237
+ mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.* linear1_[01]'
238
238
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
239
239
mesh_rules[7][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
240
240
mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModuleConfigModifier'
Original file line number Diff line number Diff line change @@ -810,6 +810,7 @@ def get_trainer_kwargs(
810
810
names_which_can_be_saved = "|" .join (
811
811
[
812
812
RematRegexSavePatterns .QKV_PROJ .value ,
813
+ RematRegexSavePatterns .O_PROJ .value ,
813
814
RematRegexSavePatterns .LINEAR1_X .value ,
814
815
]
815
816
),
You can’t perform that action at this time.
0 commit comments