Skip to content

Text Generation with Mixtral fails #91

@clintg6

Description

@clintg6

Calling `model.generate` raises the following error:
TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (1, 4096) for operand shape (1, 20).
image

To Reproduce

import copy
import jax
from EasyDel import AutoEasyDelModelForCausalLM, AutoEasyDelConfig, get_modules_by_type
from transformers import AutoTokenizer
from transformers import GenerationConfig
from transformers import MixtralForCausalLM
from EasyDel import MixtralConfig, FlaxMixtralForCausalLM
from EasyDel.transform.easydel_transform import huggingface_to_easydel

# Reproduction script: load a Mixtral checkpoint through EasyDel, build a
# second, much smaller Mixtral from a hand-written config, convert that small
# model's weights to EasyDel format, then call `model.generate` on a padded
# prompt.

# Local checkpoint directory; used for both the tokenizer and the model.
pretrained_model_name_or_path = "/LLMs/Mixtral-8x7B-v0.1/"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
# NOTE(review): the `params` returned here is discarded — the name is rebound
# further down to the converted weights of `torch_model`. Only `model` from
# this call is used afterwards.
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        dtype=jax.numpy.bfloat16,
        param_dtype=jax.numpy.bfloat16,
        precision=jax.lax.Precision("fastest"),
        device=jax.devices('cpu')[0]
)

# Small single-layer configuration; `seq_len` is only used as
# `max_position_embeddings` here.
seq_len = 128
config = MixtralConfig(
        hidden_size=256,
        num_attention_heads=8,
        num_hidden_layers=1,
        num_key_value_heads=4,
        intermediate_size=512,
        num_local_experts=8,
        max_position_embeddings=seq_len
)

# Built directly from the config above (no checkpoint path is passed), so
# presumably this model carries freshly initialised weights — TODO confirm.
# The config is deep-copied so the model does not share the `config` object.
torch_model = MixtralForCausalLM(
        config=copy.deepcopy(config)
)

# Rebinds `params` to the small model's state dict converted to EasyDel
# layout, placed on the first CPU device.
# NOTE(review): `model` was created from the checkpoint at
# `pretrained_model_name_or_path`, while these parameters come from the
# small config above; the two presumably do not have matching shapes —
# confirm whether this pairing is intended.
params = {"params":
        huggingface_to_easydel(
            torch_model.state_dict(),
            embedding_layer_names=["embed_tokens"],
            device=jax.devices("cpu")[0]
        )
}
    
# Reuse the EOS token for padding so that `padding='max_length'` can be used.
tokenizer.pad_token = tokenizer.eos_token
# The prompt is padded to 4096 tokens, which is larger than `seq_len` (128)
# used for `max_position_embeddings` above.
# NOTE(review): the reported error has update shape (1, 4096) and operand
# shape (1, 20). 4096 matches this padded length; 20 is presumably the
# default generation `max_length`, since no generation settings are passed
# to `generate` below — confirm against the EasyDel/transformers defaults.
tokens = tokenizer("Can you tell me who is the current president of the united states?", max_length=4096, padding='max_length', return_tensors='jax')
input_ids, attention_mask = tokens.input_ids, tokens.attention_mask 
# This is the call that raises the TypeError quoted in the description.
predict = model.generate(
            input_ids,
            attention_mask=attention_mask,
            params=params)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions