I am updating the initial values of trainable parameters with write_trainables. When I inspect the nodes dataframe, the values have changed, but get_parameters() still returns the old values.
(PS: I'm asking because I use get_parameters() to fetch the parameters that are passed into an optax transform, so my updated initial values are not being picked up as expected.)
>>> branch.nodes['radius']
0 1.0
1 1.0
Name: radius, dtype: float64
>>> branch.get_parameters()
[{'radius': Array([1., 1.], dtype=float32)}]
>>> branch.write_trainables([{"radius":jnp.array([0.25, 1.23], jnp.float32)}])
>>> branch.nodes['radius']
0 0.25
1 1.23
Name: radius, dtype: float64
>>> branch.get_parameters()
[{'radius': Array([1., 1.], dtype=float32)}]
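A minimal sketch of a possible stopgap, assuming the only goal is to start optimization from the new values: the list of dicts returned by get_parameters() is a JAX pytree, so the same structure can be built directly from the desired initial values instead of being read back. The loss below is a hypothetical stand-in for the actual simulation, and a plain gradient-descent step applied with jax.tree_util.tree_map stands in for the optax transform.

```python
import jax
import jax.numpy as jnp

# Same structure as get_parameters() returns above: a list of dicts mapping
# a parameter name to an array of per-node values. Built directly from the
# new initial values rather than read back via get_parameters().
params = [{"radius": jnp.array([0.25, 1.23], jnp.float32)}]


def loss(params):
    # Hypothetical stand-in loss; in practice this would run the simulation.
    return jnp.sum((params[0]["radius"] - 0.5) ** 2)


# jax.grad returns gradients with the same list-of-dicts structure.
grads = jax.grad(loss)(params)

# Plain gradient-descent update applied leaf by leaf over the pytree,
# standing in for an optax transform's update step.
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
print(new_params)
```

This sidesteps get_parameters() but does not explain why it still reports the old values after write_trainables.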