
Shape-dependent initialization gives concretization error at apply. #1077

@GJBoth


I'm working on implementing Hadamard transforms. I use the following function to build a Hadamard matrix:

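import jax.numpy as jnp

def hadamard(key, shape, dtype=jnp.float32):
    # Sylvester construction; assumes shape[0] is a power of two
    lg2 = jnp.log2(shape[0])
    H = jnp.ones((1, ), dtype=dtype)
    for i in jnp.arange(lg2):  # the loop count needs a concrete value of lg2
        H = jnp.vstack([jnp.hstack([H, H]), jnp.hstack([H, -H])])
    H = 2**(-lg2 / 2) * H  # normalize so the rows are orthonormal
    return H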

I cannot rewrite this as a fori_loop or while_loop, because the shape of H changes every iteration. I then use it as the kernel initializer in a simple dense layer:

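import flax.linen as nn
import jax.numpy as jnp

class HadamardTransform(nn.Module):
    n_hadamard: int

    @nn.compact
    def __call__(self, X):
        # hadamard only reads shape[0], so this builds an n_hadamard x n_hadamard kernel
        kernel = self.param("kernel", hadamard, (self.n_hadamard, ))
        z = jnp.dot(X, kernel)
        return z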

I can initialize this model, and the parameters look fine (i.e., a pytree containing a device array):

from jax import random

# Fake test data
key = random.PRNGKey(42)
X = random.normal(key, (1, 4096))

# Instantiate the model and initialize its parameters
model = HadamardTransform(4096)
params = model.init(key, X)
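model.init is not wrapped in jax.jit, so the initializer presumably runs eagerly there: jnp.log2(shape[0]) is then a concrete array and jnp.arange can read its value. The Flax-free sketch below checks that the same function behaves correctly when called directly. It uses n = 8 instead of 4096, an arbitrary choice to keep it small.

```python
import jax
import jax.numpy as jnp


def hadamard(key, shape, dtype=jnp.float32):
    # Same initializer as in the report, called outside any JAX transformation.
    lg2 = jnp.log2(shape[0])
    H = jnp.ones((1,), dtype=dtype)
    for i in jnp.arange(lg2):  # works eagerly: lg2 holds a concrete value
        H = jnp.vstack([jnp.hstack([H, H]), jnp.hstack([H, -H])])
    return 2 ** (-lg2 / 2) * H


key = jax.random.PRNGKey(42)
H = hadamard(key, (8,))

# Rows of a normalized Hadamard matrix are orthonormal, so H @ H.T should be I.
print(H.shape)
print(jnp.allclose(H @ H.T, jnp.eye(8), atol=1e-6))
```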

However, when I run the forward pass with

model.apply(params, X)

I get a ConcretizationTypeError raised at the for loop of the hadamard function: 'ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.'

I'm not sure why this is happening: since init works, I would expect the forward pass to work as well. Is jitting the apply function also jitting the hadamard function?
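One plausible explanation, which is an assumption and may depend on the Flax version: during apply the kernel already exists, but self.param may still trace the initializer abstractly (for example with jax.eval_shape) to check that the stored parameter has the shape the initializer would produce. Under tracing, jnp.log2(shape[0]) is a tracer rather than a number, so jnp.arange(lg2) cannot determine how many iterations to run. The sketch below reproduces this without Flax. It also shows a variant, hadamard_fixed (a hypothetical name), that computes the loop count with Python's math.log2 so that the count stays concrete under tracing.

```python
import math

import jax
import jax.numpy as jnp


def hadamard_original(key, shape, dtype=jnp.float32):
    # Initializer from the report: lg2 becomes a tracer whenever this is traced.
    lg2 = jnp.log2(shape[0])
    H = jnp.ones((1,), dtype=dtype)
    for i in jnp.arange(lg2):
        H = jnp.vstack([jnp.hstack([H, H]), jnp.hstack([H, -H])])
    return 2 ** (-lg2 / 2) * H


def hadamard_fixed(key, shape, dtype=jnp.float32):
    # Shape-dependent values are plain Python ints, so they stay concrete
    # under jax.eval_shape / jax.jit; the Python loop is unrolled at trace time.
    n = shape[0]
    lg2 = int(math.log2(n))
    if 2 ** lg2 != n:
        raise ValueError(f"shape[0] must be a power of two, got {n}")
    H = jnp.ones((1, 1), dtype=dtype)
    for _ in range(lg2):
        H = jnp.vstack([jnp.hstack([H, H]), jnp.hstack([H, -H])])
    return H * (2.0 ** (-lg2 / 2))


key = jax.random.PRNGKey(0)

# Abstract tracing of the original initializer: shapes only, no concrete values.
try:
    jax.eval_shape(lambda k: hadamard_original(k, (4096,)), key)
    original_failed = False
except jax.errors.ConcretizationTypeError:
    original_failed = True

# The fixed variant traces fine and reports the expected kernel shape.
fixed_shape = jax.eval_shape(lambda k: hadamard_fixed(k, (4096,)), key).shape
print(original_failed, fixed_shape)
```

The loop count depends only on the static shape, so an ordinary Python for loop is unrolled during tracing. That is why neither fori_loop nor while_loop is needed here, even though the shape of H changes every iteration.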
