Hello, I'm trying to do Sobolev training, i.e. fitting the function value and its derivatives simultaneously, with KFAC. As far as I understand, computing both values and input gradients requires passing through the network twice in JAX (once directly and once under `jax.grad`), and this seems to be causing issues with KFAC. Here is some sample code:
```python
import math

import jax
import jax.numpy as jnp
import kfac_jax


def mlp(params, x):
    for weights, biases in params[:-1]:
        x = jnp.dot(x, weights) + biases
        x = jax.nn.tanh(x)
    weights_last, biases_last = params[-1]
    x = jnp.dot(x, weights_last) + biases_last
    return x.squeeze()


def init_mlp(rng, dim, num_hidden_layers, hidden_size):
    params = []
    layer_sizes = [dim] + [hidden_size] * num_hidden_layers + [1]
    for size_in, size_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        rng, rng_weight = jax.random.split(rng)
        weight_lim = math.sqrt(6. / (size_in + size_out))
        weights = jax.random.uniform(
            rng_weight, (size_in, size_out), minval=-weight_lim, maxval=weight_lim)
        biases = jnp.zeros((size_out,))
        params.append((weights, biases))
    return params


def func(x):
    return jnp.exp(-jnp.sum(x))


sobolev_weight = 0.1
batch_size = 256
dim = 4

rng = jax.random.key(0)
rng, rng_sample, rng_dummy, rng_dummy_init = jax.random.split(rng, 4)
params = init_mlp(rng, dim, 4, 32)


def loss_fn(params, batch):
    x = batch
    # Predictions
    preds = jax.vmap(mlp, in_axes=(None, 0))(params, x)
    # Targets
    funcs = jax.vmap(func)(x)
    if sobolev_weight is None:
        # Plain L2 loss on the function values.
        kfac_jax.register_squared_error_loss(prediction=preds, targets=funcs)
        return jnp.mean((preds - funcs) ** 2)
    else:
        # Sobolev loss: also match the input gradients.
        # Gradients of the predictions
        grad_preds = jax.vmap(
            jax.grad(lambda x, params: mlp(params, x).squeeze()),
            in_axes=(0, None)
        )(x, params)
        # Gradients of the targets
        grad_funcs = jax.vmap(jax.grad(func))(x)
        cat_preds = jnp.concatenate([preds[:, None], grad_preds], axis=1)
        cat_funcs = jnp.concatenate([funcs[:, None], grad_funcs], axis=1)
        weights = jnp.concatenate([jnp.ones((1,)), sobolev_weight * jnp.ones((dim,))])
        kfac_jax.register_squared_error_loss(
            prediction=cat_preds,
            targets=cat_funcs,
            weight=weights,
        )
        return jnp.mean(
            jnp.sum(weights * (cat_preds - cat_funcs) ** 2, axis=-1)
        )


optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn),
    l2_reg=0.0,
    value_func_has_aux=False,
    value_func_has_state=False,
    value_func_has_rng=False,
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
    initial_damping=1.0,
    multi_device=False,
)

# Initialize the K-FAC state on a dummy batch.
dummy_x = jax.random.uniform(rng_dummy, (batch_size, dim))
optimizer_state = optimizer.init(params, rng_dummy_init, dummy_x)

for _ in range(50):
    rng, rng_sample, rng_opt = jax.random.split(rng, 3)
    x = jax.random.uniform(rng_sample, (batch_size, dim))
    # Do an update step.
    params, optimizer_state, stats = optimizer.step(
        params=params,
        state=optimizer_state,
        rng=rng_opt,
        batch=x,
    )
    loss_val = stats["loss"]
    print(loss_val)
```
Running this gives an error like the following:
```
ValueError: Parameter Var(id=124146837841728):float32[4,32] has been registered to multiple tags: ['Auto[dense_tag_0]', 'Auto[dense_tag_5]'].
```
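For what it's worth, the tag numbering looks consistent with `mlp` being traced twice: the network has five dense layers, and the `[4, 32]` first-layer weight gets tagged both as `dense_tag_0` and `dense_tag_5`. A quick way to see the duplication (just a debugging sketch, reusing the definitions above) is to print the jaxpr of the loss and look for repeated dot ops on the same weights:

```python
# Debugging sketch: I'd expect the jaxpr of loss_fn to contain mlp's
# dot ops twice, once from the direct call and once from under jax.grad.
print(jax.make_jaxpr(loss_fn)(params, dummy_x))
```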
Is there anything I can do to make this work properly?
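One direction I've considered (untested, so treat it as a sketch rather than a known fix) is to compute the value and the input gradient from a single application of `mlp` via `jax.value_and_grad`, so the network is only traced once inside `loss_fn`. I don't know whether the automatic registration copes with the extra backward-pass ops this introduces:

```python
# Possible workaround (untested): obtain the prediction and its input
# gradient from one trace of mlp instead of two separate calls.
def pred_and_grad(params, xi):
    # mlp returns a scalar for a single example, so jax.grad applies directly.
    return jax.value_and_grad(lambda x_: mlp(params, x_))(xi)

preds, grad_preds = jax.vmap(pred_and_grad, in_axes=(None, 0))(params, x)
```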