Recompute convolution without parameter update #2951
jakubMitura14 asked this question in Q&A
Replies: 1 comment
-
Hey @jakubMitura14, take a look at nn.map_variables; you can probably use a similar pattern as the example but …
-
Hello, I have a U-Net-like architecture. With x as the input, the pseudocode is:
The problem is that I am memory-constrained, so instead of keeping everything in memory I would like to use remat and sequential, so line 3 would become:
However, in this case I need to recompute
conv1(x) while avoiding any parameter learning for it, since that was already done in line 1.
I know that parameters can be frozen with optax, but as far as I understand, that would freeze conv1 everywhere, i.e. in both the first and the second invocation, whereas I want to block only the second one.
Is there a way to selectively block parameter learning?
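A minimal pure-JAX sketch of the underlying idea (not Flax-specific, and not the approach from the reply): one kernel is shared by two calls, and jax.lax.stop_gradient is applied to the kernel in the second call, so only the first call contributes a parameter gradient. The second call is also wrapped in jax.checkpoint (remat) so its intermediates are recomputed rather than stored. The names conv2d, conv2d_remat and forward, the NHWC/HWIO layout, and the toy loss are hypothetical stand-ins for conv1 and the rest of the network.

```python
import jax
import jax.numpy as jnp


def conv2d(kernel, x):
    """Plain 2-D convolution with stride 1 and SAME padding.

    x: (N, H, W, C_in), kernel: (KH, KW, C_in, C_out).
    """
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


# Rematerialized version: intermediates of this call are recomputed on the
# backward pass instead of being kept in memory.
conv2d_remat = jax.checkpoint(conv2d)


def forward(kernel, x):
    # First use (the "line 1" call): gradients flow into `kernel` as usual.
    y1 = conv2d(kernel, x)
    # Second use: recompute conv1(x) with the SAME weights, but stop the
    # gradient on the weights so this call adds nothing to d(loss)/d(kernel).
    # Gradients with respect to `x` still flow through this call.
    y2 = conv2d_remat(jax.lax.stop_gradient(kernel), x)
    # Hypothetical toy loss standing in for the rest of the U-Net.
    return jnp.sum(y1) + jnp.sum(y2 ** 2)
```

The gradient of a shared parameter is the sum of the contributions from every call that uses it; stop_gradient removes only the second call's term, whereas freezing the parameter in optax zeroes the whole sum. In Flax, the reply's nn.map_variables suggestion would presumably apply a similar transform (for example a trans_in_fn that applies jax.lax.stop_gradient to the parameters) to the second call only; that wiring is an untested assumption and is not shown here.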