Description
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 22.04.3 LTS
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): 0.9.0, 0.4.33, 0.4.33
- Python version: 3.10
- GPU/TPU model and memory: Colab T4 GPU
- CUDA version (if applicable): N/A
Problem you have encountered:
Passing a tuple as in_features and/or out_features to the NNX Conv layer raises an error saying that the // operator cannot be used on a tuple.
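The error message can be reproduced in plain Python, without Flax. The sketch below rests on an assumption about the layer's internals that this report does not confirm: that the constructor floor-divides in_features by an integer such as a group count. The variable names are illustrative only.

```python
# Assumption: the layer floor-divides in_features by an integer
# (for example a group count) while building the kernel shape.
in_features = (8, 8)       # tuple, as passed in this report
feature_group_count = 1    # illustrative integer divisor

try:
    in_features // feature_group_count
    message = None
except TypeError as err:
    # Python does not define // between a tuple and an int.
    message = str(err)

print(message)
```

With an integer in_features (for example 8) the same expression evaluates without error, which is consistent with the failure being tied to the tuple argument.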
What you expected to happen:
Creating a new nnx.Conv layer with a tuple as the input/output features should not produce an error. Applying Conv(in_features=(x, x), out_features=(y, y)) to an input of shape (a, b, x, x) should produce an output of shape (a, b, y, y), as alluded to in the documentation.
Logs, error messages, etc:
Steps to reproduce:
from flax import nnx
import jax.numpy as jnp
rngs = nnx.Rngs(0)
x = jnp.ones((2, 3, 8, 8))
conv1 = nnx.Conv(in_features=(8, 8), out_features=(4, 4), kernel_size=(3,3), rngs=rngs)
https://colab.research.google.com/drive/1jIIowlJaQ-SyS59nfy-Wn8dYcgz4mUwH?usp=sharing