Skip to content

Commit 7e37dac

Browse files
authored
Remove redundant arguments from trl/sft ScriptArguments. (#2332)
Signed-off-by: Artur Kloniecki <arturx.kloniecki@intel.com>
1 parent 00a103d commit 7e37dac

File tree

1 file changed

+3
-11
lines changed

1 file changed

+3
-11
lines changed

examples/trl/sft.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,6 @@ class ScriptArguments:
4545
use_flash_attention: Optional[bool] = field(
4646
default=False, metadata={"help": "Whether to use Habana flash attention for fine-tuning."}
4747
)
48-
flash_attention_recompute: Optional[bool] = field(
49-
default=False, metadata={"help": "Whether to enable recompute in Habana flash attention for fine-tuning."}
50-
)
51-
flash_attention_causal_mask: Optional[bool] = field(
52-
default=False, metadata={"help": "Whether to enable causal mask in Habana flash attention for fine-tuning."}
53-
)
5448

5549
# LoraConfig
5650
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
@@ -159,13 +153,11 @@ def create_datasets(tokenizer, args, seed=None):
159153
)
160154

161155
base_model.config.use_cache = False
162-
if not script_args.use_flash_attention and (
163-
script_args.flash_attention_recompute or script_args.flash_attention_recompute
164-
):
156+
if not script_args.use_flash_attention and training_args.flash_attention_recompute:
165157
assert "Need to enable use_flash_attention"
166158
base_model.generation_config.use_flash_attention = script_args.use_flash_attention
167-
base_model.generation_config.flash_attention_recompute = script_args.flash_attention_recompute
168-
base_model.generation_config.flash_attention_causal_mask = script_args.flash_attention_causal_mask
159+
base_model.generation_config.flash_attention_recompute = training_args.flash_attention_recompute
160+
base_model.generation_config.flash_attention_causal_mask = training_args.flash_attention_causal_mask
169161

170162
if is_zero3_enabled and training_args.use_zero3_leaf_promotion:
171163
apply_zero3_leaf_promotion(base_model)

0 commit comments

Comments
 (0)