Variable batch size and LR scheduler #7104

Open
wants to merge 19 commits into base: master

Commits (19)
ff214de  copy past of all files from PR 7020 (bm-synth, Mar 3, 2025)
53a3b1a  copy/paste of all files from PR 7020 and 7102 to avoid having rebase and (bm-synth, Mar 3, 2025)
28b8418  Merge branch 'variable_batch_size_and_lr_2' of github.com:bm-synth/De… (bm-synth, Mar 3, 2025)
810d89b  merge conflicts (bm-synth, Mar 4, 2025)
6e6f04c  corrected arguments in batch_by_seqlens (bm-synth, Mar 4, 2025)
588e982  corrected alias for deepspeed team (bm-synth, Mar 4, 2025)
8b44bfd  Update deepspeed/runtime/data_pipeline/data_sampling/variable_batch_s… (bm-synth, Mar 4, 2025)
ae9c667  Merge branch 'variable_batch_size_and_lr_2' of github.com:bm-synth/De… (bm-synth, Mar 4, 2025)
42015d7  Merge branch 'master' into variable_batch_size_and_lr_2 (bm-synth, Mar 4, 2025)
241eab8  clang format with the correct version (bm-synth, Mar 4, 2025)
78d80cf  Merge branch 'master' into variable_batch_size_and_lr_2 (bm-synth, Mar 4, 2025)
37b62e5  Merge branch 'master' into variable_batch_size_and_lr_2 (bm-synth, Mar 4, 2025)
91a1d28  Merge branch 'master' into variable_batch_size_and_lr_2 (tjruwase, Mar 4, 2025)
1aee19e  Merge branch 'master' into variable_batch_size_and_lr_2 (bm-synth, Mar 5, 2025)
9fbe3d1  Merge branch 'master' into variable_batch_size_and_lr_2 (tjruwase, Mar 8, 2025)
1331caf  Merge branch 'master' into variable_batch_size_and_lr_2 (tjruwase, Mar 11, 2025)
3c1ae09  Merge branch 'master' into variable_batch_size_and_lr_2 (loadams, Mar 11, 2025)
50edd6f  Merge branch 'master' into variable_batch_size_and_lr_2 (tjruwase, Mar 12, 2025)
089d323  Merge branch 'master' into variable_batch_size_and_lr_2 (loadams, Mar 14, 2025)
2 changes: 1 addition & 1 deletion deepspeed/launcher/runner.py
@@ -483,7 +483,7 @@ def main(args=None):
result = subprocess.check_output(hostname_cmd)
except subprocess.CalledProcessError as err:
logger.error(
"Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr"
"Unable to detect suitable master address via 'hostname -I', please manually specify one via --master_addr"
)
raise err
args.master_addr = result.decode('utf-8').split()[0]
1 change: 0 additions & 1 deletion deepspeed/runtime/config.py
@@ -801,7 +801,6 @@ def __init__(self, config: Union[str, dict], mpu=None, mesh_device=None):

def _initialize_params(self, param_dict):
self.train_batch_size = get_train_batch_size(param_dict)
#print(f"beginning get_train_batch_size = {get_train_batch_size}")
self.train_micro_batch_size_per_gpu = get_train_micro_batch_size_per_gpu(param_dict)
self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict)
self.steps_per_print = get_steps_per_print(param_dict)
25 changes: 23 additions & 2 deletions deepspeed/runtime/data_pipeline/config.py
@@ -20,7 +20,6 @@ def get_data_efficiency_config(param_dict):
sub_param_dict = param_dict[DATA_EFFICIENCY]
output[DATA_SAMPLING] = get_data_sampling(sub_param_dict)
output[DATA_ROUTING] = get_data_routing(sub_param_dict)

return output


@@ -43,11 +42,13 @@ def get_data_sampling(param_dict):
output[DATA_SAMPLING_ENABLED] = get_data_sampling_enabled(param_dict)
output[DATA_SAMPLING_NUM_EPOCHS] = get_data_sampling_num_epochs(param_dict)
output[DATA_SAMPLING_NUM_WORKERS] = get_data_sampling_num_workers(param_dict)
output[DATA_SAMPLING_PIN_MEMORY] = bool(
output.get(param_dict[DATA_SAMPLING][DATA_SAMPLING_PIN_MEMORY], DATA_SAMPLING_PIN_MEMORY_DEFAULT))
if DATA_SAMPLING not in param_dict.keys():
param_dict[DATA_SAMPLING] = {}
sub_param_dict = param_dict[DATA_SAMPLING]
output[CURRICULUM_LEARNING] = get_curriculum_learning(sub_param_dict)

output[DYNAMIC_BATCHING] = get_dynamic_batching(sub_param_dict)
return output


@@ -87,6 +88,26 @@ def get_curriculum_learning(param_dict):
return output


def get_dynamic_batching(param_dict):
output = copy.copy(param_dict.get(DYNAMIC_BATCHING, {}))
output[DYNAMIC_BATCHING_ENABLED] = bool(output.get(DYNAMIC_BATCHING_ENABLED, DYNAMIC_BATCHING_ENABLED_DEFAULT))
output[DYNAMIC_BATCHING_LR_SCALING_METHOD] = str(
output.get(DYNAMIC_BATCHING_LR_SCALING_METHOD, DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT))
output[DYNAMIC_BATCHING_MIN_BATCH_SIZE] = int(
output.get(DYNAMIC_BATCHING_MIN_BATCH_SIZE, DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT))
output[DYNAMIC_BATCHING_MAX_BATCH_SIZE] = int(output[DYNAMIC_BATCHING_MAX_BATCH_SIZE]) \
if DYNAMIC_BATCHING_MAX_BATCH_SIZE in output.keys() \
else DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT
output[DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER] = str(
output.get(DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER, DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT))
if output[DYNAMIC_BATCHING_ENABLED]:
assert DYNAMIC_BATCHING_MAX_TOKENS in output.keys(
), f"Dynamic batching is enabled, so {DYNAMIC_BATCHING_MAX_TOKENS} must be specified"
output[DYNAMIC_BATCHING_MAX_TOKENS] = int(output[DYNAMIC_BATCHING_MAX_TOKENS])
output[DYNAMIC_BATCHING_VERBOSE] = bool(output.get(DYNAMIC_BATCHING_VERBOSE, False))
return output


def get_curriculum_learning_enabled(param_dict):
if CURRICULUM_LEARNING in param_dict.keys():
return get_scalar_param(param_dict[CURRICULUM_LEARNING], CURRICULUM_LEARNING_ENABLED,
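The parser above fills in defaults for any dynamic-batching option the user leaves out. A minimal sketch of how it could be exercised, assuming this PR is merged so that get_dynamic_batching is importable from deepspeed.runtime.data_pipeline.config; the sample values below are made up for illustration and are not taken from the PR:

    # Illustrative only: a hand-written data_sampling sub-config whose keys mirror
    # the DYNAMIC_BATCHING_* constants defined in constants.py below.
    from deepspeed.runtime.data_pipeline.config import get_dynamic_batching

    data_sampling_config = {
        "dynamic_batching": {
            "enabled": True,
            "max_tokens": 2048,           # required whenever dynamic batching is enabled
            "lr_scaling_method": "sqrt",  # "linear" / "sqrt" / "none"
            "sequence_picking_order": "seqlen",
            "min_batch_size": 1,
        }
    }

    parsed = get_dynamic_batching(data_sampling_config)
    # Keys not set above fall back to their defaults, e.g. max_batch_size -> None, verbose -> False.
    print(parsed)
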
20 changes: 20 additions & 0 deletions deepspeed/runtime/data_pipeline/constants.py
@@ -22,6 +22,8 @@
DATA_SAMPLING_NUM_EPOCHS_DEFAULT = 1000
DATA_SAMPLING_NUM_WORKERS = "num_workers"
DATA_SAMPLING_NUM_WORKERS_DEFAULT = 0
DATA_SAMPLING_PIN_MEMORY = "pin_memory"
DATA_SAMPLING_PIN_MEMORY_DEFAULT = False

#########################################
# Data efficiency - Data Sampling - Curriculum Learning
@@ -62,6 +64,24 @@
CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION = "data_cluster_current_position"
CURRICULUM_LEARNING_NP_RNG_STATE = "np_rng_state"

#########################################
# Data efficiency - Dynamic batching and LR scaling
#########################################
DYNAMIC_BATCHING = "dynamic_batching"
DYNAMIC_BATCHING_ENABLED = "enabled"
DYNAMIC_BATCHING_ENABLED_DEFAULT = False
DYNAMIC_BATCHING_METRICS_PATH = "metrics_path"
DYNAMIC_BATCHING_LR_SCALING_METHOD = "lr_scaling_method" # "linear" / "sqrt" / "none"
DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT = "linear"
DYNAMIC_BATCHING_MIN_BATCH_SIZE = "min_batch_size"
DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT = 1
DYNAMIC_BATCHING_MAX_BATCH_SIZE = "max_batch_size"
DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT = None
DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER = "sequence_picking_order" # "random" / "seqlen" / "dataloader"
DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT = "dataloader" # "random" / "seqlen" / "dataloader"
DYNAMIC_BATCHING_MAX_TOKENS = "max_tokens"
DYNAMIC_BATCHING_VERBOSE = "verbose"

#########################################
# Curriculum Learning legacy implementation
#########################################
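The constants above only name the dynamic-batching knobs; this hunk does not show how the learning rate is rescaled when the effective batch size changes. As a rough sketch of what the "linear" and "sqrt" options conventionally mean for a batch whose size deviates from a reference batch size (the PR's actual implementation may differ in detail), a reference-style helper could look like this:

    import math

    def scale_lr(base_lr: float, batch_size: int, base_batch_size: int,
                 lr_scaling_method: str = "linear") -> float:
        # Illustrative sketch of conventional LR scaling rules, not the PR's code.
        ratio = batch_size / base_batch_size
        if lr_scaling_method == "linear":   # LR grows proportionally with batch size
            return base_lr * ratio
        if lr_scaling_method == "sqrt":     # square-root scaling, often paired with Adam-style optimizers
            return base_lr * math.sqrt(ratio)
        if lr_scaling_method == "none":
            return base_lr
        raise ValueError(f"unknown lr_scaling_method: {lr_scaling_method}")

    # Example: a batch twice the reference size gets 2x the LR (linear) or ~1.41x (sqrt).
    print(scale_lr(1e-4, 16, 8, "linear"), scale_lr(1e-4, 16, 8, "sqrt"))
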
@@ -862,8 +862,13 @@ def test_compare_both_data_analyzers(dataset):
for path in output_paths:
with open(os.path.join(da.save_path, path), 'rb') as f1, \
open(os.path.join(dda.save_path, path), 'rb') as f2:
if f1.read() != f2.read():
# if files have suffix .bin, they should be identical
if path.endswith(".bin"):
assert f1.read() == f2.read(), f"files {path} are not identical."
elif f1.read() != f2.read():
print(f"files {path} are not identical.")
dist.barrier()
dist.destroy_process_group()


if __name__ == "__main__":