@@ -161,28 +161,6 @@ class ModelArguments:
161161 default = False ,
162162 metadata = {"help" : ("Whether to run attention softmax layer in bf16 precision for fine-tuning." )},
163163 )
164- use_flash_attention : bool = field (
165- default = False ,
166- metadata = {"help" : ("Whether to use Habana flash attention for fine-tuning." )},
167- )
168- flash_attention_recompute : bool = field (
169- default = False ,
170- metadata = {
171- "help" : (
172- "Whether to enable recompute in Habana flash attention for fine-tuning."
173- " It is applicable only when use_flash_attention is True."
174- )
175- },
176- )
177- flash_attention_causal_mask : bool = field (
178- default = False ,
179- metadata = {
180- "help" : (
181- "Whether to enable causal mask in Habana flash attention for fine-tuning."
182- " It is applicable only when use_flash_attention is True."
183- )
184- },
185- )
186164
187165 def __post_init__ (self ):
188166 if self .config_overrides is not None and (self .config_name is not None or self .model_name_or_path is not None ):
@@ -569,10 +547,10 @@ def main():
569547 # We need to add these fused kernel configs
570548 if model_args .attn_softmax_bf16 :
571549 model .generation_config .attn_softmax_bf16 = True
572- if model_args . use_flash_attention :
550+ if training_args . attn_implementation == "gaudi_fused_sdpa" :
573551 model .generation_config .use_flash_attention = True
574- model .generation_config .flash_attention_recompute = model_args .flash_attention_recompute
575- model .generation_config .flash_attention_causal_mask = model_args .flash_attention_causal_mask
552+ model .generation_config .flash_attention_recompute = training_args .flash_attention_recompute
553+ model .generation_config .flash_attention_causal_mask = training_args .flash_attention_causal_mask
576554
577555 # Preprocessing the datasets.
578556 # First we tokenize all the texts.
0 commit comments