
[Bug] Adam optimizer produces near-zero effective updates when training LoRA weights in bfloat16 with ModelParallel on TPU #2629

@14790897

Description

Describe the bug

When fine-tuning a large model (Qwen2.5-7B) with LoRA using ModelParallel on TPU v5e-8,
the Adam/AdamW optimizer produces no effective weight updates: the loss at step 1 is
identical in every epoch, indicating that the LoRA weights are never modified.
The root cause may be that the LoRA variables inherit the layer's bfloat16 dtype, so Adam's state, including the second-moment estimate v_t, is kept in bfloat16, whose low precision rounds the small per-step changes away and yields near-zero effective step sizes.
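Whatever the exact mechanism, bfloat16's 8-bit mantissa (7 explicit bits) means any update smaller than roughly 1/512 of a weight's magnitude rounds away when added to it. A minimal NumPy sketch (emulating bfloat16 round-to-nearest by manipulating float32 bits; the learning rate and gradient values are hypothetical, chosen to be typical for LoRA fine-tuning, and bias correction is omitted for brevity) shows an Adam-style update getting stuck when all state is kept in bfloat16:

```python
import numpy as np

def bf16_round(x):
    """Emulate bfloat16 round-to-nearest-even by rounding a float32
    to its top 16 bits (bfloat16 is float32 with the low 16 mantissa
    bits dropped)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    bits = bits + np.uint32(0x7FFF) + ((bits >> np.uint32(16)) & np.uint32(1))
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

# Hypothetical but typical values; no bias correction, for brevity.
lr, beta1, beta2, eps = 1e-4, 0.9, 0.999, 1e-7
g = 1e-4  # a small, constant LoRA gradient

w32, m32, v32 = np.float32(1.0), np.float32(0.0), np.float32(0.0)
w16, m16, v16 = bf16_round(1.0), bf16_round(0.0), bf16_round(0.0)

for _ in range(100):
    # float32 reference: the weight moves
    m32 = beta1 * m32 + (1 - beta1) * g
    v32 = beta2 * v32 + (1 - beta2) * g * g
    w32 = w32 - lr * m32 / (np.sqrt(v32) + eps)

    # same update, but all state rounded to bfloat16 every step
    m16 = bf16_round(beta1 * m16 + (1 - beta1) * g)
    v16 = bf16_round(beta2 * v16 + (1 - beta2) * g * g)
    step = bf16_round(lr * m16 / (np.sqrt(v16) + eps))
    w16 = bf16_round(w16 - step)

print(float(w32))  # drifts below 1.0
print(float(w16))  # stays exactly 1.0: each step is below bfloat16's
                   # half-ulp at 1.0 (~0.002), so w - step rounds back to w
```

The per-step Adam update here is on the order of 1e-4 to 7e-4, well under the 0.00390625 spacing between adjacent bfloat16 values near 1.0, so the bfloat16 weight never moves.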

To Reproduce

Keras version: 3.13.2
Notebook using Adam (reproduces the bug): https://www.kaggle.com/code/liuweiq/keras3-jax-qwen2-5-tpu?scriptVersionId=302386192
Notably, switching to SGD avoids the issue.
Notebook using SGD (no issue): https://www.kaggle.com/code/liuweiq/keras3-jax-qwen2-5-tpu/notebook?scriptVersionId=302428315

Expected behavior

The loss should change from epoch to epoch when training with Adam, as it does with SGD.
Additional context

Epoch 14: loss 0.5014 (training-log screenshot)
Epoch 15: loss 0.5014, identical to the previous epoch (training-log screenshot)

Would you like to help us fix it?
No
