Unexpected change in parameter registration from 'Orphan' to 'Generic' causes training failure in kfac_jax 0.0.7 #344

@DanChai22

Description

After upgrading from kfac_jax==0.0.6 (with jax==0.4.35) to kfac_jax==0.0.7 (with jax==0.4.36), some parameters that were previously registered as 'Orphan' are now registered as 'Generic'. With this change, training crashes. In 0.0.6, training proceeded normally even when those parameters were marked as 'Orphan'.

Expected behavior

Either:

  • The behavior in 0.0.6 should be preserved (i.e., training with Orphan parameters should continue to work), or
  • If the new behavior in 0.0.7 is intentional, there should be clear documentation on:
    • The distinction between 'Orphan' and 'Generic'
    • How these affect the optimizer's behavior

Additional context / Questions

  1. What is the intended difference between 'Orphan' and 'Generic' parameter types in kfac_jax?
    The documentation and source code comments do not clearly explain this.

  2. What internal logic determines whether a parameter is assigned 'Orphan' vs. 'Generic'?
    Was this logic modified between 0.0.6 and 0.0.7?

Environment

  • kfac_jax: 0.0.6 → 0.0.7
  • jax: 0.4.35 → 0.4.36
  • Hardware: A100 GPU
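
For anyone trying to reproduce this, the sketch below is one way to confirm which versions are actually installed in each environment before comparing the parameter registrations. It uses only the standard library; the helper name `installed_versions` and the default package list are illustrative assumptions, not part of kfac_jax.

```python
from importlib import metadata


def installed_versions(packages=("kfac_jax", "jax", "jaxlib")):
    """Return a dict mapping each package name to its installed version.

    A package that is not installed maps to None. The default package
    list is an assumption chosen for this issue.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running it in both the 0.0.6 and the 0.0.7 environment should show whether anything besides kfac_jax and jax differs between the two setups.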
