Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion haiku/_src/moving_averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from haiku._src import module
import jax
import jax.numpy as jnp
import haiku as hk


# If you are forking replace this block with `import haiku as hk`.
Expand Down Expand Up @@ -149,7 +150,8 @@ class EMAParamsTree(hk.Module):

>>> network_fn = lambda x: hk.Linear(10)(x)
>>> x = jnp.ones([1, 1])
>>> params = hk.transform(network_fn).init(jax.random.PRNGKey(428), x)
>>> rng = jax.random.PRNGKey(428)
>>> params = hk.transform(network_fn).init(rng, x)

You might use the EMAParamsTree like follows:

Expand Down
31 changes: 31 additions & 0 deletions haiku/_src/moving_averages_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,37 @@ def f(x):
# Floating point error creeps up to 1e-7 (the default).
np.testing.assert_allclose(inp_value, value, rtol=1e-6)

def test_ema_preserves_dtype(self):
def f(x):
return moving_averages.ExponentialMovingAverage(0.5)(x)

inp_value = jnp.array(1.0, dtype=jnp.float32)

init_fn, apply_fn = multi_transform.without_apply_rng(
transform.transform_with_state(f))
_, params_state = init_fn(None, inp_value)

value = inp_value
for _ in range(5):
value, params_state = apply_fn(None, params_state, value)
self.assertEqual(value.dtype, jnp.float32)

def test_ema_stability_on_constant_input(self):
def f(x):
return moving_averages.ExponentialMovingAverage(0.5)(x)

inp_value = jnp.array(10.0, dtype=jnp.float32)

init_fn, apply_fn = multi_transform.without_apply_rng(
transform.transform_with_state(f))
_, params_state = init_fn(None, inp_value)

value = inp_value
for _ in range(50):
value, params_state = apply_fn(None, params_state, inp_value)
self.assertTrue(jnp.isfinite(value))
self.assertLess(abs(value), 1000.0)

@parameterized.parameters(True, False)
@test_utils.transform_and_run
def test_initialize(self, legacy_initialize):
Expand Down