
Add an option to cache NNX traversals in PEFT trainer.#928

Merged
copybara-service[bot] merged 1 commit into main from test_847406317
Dec 22, 2025

Conversation


@copybara-service copybara-service bot commented Dec 21, 2025

Add an option to cache NNX traversals in PEFT trainer.

In peft_trainer.py, wrap train_step and eval_step with nnx.cached_partial to cache NNX graph traversals for performance as documented in https://flax.readthedocs.io/en/latest/guides/performance.html#caching-graph-node-traversals

This was shown to significantly reduce the time spent in Python traversing the graph, which matters most for smaller models (e.g. Gemma3 1B) that spend relatively little accelerator time per step.

copybara-service bot changed the title from "Optimize NNX traversals in PEFT trainer." to "Add an option to cache NNX traversals in PEFT trainer." on Dec 22, 2025
PiperOrigin-RevId: 847827483
