Describe the bug
When fine-tuning a large model (Qwen2.5-7B) with LoRA using ModelParallel on TPU v5e-8,
the Adam/AdamW optimizer produces no effective weight updates: the loss is identical
at step 1 of every epoch, indicating the LoRA weights are never modified.
Suspected root cause: the LoRA variables inherit the layer's bfloat16 dtype, and Adam's second-moment estimate v_t underflows in bfloat16, causing near-zero effective step sizes.
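The precision hypothesis can be probed with a small, self-contained sketch (not Keras code; `to_bf16` is a helper written for this illustration that rounds a Python float to the nearest bfloat16 value via bit masking). It shows that an Adam-sized step of ~1e-3 is smaller than bfloat16's spacing near 1.0, so applying the update to a bfloat16 weight rounds it away entirely:

```python
import struct

def to_bf16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (round-to-nearest-even),
    returned as a regular Python float. Illustration helper, not a Keras API."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round-to-nearest-even on the low 16 bits that bfloat16 discards.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# bfloat16 keeps only 8 mantissa bits, so near 1.0 the spacing between
# representable values is 2**-8 ~= 0.0039. Adam's effective step is ~lr
# when |m_t| / sqrt(v_t) ~= 1, so with lr = 1e-3 the update is smaller
# than that spacing and is lost when the weight is stored in bfloat16:
w = 1.0
adam_step = 1e-3
w_new = to_bf16(w + adam_step)
print(w_new)  # prints 1.0 -- the update vanished
```

This is why keeping the LoRA variables (or at least the optimizer state) in float32 is the usual workaround for this class of stall.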
To Reproduce
Keras version: 3.13.2
Using Adam: https://www.kaggle.com/code/liuweiq/keras3-jax-qwen2-5-tpu?scriptVersionId=302386192
It's worth noting that switching to SGD avoids the issue.
Using SGD: https://www.kaggle.com/code/liuweiq/keras3-jax-qwen2-5-tpu/notebook?scriptVersionId=302428315
Expected behavior
The loss should change between epochs when training with Adam.
Additional context
Epoch 14 loss: 0.5014
Epoch 15 loss: 0.5014 (unchanged)
Would you like to help us fix it?
no