|
17 | 17 | # Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company |
18 | 18 | ############################################################################### |
19 | 19 |
|
| 20 | +import argparse |
20 | 21 | import copy |
21 | 22 | import glob |
22 | 23 | import os |
@@ -796,3 +797,46 @@ def local_split_rank_state_dict(model, gathered_state_dict): |
796 | 797 | cur_accelerator.synchronize() |
797 | 798 |
|
798 | 799 | return rank_state_dict |
| 800 | + |
| 801 | + |
| 802 | +class SetTrueOrFalseOrNone(argparse.Action): |
| 803 | + """ |
| 804 | + Custom argparse action to handle a flag that can be set to True, False, or None. |
| 805 | +
|
| 806 | + This action allows an argument to be: |
| 807 | + - Set to True if the flag is present without a value. |
| 808 | + - Set to a boolean value (True or False) if explicitly provided. |
| 809 | + - Set to None if the flag is not present. |
| 810 | +
|
| 811 | + The argument accepts the following values (case-insensitive): |
| 812 | + - True values: 'true', '1', 't', 'y', 'yes' |
| 813 | + - False values: 'false', '0', 'f', 'n', 'no' |
| 814 | +
|
| 815 | + If an invalid value is provided, an argparse.ArgumentTypeError is raised. |
| 816 | + """ |
| 817 | + |
| 818 | + def __call__(self, parser, namespace, values, option_string=None): |
| 819 | + value_map = { |
| 820 | + "true": True, |
| 821 | + "1": True, |
| 822 | + "t": True, |
| 823 | + "y": True, |
| 824 | + "yes": True, |
| 825 | + "false": False, |
| 826 | + "0": False, |
| 827 | + "f": False, |
| 828 | + "n": False, |
| 829 | + "no": False, |
| 830 | + } |
| 831 | + if values is None: |
| 832 | + setattr(namespace, self.dest, True) |
| 833 | + elif isinstance(values, bool): |
| 834 | + setattr(namespace, self.dest, values) |
| 835 | + else: |
| 836 | + value_lower = values.lower() |
| 837 | + if value_lower in value_map: |
| 838 | + setattr(namespace, self.dest, value_map[value_lower]) |
| 839 | + else: |
| 840 | + raise argparse.ArgumentTypeError( |
| 841 | + f"Invalid value for {option_string}: {values}. Expected one of: {', '.join(value_map.keys())}." |
| 842 | + ) |
0 commit comments