Variable batch size and LR scheduler #7104

Open · bm-synth wants to merge 20 commits into master from bm-synth:variable_batch_size_and_lr_2
Changes from 9 commits

Commits (20)
ff214de
copy past of all files from PR 7020
bm-synth Mar 3, 2025
53a3b1a
copy/paste of all files from PR 7020 and 7102 to avoid having rebase and
bm-synth Mar 3, 2025
28b8418
Merge branch 'variable_batch_size_and_lr_2' of github.com:bm-synth/De…
bm-synth Mar 3, 2025
810d89b
merge conflicts
bm-synth Mar 4, 2025
6e6f04c
corrected arguments in batch_by_seqlens
bm-synth Mar 4, 2025
588e982
corrected alias for deepspeed team
bm-synth Mar 4, 2025
8b44bfd
Update deepspeed/runtime/data_pipeline/data_sampling/variable_batch_s…
bm-synth Mar 4, 2025
ae9c667
Merge branch 'variable_batch_size_and_lr_2' of github.com:bm-synth/De…
bm-synth Mar 4, 2025
42015d7
Merge branch 'master' into variable_batch_size_and_lr_2
bm-synth Mar 4, 2025
241eab8
clang format with the correct version
bm-synth Mar 4, 2025
78d80cf
Merge branch 'master' into variable_batch_size_and_lr_2
bm-synth Mar 4, 2025
37b62e5
Merge branch 'master' into variable_batch_size_and_lr_2
bm-synth Mar 4, 2025
91a1d28
Merge branch 'master' into variable_batch_size_and_lr_2
tjruwase Mar 4, 2025
1aee19e
Merge branch 'master' into variable_batch_size_and_lr_2
bm-synth Mar 5, 2025
9fbe3d1
Merge branch 'master' into variable_batch_size_and_lr_2
tjruwase Mar 8, 2025
1331caf
Merge branch 'master' into variable_batch_size_and_lr_2
tjruwase Mar 11, 2025
3c1ae09
Merge branch 'master' into variable_batch_size_and_lr_2
loadams Mar 11, 2025
50edd6f
Merge branch 'master' into variable_batch_size_and_lr_2
tjruwase Mar 12, 2025
089d323
Merge branch 'master' into variable_batch_size_and_lr_2
loadams Mar 14, 2025
ff0ea73
Merge branch 'master' into variable_batch_size_and_lr_2
tjruwase Mar 21, 2025
12 changes: 4 additions & 8 deletions csrc/deepspeed4science/evoformer_attn/attention_back.cu
@@ -16,10 +8,8 @@ constexpr auto kBlockSizeJ = 64;
template <typename arch,
typename scalar_t,
typename torch_scalar_t,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
template <typename, typename, typename> class Broadcast1_,
template <typename, typename, typename> class Broadcast2_>
typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_back_impl_template(
torch::Tensor& go,
torch::Tensor& q,
@@ -42,10 +40,8 @@ typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_back_
template <typename arch,
typename scalar_t,
typename torch_scalar_t,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
template <typename, typename, typename> class Broadcast1_,
template <typename, typename, typename> class Broadcast2_>
typename std::enable_if<CheckArch<arch, scalar_t>::value>::type attention_back_impl_template(
torch::Tensor& go,
torch::Tensor& q,
12 changes: 4 additions & 8 deletions csrc/deepspeed4science/evoformer_attn/attention_cu.cu
@@ -12,10 +8,8 @@
template <typename arch,
typename scalar_t,
typename torch_scalar_t,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
template <typename, typename, typename> class Broadcast1_,
template <typename, typename, typename> class Broadcast2_>
typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_impl_template(
torch::Tensor& q,
torch::Tensor& k,
@@ -31,10 +29,8 @@ typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_impl_
template <typename arch,
typename scalar_t,
typename torch_scalar_t,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
template <typename, typename, typename> class Broadcast1_,
template <typename, typename, typename> class Broadcast2_>
typename std::enable_if<CheckArch<arch, scalar_t>::value>::type attention_impl_template(
torch::Tensor& q,
torch::Tensor& k,
@@ -488,7 +488,7 @@ class PredicatedTileAccessIteratorResidualLast<Shape_,

/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))){};
Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {};

/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
@@ -695,7 +695,7 @@ class PredicatedTileAccessIteratorResidualLast<Shape_,

/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))){};
Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {};

/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
@@ -1211,7 +1211,7 @@ class PredicatedTileAccessIteratorResidualLast<Shape_,
/// Construct the Params object given an AffineRankN<2> tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout)
: params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){};
: params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {};
};

private:
@@ -1413,7 +1413,7 @@ class PredicatedTileAccessIteratorResidualLast<Shape_,
/// Construct the Params object given an AffineRankN<2> tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout)
: params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){};
: params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {};
};

private:
@@ -90,10 +88,8 @@ struct BroadcastB : public BroadcastNoLoad<ThreadMap, Shape, scalar_t> {
template <typename Shape,
typename scalar_t,
int kThreads,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
template <typename, typename, typename> class Broadcast1_,
template <typename, typename, typename> class Broadcast2_>
struct AttentionBiasEpilogue {
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
cutlass::layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
Expand Down
2 changes: 1 addition & 1 deletion csrc/includes/type_shim.h
@@ -107,7 +107,7 @@ reduce_block_into_lanes(T* x,
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
// __SYNCWARP();

#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
2 changes: 1 addition & 1 deletion csrc/xpu/includes/type_shim.h
@@ -129,7 +129,7 @@ reduce_block_into_lanes(T* x,
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
// __SYNCWARP();

#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
@@ -14,8 +14,7 @@ struct
{
int32_t n_tokens;
int32_t n_sequences;
}
typedef RaggedBatchDescriptor;
} typedef RaggedBatchDescriptor;

struct
#ifdef __CUDA_CC__
@@ -26,8 +25,7 @@ struct
int32_t n_tokens;
int32_t seen_tokens;
int32_t UNUSED; // Explicit padding to match the Python code pattern.
}
typedef InflightSeqDescriptor;
} typedef InflightSeqDescriptor;

struct
#ifdef __CUDA_CC__
@@ -37,8 +35,7 @@ struct
int32_t** block_lists;
int32_t block_size;
int32_t n_blocks;
}
typedef KVCacheDescriptor;
} typedef KVCacheDescriptor;

struct {
const RaggedBatchDescriptor* batch_metadata; // Offset 0
@@ -108,7 +108,7 @@ void launch_top_k_gating(int32_t* expert_counts,
}

#define INSTANTIATE_top_k_KERNEL(T) \
template void launch_top_k_gating<T>(int32_t * expert_counts, \
template void launch_top_k_gating<T>(int32_t* expert_counts, \
float* scores, \
int32_t* assignments, \
int32_t* offsets, \
2 changes: 1 addition & 1 deletion deepspeed/launcher/runner.py
@@ -485,7 +485,7 @@ def main(args=None):
result = subprocess.check_output(hostname_cmd)
except subprocess.CalledProcessError as err:
logger.error(
"Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr"
"Unable to detect suitable master address via 'hostname -I', please manually specify one via --master_addr"
)
raise err
args.master_addr = result.decode('utf-8').split()[0]
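
For context, the surrounding code (partially visible above) resolves the master address by running `hostname -I` and taking the first reported IP, falling back to an explicit --master_addr. A minimal standalone sketch of that pattern, assuming the command is exactly `hostname -I` as the error message states:

import subprocess

def detect_master_addr() -> str:
    # Run `hostname -I` and take the first reported IP address.
    # Sketch only; DeepSpeed's runner wraps this with its own logging and CLI fallback.
    try:
        result = subprocess.check_output(["hostname", "-I"])
    except subprocess.CalledProcessError as err:
        raise RuntimeError("Unable to detect a suitable master address via 'hostname -I'; "
                           "please specify one via --master_addr") from err
    return result.decode("utf-8").split()[0]
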
1 change: 0 additions & 1 deletion deepspeed/runtime/config.py
@@ -801,7 +801,6 @@ def __init__(self, config: Union[str, dict], mpu=None, mesh_device=None):

def _initialize_params(self, param_dict):
self.train_batch_size = get_train_batch_size(param_dict)
#print(f"beginning get_train_batch_size = {get_train_batch_size}")
self.train_micro_batch_size_per_gpu = get_train_micro_batch_size_per_gpu(param_dict)
self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict)
self.steps_per_print = get_steps_per_print(param_dict)
25 changes: 23 additions & 2 deletions deepspeed/runtime/data_pipeline/config.py
@@ -20,7 +20,6 @@ def get_data_efficiency_config(param_dict):
sub_param_dict = param_dict[DATA_EFFICIENCY]
output[DATA_SAMPLING] = get_data_sampling(sub_param_dict)
output[DATA_ROUTING] = get_data_routing(sub_param_dict)

return output


@@ -43,11 +42,13 @@ def get_data_sampling(param_dict):
output[DATA_SAMPLING_ENABLED] = get_data_sampling_enabled(param_dict)
output[DATA_SAMPLING_NUM_EPOCHS] = get_data_sampling_num_epochs(param_dict)
output[DATA_SAMPLING_NUM_WORKERS] = get_data_sampling_num_workers(param_dict)
output[DATA_SAMPLING_PIN_MEMORY] = bool(
output.get(param_dict[DATA_SAMPLING][DATA_SAMPLING_PIN_MEMORY], DATA_SAMPLING_PIN_MEMORY_DEFAULT))
if DATA_SAMPLING not in param_dict.keys():
param_dict[DATA_SAMPLING] = {}
sub_param_dict = param_dict[DATA_SAMPLING]
output[CURRICULUM_LEARNING] = get_curriculum_learning(sub_param_dict)

output[DYNAMIC_BATCHING] = get_dynamic_batching(sub_param_dict)
return output
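
The new pin_memory knob parsed above mirrors the PyTorch DataLoader option of the same name. As a hedged illustration (the forwarding site is not part of this diff, and the config dict below is an assumption), the parsed value would typically end up in a DataLoader like this:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative only: forward the parsed data_sampling options to a DataLoader.
dataset = TensorDataset(torch.arange(16, dtype=torch.float32))
data_sampling_cfg = {"num_workers": 0, "pin_memory": True}  # hypothetical parsed config
loader = DataLoader(dataset,
                    batch_size=4,
                    num_workers=data_sampling_cfg["num_workers"],
                    pin_memory=data_sampling_cfg["pin_memory"])  # page-locked host buffers speed up host-to-device copies
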


@@ -87,6 +88,26 @@ def get_curriculum_learning(param_dict):
return output


def get_dynamic_batching(param_dict):
output = copy.copy(param_dict.get(DYNAMIC_BATCHING, {}))
output[DYNAMIC_BATCHING_ENABLED] = bool(output.get(DYNAMIC_BATCHING_ENABLED, DYNAMIC_BATCHING_ENABLED_DEFAULT))
output[DYNAMIC_BATCHING_LR_SCALING_METHOD] = str(
output.get(DYNAMIC_BATCHING_LR_SCALING_METHOD, DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT))
output[DYNAMIC_BATCHING_MIN_BATCH_SIZE] = int(
output.get(DYNAMIC_BATCHING_MIN_BATCH_SIZE, DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT))
output[DYNAMIC_BATCHING_MAX_BATCH_SIZE] = int(output[DYNAMIC_BATCHING_MAX_BATCH_SIZE]) \
if DYNAMIC_BATCHING_MAX_BATCH_SIZE in output.keys() \
else DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT
output[DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER] = str(
output.get(DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER, DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT))
if output[DYNAMIC_BATCHING_ENABLED]:
assert DYNAMIC_BATCHING_MAX_TOKENS in output.keys(
), f"Dynamic batching is enabled, so {DYNAMIC_BATCHING_MAX_TOKENS} must be specified"
output[DYNAMIC_BATCHING_MAX_TOKENS] = int(output[DYNAMIC_BATCHING_MAX_TOKENS])
output[DYNAMIC_BATCHING_VERBOSE] = bool(output.get(DYNAMIC_BATCHING_VERBOSE, False))
return output
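
Putting the keys above together, a config fragment that get_dynamic_batching would accept might look like the sketch below. The key names come from constants.py in this PR; the data_efficiency/data_sampling nesting and all values are illustrative assumptions, not defaults taken from the diff:

# Hypothetical DeepSpeed config fragment (values are illustrative only).
ds_config = {
    "data_efficiency": {
        "enabled": True,
        "data_sampling": {
            "enabled": True,
            "num_workers": 2,
            "pin_memory": True,                       # new data_sampling key in this PR
            "dynamic_batching": {                     # new sub-section parsed by get_dynamic_batching
                "enabled": True,
                "max_tokens": 8192,                   # required when enabled
                "min_batch_size": 1,
                "max_batch_size": 64,
                "sequence_picking_order": "seqlen",   # "random" / "seqlen" / "dataloader"
                "lr_scaling_method": "sqrt",          # "linear" / "sqrt" / "none"
                "verbose": False,
            },
        },
    },
}

With a fragment like this, the parser above fills any omitted keys with the defaults from constants.py and asserts that max_tokens is present whenever dynamic batching is enabled.
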


def get_curriculum_learning_enabled(param_dict):
if CURRICULUM_LEARNING in param_dict.keys():
return get_scalar_param(param_dict[CURRICULUM_LEARNING], CURRICULUM_LEARNING_ENABLED,
20 changes: 20 additions & 0 deletions deepspeed/runtime/data_pipeline/constants.py
@@ -22,6 +22,8 @@
DATA_SAMPLING_NUM_EPOCHS_DEFAULT = 1000
DATA_SAMPLING_NUM_WORKERS = "num_workers"
DATA_SAMPLING_NUM_WORKERS_DEFAULT = 0
DATA_SAMPLING_PIN_MEMORY = "pin_memory"
DATA_SAMPLING_PIN_MEMORY_DEFAULT = False

#########################################
# Data efficiency - Data Sampling - Curriculum Learning
@@ -62,6 +64,24 @@
CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION = "data_cluster_current_position"
CURRICULUM_LEARNING_NP_RNG_STATE = "np_rng_state"

#########################################
# Data efficiency - Dynamic batching and LR scaling
#########################################
DYNAMIC_BATCHING = "dynamic_batching"
DYNAMIC_BATCHING_ENABLED = "enabled"
DYNAMIC_BATCHING_ENABLED_DEFAULT = False
DYNAMIC_BATCHING_METRICS_PATH = "metrics_path"
DYNAMIC_BATCHING_LR_SCALING_METHOD = "lr_scaling_method" # "linear" / "sqrt" / "none"
DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT = "linear"
DYNAMIC_BATCHING_MIN_BATCH_SIZE = "min_batch_size"
DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT = 1
DYNAMIC_BATCHING_MAX_BATCH_SIZE = "max_batch_size"
DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT = None
DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER = "sequence_picking_order" # "random" / "seqlen" / "dataloader"
DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT = "dataloader" # "random" / "seqlen" / "dataloader"
DYNAMIC_BATCHING_MAX_TOKENS = "max_tokens"
DYNAMIC_BATCHING_VERBOSE = "verbose"
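
The max_tokens, sequence_picking_order and lr_scaling_method knobs describe the overall scheme: sequences are grouped into variable-size batches under a token budget, and the learning rate is rescaled to match each batch's effective size. The sketch below only illustrates that idea under assumed semantics; it is not the PR's batch_by_seqlens implementation or its LR scheduler:

from typing import List, Optional

def batch_by_max_tokens(seqlens: List[int],
                        max_tokens: int,
                        min_batch_size: int = 1,
                        max_batch_size: Optional[int] = None,
                        sequence_picking_order: str = "seqlen") -> List[List[int]]:
    """Group sequence indices into variable-size batches whose total token count
    stays within max_tokens (illustrative sketch only)."""
    order = list(range(len(seqlens)))
    if sequence_picking_order == "seqlen":
        order.sort(key=lambda i: seqlens[i])  # "dataloader" keeps the given order; "random" would shuffle
    batches, current, tokens = [], [], 0
    for i in order:
        over_tokens = tokens + seqlens[i] > max_tokens
        over_size = max_batch_size is not None and len(current) >= max_batch_size
        if current and (over_tokens or over_size):
            if len(current) >= min_batch_size:
                batches.append(current)  # batches smaller than min_batch_size are dropped in this sketch
            current, tokens = [], 0
        current.append(i)
        tokens += seqlens[i]
    if current and len(current) >= min_batch_size:
        batches.append(current)
    return batches

def scale_lr(base_lr: float, batch_size: int, base_batch_size: int, method: str = "linear") -> float:
    """Rescale the LR for a batch of a different size ("linear" / "sqrt" / "none")."""
    ratio = batch_size / base_batch_size
    if method == "linear":
        return base_lr * ratio
    if method == "sqrt":
        return base_lr * ratio ** 0.5
    return base_lr

# Example: a 1024-token budget yields batches of different sizes, each with its own LR.
seqlens = [512, 128, 384, 960, 64, 256]
for batch in batch_by_max_tokens(seqlens, max_tokens=1024):
    print(batch, "tokens:", sum(seqlens[i] for i in batch),
          "lr:", scale_lr(1e-4, len(batch), base_batch_size=2, method="sqrt"))
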

#########################################
# Curriculum Learning legacy implementation
#########################################
@@ -862,8 +862,13 @@ def test_compare_both_data_analyzers(dataset):
for path in output_paths:
with open(os.path.join(da.save_path, path), 'rb') as f1, \
open(os.path.join(dda.save_path, path), 'rb') as f2:
if f1.read() != f2.read():
# if files have suffix .bin, they should be identical
if path.endswith(".bin"):
assert f1.read() == f2.read(), f"files {path} are not identical."
elif f1.read() != f2.read():
print(f"files {path} are not identical.")
dist.barrier()
dist.destroy_process_group()


if __name__ == "__main__":