When upgrading from kfac_jax==0.0.6 (with jax==0.4.35) to kfac_jax==0.0.7 (with jax==0.4.36), some parameters that were previously registered as Orphan are now being registered as Generic. This change leads to a training crash. In version 0.0.6, training proceeded normally even when parameters were marked as Orphan.
Expected behavior
Either:
- The behavior in 0.0.6 should be preserved (i.e., training with Orphan parameters should continue to work), or
- If the new behavior in 0.0.7 is intentional, there should be clear documentation on:
  - The distinction between 'Orphan' and 'Generic'
  - How these affect the optimizer's behavior
Additional context / Questions
- What is the intended difference between the 'Orphan' and 'Generic' parameter types in kfac_jax? The documentation and source code comments do not clearly explain this.
- What internal logic determines whether a parameter is assigned 'Orphan' vs. 'Generic'? Was this logic modified between 0.0.6 and 0.0.7?
Environment
- kfac_jax: 0.0.6 → 0.0.7
- jax: 0.4.35 → 0.4.36
- Hardware: A100 GPU
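
To confirm which versions are active in each environment before comparing runs, something like the following could be used. This is only a minimal sketch: the helper name installed_versions is hypothetical, it relies solely on the standard-library importlib.metadata module, and it assumes the packages are discoverable under the distribution names kfac_jax and jax (older Python versions may need the hyphenated name kfac-jax instead).

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent.

    Hypothetical helper for this report; not part of kfac_jax or jax.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Distribution names are assumptions; adjust them if lookup fails.
    for name, version in installed_versions(["kfac_jax", "jax"]).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this once in the working environment and once in the crashing one gives a quick check that the only difference between the two is the kfac_jax/jax pair listed above.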