Description
Hello,
When following the template from the quantization sampling example in the docs (Link), but for a multimodal use case (the example in the docs is text-only), I encounter this error:
Could not find parameter named "scale" in scope "/jit()/jit(vision_encoder)/siglip_encoder/Transformer/encoderblock_0/MlpBlock_0/Dense_0"
Minimal code to reproduce the error:
from gemma import gm
from gemma import peft
from gemma.gm.text import _sampler
import jax
import jax.numpy as jnp  # needed for jnp.int4 below

model = gm.nn.IntWrapper(model=gm.nn.Gemma3_4B(), dtype=jnp.int4)
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)
params = peft.quantize(params, method='INT4', checkpoint_kernel_key='w')
sampler = _sampler.Sampler(
    model=model,
    params=params,
    cache_length=512,
    max_out_length=200,
)
rng = jax.random.PRNGKey(0)
out = sampler.sample(
    'Describe this image: <start_of_image>',
    images=image,  # `image` is any image array loaded beforehand
    max_new_tokens=200,
    rng=rng,
    return_state=True,
)
Exploring the code, I found that gm.nn.IntWrapper also wraps the vision encoder part of the model, whose parameters are in fact not quantized by peft.quantize.
I managed to make it work by ensuring that the interceptor used by gm.nn.IntWrapper does not modify the vision encoder's modules (keeping the vision encoder in full precision while the text model is quantized).
But I'm hoping for a cleaner fix in the library.
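For reference, the selective-quantization idea behind the workaround can be sketched in plain Python. This is only an illustration: `quantize_tree`, `quantize_except`, and the `'vision_encoder'` top-level key are assumptions made for the sketch (the key mirrors the scope in the error message), not the actual gemma param layout or API:

```python
# Sketch: quantize every parameter subtree EXCEPT the vision encoder,
# so the vision tower stays in full precision while the text model
# is quantized. The real INT4 conversion is replaced by a tagging
# stand-in; only the tree-filtering logic is the point here.

def quantize_tree(tree):
    """Recursively 'quantize' every leaf of a nested-dict param tree."""
    if isinstance(tree, dict):
        return {k: quantize_tree(v) for k, v in tree.items()}
    return ('int4', tree)  # stand-in for the real INT4 conversion

def quantize_except(params, skip_keys=('vision_encoder',)):
    """Quantize all top-level subtrees except those in skip_keys."""
    return {
        k: (v if k in skip_keys else quantize_tree(v))
        for k, v in params.items()
    }

# Toy param tree shaped loosely after the scopes in the error message.
params = {
    'vision_encoder': {'siglip_encoder': {'Dense_0': {'w': 1.0}}},
    'transformer': {'layer_0': {'attn': {'w': 2.0}}},
}
out = quantize_except(params)
assert out['vision_encoder'] == params['vision_encoder']       # untouched
assert out['transformer']['layer_0']['attn']['w'] == ('int4', 2.0)
```

In the library itself, an equivalent filter would presumably need to be applied consistently in two places: in peft.quantize (so vision params are skipped) and in the module interceptor of gm.nn.IntWrapper (so the vision modules keep expecting full-precision params).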