diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index 910fb3af2..f637d0b06 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -104,6 +104,7 @@ from .nn.normalization import LayerNorm as LayerNorm
 from .nn.normalization import RMSNorm as RMSNorm
 from .nn.normalization import GroupNorm as GroupNorm
+from .nn.normalization import WeightNorm as WeightNorm
 from .nn.stochastic import Dropout as Dropout
 from .rnglib import Rngs as Rngs
 from .rnglib import RngStream as RngStream
diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py
index 72c6450cf..c1a3dc2b6 100644
--- a/flax/nnx/nn/normalization.py
+++ b/flax/nnx/nn/normalization.py
@@ -181,6 +181,24 @@ def _normalize(
   return jnp.asarray(y, dtype)
 
 
+def _l2_normalize(x, axis=None, eps=1e-12):
+  """Normalizes along dimension `axis` using an L2 norm.
+
+  This specialized function exists for numerical stability reasons.
+
+  Args:
+    x: An input ndarray.
+    axis: Dimension along which to normalize, e.g. `1` to separately normalize
+      vectors in a batch. Passing `None` views `x` as a flattened vector when
+      calculating the norm (equivalent to Frobenius norm).
+    eps: Epsilon to avoid dividing by zero.
+
+  Returns:
+    An array of the same shape as `x` L2-normalized along `axis`.
+  """
+  return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)
+
+
 class BatchNorm(Module):
   """BatchNorm Module.
 
@@ -835,4 +853,138 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
       (self.feature_axis,),
       self.dtype,
       self.epsilon,
-    )
\ No newline at end of file
+    )
+
+
+class WeightNorm(nnx.Module):
+  """L2 weight normalization (https://arxiv.org/abs/1602.07868).
+
+  Weight normalization normalizes the weight params so that each slice along
+  the feature axes has an l2-norm equal to 1. This is implemented as a layer
+  wrapper where each wrapped layer will have its params l2-normalized before
+  computing its ``__call__`` output.
+
+  Example usage::
+
+    >>> import jax
+    >>> import numpy as np
+    >>> from flax import nnx
+
+    >>> class Foo(nnx.Module):
+    ...   def __init__(self, rngs: nnx.Rngs):
+    ...     self.normed_linear = nnx.WeightNorm(
+    ...       nnx.Linear(8, 4, rngs=rngs),
+    ...       variable_filter=nnx.PathContains('kernel'),
+    ...       rngs=rngs,
+    ...     )
+    ...
+    ...   def __call__(self, x: jax.Array) -> jax.Array:
+    ...     return self.normed_linear(x)
+
+    >>> rng = jax.random.PRNGKey(42)
+    >>> model = Foo(rngs=nnx.Rngs(rng))
+
+    >>> x = jax.random.normal(rng, (5, 8))
+    >>> y = model(x)
+    >>> y.shape
+    (5, 4)
+
+    >>> w = model.normed_linear.layer_instance.kernel.value
+    >>> col_norms = np.linalg.norm(np.array(w), axis=0)
+    >>> np.testing.assert_allclose(col_norms, np.ones(4), rtol=1e-5)
+
+  Args:
+    layer_instance: The layer instance to wrap.
+    feature_axes: The feature axes; the l2-norm is taken over all other axes.
+    use_scale: Whether to use a scale parameter.
+    scale_init: The initializer for the scale parameter, by default ones.
+    epsilon: The epsilon value for the normalization, by default 1e-12.
+    dtype: The dtype of the result, by default infer from input and params.
+    param_dtype: The dtype of the parameters, by default float32.
+    variable_filter: The variable filter, by default ``nnx.PathContains('kernel')``.
+    rngs: The rngs used to initialize the scale parameter.
+ """ + def __init__( + self, + layer_instance: nnx.Module, + *, + feature_axes: Axes | None = -1, + use_scale: bool = True, + scale_init: Initializer = initializers.ones, + epsilon: float = 1e-12, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + variable_filter: nnx.filterlib.Filter = nnx.PathContains('kernel'), + rngs: rnglib.Rngs, + ): + self.layer_instance = layer_instance + self.feature_axes = feature_axes + self.use_scale = use_scale + self.scale_init = scale_init + self.epsilon = epsilon + self.dtype = dtype + self.param_dtype = param_dtype + self.variable_filter = variable_filter + self.rngs = rngs + + def __call__(self, x: Array, *args, **kwargs) -> Array: + """Compute the l2-norm of the weights in ``self.layer_instance`` + and normalize the weights using this value before computing the + ``__call__`` output. + + Args: + *args: positional arguments to be passed into the call method of the + underlying layer instance in ``self.layer_instance``. + **kwargs: keyword arguments to be passed into the call method of the + underlying layer instance in ``self.layer_instance``. + + Returns: + Output of the layer using l2-normalized weights. + """ + state = nnx.state(self.layer_instance) + + def apply_weightnorm(path, var_state): + if not self.variable_filter(path, var_state): + return var_state + + param_val = jnp.asarray(var_state.value) + if self.feature_axes is None: + feature_axes = () + reduction_axes = tuple(range(param_val.ndim)) + else: + feature_axes = _canonicalize_axes(param_val.ndim, self.feature_axes) + reduction_axes = tuple(i for i in range(param_val.ndim) if i not in feature_axes) + + value_bar = _l2_normalize(param_val, axis=reduction_axes, eps=self.epsilon) + + if self.use_scale: + scale_shape = tuple(param_val.shape[ax] for ax in feature_axes) + scale_path = path + ("scale",) + try: + scale_state = state[scale_path] + scale_value = scale_state.value + except KeyError: + key = self.rngs.params() + scale_value = self.scale_init(key, scale_shape, self.param_dtype) + state[scale_path] = nnx.Param(scale_value) + + if len(feature_axes) < param_val.ndim: + broadcast_shape = [1] * param_val.ndim + for ax in feature_axes: + broadcast_shape[ax] = param_val.shape[ax] + scale_value = scale_value.reshape(broadcast_shape) + value_bar = value_bar * scale_value + + cast_args = [param_val] + if self.use_scale: + cast_args.append(scale_value) + + final_dtype = dtypes.canonicalize_dtype(*cast_args, dtype=self.dtype) + new_val = jnp.asarray(value_bar, final_dtype) + + return nnx.Param(new_val) + + state = nnx.map_state(apply_weightnorm, state) + nnx.update(self.layer_instance, state) + + return self.layer_instance(x, *args, **kwargs) # type: ignore diff --git a/tests/nnx/nn/normalization_test.py b/tests/nnx/nn/normalization_test.py index d6a399196..2601dc1fc 100644 --- a/tests/nnx/nn/normalization_test.py +++ b/tests/nnx/nn/normalization_test.py @@ -323,6 +323,69 @@ def __call__(self, x, *, mask=None): assert isinstance(linen_out, jax.Array) np.testing.assert_array_equal(linen_out, nnx_out) + @parameterized.product( + dtype=[jnp.float32, jnp.float16], + param_dtype=[jnp.float32, jnp.float16], + scale_init=[ + nnx.initializers.ones, + nnx.initializers.constant(10.0), + nnx.initializers.constant(0.5), + ], + ) + def test_nnx_linen_weightnorm_equivalence( + self, + dtype: tp.Optional[Dtype], + param_dtype: Dtype, + scale_init: nnx.Initializer, + ): + class NNXModel(nnx.Module): + def __init__(self, dtype, param_dtype, rngs): + self.dense = nnx.Linear( + 8, 4, 
+          8, 4, dtype=dtype, param_dtype=param_dtype, rngs=rngs
+        )
+        self.normed = nnx.WeightNorm(
+          self.dense,
+          use_scale=True,
+          scale_init=scale_init,
+          feature_axes=-1,
+          dtype=dtype,
+          param_dtype=param_dtype,
+          rngs=rngs,
+        )
+
+      def __call__(self, x, *, mask=None):
+        return self.normed(x)
+
+    class LinenModel(linen.Module):
+      dtype: tp.Optional[Dtype] = None
+      param_dtype: Dtype = jnp.float32
+
+      def setup(self):
+        self.dense = linen.Dense(
+          4, dtype=self.dtype, param_dtype=self.param_dtype
+        )
+        self.weight_norm = linen.WeightNorm(
+          self.dense, variable_filter={'kernel'}, scale_init=scale_init
+        )
+
+      def __call__(self, x, *, mask=None):
+        return self.weight_norm(x)
+
+    rngs = nnx.Rngs(42)
+
+    x = jax.random.normal(jax.random.key(0), (10, 8))
+
+    linen_model = LinenModel(dtype=dtype, param_dtype=param_dtype)
+    variables = linen_model.init(jax.random.key(1), x)
+
+    nnx_model = NNXModel(dtype=dtype, param_dtype=param_dtype, rngs=rngs)
+    nnx_model.dense.kernel.value = variables['params']['dense']['kernel']
+    nnx_model.dense.bias.value = variables['params']['dense']['bias']
+
+    linen_out = linen_model.apply(variables, x)
+
+    nnx_out = nnx_model(x)
+    np.testing.assert_array_equal(linen_out, nnx_out)
 
 if __name__ == '__main__':
   absltest.main()
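
The WeightNorm wrapper above applies the usual weight-norm reparameterization w_bar = g * v / ||v||_2, with the norm taken over every axis not listed in feature_axes. Below is a minimal standalone sketch of that math in plain jax.numpy for comparison; it does not call the nnx wrapper, and the names (weight_norm_reference, kernel, g), the (8, 4) shape, and the feature_axes=-1 split are illustrative assumptions rather than part of this change.

import jax
import jax.numpy as jnp

def weight_norm_reference(kernel, g, eps=1e-12):
  # Reduce over every axis except the last (feature) axis, mirroring the
  # feature_axes=-1 default and the reduction_axes computed in __call__.
  reduction_axes = tuple(range(kernel.ndim - 1))
  sq_norm = jnp.sum(kernel * kernel, axis=reduction_axes, keepdims=True)
  normalized = kernel * jax.lax.rsqrt(sq_norm + eps)  # same form as _l2_normalize
  return normalized * g  # per-feature scale, broadcast over the reduced axes

key = jax.random.key(0)
kernel = jax.random.normal(key, (8, 4))
g = jnp.ones((4,))  # a ones scale leaves unit column norms, as in the doctest
w_bar = weight_norm_reference(kernel, g)
print(jnp.linalg.norm(w_bar, axis=0))  # ~[1. 1. 1. 1.]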