Skip to content

Commit b780d70

Browse files
committed
[skip ci] Merge branch 'main' into transformers_future
2 parents 9cf57be + f76f572 commit b780d70

File tree

2 files changed

+61
-3
lines changed

2 files changed

+61
-3
lines changed

examples/text-generation/run_generation.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,14 @@
2929

3030
import torch
3131
from transformers import BatchEncoding
32-
from utils import adjust_batch, count_hpu_graphs, finalize_quantization, initialize_model, save_model
32+
from utils import (
33+
SetTrueOrFalseOrNone,
34+
adjust_batch,
35+
count_hpu_graphs,
36+
finalize_quantization,
37+
initialize_model,
38+
save_model,
39+
)
3340

3441
from optimum.habana.utils import get_hpu_memory_stats
3542

@@ -276,7 +283,9 @@ def setup_parser(parser):
276283
)
277284
parser.add_argument(
278285
"--flash_attention_fast_softmax",
279-
action="store_true",
286+
nargs="?",
287+
const=None,
288+
action=SetTrueOrFalseOrNone,
280289
help="Whether to enable Habana Flash Attention in fast softmax mode.",
281290
)
282291
parser.add_argument(
@@ -382,8 +391,13 @@ def setup_parser(parser):
382391
if not args.use_hpu_graphs:
383392
args.limit_hpu_graphs = False
384393

385-
if args.use_flash_attention and not args.flash_attention_fast_softmax:
394+
if args.use_flash_attention and args.flash_attention_fast_softmax is None:
395+
logger.warning(
396+
"`--flash_attention_fast_softmax` was not set; defaulting to True due to `--use_flash_attention` being enabled."
397+
)
386398
args.flash_attention_fast_softmax = True
399+
else:
400+
args.flash_attention_fast_softmax = False
387401

388402
args.quant_config = os.getenv("QUANT_CONFIG", "")
389403
if args.quant_config and args.load_quantized_model_with_autogptq:

examples/text-generation/utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
1818
###############################################################################
1919

20+
import argparse
2021
import copy
2122
import glob
2223
import os
@@ -796,3 +797,46 @@ def local_split_rank_state_dict(model, gathered_state_dict):
796797
cur_accelerator.synchronize()
797798

798799
return rank_state_dict
800+
801+
802+
class SetTrueOrFalseOrNone(argparse.Action):
    """
    Custom argparse action for a tri-state boolean flag (True, False, or None).

    Intended to be registered with ``nargs="?"`` and ``const=None`` so that:
      - the flag given without a value sets the destination to True,
      - the flag given with an explicit value sets it to that boolean,
      - the flag omitted entirely leaves the destination at its default (None).

    Accepted values (case-insensitive):
      - True values: 'true', '1', 't', 'y', 'yes'
      - False values: 'false', '0', 'f', 'n', 'no'

    An invalid value raises ``argparse.ArgumentError``, which argparse converts
    into a clean usage/error message.
    """

    # Accepted (lowercased) string spellings mapped to their boolean meaning.
    _VALUE_MAP = {
        "true": True,
        "1": True,
        "t": True,
        "y": True,
        "yes": True,
        "false": False,
        "0": False,
        "f": False,
        "n": False,
        "no": False,
    }

    def __call__(self, parser, namespace, values, option_string=None):
        if values is None:
            # Flag present without a value (nargs="?" delivered const=None): enable.
            setattr(namespace, self.dest, True)
        elif isinstance(values, bool):
            # Already a boolean (e.g. supplied programmatically): store as-is.
            setattr(namespace, self.dest, values)
        else:
            try:
                setattr(namespace, self.dest, self._VALUE_MAP[values.lower()])
            except KeyError:
                # Raise ArgumentError, not ArgumentTypeError: argparse only
                # translates ArgumentTypeError for `type=` callables. Raised from
                # an Action, ArgumentTypeError escapes as an unhandled traceback,
                # whereas ArgumentError is caught and reported as a usage error.
                raise argparse.ArgumentError(
                    self,
                    f"Invalid value for {option_string}: {values}. "
                    f"Expected one of: {', '.join(self._VALUE_MAP.keys())}.",
                ) from None

0 commit comments

Comments
 (0)