
NNX Conv Layer Input Tuple Error #4295

Open
@riverliway

Description



System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 22.04.3 LTS
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax 0.9.0, jax 0.4.33, jaxlib 0.4.33
  • Python version: 3.10
  • GPU/TPU model and memory: Colab T4 GPU
  • CUDA version (if applicable): N/A

Problem you have encountered:

Passing a tuple as in_features and/or out_features to the NNX convolution layer (nnx.Conv) raises an error saying that the // operator cannot be used on a tuple.
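A minimal plain-Python illustration of the failing operation. It assumes (this is not confirmed in the issue itself) that the layer builds its kernel shape by floor-dividing in_features by a feature_group_count that defaults to 1; kernel_shape is a hypothetical helper, not the actual Flax source.

```python
# Hypothetical stand-in for how a conv layer might build its kernel shape;
# assumed to resemble nnx.Conv, whose real code may differ in detail.
def kernel_shape(kernel_size, in_features, out_features, feature_group_count=1):
    # The kernel's input-feature dimension is in_features split across groups.
    return tuple(kernel_size) + (in_features // feature_group_count, out_features)

print(kernel_shape((3, 3), 8, 4))  # (3, 3, 8, 4)

try:
    # Tuples for the feature counts reach the // operator and fail.
    kernel_shape((3, 3), (8, 8), (4, 4))
except TypeError as err:
    print(err)  # unsupported operand type(s) for //: 'tuple' and 'int'
```

With integer feature counts the arithmetic is well defined; with tuples, Python has no // between a tuple and an int, which matches the error described above.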

What you expected to happen:

Creating an nnx.Conv layer with tuples for in_features and/or out_features should not raise an error. Applying Conv(in_features=(x, x), out_features=(y, y)) to an input of shape (a, b, x, x) should produce an output of shape (a, b, y, y), as the documentation suggests.
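For comparison, a pure-Python sketch of the output-shape arithmetic that integer feature counts follow. It assumes a channels-last layout (batch, spatial..., features) and the standard SAME/VALID padding formulas; conv_output_shape is a hypothetical helper, not a Flax API.

```python
import math
from typing import Optional, Sequence, Tuple


def conv_output_shape(
    input_shape: Sequence[int],
    out_features: int,
    kernel_size: Sequence[int],
    strides: Optional[Sequence[int]] = None,
    padding: str = "SAME",
) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of a channels-last convolution."""
    # Channels-last: (batch, *spatial, in_features); the last axis is features.
    batch, *spatial, _in_features = input_shape
    if strides is None:
        strides = (1,) * len(spatial)
    out_spatial = []
    for size, k, s in zip(spatial, kernel_size, strides):
        if padding == "SAME":
            # SAME keeps ceil(size / stride) positions regardless of kernel size.
            out_spatial.append(math.ceil(size / s))
        elif padding == "VALID":
            # VALID keeps only windows that fit entirely inside the input.
            out_spatial.append((size - k) // s + 1)
        else:
            raise ValueError(f"unsupported padding: {padding!r}")
    # Only the last (feature) axis becomes out_features.
    return (batch, *out_spatial, out_features)


print(conv_output_shape((2, 8, 8, 3), 4, (3, 3)))                   # (2, 8, 8, 4)
print(conv_output_shape((2, 8, 8, 3), 4, (3, 3), padding="VALID"))  # (2, 6, 6, 4)
```

Under these assumptions, an input shaped like the one in the reproduction, (2, 3, 8, 8), would have 8 read as the feature count and (3, 8) as the spatial extent, so conv_output_shape((2, 3, 8, 8), 4, (3, 3)) gives (2, 3, 8, 4) rather than (2, 3, 4, 4).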


Logs, error messages, etc:

[Screenshot of the error output; image not available.]

Steps to reproduce:


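```python
from flax import nnx
import jax.numpy as jnp

rngs = nnx.Rngs(0)
x = jnp.ones((2, 3, 8, 8))
# Fails at construction: tuple in_features/out_features hit the // operator,
# so x is never passed to the layer.
conv1 = nnx.Conv(in_features=(8, 8), out_features=(4, 4), kernel_size=(3, 3), rngs=rngs)
```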

https://colab.research.google.com/drive/1jIIowlJaQ-SyS59nfy-Wn8dYcgz4mUwH?usp=sharing

Metadata



Labels

Priority: P1 - soon (Response within 5 business days. Resolution within 30 days. Assignee required.)


Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
