Skip to content

Commit 5acd007

Browse files
pbielak (Piotr Bielak)
and co-authors authored
Remove flash attention flags from run_clm.py (#2314)
Co-authored-by: Piotr Bielak <pbielak@habana.ai>
1 parent f5ed4bf commit 5acd007

File tree

1 file changed

+3
-25
lines changed

1 file changed

+3
-25
lines changed

examples/language-modeling/run_clm.py

Lines changed: 3 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -161,28 +161,6 @@ class ModelArguments:
161161
default=False,
162162
metadata={"help": ("Whether to run attention softmax layer in bf16 precision for fine-tuning.")},
163163
)
164-
use_flash_attention: bool = field(
165-
default=False,
166-
metadata={"help": ("Whether to use Habana flash attention for fine-tuning.")},
167-
)
168-
flash_attention_recompute: bool = field(
169-
default=False,
170-
metadata={
171-
"help": (
172-
"Whether to enable recompute in Habana flash attention for fine-tuning."
173-
" It is applicable only when use_flash_attention is True."
174-
)
175-
},
176-
)
177-
flash_attention_causal_mask: bool = field(
178-
default=False,
179-
metadata={
180-
"help": (
181-
"Whether to enable causal mask in Habana flash attention for fine-tuning."
182-
" It is applicable only when use_flash_attention is True."
183-
)
184-
},
185-
)
186164

187165
def __post_init__(self):
188166
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -569,10 +547,10 @@ def main():
569547
# We need to add these fused kernels config
570548
if model_args.attn_softmax_bf16:
571549
model.generation_config.attn_softmax_bf16 = True
572-
if model_args.use_flash_attention:
550+
if training_args.attn_implementation == "gaudi_fused_sdpa":
573551
model.generation_config.use_flash_attention = True
574-
model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute
575-
model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask
552+
model.generation_config.flash_attention_recompute = training_args.flash_attention_recompute
553+
model.generation_config.flash_attention_causal_mask = training_args.flash_attention_causal_mask
576554

577555
# Preprocessing the datasets.
578556
# First we tokenize all the texts.

0 commit comments

Comments (0)