
TypeError: unhashable type: 'Literal' #310

@DanChai22


Issue Report: TypeError with Constant Multiplication in a Network When Using kfac_jax

Summary

When using kfac_jax with a network, introducing a scaling factor (e.g., geo_scale) for the lattice parameters anywhere in the computational graph raises TypeError: unhashable type: 'Literal'. Literal is the type JAX uses for scalar constants inlined into a traced jaxpr. The error occurs whether geo_scale is passed as a function parameter or written as a constant multiplier directly in the computation.


Steps to Reproduce

Here’s an example illustrating the issue (get_jacobian is part of the network):

  1. Working Example: No Scaling
    The following works without errors:

    import jax.numpy as jnp

    def get_jacobian(params):
        p_cell = params['cell'].ravel()  # No scaling applied
        return jnp.diag(p_cell)

  2. Failing Example 1: Direct Multiplication
    Adding a constant multiplier to params['cell'] causes a TypeError:

    def get_jacobian(params):
        p_cell = params['cell'].ravel() * 1e-3  # Multiplying with a constant
        return jnp.diag(p_cell)

    Error Raised:

    TypeError: unhashable type: 'Literal'
    
  3. Failing Example 2: Adding geo_scale Parameter
    Introducing a geo_scale parameter also causes the same TypeError:

    def get_jacobian(params, geo_scale=1e-3):
        p_cell = params['cell'].ravel() * geo_scale
        return jnp.diag(p_cell)

    Error Raised:

    TypeError: unhashable type: 'Literal'
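To see where the Literal comes from, the constant can be located in the traced jaxpr with plain JAX, independently of kfac_jax. The sketch below is a minimal, hypothetical diagnostic: get_jacobian_unscaled and get_jacobian_scaled are renamed copies of the examples above, literal_values and contains_value are illustrative helpers (not part of JAX or kfac_jax), and the 3-element cell shape is assumed. It only shows that the Python float 1e-3 ends up as a Literal atom in the graph; it does not reproduce the kfac_jax error itself.

```python
import jax
import jax.numpy as jnp
import numpy as np


def get_jacobian_unscaled(params):
    # Working Example: no constant multiplier in the graph.
    return jnp.diag(params['cell'].ravel())


def get_jacobian_scaled(params):
    # Failing Example 1: the Python float 1e-3 is staged into the jaxpr.
    return jnp.diag(params['cell'].ravel() * 1e-3)


def literal_values(fn, *args):
    """Values of Literal atoms feeding any equation, nested jaxprs included."""
    def walk(jaxpr):
        for eqn in jaxpr.eqns:
            for atom in eqn.invars:
                # Match by class name to avoid depending on where JAX exports Literal.
                if type(atom).__name__ == 'Literal':
                    yield atom.val
            for value in eqn.params.values():
                items = value if isinstance(value, (tuple, list)) else (value,)
                for item in items:
                    inner = getattr(item, 'jaxpr', item)  # ClosedJaxpr -> Jaxpr
                    if hasattr(inner, 'eqns'):
                        yield from walk(inner)

    return list(walk(jax.make_jaxpr(fn)(*args).jaxpr))


def contains_value(values, target, tol=1e-6):
    # Literal values may be Python or NumPy scalars; compare numerically.
    return any(np.ndim(v) == 0 and abs(float(v) - target) < tol for v in values)


params = {'cell': jnp.ones((3,), dtype=jnp.float32)}  # hypothetical shape
print(jax.make_jaxpr(get_jacobian_scaled)(params))  # look for a literal near 1e-3

scaled_has_literal = contains_value(literal_values(get_jacobian_scaled, params), 1e-3)
unscaled_has_literal = contains_value(literal_values(get_jacobian_unscaled, params), 1e-3)
print(scaled_has_literal, unscaled_has_literal)
```

If the scaled version reports a literal near 1e-3 and the unscaled one does not, that lines up with the working/failing split in the examples above.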
    

Questions

  1. Is there a recommended approach for handling constants or scaling factors like geo_scale in kfac_jax workflows to avoid such issues?
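One direction that may be worth trying, untested against kfac_jax: supply the scale as a traced array input instead of a Python float, so it enters the jaxpr as a variable rather than a Literal. How geo_scale would reach the loss function in a kfac_jax setup (for example, carried alongside the batch) is an assumption about the surrounding code, and literal_values is a hypothetical helper. The sketch only checks, with plain JAX, that no Literal near 1e-3 remains and that the numerical result is unchanged; whether kfac_jax then accepts the graph is not verified here.

```python
import jax
import jax.numpy as jnp
import numpy as np


def get_jacobian(params, geo_scale):
    # Assumption: geo_scale is passed as a traced array argument,
    # not as a Python float default as in Failing Example 2.
    p_cell = params['cell'].ravel() * geo_scale
    return jnp.diag(p_cell)


def literal_values(fn, *args):
    """Values of Literal atoms feeding any equation, nested jaxprs included."""
    def walk(jaxpr):
        for eqn in jaxpr.eqns:
            for atom in eqn.invars:
                if type(atom).__name__ == 'Literal':
                    yield atom.val
            for value in eqn.params.values():
                items = value if isinstance(value, (tuple, list)) else (value,)
                for item in items:
                    inner = getattr(item, 'jaxpr', item)  # ClosedJaxpr -> Jaxpr
                    if hasattr(inner, 'eqns'):
                        yield from walk(inner)

    return list(walk(jax.make_jaxpr(fn)(*args).jaxpr))


params = {'cell': jnp.arange(1.0, 4.0, dtype=jnp.float32)}  # hypothetical cell
geo_scale = jnp.asarray(1e-3, dtype=jnp.float32)  # traced input, not a literal

vals = literal_values(get_jacobian, params, geo_scale)
scale_is_literal = any(np.ndim(v) == 0 and abs(float(v) - 1e-3) < 1e-6 for v in vals)
out = get_jacobian(params, geo_scale)
print(scale_is_literal)
print(out)
```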
