@@ -45,12 +45,6 @@ class ScriptArguments:
     use_flash_attention: Optional[bool] = field(
         default=False, metadata={"help": "Whether to use Habana flash attention for fine-tuning."}
     )
-    flash_attention_recompute: Optional[bool] = field(
-        default=False, metadata={"help": "Whether to enable recompute in Habana flash attention for fine-tuning."}
-    )
-    flash_attention_causal_mask: Optional[bool] = field(
-        default=False, metadata={"help": "Whether to enable causal mask in Habana flash attention for fine-tuning."}
-    )
 
     # LoraConfig
     lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
@@ -159,13 +153,11 @@ def create_datasets(tokenizer, args, seed=None):
     )
 
     base_model.config.use_cache = False
-    if not script_args.use_flash_attention and (
-        script_args.flash_attention_recompute or script_args.flash_attention_recompute
-    ):
+    if not script_args.use_flash_attention and training_args.flash_attention_recompute:
         assert "Need to enable use_flash_attention"
     base_model.generation_config.use_flash_attention = script_args.use_flash_attention
-    base_model.generation_config.flash_attention_recompute = script_args.flash_attention_recompute
-    base_model.generation_config.flash_attention_causal_mask = script_args.flash_attention_causal_mask
+    base_model.generation_config.flash_attention_recompute = training_args.flash_attention_recompute
+    base_model.generation_config.flash_attention_causal_mask = training_args.flash_attention_causal_mask
 
     if is_zero3_enabled and training_args.use_zero3_leaf_promotion:
         apply_zero3_leaf_promotion(base_model)
0 commit comments