import copy
import jax
from EasyDel import AutoEasyDelModelForCausalLM, AutoEasyDelConfig, get_modules_by_type
from transformers import AutoTokenizer
from transformers import GenerationConfig
from transformers import MixtralForCausalLM
from EasyDel import MixtralConfig, FlaxMixtralForCausalLM
from EasyDel.transform.easydel_transform import huggingface_to_easydel
pretrained_model_name_or_path = "/LLMs/Mixtral-8x7B-v0.1/"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
pretrained_model_name_or_path,
dtype=jax.numpy.bfloat16,
param_dtype=jax.numpy.bfloat16,
precision=jax.lax.Precision("fastest"),
device=jax.devices('cpu')[0]
)
seq_len = 128
config = MixtralConfig(
hidden_size=256,
num_attention_heads=8,
num_hidden_layers=1,
num_key_value_heads=4,
intermediate_size=512,
num_local_experts=8,
max_position_embeddings=seq_len
)
torch_model = MixtralForCausalLM(
config=copy.deepcopy(config)
)
params = {"params":
huggingface_to_easydel(
torch_model.state_dict(),
embedding_layer_names=["embed_tokens"],
device=jax.devices("cpu")[0]
)
}
tokenizer.pad_token = tokenizer.eos_token
tokens = tokenizer("Can you tell me who is the current president of the united states?", max_length=4096, padding='max_length', return_tensors='jax')
input_ids, attention_mask = tokens.input_ids, tokens.attention_mask
predict = model.generate(
input_ids,
attention_mask=attention_mask,
params=params)
Throws this error when generate is called

TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (1, 4096) for operand shape (1, 20).
To Reproduce