Issue with Quantization of the multimodal model (sampling) with gm.nn.IntWrapper #516

@Reytuag

Description

Hello,

When following the template from the quantization sampling example in the docs (Link), but for a multimodal use case (the example in the docs is text only), I encounter this error:

Could not find parameter named "scale" in scope "/jit()/jit(vision_encoder)/siglip_encoder/Transformer/encoderblock_0/MlpBlock_0/Dense_0"

Minimal code to reproduce the error:

from gemma import gm
from gemma import peft
from gemma.gm.text import _sampler
import jax
import jax.numpy as jnp  # needed for jnp.int4 below

model = gm.nn.IntWrapper(model=gm.nn.Gemma3_4B(), dtype=jnp.int4)
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)
params = peft.quantize(params, method='INT4', checkpoint_kernel_key='w')

sampler = _sampler.Sampler(
    model=model,
    params=params,
    cache_length=512,
    max_out_length=200,
)

rng = jax.random.PRNGKey(0)
out = sampler.sample(
    'Describe this image: <start_of_image>',
    images=image,  # `image` is a user-loaded image (not shown here)
    max_new_tokens=200,
    rng=rng,
    return_state=True,
)

Exploring the code, I found that gm.nn.IntWrapper also wraps the vision encoder part of the model, whose parameters are in fact not quantized by peft.quantize.

I managed to make it work by ensuring that the interceptor used by gm.nn.IntWrapper does not modify the vision encoder modules (keeping the vision encoder in full precision while the text model is quantized).
But I'd welcome a cleaner fix in the library.
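To illustrate the idea behind the workaround, here is a small self-contained sketch (plain NumPy, not the actual gemma internals): walk the parameter tree and quantize every subtree except the one rooted at the vision encoder. The tree layout and helper names below are hypothetical stand-ins for the real checkpoint structure.

```python
import numpy as np

def quantize_array(w, bits=8):
    # Simple symmetric per-tensor quantization: map the max
    # magnitude to the largest representable signed integer.
    qmax = 2 ** (bits - 1) - 1
    scale = np.max(np.abs(w)) / qmax
    q = np.round(w / scale).astype(np.int8)
    return {'w': q, 'scale': np.float32(scale)}

def quantize_except(tree, skip_key):
    # Recursively quantize every leaf array, but leave any subtree
    # rooted at `skip_key` (here: the vision encoder) untouched.
    out = {}
    for k, v in tree.items():
        if k == skip_key:
            out[k] = v  # keep full precision
        elif isinstance(v, dict):
            out[k] = quantize_except(v, skip_key)
        else:
            out[k] = quantize_array(v)
    return out

# Hypothetical layout loosely mirroring the error's parameter scope:
params = {
    'vision_encoder': {
        'siglip_encoder': {'Dense_0': np.ones((4, 4), np.float32)},
    },
    'transformer': {
        'layer_0': {'mlp': np.ones((4, 4), np.float32)},
    },
}

quantized = quantize_except(params, skip_key='vision_encoder')
```

After this, the vision encoder leaves stay float32 (no `scale` parameter is expected for them), while the text-model leaves carry quantized weights plus a `scale`, which is exactly the split the sampler needs.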
