Manually updating parameters - documentation traverse util #1073
-
I'm working on implementing reservoir computing / echo state networks in JAX & Flax. (We're seeing almost a 2x speedup compared to PyTorch code, so I hope to release it soon!) After training I want to manually set the weights in the output layer. Currently I use
which works, but I think the preferred method is to use a traverse util? Is this indeed the case, and if so, is there a bit more documentation on how to apply these? The example in the documentation is sparse and I'm having issues making it work... Thanks!
-
Hi, how is the progress with your implementation? I have opened a new thread, where I mention this discussion.
I think your proposed solution works fine, and I don't think it is necessary to use `traverse_util` if you simply want to update a subtree of your variables. The `unfreeze`/`freeze` pattern is quite common, and it is for instance also used in our Model Surgery HOWTO. `traverse_util` can be used for more complex operations, for instance if you would like to replace all occurrences of `output_layer` in your params with a different subtree. For more details on how this can be used, it may be insightful to take a look at some of the tests in tests/traverse_util_test.py.