How to manage variables in a Module that uses different submodules based on some input? #1151
Original question by Anton Raichunk: Suppose I have a module that uses different submodules based on some input. Is my understanding correct that handling such a module is going to be cumbersome? Is there something I'm missing that is supposed to make my life easier? Here's a simple example using Gin:

```python
import functools

import gin
import jax
from flax import linen

# Make functools.partial referenceable from Gin config strings.
gin.external_configurable(functools.partial)
gin.clear_config()
gin.enter_interactive_mode()


@gin.configurable
class MLP(linen.Module):
    a: int = 0
    b: int = 0


@gin.configurable
def myfunction(network):
    # `network` is injected by Gin at the time myfunction() is called.
    class Yo(linen.Module):
        @linen.compact
        def __call__(self):
            m = network
            print(m.a, m.b)

    return Yo().init


lines = ['myfunction.network=@MyMLP/MLP()',
         'MyMLP/MLP.a = 2',
         'MyMLP/MLP.b = 2']
gin.parse_config('\n'.join(lines))
t = myfunction()
t(jax.random.PRNGKey(0))

gin.clear_config()
# `t` still closes over the network built above, even after the config is cleared.
t(jax.random.PRNGKey(0))
```
Answer from @jheek:
The new variables returned by `apply` will also contain the untouched variables, as long as you provided them as input to `apply`. I don't think this should be cumbersome, but perhaps I'm missing something.
Typically you would make explicit what types of state (collections) you expect, because your training pipeline usually makes some assumptions about them. For example, eval doesn't update batch stats, but it might update the cache used for decoding.
Adding a collection to `mutable` that isn't there doesn't have side effects. So you could interpret the argument as: "in this context I support the following kinds of state: ..."