Hi, I am trying to bundle two networks into the same larger network, and my understanding was that something like this would work:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Net1(nn.Module):
    n: int

    @nn.compact
    def __call__(self, x):
        for _ in range(2):
            x = nn.relu(nn.Dense(64)(x))
        x = nn.Dense(self.n)(x)
        return x


class Net2(nn.Module):
    n: int

    def setup(self):
        self.inner = Net1(self.n)
        self.outer = Net1(self.n * 2)

    def __call__(self, x):
        x = self.outer(x)
        return x


key = jax.random.key(42)
x1_dummy = jnp.ones(shape=(20,))
inner_params = Net2(n=10).init(key, x1_dummy, method='inner')
...
apply_inner_fn = lambda params, x: Net2(n=10).apply(params, x, method='inner')
```

where I want to have access to both functions inside the loss:

```python
def loss_f(params, x, y):
    z = train_state.apply_inner_fn(params, x)
    y_hat = train_state.apply_fn(params, z)
    return jnp.mean((y - y_hat) ** 2)

...
# Update params together with the same grads
```

But running the first snippet, I get
I am calling it inside
Answered by chiamp on Nov 13, 2023
Answer selected by alonfnt
`Net2(n=10).init(key, x1_dummy, method='inner')` is looking for a method called `inner`, which isn't defined. You could try something like this: