Skip to content

Commit e424a08

Browse files
committed
fix: model mismatch
1 parent 5348361 commit e424a08

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

test/convergence/bf16/test_mini_models_multimodal.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@
     MLLAMA_AVAILABLE = False

 try:
-
     from transformers import CLIPImageProcessor
     from transformers import CLIPVisionConfig
     from transformers import LlamaConfig
@@ -617,6 +616,8 @@ def run_mini_model_multimodal(

     set_seed(42)

+    model = create_model(model_name).to(dtype).to(device)
+
     revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]}
     if "mllama" in model_name:
         revert_kwargs["model_type"] = "conditional_generation"
@@ -636,13 +637,12 @@ def run_mini_model_multimodal(
         else:
             kwargs["swiglu"] = True

-        kwargs["model"] = create_model(model_name)
+        kwargs["model"] = model

         MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
     else:
         MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)

-    model = create_model(model_name).to(dtype).to(device)
     model.gradient_checkpointing_enable()

     train_dataset = create_multimodal_dataset(model_name)

0 commit comments

Comments (0)