A question about how to increase batch size #125

@IvoryTower800

Describe the bug
Hi, I tried to fine-tune the gemma-2b model with sharding_array=(1, 1, 1, -1) on a Kaggle TPU VM v3-8.

TrainArguments has two parameters related to batch size: total_batch_size and gradient_accumulation_steps.

If I set both to 1, training works well, and the total TPU memory used is 14.4 GB (reported by track_memory=True).

However, when I set total_batch_size=1 and gradient_accumulation_steps=4, it reports that TPU memory is exhausted.

I'm unsure how batch size works when training on TPU with EasyDeL. When I fine-tune a model with transformers on GPU, I can set gradient_accumulation_steps as high as I want without increasing GPU VRAM usage, but on TPU that does not seem to be the case.

Am I misunderstanding gradient_accumulation_steps in EasyDeL? Could you please tell me how to increase my actual batch size when fine-tuning, for example if I want a final batch size of 32?
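
For reference, the behaviour expected above (more accumulation steps without more memory per step) corresponds to the usual definition of gradient accumulation: process one micro-batch at a time, keep only a running sum of gradients, and average at the end. The sketch below illustrates that idea in plain JAX on a toy linear model. It is not EasyDeL's implementation; `loss_fn`, `accumulated_grads`, and the example shapes are hypothetical names and values chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Mean squared error of a toy linear model (stand-in for a real LM loss)."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulated_grads(params, xs, ys, accum_steps):
    """Average gradients over `accum_steps` equally sized micro-batches.

    Only one micro-batch is differentiated at a time; the carry holds the
    running gradient sum, so activation memory does not grow with accum_steps.
    """
    micro_x = xs.reshape(accum_steps, -1, xs.shape[-1])
    micro_y = ys.reshape(accum_steps, -1)
    grad_fn = jax.grad(loss_fn)

    def step(acc, batch):
        g = grad_fn(params, *batch)
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, (micro_x, micro_y))
    return jax.tree_util.tree_map(lambda g: g / accum_steps, total)


# Toy data: 4 examples, micro-batch size 1, 4 accumulation steps.
kx, ky = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(kx, (4, 3))
ys = jax.random.normal(ky, (4,))
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

full = jax.grad(loss_fn)(params, xs, ys)                  # one batch of 4
accum = accumulated_grads(params, xs, ys, accum_steps=4)  # 4 micro-batches of 1
```

With equally sized micro-batches and a mean-reduced loss, `accum` matches `full` up to floating-point error, which is why an effective batch of 32 is normally reachable through accumulation alone. Whether EasyDeL's gradient_accumulation_steps follows this scheme is exactly the open question here.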

To Reproduce

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="gemma_2b_it",
    num_train_epochs=1,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    custom_rule=model.config.get_partition_rules(True),
    learning_rate=2e-5,
    learning_rate_end=2e-7,
    max_sequence_length=max_length,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,  # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear" are supported
    weight_decay=0.01,
    total_batch_size=1,
    max_training_steps=None,  # None to let trainer Decide
    do_train=True,
    do_eval=False,  # it's optional but supported
    backend="tpu",  # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
    max_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),  # the way to shard model across gpu,cpu or TPUs using sharding array (1, 1, 1, -1)
    # everything training will be in sequence and model parallel automatic and share data between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    init_input_shape=(1, max_length),
    gradient_accumulation_steps=4,
    loss_re_mat="",
    dtype=jnp.bfloat16,
    track_memory=True,
    use_wandb=True,  # This disable WANB usage
)
