Merged

92 commits
f84c03f
Add: LLaVa lce_loss
jp1924 Jan 8, 2025
4aaf1ef
Add: LLaVa monkey_patched
jp1924 Jan 8, 2025
8a6f8cd
Add: llava config
jp1924 Jan 9, 2025
4477479
Add: load_image_processing_config
jp1924 Jan 9, 2025
f77d647
Add: llava test
jp1924 Jan 9, 2025
53efec3
Add: llava
jp1924 Jan 9, 2025
1fadb0e
Add: llava test
jp1924 Jan 15, 2025
2f14cf6
Add: llava test
jp1924 Jan 15, 2025
a3db21c
Refactor: update llava model forward method and enhance loss handling
jp1924 Jan 16, 2025
79e2102
debugging
jp1924 Jan 16, 2025
80cfc08
Refactor: simplify lce_forward method and improve loss calculation
jp1924 Jan 16, 2025
24ca457
Refactor: clean up apply_liger_kernel_to_llava function by removing c…
jp1924 Jan 16, 2025
8c181a6
rm debugging
jp1924 Jan 16, 2025
24bc56c
Refactor: remove redundant import of llava_lce_forward in apply_liger…
jp1924 Jan 16, 2025
42dbe36
Refactor: update mini_llava model configuration and add test cases fo…
jp1924 Jan 16, 2025
f6ba33f
Fix: loss value error
jp1924 Jan 16, 2025
9223959
Refactor: add new processor and tokenizer configuration files, remove…
jp1924 Jan 16, 2025
61476e8
Refactor: clean up lce_forward function and update apply_liger_kernel…
jp1924 Jan 16, 2025
22cafbf
Add: processor configuration loading function and clean up model conf…
jp1924 Jan 16, 2025
7f61b13
Merge branch 'main' into add_llava
jp1924 Jan 16, 2025
21bdc0a
Fix: typo Qwen2-VL -> LLaVa
jp1924 Jan 20, 2025
37857ef
Refactor: remove unused input list from run_mini_model function
jp1924 Jan 20, 2025
23403c5
Add model initialization for llava in multimodal tests
jp1924 Jan 20, 2025
5fb2853
Add support for deprecated llava forward function and warning for leg…
jp1924 Jan 20, 2025
8289347
Merge branch 'main' into add_llava
lancerts Jan 21, 2025
dbc1cdd
Merge branch 'main' into add_llava
jp1924 Jan 22, 2025
5adf3e2
Clean: unused module
jp1924 Jan 23, 2025
5b72328
Update: validate kwargs & model
jp1924 Jan 23, 2025
9b1f929
Merge branch 'main' into add_llava
jp1924 Jan 23, 2025
35c1788
Fix: incorrect model input
jp1924 Jan 23, 2025
20e87db
Update: enhance documentation for Llava
jp1924 Jan 23, 2025
419467e
Clean: check model
jp1924 Jan 23, 2025
a12c95b
Merge branch 'main' into add_llava
jp1924 Jan 24, 2025
5801eb9
Fix: conflict
jp1924 Jan 24, 2025
a980a88
revert change
jp1924 Jan 24, 2025
bcca064
solve conflict
jp1924 Jan 26, 2025
4fe5307
solve conflict
jp1924 Jan 26, 2025
5dae0ff
fix conflict
jp1924 Jan 26, 2025
dcef62f
Merge branch 'main' into add_llava
jp1924 Jan 26, 2025
eaef787
solve conflict
jp1924 Jan 26, 2025
90587e8
Add: load_processor_config, load_image_processing_config
jp1924 Jan 26, 2025
ec94e9c
Add: revert_liger_kernel_to_llava
jp1924 Jan 26, 2025
34db41f
apply lint
jp1924 Jan 28, 2025
f29dce4
resolve conflict
jp1924 Jan 28, 2025
38ef343
Merge branch 'main' into add_llava
jp1924 Jan 28, 2025
3f360b5
Merge branch 'main' into add_llava
jp1924 Jan 30, 2025
01d8dd4
Merge branch 'main' into add_llava
jp1924 Feb 4, 2025
f02750b
Merge branch 'main' into add_llava
jp1924 Feb 11, 2025
b6a0c48
Merge branch 'main' into add_llava
jp1924 Feb 12, 2025
263aca3
split: 16 & 32
jp1924 Feb 21, 2025
7c9a891
solve conflict 1
jp1924 Feb 21, 2025
85d8db9
Merge branch 'main' into add_llava
jp1924 Feb 21, 2025
23b8b75
Merge branch 'main' into add_llava
jp1924 Feb 25, 2025
70921a5
apply checkstyle
jp1924 Feb 26, 2025
d32c208
Merge branch 'main' into add_llava
jp1924 Feb 27, 2025
969923b
Merge branch 'main' into add_llava
jp1924 Mar 3, 2025
9e44857
Merge branch 'main' into add_llava
jp1924 Mar 6, 2025
a2aabd9
Merge branch 'main' into add_llava
jp1924 Mar 7, 2025
fafeccb
Apply: #596
jp1924 Mar 8, 2025
a67c612
Update: llava
jp1924 Mar 8, 2025
7bd027c
apply lint
jp1924 Mar 8, 2025
fd3fc8e
Add: llava test
jp1924 Mar 8, 2025
c312a3c
rm: legacy
jp1924 Mar 8, 2025
22a34c3
revert indent
jp1924 Mar 8, 2025
eb495f5
fix: logit scale
jp1924 Mar 8, 2025
8e416f7
Add: deprecate_kwarg
jp1924 Mar 8, 2025
a651cb0
Fix: nn missing
jp1924 Mar 8, 2025
4358b82
Fix: correct variable name for num_logits_to_keep in lce_forward_depr…
jp1924 Mar 8, 2025
89cee9e
fix: deprecate version
jp1924 Mar 8, 2025
a452fdd
fix: add liger_kernel_patch_revert_func call in run_mini_model_multim…
jp1924 Mar 10, 2025
98868b7
fix: update deprecate_kwarg for num_logits_to_keep in lce_forward_dep…
jp1924 Mar 10, 2025
7261db0
fix: correct parameter name from num_logits_to_keep to logits_to_keep…
jp1924 Mar 10, 2025
c4cc4c5
fix: reshape shift_hidden_states and shift_labels when attention_mask…
jp1924 Mar 11, 2025
040d9a2
Merge branch 'main' into add_llava
jp1924 Mar 15, 2025
82bac66
fix: add model initialization for llava in mini model tests
jp1924 Mar 15, 2025
1437aa2
Merge branch 'main' into add_llava
jp1924 Mar 16, 2025
d5cb69f
Merge branch 'main' into add_llava
jp1924 Mar 16, 2025
824cb66
fix: tokenizer_base
jp1924 Mar 16, 2025
82e4517
fix: simplify model assignment in mini model tests
jp1924 Mar 16, 2025
8e74b53
fix: reload modeling modules in revert_liger_kernel_to_llava function
jp1924 Mar 16, 2025
46905b0
fix: refactor model creation in mini model tests
jp1924 Mar 16, 2025
3a51845
fix: update LLAVA model imports and refactor model creation in mini m…
jp1924 Mar 24, 2025
af1420b
Merge branch 'main' into add_llava
jp1924 Mar 24, 2025
70e57e0
Merge branch 'main' into add_llava
jp1924 Mar 24, 2025
5348361
fix: update hidden_size assignment in _patch_layer_norm_module for be…
jp1924 Mar 25, 2025
e424a08
fix: model mismatch
jp1924 Mar 25, 2025
0ec78db
Merge branch 'linkedin:main' into add_llava
jp1924 Mar 27, 2025
cd77128
Merge branch 'main' into add_llava
jp1924 Mar 30, 2025
41526d2
fix: refactor model creation and apply liger kernel for llava models
jp1924 Mar 30, 2025
eedc2d1
fix: remove redundant importlib.reload for modeling_llama in revert_l…
jp1924 Mar 30, 2025
6d9d951
Update test/convergence/bf16/test_mini_models.py
jp1924 Mar 31, 2025
c2c5186
fix: update model configuration parameters in test files for consistency
jp1924 Mar 31, 2025
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -12,6 +12,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
383 changes: 383 additions & 0 deletions src/liger_kernel/transformers/model/llava.py

Large diffs are not rendered by default.
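
The llava.py diff is collapsed here. The core idea of an `lce_forward`, following the pattern of the other `lce_forward` implementations in this repo, is to skip materializing the full-vocabulary logits and instead hand the shifted hidden states and the `lm_head` weight to Liger's fused linear cross entropy loss. A minimal sketch of that pattern (illustrative only, not the PR's exact code):

import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

def compute_lce_loss(hidden_states: torch.Tensor, labels: torch.Tensor, lm_head: torch.nn.Linear) -> torch.Tensor:
    # Shift so that tokens < n predict token n, matching the standard causal LM loss
    shift_hidden_states = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten to (batch * seq_len, hidden_size) and (batch * seq_len,)
    shift_hidden_states = shift_hidden_states.view(-1, shift_hidden_states.size(-1))
    shift_labels = shift_labels.view(-1)

    # The fused kernel computes lm_head(hidden) and the cross entropy in one pass,
    # so the (batch * seq_len, vocab_size) logits tensor is never allocated
    loss_fn = LigerFusedLinearCrossEntropyLoss()
    return loss_fn(lm_head.weight, shift_hidden_states, shift_labels)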

85 changes: 84 additions & 1 deletion src/liger_kernel/transformers/monkey_patch.py
@@ -19,6 +19,8 @@
from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
@@ -57,7 +59,8 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i

def _patch_layer_norm_module(module, eps=1e-6):
    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
    module.hidden_size = module.normalized_shape
    module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)

    _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
    _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
    module.__class__.__name__ = LigerLayerNorm.__name__
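
For context on the change above: `torch.nn.LayerNorm` exposes `normalized_shape` rather than `hidden_size`, while some model-specific norm modules already carry a `hidden_size` attribute, hence the new fallback. A quick illustration (a sketch, not part of the PR):

import torch.nn as nn

ln = nn.LayerNorm(1024)
print(hasattr(ln, "hidden_size"))  # False: plain LayerNorm has no hidden_size attribute
print(ln.normalized_shape)         # (1024,): the value the patch falls back to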
@@ -224,6 +227,85 @@ def apply_liger_kernel_to_llama(
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

def apply_liger_kernel_to_llava(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    model: PreTrainedModel = None,
    **kwargs,
) -> None:
    """
    Apply Liger kernels to replace the original implementation in HuggingFace LLaVA models.
    Due to how LLaVA composes sub-models, the model instance must be passed in so that Liger-Kernel's
    patches can also be applied to the models attached to LLaVA. However, if a model that Liger-Kernel
    does not support is attached to LLaVA, unexpected side effects may occur.
    NOTE: Llava is not available in transformers<4.36.0

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory-efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.llava import modeling_llava

    if cross_entropy:
        logger.warning(TRANSFORMER_DEPRECATION_WARNING)
        modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse("4.49.0"):
            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
        else:  # if version < 4.49.0
            logger.warning(
                "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
            )
            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated

    if model is not None:
        text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
        vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)

        # The loss is handled at the LLaVA level above, so disable CE/FLCE for the sub-model
        # patches unless the caller explicitly overrides them.
        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
        if text_liger_fn:
            accept_params = inspect.signature(text_liger_fn).parameters
            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}

            if remain_params:
                logger.warning(
                    f"These parameters are not supported by {text_model_name}: {list(remain_params)}. Applying only the supported parameters {list(text_kwargs.keys())}.\n"
                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                )
            text_kwargs["model"] = model.language_model
            text_liger_fn(**text_kwargs)
        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
            logger.warning(f"{text_model_name} is not supported by Liger kernel.")

        if vision_liger_fn:
            accept_params = inspect.signature(vision_liger_fn).parameters
            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
            vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}

            if remain_params:
                logger.warning(
                    f"These parameters are not supported by {vision_model_name}: {list(remain_params)}. Applying only the supported parameters {list(vision_kwargs.keys())}.\n"
                    f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
                )
            vision_kwargs["model"] = model.vision_tower
            vision_liger_fn(**vision_kwargs)
        elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
            logger.warning(f"{vision_model_name} is not supported by Liger kernel.")


def apply_liger_kernel_to_mllama(
    rope: bool = True,
    cross_entropy: bool = False,
@@ -1071,6 +1153,7 @@ def apply_liger_kernel_to_olmo2(
    "gemma": apply_liger_kernel_to_gemma,
    "gemma2": apply_liger_kernel_to_gemma2,
    "llama": apply_liger_kernel_to_llama,
    "llava": apply_liger_kernel_to_llava,
    "granite": apply_liger_kernel_to_granite,
    "mllama": apply_liger_kernel_to_mllama,
    "mllama_text_model": apply_liger_kernel_to_mllama,
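
Taken together, the patch works in two layers: the LLaVA-level forward is swapped for the LCE variant, and the attached language/vision models are patched through the MODEL_TYPE_TO_APPLY_LIGER_FN registry. A rough usage sketch for an already-loaded model (illustrative; the checkpoint id and kwargs are examples, not taken from this PR):

from transformers import LlavaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_llava

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Passing the instance lets the patcher reach model.language_model (llama) and
# model.vision_tower (clip_vision_model, which is skipped with a warning since
# it has no entry in MODEL_TYPE_TO_APPLY_LIGER_FN).
apply_liger_kernel_to_llava(
    fused_linear_cross_entropy=True,  # patches LlavaForConditionalGeneration.forward
    model=model,
    rope=True,      # forwarded to the llama patcher via **kwargs
    rms_norm=True,  # forwarded likewise
)
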
93 changes: 93 additions & 0 deletions test/convergence/bf16/test_mini_models.py
@@ -22,6 +22,7 @@
from liger_kernel.transformers import apply_liger_kernel_to_gemma2
from liger_kernel.transformers import apply_liger_kernel_to_granite
from liger_kernel.transformers import apply_liger_kernel_to_llama
from liger_kernel.transformers import apply_liger_kernel_to_llava
from liger_kernel.transformers import apply_liger_kernel_to_mistral
from liger_kernel.transformers import apply_liger_kernel_to_mixtral
from liger_kernel.transformers import apply_liger_kernel_to_mllama
@@ -37,6 +38,7 @@
from test.utils import revert_liger_kernel_to_gemma2
from test.utils import revert_liger_kernel_to_granite
from test.utils import revert_liger_kernel_to_llama
from test.utils import revert_liger_kernel_to_llava
from test.utils import revert_liger_kernel_to_mistral
from test.utils import revert_liger_kernel_to_mixtral
from test.utils import revert_liger_kernel_to_mllama
@@ -84,6 +86,15 @@
except ImportError:
    GRANITE_AVAILABLE = False

try:
    from transformers import CLIPVisionConfig
    from transformers.models.llava.configuration_llava import LlavaConfig
    from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration

    LLAVA_AVAILABLE = True
except ImportError:
    LLAVA_AVAILABLE = False

try:
    # OLMO2 is only available in transformers>=4.47.0
    from transformers.models.olmo2.configuration_olmo2 import Olmo2Config
@@ -93,6 +104,7 @@
except ImportError:
    OLMO2_AVAILABLE = False


from liger_kernel.utils import infer_device

device = infer_device()
@@ -504,6 +516,65 @@
        ),
    )

if LLAVA_AVAILABLE:
    # https://huggingface.co/llava-hf/llava-1.5-7b-hf
    MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig(
        liger_kernel_patch_func=apply_liger_kernel_to_llava,
        liger_kernel_patch_revert_func=revert_liger_kernel_to_llava,
        model_class=LlavaForConditionalGeneration,
        mini_model_config=LlavaConfig(
            text_config=LlamaConfig(
                attention_bias=False,
                attention_dropout=0.0,
                bos_token_id=1,
                eos_token_id=2,
                hidden_act="silu",
                hidden_size=1024,
                initializer_range=0.02,
                intermediate_size=2048,
                num_attention_heads=8,
                num_hidden_layers=4,
                num_key_value_heads=2,
                pretraining_tp=1,
                rope_scaling=None,
                rope_theta=500000.0,
                tie_word_embeddings=False,
                use_cache=True,
                max_position_embeddings=4096,  # llava-1.5-7b-hf
                rms_norm_eps=1e-05,  # llava-1.5-7b-hf
                vocab_size=32064,  # llava-1.5-7b-hf
                # At rope backward
                # Eager produces incontiguous dq and dk
                # SDPA produces contiguous dq and incontiguous dk
                # Flash_attn produces contiguous dq and dk
                attn_implementation="sdpa",  # default value, pytorch native attention
            ),
            vision_config=CLIPVisionConfig(
                hidden_size=1024,
                image_size=336,
                intermediate_size=2048,  # 4096
                model_type="clip_vision_model",
                num_attention_heads=4,  # 16
                num_hidden_layers=4,  # 24
                patch_size=14,
                projection_dim=768,
                vocab_size=32000,
            ),
            vocab_size=32064,
            ignore_index=-100,
            pad_token_id=4,
            image_token_index=3,
            projector_hidden_act="gelu",
            vision_feature_layer=-2,
            vision_feature_select_strategy="default",
            # At rope backward
            # Eager produces incontiguous dq and dk
            # SDPA produces contiguous dq and incontiguous dk
            # Flash_attn produces contiguous dq and dk
            attn_implementation="sdpa",  # default value, pytorch native attention
        ),
    )

if OLMO2_AVAILABLE:
    MINI_MODEL_SETUPS["mini_olmo2"] = MiniModelConfig(
        liger_kernel_patch_func=apply_liger_kernel_to_olmo2,
@@ -577,6 +648,9 @@ def run_mini_model(
        else:
            kwargs["swiglu"] = True

        if "llava" in model_name:
Collaborator

This condition would prevent the other models from being patched.

jp1924 (Contributor, Author) Mar 16, 2025

> Instead, I think you should write a conditional patching method in your tests for models that require instance passing.

Hey, did you mean to pass model argument values to other models as well, including llava?

I interpreted it as liger patch not being applied to llava, so we have to subtract the argument value separately.
I'll fix that part and make it possible to pass the argument value to other models.

Tcc0403 (Collaborator) Mar 16, 2025

> Instead, I think you should write a conditional patching method in your tests for models that require instance passing.
>
> Hey, did you mean to pass model argument values to other models as well, including llava?
>
> I interpreted it as liger patch not being applied to llava, so we have to subtract the argument value separately.
> I'll fix that part and make it possible to pass the argument value to other models.

My bad, I was thinking to implement two different workflows (create model -> patch for llava, patch -> create model for the others), which complicates the testing. Since the main goal is "convergence", simply picking the create model -> patch workflow that works for all cases should be better and cleaner.
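
For readers following along, the "create model -> patch" workflow described here looks roughly like this in the test harness (a sketch using the names from this PR, not the exact test code):

# create the mini model first, exactly as the non-liger baseline run does
model = create_model(model_name).to(dtype).to(device)

# then patch the live instance in place, passing it through the model kwarg
kwargs["model"] = model
MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)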

jp1924 (Contributor, Author)

Hmm, I don't understand this. I pass kwargs['model'] and all tests fail. @Tcc0403, any idea why this is the case?

test convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models.py
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/jp/Liger-Kernel
configfile: pyproject.toml
plugins: jaxtyping-0.2.36, anyio-4.6.2.post1

----------------------------- live log collection ------------------------------
INFO     numexpr.utils:utils.py:148 Note: NumExpr detected 48 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
INFO     numexpr.utils:utils.py:161 NumExpr defaulting to 16 threads.
INFO     datasets:config.py:54 PyTorch version 2.5.1+cu121 available.
collected 13 items

test/convergence/fp32/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype0-1e-08-2e-05-0.0001-1e-05-0.005-1e-05] FAILED [  7%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05] 
-------------------------------- live log call ---------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:293 clip_vision_model is not supported by Liger kernel.
FAILED                                                                   [ 15%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_mllama-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05] FAILED [ 23%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype3-1e-08-1e-05-0.005-1e-05-0.005-1e-05] FAILED [ 30%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype4-1e-05-0.1-0.005-1e-05-0.005-1e-05] FAILED [ 38%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_5_vl-32-0.0001-dtype5-1e-05-0.1-0.005-1e-05-0.005-1e-05] FAILED [ 46%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_olmo2-32-0.0001-dtype6-1e-08-1e-05-0.005-1e-05-0.005-1e-05] FAILED [ 53%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype7-1e-08-1e-05-0.005-1e-05-0.005-1e-05] FAILED [ 61%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype8-1e-08-1e-05-0.005-1e-05-0.005-1e-05] FAILED [ 69%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype9-1e-08-0.0001-0.005-1e-05-0.005-1e-05] FAILED [ 76%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype10-1e-08-0.0001-0.005-1e-05-0.005-1e-05] FAILED [ 84%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma2-32-0.0001-dtype11-1e-08-0.0001-0.005-1e-05-0.005-1e-05] FAILED [ 92%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_granite3-32-0.0001-dtype12-1e-08-0.0001-0.005-1e-05-0.005-1e-05] FAILED [100%]

=================================== FAILURES ===================================
_ test_mini_model[mini_llama3-32-0.0001-dtype0-1e-08-2e-05-0.0001-1e-05-0.005-1e-05] _

model_name = 'mini_llama3', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-08, loss_rtol = 2e-05, logits_atol = 0.0001, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
        [
            ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_llava",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not LLAVA_AVAILABLE,
                    reason="LLaVa not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_mllama",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not MLLAMA_AVAILABLE,
                    reason="Mllama not available in this version of transformers",
                ),
            ),
            ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(  # qwen2_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_VL_AVAILABLE,
                    reason="Qwen2-VL not available in this version of transformers",
                ),
            ),
            pytest.param(  # qwen2_5_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_5_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_5_VL_AVAILABLE,
                    reason="Qwen2.5-VL not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_olmo2",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not OLMO2_AVAILABLE,
                    reason="OLMO2 not available in this version of transformers",
                ),
            ),
            ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            # TODO: mixtral is flaky so disable the test for now
            # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
            # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
            ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_granite3",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-4,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not GRANITE_AVAILABLE,
                    reason="Granite not available in this version of transformers",
                ),
            ),
        ],
    )
    def test_mini_model(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logits_atol,
        logits_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
    
        expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
        actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
    
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[10.6516,  2.1799,  0.8609,  0.7009,  0.6704,  0.7875,  0.7169,  0.9746,
          0.6826,  0.7658,  0.7752,  ...  0.6201,  0.3873,  0.4851,  0.4619,
          0.4543,  0.4247,  0.3782,  0.5069,  0.5950,  0.3770,  0.7420,  0.4775]])
tensor2 = tensor([[10.1816,  2.0437,  0.8494,  0.6971,  0.6678,  0.7906,  0.7074,  0.9635,
          0.6718,  0.7569,  0.7764,  ...  0.6062,  0.3785,  0.4786,  0.4540,
          0.4473,  0.4196,  0.3656,  0.4940,  0.5871,  0.3639,  0.7289,  0.4616]])
rtol = 2e-05, atol = 1e-08, max_print = 5

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
            if num_mismatched > max_print:
                mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
    
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 10.651559829711914, tensor2[(0, 0)] = 10.181577682495117
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 2.1799447536468506, tensor2[(0, 1)] = 2.043672561645508
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 0.8609191179275513, tensor2[(0, 2)] = 0.849406898021698
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 0.7009443044662476, tensor2[(0, 3)] = 0.6970827579498291
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 0.6704153418540955, tensor2[(0, 4)] = 0.667807936668396
E           ... and 27 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 10.651559829711914
Step 1, Loss: 2.1799447536468506
Step 2, Loss: 0.8609191179275513
Step 3, Loss: 0.7009443044662476
Step 4, Loss: 0.6704153418540955
Step 5, Loss: 0.7874557375907898
Step 6, Loss: 0.716873288154602
Step 7, Loss: 0.9746330380439758
Step 8, Loss: 0.6826125383377075
Step 9, Loss: 0.7657949328422546
Step 10, Loss: 0.7752344012260437
Step 11, Loss: 0.47871941328048706
Step 12, Loss: 0.7085208296775818
Step 13, Loss: 0.6854400038719177
Step 14, Loss: 0.6153515577316284
Step 15, Loss: 0.7325771450996399
Step 16, Loss: 0.7106887698173523
Step 17, Loss: 0.7801141738891602
Step 18, Loss: 0.6898989081382751
Step 19, Loss: 0.7559919357299805
Step 20, Loss: 0.6201393008232117
Step 21, Loss: 0.3873378336429596
Step 22, Loss: 0.48507416248321533
Step 23, Loss: 0.4619227349758148
Step 24, Loss: 0.4543335437774658
Step 25, Loss: 0.42470917105674744
Step 26, Loss: 0.3781709372997284
Step 27, Loss: 0.5069137811660767
Step 28, Loss: 0.5949812531471252
Step 29, Loss: 0.3769553303718567
Step 30, Loss: 0.7420326471328735
Step 31, Loss: 0.4775380492210388
Liger kernel patches have been reverted.
Step 0, Loss: 10.181577682495117
Step 1, Loss: 2.043672561645508
Step 2, Loss: 0.849406898021698
Step 3, Loss: 0.6970827579498291
Step 4, Loss: 0.667807936668396
Step 5, Loss: 0.7905953526496887
Step 6, Loss: 0.7073678374290466
Step 7, Loss: 0.963516891002655
Step 8, Loss: 0.6717684864997864
Step 9, Loss: 0.7569032311439514
Step 10, Loss: 0.7764208912849426
Step 11, Loss: 0.47319403290748596
Step 12, Loss: 0.699938952922821
Step 13, Loss: 0.672838032245636
Step 14, Loss: 0.6139233112335205
Step 15, Loss: 0.7216891050338745
Step 16, Loss: 0.7029391527175903
Step 17, Loss: 0.7786717414855957
Step 18, Loss: 0.6813715100288391
Step 19, Loss: 0.7553117871284485
Step 20, Loss: 0.6062243580818176
Step 21, Loss: 0.3784829080104828
Step 22, Loss: 0.47860071063041687
Step 23, Loss: 0.4540420174598694
Step 24, Loss: 0.4473242461681366
Step 25, Loss: 0.41960909962654114
Step 26, Loss: 0.36560794711112976
Step 27, Loss: 0.49400007724761963
Step 28, Loss: 0.5870980024337769
Step 29, Loss: 0.3638893961906433
Step 30, Loss: 0.7288545370101929
Step 31, Loss: 0.4616243541240692
Liger kernel patches have been reverted.
_ test_mini_model[mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_llava', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-08, loss_rtol = 1e-05, logits_atol = 0.005, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05


test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[10.8153,  9.0049,  9.0121,  8.3405,  8.1161,  8.4512,  8.1446,  8.3607,
          7.8458,  7.7532,  8.0552,  ...  7.0205,  6.0649,  6.8576,  6.4975,
          6.5468,  6.7241,  6.5390,  7.0226,  7.1243,  6.2323,  7.1499,  6.6104]])
tensor2 = tensor([[10.3328,  8.9970,  8.9445,  8.2137,  8.0093,  8.3785,  8.0023,  8.2842,
          7.7432,  7.7735,  8.0738,  ...  7.0994,  6.2638,  6.9236,  6.7024,
          6.5430,  6.7543,  6.4407,  7.1172,  7.1593,  6.1753,  7.0963,  6.6384]])
rtol = 1e-05, atol = 1e-08, max_print = 5

E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 10.81532096862793, tensor2[(0, 0)] = 10.332820892333984
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 9.004902839660645, tensor2[(0, 1)] = 8.99697208404541
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 9.01207447052002, tensor2[(0, 2)] = 8.944518089294434
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 8.34049129486084, tensor2[(0, 3)] = 8.213672637939453
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 8.116089820861816, tensor2[(0, 4)] = 8.009262084960938
E           ... and 27 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 10.81532096862793
Step 1, Loss: 9.004902839660645
Step 2, Loss: 9.01207447052002
Step 3, Loss: 8.34049129486084
Step 4, Loss: 8.116089820861816
Step 5, Loss: 8.451170921325684
Step 6, Loss: 8.144582748413086
Step 7, Loss: 8.360725402832031
Step 8, Loss: 7.845808982849121
Step 9, Loss: 7.753244876861572
Step 10, Loss: 8.055215835571289
Step 11, Loss: 6.473562717437744
Step 12, Loss: 7.501335620880127
Step 13, Loss: 7.2815632820129395
Step 14, Loss: 6.985671043395996
Step 15, Loss: 7.557857513427734
Step 16, Loss: 7.431850910186768
Step 17, Loss: 7.591641426086426
Step 18, Loss: 7.8725175857543945
Step 19, Loss: 7.626773357391357
Step 20, Loss: 7.020523548126221
Step 21, Loss: 6.064859867095947
Step 22, Loss: 6.857598781585693
Step 23, Loss: 6.497462272644043
Step 24, Loss: 6.546847343444824
Step 25, Loss: 6.724087238311768
Step 26, Loss: 6.539017677307129
Step 27, Loss: 7.022586822509766
Step 28, Loss: 7.124300479888916
Step 29, Loss: 6.23234748840332
Step 30, Loss: 7.1498613357543945
Step 31, Loss: 6.610395431518555
Liger kernel patches have been reverted.
Step 0, Loss: 10.332820892333984
Step 1, Loss: 8.99697208404541
Step 2, Loss: 8.944518089294434
Step 3, Loss: 8.213672637939453
Step 4, Loss: 8.009262084960938
Step 5, Loss: 8.378498077392578
Step 6, Loss: 8.00233268737793
Step 7, Loss: 8.284164428710938
Step 8, Loss: 7.743182182312012
Step 9, Loss: 7.773451328277588
Step 10, Loss: 8.073831558227539
Step 11, Loss: 6.484562873840332
Step 12, Loss: 7.451560974121094
Step 13, Loss: 7.3586530685424805
Step 14, Loss: 6.855362415313721
Step 15, Loss: 7.5276994705200195
Step 16, Loss: 7.417705535888672
Step 17, Loss: 7.677684307098389
Step 18, Loss: 7.8316192626953125
Step 19, Loss: 7.593496322631836
Step 20, Loss: 7.099413871765137
Step 21, Loss: 6.263830184936523
Step 22, Loss: 6.92359733581543
Step 23, Loss: 6.70235538482666
Step 24, Loss: 6.5430192947387695
Step 25, Loss: 6.754302024841309
Step 26, Loss: 6.440676689147949
Step 27, Loss: 7.117189884185791
Step 28, Loss: 7.159293174743652
Step 29, Loss: 6.175302028656006
Step 30, Loss: 7.096262454986572
Step 31, Loss: 6.63844108581543
Liger kernel patches have been reverted.
------------------------------ Captured log call -------------------------------
WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:293 clip_vision_model is not supported by Liger kernel.
_ test_mini_model[mini_mllama-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_mllama', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-08, loss_rtol = 1e-05, logits_atol = 0.005, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05


test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[10.9785,  6.8807,  5.2030,  2.7935,  1.4178,  1.0796,  1.0623,  1.3844,
          0.9944,  1.1533,  1.1611,  ...  0.8267,  0.5791,  0.7015,  0.6890,
          0.6710,  0.6036,  0.5282,  0.6502,  0.7508,  0.5036,  0.9282,  0.6071]])
tensor2 = tensor([[11.5642,  7.1794,  5.6338,  3.1170,  1.3802,  1.0824,  0.9557,  1.3528,
          0.9690,  1.1335,  1.1417,  ...  0.8166,  0.5886,  0.6985,  0.6596,
          0.6372,  0.5665,  0.5001,  0.6307,  0.7307,  0.4899,  0.9123,  0.5907]])
rtol = 1e-05, atol = 1e-08, max_print = 5

E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 10.97847843170166, tensor2[(0, 0)] = 11.564229965209961
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 6.880693435668945, tensor2[(0, 1)] = 7.179394245147705
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 5.203007221221924, tensor2[(0, 2)] = 5.633819580078125
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 2.79349684715271, tensor2[(0, 3)] = 3.116971254348755
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 1.4178184270858765, tensor2[(0, 4)] = 1.3802435398101807
E           ... and 27 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 10.97847843170166
Step 1, Loss: 6.880693435668945
Step 2, Loss: 5.203007221221924
Step 3, Loss: 2.79349684715271
Step 4, Loss: 1.4178184270858765
Step 5, Loss: 1.0796024799346924
Step 6, Loss: 1.062251329421997
Step 7, Loss: 1.3844060897827148
Step 8, Loss: 0.994351863861084
Step 9, Loss: 1.1533421277999878
Step 10, Loss: 1.1611496210098267
Step 11, Loss: 0.7341922521591187
Step 12, Loss: 1.085917353630066
Step 13, Loss: 1.0426464080810547
Step 14, Loss: 0.9200263023376465
Step 15, Loss: 1.0746867656707764
Step 16, Loss: 1.0197221040725708
Step 17, Loss: 1.0846731662750244
Step 18, Loss: 0.9172532558441162
Step 19, Loss: 1.0032974481582642
Step 20, Loss: 0.8267252445220947
Step 21, Loss: 0.5791030526161194
Step 22, Loss: 0.7015061378479004
Step 23, Loss: 0.6890019774436951
Step 24, Loss: 0.6710031628608704
Step 25, Loss: 0.6035834550857544
Step 26, Loss: 0.5281577706336975
Step 27, Loss: 0.6502063274383545
Step 28, Loss: 0.7507665753364563
Step 29, Loss: 0.5035553574562073
Step 30, Loss: 0.9281589388847351
Step 31, Loss: 0.6070712208747864
Liger kernel patches have been reverted.
Step 0, Loss: 11.564229965209961
Step 1, Loss: 7.179394245147705
Step 2, Loss: 5.633819580078125
Step 3, Loss: 3.116971254348755
Step 4, Loss: 1.3802435398101807
Step 5, Loss: 1.082366704940796
Step 6, Loss: 0.9556831121444702
Step 7, Loss: 1.3528398275375366
Step 8, Loss: 0.9690356254577637
Step 9, Loss: 1.1334657669067383
Step 10, Loss: 1.1417131423950195
Step 11, Loss: 0.7202014327049255
Step 12, Loss: 1.0641112327575684
Step 13, Loss: 1.0251208543777466
Step 14, Loss: 0.9098283052444458
Step 15, Loss: 1.056693196296692
Step 16, Loss: 1.0093092918395996
Step 17, Loss: 1.0651583671569824
Step 18, Loss: 0.8884314298629761
Step 19, Loss: 0.9759615659713745
Step 20, Loss: 0.8166440725326538
Step 21, Loss: 0.5886070132255554
Step 22, Loss: 0.6984806060791016
Step 23, Loss: 0.6595674157142639
Step 24, Loss: 0.6372134685516357
Step 25, Loss: 0.5665295124053955
Step 26, Loss: 0.5001171231269836
Step 27, Loss: 0.6307035684585571
Step 28, Loss: 0.7306686043739319
Step 29, Loss: 0.48985299468040466
Step 30, Loss: 0.9122832417488098
Step 31, Loss: 0.5906772613525391
Liger kernel patches have been reverted.
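For context, assert_verbose_allclose (shown in full in the next failure) flags an element whenever abs(tensor1 - tensor2) exceeds atol + rtol * abs(tensor2). A minimal sketch of that check, using the step-0 losses logged above; the atol=1e-8 / rtol=1e-5 values are an assumption based on the parametrize table, since this failure's header is not shown here:

# Minimal sketch of the per-element check, with the step-0 losses copied
# from the failure above (eager run vs. Liger run). atol/rtol are assumed.
t1, t2 = 10.97847843170166, 11.564229965209961
diff = abs(t1 - t2)          # ~0.5858
tol = 1e-8 + 1e-5 * abs(t2)  # ~0.0001156
print(diff > tol)            # True -> step 0 is counted as a mismatch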
_ test_mini_model[mini_qwen2-32-0.0001-dtype3-1e-08-1e-05-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_qwen2', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-08, loss_rtol = 1e-05, logits_atol = 0.005, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
        [
            ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_llava",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not LLAVA_AVAILABLE,
                    reason="LLaVa not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_mllama",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not MLLAMA_AVAILABLE,
                    reason="Mllama not available in this version of transformers",
                ),
            ),
            ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(  # qwen2_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_VL_AVAILABLE,
                    reason="Qwen2-VL not available in this version of transformers",
                ),
            ),
            pytest.param(  # qwen2_5_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_5_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_5_VL_AVAILABLE,
                    reason="Qwen2.5-VL not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_olmo2",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not OLMO2_AVAILABLE,
                    reason="OLMO2 not available in this version of transformers",
                ),
            ),
            ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            # TODO: mixtral is flaky so disable the test for now
            # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
            # Gemma 1.1 and 2 have larger tolerances because the kernel is currently not a perfect match (casts are not done the same way)
            ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_granite3",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-4,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not GRANITE_AVAILABLE,
                    reason="Granite not available in this version of transformers",
                ),
            ),
        ],
    )
    def test_mini_model(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logits_atol,
        logits_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
    
        expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
        actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
    
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[11.0585,  2.3595,  0.8737,  0.7047,  0.6766,  0.7809,  0.7184,  0.9788,
          0.6817,  0.7750,  0.7798,  ...  0.6357,  0.4045,  0.5104,  0.4811,
          0.4740,  0.4416,  0.3968,  0.5272,  0.6155,  0.4000,  0.7706,  0.5069]])
tensor2 = tensor([[10.2202,  2.1691,  0.8705,  0.7057,  0.6822,  0.7976,  0.7253,  0.9839,
          0.7018,  0.7822,  0.7995,  ...  0.6380,  0.4088,  0.5076,  0.4798,
          0.4764,  0.4476,  0.4039,  0.5276,  0.6252,  0.4035,  0.7882,  0.5055]])
rtol = 1e-05, atol = 1e-08, max_print = 5

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
            if num_mismatched > max_print:
                mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
    
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 11.058478355407715, tensor2[(0, 0)] = 10.220199584960938
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 2.359509229660034, tensor2[(0, 1)] = 2.169123649597168
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 0.8737365007400513, tensor2[(0, 2)] = 0.8704574704170227
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 0.7047495245933533, tensor2[(0, 3)] = 0.7057247757911682
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 0.6766155958175659, tensor2[(0, 4)] = 0.6822481155395508
E           ... and 27 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 11.058478355407715
Step 1, Loss: 2.359509229660034
Step 2, Loss: 0.8737365007400513
Step 3, Loss: 0.7047495245933533
Step 4, Loss: 0.6766155958175659
Step 5, Loss: 0.7808870077133179
Step 6, Loss: 0.7183868288993835
Step 7, Loss: 0.9787932634353638
Step 8, Loss: 0.6816532611846924
Step 9, Loss: 0.7750170826911926
Step 10, Loss: 0.779798686504364
Step 11, Loss: 0.47361063957214355
Step 12, Loss: 0.7199702262878418
Step 13, Loss: 0.6917579770088196
Step 14, Loss: 0.625240683555603
Step 15, Loss: 0.7451457977294922
Step 16, Loss: 0.7356531620025635
Step 17, Loss: 0.8098469972610474
Step 18, Loss: 0.715949296951294
Step 19, Loss: 0.775590717792511
Step 20, Loss: 0.6357059478759766
Step 21, Loss: 0.40446850657463074
Step 22, Loss: 0.510444700717926
Step 23, Loss: 0.48105382919311523
Step 24, Loss: 0.4739769697189331
Step 25, Loss: 0.4415643513202667
Step 26, Loss: 0.3968295156955719
Step 27, Loss: 0.5271892547607422
Step 28, Loss: 0.6155114769935608
Step 29, Loss: 0.3999830186367035
Step 30, Loss: 0.7705995440483093
Step 31, Loss: 0.5069104433059692
Liger kernel patches have been reverted.
Applied Liger kernels to Qwen2
Step 0, Loss: 10.220199584960938
Step 1, Loss: 2.169123649597168
Step 2, Loss: 0.8704574704170227
Step 3, Loss: 0.7057247757911682
Step 4, Loss: 0.6822481155395508
Step 5, Loss: 0.7976466417312622
Step 6, Loss: 0.7252814769744873
Step 7, Loss: 0.9839460849761963
Step 8, Loss: 0.701795220375061
Step 9, Loss: 0.782172679901123
Step 10, Loss: 0.799495279788971
Step 11, Loss: 0.49434420466423035
Step 12, Loss: 0.7387100458145142
Step 13, Loss: 0.705495297908783
Step 14, Loss: 0.6422168016433716
Step 15, Loss: 0.7519472241401672
Step 16, Loss: 0.7420428991317749
Step 17, Loss: 0.8044173121452332
Step 18, Loss: 0.715694010257721
Step 19, Loss: 0.7845567464828491
Step 20, Loss: 0.6379701495170593
Step 21, Loss: 0.4087505042552948
Step 22, Loss: 0.5076121091842651
Step 23, Loss: 0.4798239767551422
Step 24, Loss: 0.47640153765678406
Step 25, Loss: 0.44758713245391846
Step 26, Loss: 0.40385881066322327
Step 27, Loss: 0.5275834202766418
Step 28, Loss: 0.6252421736717224
Step 29, Loss: 0.40345272421836853
Step 30, Loss: 0.7882486581802368
Step 31, Loss: 0.5054807662963867
Liger kernel patches have been reverted.
----------------------------- Captured stderr call -----------------------------
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
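To iterate on a single failing case without re-running the whole suite, the node ID from the failure header can be passed back to pytest. A hypothetical repro sketch (node ID copied verbatim from the mini_qwen2 header above; -x and -s are standard pytest options):

# Hypothetical repro sketch: re-run only the failing mini_qwen2 fp32 case.
# -x stops at the first failure; -s surfaces the per-step loss prints.
import pytest

pytest.main([
    "test/convergence/fp32/test_mini_models.py::test_mini_model"
    "[mini_qwen2-32-0.0001-dtype3-1e-08-1e-05-0.005-1e-05-0.005-1e-05]",
    "-x",
    "-s",
])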
_ test_mini_model[mini_qwen2_vl-32-0.0001-dtype4-1e-05-0.1-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_qwen2_vl', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-05, loss_rtol = 0.1, logits_atol = 0.005, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05

    [parametrize table and test_mini_model source identical to the mini_qwen2 failure above; elided]
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[11.5665,  0.7003,  0.8316,  0.6783,  0.6395,  0.7315,  0.6408,  0.9450,
          0.6134,  0.7097,  0.7467,  ...  0.5771,  0.3562,  0.4486,  0.4165,
          0.4210,  0.4156,  0.3573,  0.4765,  0.5671,  0.3587,  0.6944,  0.4625]])
tensor2 = tensor([[10.0765,  0.6772,  0.8197,  0.6674,  0.6292,  0.7310,  0.6350,  0.8791,
          0.5850,  0.6574,  0.6953,  ...  0.5223,  0.3140,  0.4146,  0.3775,
          0.3821,  0.3570,  0.3200,  0.4395,  0.5268,  0.3123,  0.6465,  0.4126]])
rtol = 0.1, atol = 1e-05, max_print = 5

    [assert_verbose_allclose source identical to the full listing in the mini_qwen2 failure above; elided]
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 15
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 11.566545486450195, tensor2[(0, 0)] = 10.076494216918945
E           Mismatch at index (0, 11): tensor1[(0, 11)] = 0.4500718116760254, tensor2[(0, 11)] = 0.37524980306625366
E           Mismatch at index (0, 12): tensor1[(0, 12)] = 0.6924243569374084, tensor2[(0, 12)] = 0.6108517050743103
E           Mismatch at index (0, 13): tensor1[(0, 13)] = 0.6685009002685547, tensor2[(0, 13)] = 0.581950843334198
E           Mismatch at index (0, 14): tensor1[(0, 14)] = 0.594361424446106, tensor2[(0, 14)] = 0.5191696882247925
E           ... and 10 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 11.566545486450195
Step 1, Loss: 0.7002792954444885
Step 2, Loss: 0.8316103219985962
Step 3, Loss: 0.6782925128936768
Step 4, Loss: 0.6394776105880737
Step 5, Loss: 0.731537401676178
Step 6, Loss: 0.6408157348632812
Step 7, Loss: 0.9449859857559204
Step 8, Loss: 0.6134490966796875
Step 9, Loss: 0.7096973061561584
Step 10, Loss: 0.7467284798622131
Step 11, Loss: 0.4500718116760254
Step 12, Loss: 0.6924243569374084
Step 13, Loss: 0.6685009002685547
Step 14, Loss: 0.594361424446106
Step 15, Loss: 0.7167119383811951
Step 16, Loss: 0.696448028087616
Step 17, Loss: 0.7723467946052551
Step 18, Loss: 0.6912583112716675
Step 19, Loss: 0.7467405796051025
Step 20, Loss: 0.5770792961120605
Step 21, Loss: 0.3562013804912567
Step 22, Loss: 0.44857344031333923
Step 23, Loss: 0.41650092601776123
Step 24, Loss: 0.4209594130516052
Step 25, Loss: 0.4156118929386139
Step 26, Loss: 0.3572855591773987
Step 27, Loss: 0.4764554798603058
Step 28, Loss: 0.5670983195304871
Step 29, Loss: 0.3587135970592499
Step 30, Loss: 0.694358766078949
Step 31, Loss: 0.4624866247177124
Liger kernel patches have been reverted.
Step 0, Loss: 10.076494216918945
Step 1, Loss: 0.6771941184997559
Step 2, Loss: 0.8197367191314697
Step 3, Loss: 0.6673524975776672
Step 4, Loss: 0.6292317509651184
Step 5, Loss: 0.731016993522644
Step 6, Loss: 0.6350122690200806
Step 7, Loss: 0.8791137337684631
Step 8, Loss: 0.585008978843689
Step 9, Loss: 0.6574373841285706
Step 10, Loss: 0.6952733397483826
Step 11, Loss: 0.37524980306625366
Step 12, Loss: 0.6108517050743103
Step 13, Loss: 0.581950843334198
Step 14, Loss: 0.5191696882247925
Step 15, Loss: 0.6507109999656677
Step 16, Loss: 0.6424310207366943
Step 17, Loss: 0.7138785719871521
Step 18, Loss: 0.6281083226203918
Step 19, Loss: 0.6843845248222351
Step 20, Loss: 0.5223273038864136
Step 21, Loss: 0.3139645457267761
Step 22, Loss: 0.41460877656936646
Step 23, Loss: 0.3775334358215332
Step 24, Loss: 0.38212329149246216
Step 25, Loss: 0.35704103112220764
Step 26, Loss: 0.32000523805618286
Step 27, Loss: 0.43948787450790405
Step 28, Loss: 0.5267584323883057
Step 29, Loss: 0.31233325600624084
Step 30, Loss: 0.6465412378311157
Step 31, Loss: 0.4126203954219818
Liger kernel patches have been reverted.
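Note that even the deliberately relaxed qwen2_vl tolerances (atol=1e-5, rtol=1e-1) are exceeded at step 0. A quick check with the logged values, under the same element-wise rule as above:

# Step-0 losses from the mini_qwen2_vl failure above (eager vs. Liger).
t1, t2 = 11.566545486450195, 10.076494216918945
rel = abs(t1 - t2) / abs(t2)  # ~0.1479
print(rel > 1e-1)             # True -> outside even the relaxed rtol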
_ test_mini_model[mini_qwen2_5_vl-32-0.0001-dtype5-1e-05-0.1-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_qwen2_5_vl', num_steps = 32, lr = 0.0001
dtype = torch.float32, loss_atol = 1e-05, loss_rtol = 0.1, logits_atol = 0.005
logits_rtol = 1e-05, param_atol = 0.005, param_rtol = 1e-05

    [parametrize table and test_mini_model source identical to the mini_qwen2 failure above; elided]
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[12.0822,  0.7677,  0.8430,  0.6976,  0.6749,  0.7734,  0.6874,  0.9115,
          0.5968,  0.6990,  0.6870,  ...  0.5519,  0.3374,  0.4415,  0.4146,
          0.4155,  0.4046,  0.3577,  0.4754,  0.5705,  0.3547,  0.6901,  0.4538]])
tensor2 = tensor([[10.5113,  0.6981,  0.8267,  0.6682,  0.6340,  0.7203,  0.6464,  0.9055,
          0.5970,  0.6794,  0.7126,  ...  0.5456,  0.3410,  0.4319,  0.3988,
          0.3952,  0.3709,  0.3340,  0.4576,  0.5344,  0.3261,  0.6593,  0.4224]])
rtol = 0.1, atol = 1e-05, max_print = 5

    [assert_verbose_allclose source identical to the full listing in the mini_qwen2 failure above; elided]
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 1
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 12.082216262817383, tensor2[(0, 0)] = 10.511345863342285

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 12.082216262817383
Step 1, Loss: 0.7676835060119629
Step 2, Loss: 0.8430027961730957
Step 3, Loss: 0.6976154446601868
Step 4, Loss: 0.6748846769332886
Step 5, Loss: 0.7733780741691589
Step 6, Loss: 0.6873887777328491
Step 7, Loss: 0.9114535450935364
Step 8, Loss: 0.5968311429023743
Step 9, Loss: 0.6989853978157043
Step 10, Loss: 0.687013566493988
Step 11, Loss: 0.3859785199165344
Step 12, Loss: 0.6277555823326111
Step 13, Loss: 0.60830157995224
Step 14, Loss: 0.5365318655967712
Step 15, Loss: 0.6612260341644287
Step 16, Loss: 0.6523535847663879
Step 17, Loss: 0.722366452217102
Step 18, Loss: 0.6390480399131775
Step 19, Loss: 0.7095366716384888
Step 20, Loss: 0.5519123673439026
Step 21, Loss: 0.337378591299057
Step 22, Loss: 0.44145601987838745
Step 23, Loss: 0.4145508408546448
Step 24, Loss: 0.41546767950057983
Step 25, Loss: 0.4046224057674408
Step 26, Loss: 0.35766521096229553
Step 27, Loss: 0.47539013624191284
Step 28, Loss: 0.570549488067627
Step 29, Loss: 0.3547478914260864
Step 30, Loss: 0.6901333332061768
Step 31, Loss: 0.4537835717201233
Liger kernel patches have been reverted.
Step 0, Loss: 10.511345863342285
Step 1, Loss: 0.698088526725769
Step 2, Loss: 0.8267481923103333
Step 3, Loss: 0.6681709289550781
Step 4, Loss: 0.6340337991714478
Step 5, Loss: 0.7202634215354919
Step 6, Loss: 0.6464317440986633
Step 7, Loss: 0.9055057168006897
Step 8, Loss: 0.5970208048820496
Step 9, Loss: 0.6794081926345825
Step 10, Loss: 0.7126320004463196
Step 11, Loss: 0.39510583877563477
Step 12, Loss: 0.6385011672973633
Step 13, Loss: 0.6056986451148987
Step 14, Loss: 0.5322800278663635
Step 15, Loss: 0.6470276713371277
Step 16, Loss: 0.6550820469856262
Step 17, Loss: 0.7208991050720215
Step 18, Loss: 0.6374115943908691
Step 19, Loss: 0.6962684988975525
Step 20, Loss: 0.5456441640853882
Step 21, Loss: 0.3409982919692993
Step 22, Loss: 0.43185386061668396
Step 23, Loss: 0.3988208770751953
Step 24, Loss: 0.39517685770988464
Step 25, Loss: 0.3709378242492676
Step 26, Loss: 0.33398640155792236
Step 27, Loss: 0.4576243758201599
Step 28, Loss: 0.5343964695930481
Step 29, Loss: 0.3261362314224243
Step 30, Loss: 0.6592898964881897
Step 31, Loss: 0.4223594069480896
Liger kernel patches have been reverted.
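The mini_qwen2_5_vl case is the narrowest failure: only step 0 is flagged, and step 1 already falls back inside the relaxed rtol (the atol=1e-5 term is negligible at these magnitudes). A quick check with the logged values:

# Steps 0 and 1 from the mini_qwen2_5_vl failure above (eager vs. Liger).
for t1, t2 in [
    (12.082216262817383, 10.511345863342285),  # step 0
    (0.7676835060119629, 0.698088526725769),   # step 1
]:
    rel = abs(t1 - t2) / abs(t2)
    print(round(rel, 4), rel > 1e-1)  # ~0.1494 True, then ~0.0997 False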
_ test_mini_model[mini_olmo2-32-0.0001-dtype6-1e-08-1e-05-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_olmo2', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-08, loss_rtol = 1e-05, logits_atol = 0.005, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05

    [parametrize table and test_mini_model source identical to the mini_qwen2 failure above; elided]
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[10.6500,  7.0135,  3.9183,  1.5300,  0.8660,  0.9542,  0.8878,  1.2609,
          0.9040,  1.0310,  1.0320,  ...  0.7837,  0.5266,  0.6080,  0.5741,
          0.5660,  0.5096,  0.4608,  0.5860,  0.6817,  0.4531,  0.8551,  0.5535]])
tensor2 = tensor([[10.7046,  6.8448,  3.9706,  1.4798,  0.8829,  0.9579,  0.8743,  1.2318,
          0.8670,  0.9912,  0.9969,  ...  0.7333,  0.4728,  0.5714,  0.5362,
          0.5369,  0.4941,  0.4407,  0.5692,  0.6735,  0.4373,  0.8328,  0.5408]])
rtol = 1e-05, atol = 1e-08, max_print = 5

    [assert_verbose_allclose source identical to the full listing in the mini_qwen2 failure above; elided]
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 10.649967193603516, tensor2[(0, 0)] = 10.704586029052734
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 7.013522148132324, tensor2[(0, 1)] = 6.844752788543701
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 3.9183390140533447, tensor2[(0, 2)] = 3.9706192016601562
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 1.5299837589263916, tensor2[(0, 3)] = 1.4797770977020264
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 0.8659607768058777, tensor2[(0, 4)] = 0.8829061985015869
E           ... and 27 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 10.649967193603516
Step 1, Loss: 7.013522148132324
Step 2, Loss: 3.9183390140533447
Step 3, Loss: 1.5299837589263916
Step 4, Loss: 0.8659607768058777
Step 5, Loss: 0.954227864742279
Step 6, Loss: 0.8877675533294678
Step 7, Loss: 1.2608544826507568
Step 8, Loss: 0.90403151512146
Step 9, Loss: 1.0310240983963013
Step 10, Loss: 1.0319746732711792
Step 11, Loss: 0.6489917635917664
Step 12, Loss: 0.9354605078697205
Step 13, Loss: 0.8912016749382019
Step 14, Loss: 0.7892565131187439
Step 15, Loss: 0.9090968370437622
Step 16, Loss: 0.8636245727539062
Step 17, Loss: 0.9345141053199768
Step 18, Loss: 0.8144885897636414
Step 19, Loss: 0.9413245320320129
Step 20, Loss: 0.7836901545524597
Step 21, Loss: 0.5266157984733582
Step 22, Loss: 0.6079881191253662
Step 23, Loss: 0.574065625667572
Step 24, Loss: 0.5660095810890198
Step 25, Loss: 0.5096372365951538
Step 26, Loss: 0.4608082175254822
Step 27, Loss: 0.5859861373901367
Step 28, Loss: 0.6817017793655396
Step 29, Loss: 0.45309633016586304
Step 30, Loss: 0.8551167845726013
Step 31, Loss: 0.5534916520118713
Liger kernel patches have been reverted.
Step 0, Loss: 10.704586029052734
Step 1, Loss: 6.844752788543701
Step 2, Loss: 3.9706192016601562
Step 3, Loss: 1.4797770977020264
Step 4, Loss: 0.8829061985015869
Step 5, Loss: 0.9579141139984131
Step 6, Loss: 0.8742548227310181
Step 7, Loss: 1.2317748069763184
Step 8, Loss: 0.8670055270195007
Step 9, Loss: 0.991202175617218
Step 10, Loss: 0.9968867897987366
Step 11, Loss: 0.6167395710945129
Step 12, Loss: 0.8750386834144592
Step 13, Loss: 0.8204431533813477
Step 14, Loss: 0.731562614440918
Step 15, Loss: 0.8526293039321899
Step 16, Loss: 0.8303168416023254
Step 17, Loss: 0.9314197301864624
Step 18, Loss: 0.8446087837219238
Step 19, Loss: 0.8938522934913635
Step 20, Loss: 0.7333099246025085
Step 21, Loss: 0.4727902412414551
Step 22, Loss: 0.5713554620742798
Step 23, Loss: 0.5362161993980408
Step 24, Loss: 0.5369151830673218
Step 25, Loss: 0.49410873651504517
Step 26, Loss: 0.4407224953174591
Step 27, Loss: 0.569237232208252
Step 28, Loss: 0.673481285572052
Step 29, Loss: 0.43728938698768616
Step 30, Loss: 0.8327814936637878
Step 31, Loss: 0.5408373475074768
Liger kernel patches have been reverted.
_ test_mini_model[mini_phi3-32-0.0001-dtype7-1e-08-1e-05-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_phi3', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-08, loss_rtol = 1e-05, logits_atol = 0.005, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05

    [parametrize table and test_mini_model source identical to the mini_qwen2 failure above; elided]
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[11.0310,  2.3919,  0.8640,  0.7026,  0.6790,  0.7819,  0.7124,  0.9715,
          0.6776,  0.7696,  0.7849,  ...  0.6209,  0.3927,  0.4986,  0.4713,
          0.4617,  0.4346,  0.3885,  0.5137,  0.6146,  0.3939,  0.7605,  0.4955]])
tensor2 = tensor([[9.4221, 1.5946, 0.8646, 0.7050, 0.6782, 0.7911, 0.7188, 0.9777, 0.6778,
         0.7700, 0.7810, 0.4701, 0.71...071, 0.7805, 0.6222, 0.3999, 0.4990, 0.4746, 0.4633, 0.4363, 0.3946,
         0.5259, 0.6116, 0.3954, 0.7658, 0.4964]])
rtol = 1e-05, atol = 1e-08, max_print = 5

    [assert_verbose_allclose source identical to the full listing in the mini_qwen2 failure above; elided]
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 11.031005859375, tensor2[(0, 0)] = 9.422085762023926
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 2.3918755054473877, tensor2[(0, 1)] = 1.5946464538574219
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 0.8639824390411377, tensor2[(0, 2)] = 0.8645725250244141
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 0.7025548815727234, tensor2[(0, 3)] = 0.7049572467803955
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 0.6789949536323547, tensor2[(0, 4)] = 0.6782235503196716
E           ... and 27 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 11.031005859375
Step 1, Loss: 2.3918755054473877
Step 2, Loss: 0.8639824390411377
Step 3, Loss: 0.7025548815727234
Step 4, Loss: 0.6789949536323547
Step 5, Loss: 0.7818920612335205
Step 6, Loss: 0.7123703956604004
Step 7, Loss: 0.9714922308921814
Step 8, Loss: 0.6776041388511658
Step 9, Loss: 0.7695889472961426
Step 10, Loss: 0.7848676443099976
Step 11, Loss: 0.46163010597229004
Step 12, Loss: 0.7140446901321411
Step 13, Loss: 0.6872910857200623
Step 14, Loss: 0.619159996509552
Step 15, Loss: 0.7348251342773438
Step 16, Loss: 0.7239388823509216
Step 17, Loss: 0.8000224828720093
Step 18, Loss: 0.7025389075279236
Step 19, Loss: 0.7704784870147705
Step 20, Loss: 0.6208988428115845
Step 21, Loss: 0.3926902413368225
Step 22, Loss: 0.49856045842170715
Step 23, Loss: 0.4713331460952759
Step 24, Loss: 0.4616822898387909
Step 25, Loss: 0.43457287549972534
Step 26, Loss: 0.38849711418151855
Step 27, Loss: 0.5137475728988647
Step 28, Loss: 0.6145758032798767
Step 29, Loss: 0.393913209438324
Step 30, Loss: 0.7605491876602173
Step 31, Loss: 0.4954882860183716
Liger kernel patches have been reverted.
Step 0, Loss: 9.422085762023926
Step 1, Loss: 1.5946464538574219
Step 2, Loss: 0.8645725250244141
Step 3, Loss: 0.7049572467803955
Step 4, Loss: 0.6782235503196716
Step 5, Loss: 0.7911245226860046
Step 6, Loss: 0.7187986373901367
Step 7, Loss: 0.9776793122291565
Step 8, Loss: 0.6778076887130737
Step 9, Loss: 0.7699770927429199
Step 10, Loss: 0.7809611558914185
Step 11, Loss: 0.47009801864624023
Step 12, Loss: 0.7190852761268616
Step 13, Loss: 0.6920911073684692
Step 14, Loss: 0.6234636306762695
Step 15, Loss: 0.7407030463218689
Step 16, Loss: 0.7293444871902466
Step 17, Loss: 0.8047713041305542
Step 18, Loss: 0.7070678472518921
Step 19, Loss: 0.7805389761924744
Step 20, Loss: 0.6221668720245361
Step 21, Loss: 0.3999490439891815
Step 22, Loss: 0.4989567995071411
Step 23, Loss: 0.47456565499305725
Step 24, Loss: 0.46330246329307556
Step 25, Loss: 0.43631160259246826
Step 26, Loss: 0.3946307897567749
Step 27, Loss: 0.5258839130401611
Step 28, Loss: 0.6115926504135132
Step 29, Loss: 0.3953520357608795
Step 30, Loss: 0.7657529711723328
Step 31, Loss: 0.4963764250278473
Liger kernel patches have been reverted.
_ test_mini_model[mini_mistral-32-0.0001-dtype8-1e-08-1e-05-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_mistral', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-08, loss_rtol = 1e-05, logits_atol = 0.005, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
        [
            ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_llava",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not LLAVA_AVAILABLE,
                    reason="LLaVa not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_mllama",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not MLLAMA_AVAILABLE,
                    reason="Mllama not available in this version of transformers",
                ),
            ),
            ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(  # qwen2_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_VL_AVAILABLE,
                    reason="Qwen2-VL not available in this version of transformers",
                ),
            ),
            pytest.param(  # qwen2_5_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_5_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_5_VL_AVAILABLE,
                    reason="Qwen2.5-VL not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_olmo2",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not OLMO2_AVAILABLE,
                    reason="OLMO2 not available in this version of transformers",
                ),
            ),
            ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            # TODO: mixtral is flaky so disable the test for now
            # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
            # Gemma 1.1 and 2 have more tolerance because, currently, the kernel is not a perfect match (casts are not done the same way)
            ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_granite3",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-4,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not GRANITE_AVAILABLE,
                    reason="Granite not available in this version of transformers",
                ),
            ),
        ],
    )
    def test_mini_model(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logits_atol,
        logits_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
    
        expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
        actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
    
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[10.6504,  2.1981,  0.8611,  0.7005,  0.6689,  0.7863,  0.7165,  0.9742,
          0.6818,  0.7652,  0.7750,  ...  0.6202,  0.3877,  0.4857,  0.4621,
          0.4545,  0.4249,  0.3780,  0.5064,  0.5941,  0.3757,  0.7396,  0.4748]])
tensor2 = tensor([[10.1866,  2.0517,  0.8494,  0.6975,  0.6682,  0.7902,  0.7079,  0.9638,
          0.6723,  0.7580,  0.7772,  ...  0.6049,  0.3775,  0.4772,  0.4518,
          0.4450,  0.4175,  0.3636,  0.4920,  0.5843,  0.3614,  0.7265,  0.4594]])
rtol = 1e-05, atol = 1e-08, max_print = 5

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
            if num_mismatched > max_print:
                mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
    
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 10.650413513183594, tensor2[(0, 0)] = 10.186630249023438
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 2.1980717182159424, tensor2[(0, 1)] = 2.051693916320801
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 0.8610823750495911, tensor2[(0, 2)] = 0.8493916988372803
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 0.7004667520523071, tensor2[(0, 3)] = 0.6974519491195679
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 0.6688546538352966, tensor2[(0, 4)] = 0.6681956052780151
E           ... and 27 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 10.650413513183594
Step 1, Loss: 2.1980717182159424
Step 2, Loss: 0.8610823750495911
Step 3, Loss: 0.7004667520523071
Step 4, Loss: 0.6688546538352966
Step 5, Loss: 0.7862818837165833
Step 6, Loss: 0.7164952158927917
Step 7, Loss: 0.9742444157600403
Step 8, Loss: 0.6818380951881409
Step 9, Loss: 0.7652358412742615
Step 10, Loss: 0.7749951481819153
Step 11, Loss: 0.4780674874782562
Step 12, Loss: 0.7077803611755371
Step 13, Loss: 0.684683620929718
Step 14, Loss: 0.6145927309989929
Step 15, Loss: 0.7329809069633484
Step 16, Loss: 0.7110365033149719
Step 17, Loss: 0.7807507514953613
Step 18, Loss: 0.6901764869689941
Step 19, Loss: 0.7565171122550964
Step 20, Loss: 0.6201580762863159
Step 21, Loss: 0.38774964213371277
Step 22, Loss: 0.4856514036655426
Step 23, Loss: 0.4620998203754425
Step 24, Loss: 0.45451074838638306
Step 25, Loss: 0.4248639941215515
Step 26, Loss: 0.3780428171157837
Step 27, Loss: 0.5063985586166382
Step 28, Loss: 0.5941102504730225
Step 29, Loss: 0.3757373094558716
Step 30, Loss: 0.7395505905151367
Step 31, Loss: 0.47481733560562134
Liger kernel patches have been reverted.
Step 0, Loss: 10.186630249023438
Step 1, Loss: 2.051693916320801
Step 2, Loss: 0.8493916988372803
Step 3, Loss: 0.6974519491195679
Step 4, Loss: 0.6681956052780151
Step 5, Loss: 0.7902041673660278
Step 6, Loss: 0.7079319953918457
Step 7, Loss: 0.9638433456420898
Step 8, Loss: 0.6723192930221558
Step 9, Loss: 0.7579817175865173
Step 10, Loss: 0.7771773934364319
Step 11, Loss: 0.4739665389060974
Step 12, Loss: 0.6998313069343567
Step 13, Loss: 0.6728542447090149
Step 14, Loss: 0.6151115894317627
Step 15, Loss: 0.7236663699150085
Step 16, Loss: 0.7049237489700317
Step 17, Loss: 0.780116856098175
Step 18, Loss: 0.6817130446434021
Step 19, Loss: 0.7547957301139832
Step 20, Loss: 0.6049016118049622
Step 21, Loss: 0.3775196671485901
Step 22, Loss: 0.47715944051742554
Step 23, Loss: 0.4517800211906433
Step 24, Loss: 0.4450123608112335
Step 25, Loss: 0.41752997040748596
Step 26, Loss: 0.3636443614959717
Step 27, Loss: 0.4920283257961273
Step 28, Loss: 0.5842927694320679
Step 29, Loss: 0.36143186688423157
Step 30, Loss: 0.7265202403068542
Step 31, Loss: 0.45942458510398865
Liger kernel patches have been reverted.
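
For reference, the first mini_mistral mismatch above can be reproduced with the same rule assert_verbose_allclose applies: an element fails when |tensor1 - tensor2| > atol + rtol * |tensor2|. A minimal standalone sketch (not part of the test suite), using the step-0 losses and the loss tolerances from this log:

    import torch

    t1 = torch.tensor(10.650413513183594)    # baseline loss at step 0 (from the log)
    t2 = torch.tensor(10.186630249023438)    # Liger loss at step 0 (from the log)
    atol, rtol = 1e-8, 1e-5                  # loss_atol / loss_rtol for mini_mistral

    diff = torch.abs(t1 - t2)                # ~0.4638
    tolerance = atol + rtol * torch.abs(t2)  # ~1.02e-4
    print(bool(diff > tolerance))            # True -> step 0 is reported as a mismatch
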
_ test_mini_model[mini_gemma1-32-0.0001-dtype9-1e-08-0.0001-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_gemma1', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-08, loss_rtol = 0.0001, logits_atol = 0.005, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
        [
            ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_llava",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not LLAVA_AVAILABLE,
                    reason="LLaVa not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_mllama",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not MLLAMA_AVAILABLE,
                    reason="Mllama not available in this version of transformers",
                ),
            ),
            ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(  # qwen2_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_VL_AVAILABLE,
                    reason="Qwen2-VL not available in this version of transformers",
                ),
            ),
            pytest.param(  # qwen2_5_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_5_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_5_VL_AVAILABLE,
                    reason="Qwen2.5-VL not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_olmo2",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not OLMO2_AVAILABLE,
                    reason="OLMO2 not available in this version of transformers",
                ),
            ),
            ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            # TODO: mixtral is flaky so disable the test for now
            # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
            # Gemma 1.1 and 2 have more tolerance because, currently, the kernel is not a perfect match (casts are not done the same way)
            ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_granite3",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-4,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not GRANITE_AVAILABLE,
                    reason="Granite not available in this version of transformers",
                ),
            ),
        ],
    )
    def test_mini_model(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logits_atol,
        logits_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
    
        expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
        actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
    
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[0.8182, 0.8556, 1.0365, 0.8061, 0.7326, 0.8181, 0.7028, 0.9490, 0.6424,
         0.7324, 0.7475, 0.4348, 0.67...584, 0.7241, 0.5794, 0.3548, 0.4489, 0.4197, 0.4129, 0.3807, 0.3338,
         0.4561, 0.5375, 0.3328, 0.6660, 0.4200]])
tensor2 = tensor([[0.8417, 0.8618, 1.0506, 0.8105, 0.7480, 0.8391, 0.7081, 0.9673, 0.6531,
         0.7435, 0.7460, 0.4431, 0.67...653, 0.7283, 0.5878, 0.3678, 0.4586, 0.4360, 0.4321, 0.4014, 0.3513,
         0.4752, 0.5597, 0.3492, 0.6921, 0.4344]])
rtol = 0.0001, atol = 1e-08, max_print = 5

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
            if num_mismatched > max_print:
                mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
    
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 0.8181616067886353, tensor2[(0, 0)] = 0.8416699171066284
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 0.8555576801300049, tensor2[(0, 1)] = 0.8618330955505371
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 1.0364792346954346, tensor2[(0, 2)] = 1.0505924224853516
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 0.806117594242096, tensor2[(0, 3)] = 0.8105048537254333
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 0.7325620055198669, tensor2[(0, 4)] = 0.7479591965675354
E           ... and 27 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 0.8181616067886353
Step 1, Loss: 0.8555576801300049
Step 2, Loss: 1.0364792346954346
Step 3, Loss: 0.806117594242096
Step 4, Loss: 0.7325620055198669
Step 5, Loss: 0.8181335926055908
Step 6, Loss: 0.7028118968009949
Step 7, Loss: 0.9490224719047546
Step 8, Loss: 0.6424455046653748
Step 9, Loss: 0.7323514223098755
Step 10, Loss: 0.7475058436393738
Step 11, Loss: 0.4348335266113281
Step 12, Loss: 0.6739875078201294
Step 13, Loss: 0.6441472172737122
Step 14, Loss: 0.5794464945793152
Step 15, Loss: 0.6939228177070618
Step 16, Loss: 0.6797480583190918
Step 17, Loss: 0.7528703212738037
Step 18, Loss: 0.6584448218345642
Step 19, Loss: 0.7241498827934265
Step 20, Loss: 0.5793939828872681
Step 21, Loss: 0.3547557592391968
Step 22, Loss: 0.4489109218120575
Step 23, Loss: 0.4197341203689575
Step 24, Loss: 0.41293245553970337
Step 25, Loss: 0.3807133138179779
Step 26, Loss: 0.33384984731674194
Step 27, Loss: 0.4561339020729065
Step 28, Loss: 0.5375469326972961
Step 29, Loss: 0.33277252316474915
Step 30, Loss: 0.6659684181213379
Step 31, Loss: 0.42003539204597473
Liger kernel patches have been reverted.
Step 0, Loss: 0.8416699171066284
Step 1, Loss: 0.8618330955505371
Step 2, Loss: 1.0505924224853516
Step 3, Loss: 0.8105048537254333
Step 4, Loss: 0.7479591965675354
Step 5, Loss: 0.8391144871711731
Step 6, Loss: 0.7081260085105896
Step 7, Loss: 0.9673135876655579
Step 8, Loss: 0.6531041264533997
Step 9, Loss: 0.743484616279602
Step 10, Loss: 0.7459954619407654
Step 11, Loss: 0.4431079924106598
Step 12, Loss: 0.6759545803070068
Step 13, Loss: 0.6579065918922424
Step 14, Loss: 0.5922181606292725
Step 15, Loss: 0.6976196765899658
Step 16, Loss: 0.6879425644874573
Step 17, Loss: 0.7612311840057373
Step 18, Loss: 0.6653209924697876
Step 19, Loss: 0.7283201813697815
Step 20, Loss: 0.5877732038497925
Step 21, Loss: 0.3677934408187866
Step 22, Loss: 0.4585503041744232
Step 23, Loss: 0.43603450059890747
Step 24, Loss: 0.43212780356407166
Step 25, Loss: 0.40142807364463806
Step 26, Loss: 0.3513464033603668
Step 27, Loss: 0.47521498799324036
Step 28, Loss: 0.5596947073936462
Step 29, Loss: 0.34919875860214233
Step 30, Loss: 0.6920917630195618
Step 31, Loss: 0.43439042568206787
Liger kernel patches have been reverted.
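
The NaN and infinity checks in assert_verbose_allclose flag a position only when exactly one of the two tensors is NaN (or +/-inf) there, which is what the logical_xor calls implement. A minimal sketch with hypothetical inputs (not values from this log):

    import torch

    a = torch.tensor([float("nan"), float("nan"), 1.0])   # hypothetical values
    b = torch.tensor([float("nan"), 2.0, float("inf")])   # hypothetical values

    # NaN in both tensors -> not a mismatch; NaN in exactly one -> mismatch
    print(torch.logical_xor(torch.isnan(a), torch.isnan(b)))        # tensor([False,  True, False])
    # +inf in exactly one tensor is likewise a mismatch
    print(torch.logical_xor(torch.isposinf(a), torch.isposinf(b)))  # tensor([False, False,  True])
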
_ test_mini_model[mini_gemma1.1-32-0.0001-dtype10-1e-08-0.0001-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_gemma1.1', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-08, loss_rtol = 0.0001, logits_atol = 0.005, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
        [
            ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_llava",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not LLAVA_AVAILABLE,
                    reason="LLaVa not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_mllama",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not MLLAMA_AVAILABLE,
                    reason="Mllama not available in this version of transformers",
                ),
            ),
            ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(  # qwen2_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_VL_AVAILABLE,
                    reason="Qwen2-VL not available in this version of transformers",
                ),
            ),
            pytest.param(  # qwen2_5_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_5_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_5_VL_AVAILABLE,
                    reason="Qwen2.5-VL not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_olmo2",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not OLMO2_AVAILABLE,
                    reason="OLMO2 not available in this version of transformers",
                ),
            ),
            ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            # TODO: mixtral is flaky so disable the test for now
            # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
            # Gemma 1.1 and 2 have more tolerance because, currently, the kernel is not a perfect match (casts are not done the same way)
            ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_granite3",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-4,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not GRANITE_AVAILABLE,
                    reason="Granite not available in this version of transformers",
                ),
            ),
        ],
    )
    def test_mini_model(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logits_atol,
        logits_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
    
        expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
        actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
    
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[0.8181, 0.8556, 1.0365, 0.8061, 0.7326, 0.8181, 0.7028, 0.9490, 0.6424,
         0.7323, 0.7475, 0.4348, 0.67...584, 0.7242, 0.5794, 0.3548, 0.4489, 0.4197, 0.4129, 0.3807, 0.3339,
         0.4561, 0.5375, 0.3328, 0.6660, 0.4200]])
tensor2 = tensor([[0.8417, 0.8618, 1.0506, 0.8105, 0.7480, 0.8391, 0.7081, 0.9673, 0.6531,
         0.7435, 0.7460, 0.4431, 0.67...653, 0.7283, 0.5878, 0.3678, 0.4586, 0.4360, 0.4321, 0.4014, 0.3513,
         0.4752, 0.5597, 0.3492, 0.6921, 0.4344]])
rtol = 0.0001, atol = 1e-08, max_print = 5

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
            if num_mismatched > max_print:
                mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
    
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 0.818123459815979, tensor2[(0, 0)] = 0.8416699171066284
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 0.8555735945701599, tensor2[(0, 1)] = 0.8618330955505371
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 1.0364962816238403, tensor2[(0, 2)] = 1.0505924224853516
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 0.8061254024505615, tensor2[(0, 3)] = 0.8105048537254333
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 0.7325626015663147, tensor2[(0, 4)] = 0.7479591965675354
E           ... and 27 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 0.818123459815979
Step 1, Loss: 0.8555735945701599
Step 2, Loss: 1.0364962816238403
Step 3, Loss: 0.8061254024505615
Step 4, Loss: 0.7325626015663147
Step 5, Loss: 0.8181264996528625
Step 6, Loss: 0.7028031945228577
Step 7, Loss: 0.9490163326263428
Step 8, Loss: 0.6424455642700195
Step 9, Loss: 0.7323499917984009
Step 10, Loss: 0.7475047707557678
Step 11, Loss: 0.4348352253437042
Step 12, Loss: 0.6739868521690369
Step 13, Loss: 0.644147515296936
Step 14, Loss: 0.5794475674629211
Step 15, Loss: 0.6939223408699036
Step 16, Loss: 0.6797480583190918
Step 17, Loss: 0.7528692483901978
Step 18, Loss: 0.6584441661834717
Step 19, Loss: 0.7241503596305847
Step 20, Loss: 0.5793960094451904
Step 21, Loss: 0.3547539710998535
Step 22, Loss: 0.4489099979400635
Step 23, Loss: 0.4197331666946411
Step 24, Loss: 0.412931352853775
Step 25, Loss: 0.3807120621204376
Step 26, Loss: 0.3338501453399658
Step 27, Loss: 0.4561348557472229
Step 28, Loss: 0.5375467538833618
Step 29, Loss: 0.3327734172344208
Step 30, Loss: 0.6659685373306274
Step 31, Loss: 0.420035183429718
Liger kernel patches have been reverted.
Step 0, Loss: 0.8416699171066284
Step 1, Loss: 0.8618330955505371
Step 2, Loss: 1.0505924224853516
Step 3, Loss: 0.8105048537254333
Step 4, Loss: 0.7479591965675354
Step 5, Loss: 0.8391144871711731
Step 6, Loss: 0.7081260085105896
Step 7, Loss: 0.9673135280609131
Step 8, Loss: 0.6531041264533997
Step 9, Loss: 0.743484616279602
Step 10, Loss: 0.7459954619407654
Step 11, Loss: 0.4431079924106598
Step 12, Loss: 0.6759545803070068
Step 13, Loss: 0.6579065918922424
Step 14, Loss: 0.5922181606292725
Step 15, Loss: 0.6976196765899658
Step 16, Loss: 0.6879425644874573
Step 17, Loss: 0.7612311840057373
Step 18, Loss: 0.6653209924697876
Step 19, Loss: 0.7283201813697815
Step 20, Loss: 0.5877732038497925
Step 21, Loss: 0.36779338121414185
Step 22, Loss: 0.45855027437210083
Step 23, Loss: 0.43603450059890747
Step 24, Loss: 0.43212783336639404
Step 25, Loss: 0.4014280438423157
Step 26, Loss: 0.3513464331626892
Step 27, Loss: 0.47521501779556274
Step 28, Loss: 0.5596947073936462
Step 29, Loss: 0.34919875860214233
Step 30, Loss: 0.6920917630195618
Step 31, Loss: 0.43439042568206787
Liger kernel patches have been reverted.
_ test_mini_model[mini_gemma2-32-0.0001-dtype11-1e-08-0.0001-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_gemma2', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-08, loss_rtol = 0.0001, logits_atol = 0.005, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
        [
            ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_llava",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not LLAVA_AVAILABLE,
                    reason="LLaVa not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_mllama",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not MLLAMA_AVAILABLE,
                    reason="Mllama not available in this version of transformers",
                ),
            ),
            ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(  # qwen2_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_VL_AVAILABLE,
                    reason="Qwen2-VL not available in this version of transformers",
                ),
            ),
            pytest.param(  # qwen2_5_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_5_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_5_VL_AVAILABLE,
                    reason="Qwen2.5-VL not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_olmo2",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not OLMO2_AVAILABLE,
                    reason="OLMO2 not available in this version of transformers",
                ),
            ),
            ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            # TODO: mixtral is flaky so disable the test for now
            # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
            # Gemma 1.1 and 2 have more tolerance because, currently, the kernel is not a perfect match (casts are not done the same way)
            ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_granite3",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-4,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not GRANITE_AVAILABLE,
                    reason="Granite not available in this version of transformers",
                ),
            ),
        ],
    )
    def test_mini_model(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logits_atol,
        logits_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
    
        expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
        actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
    
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[6.3975, 0.9506, 0.8501, 0.6923, 0.6645, 0.7852, 0.7055, 0.9699, 0.6682,
         0.7611, 0.7799, 0.4486, 0.70...043, 0.7717, 0.6187, 0.3963, 0.5011, 0.4721, 0.4620, 0.4374, 0.3967,
         0.5223, 0.6152, 0.3913, 0.7618, 0.4976]])
tensor2 = tensor([[6.6696, 0.8469, 0.8440, 0.6889, 0.6647, 0.7792, 0.7059, 0.9622, 0.6579,
         0.7480, 0.7625, 0.4430, 0.68...930, 0.7569, 0.6045, 0.3895, 0.4902, 0.4685, 0.4604, 0.4364, 0.3885,
         0.5161, 0.6136, 0.3911, 0.7599, 0.4894]])
rtol = 0.0001, atol = 1e-08, max_print = 5

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
            if num_mismatched > max_print:
                mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
    
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 6.397529602050781, tensor2[(0, 0)] = 6.66957950592041
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 0.950577437877655, tensor2[(0, 1)] = 0.8468623161315918
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 0.8500818014144897, tensor2[(0, 2)] = 0.8440279364585876
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 0.6923311948776245, tensor2[(0, 3)] = 0.6889015436172485
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 0.6644583344459534, tensor2[(0, 4)] = 0.6646580100059509
E           ... and 27 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 6.397529602050781
Step 1, Loss: 0.950577437877655
Step 2, Loss: 0.8500818014144897
Step 3, Loss: 0.6923311948776245
Step 4, Loss: 0.6644583344459534
Step 5, Loss: 0.7852233052253723
Step 6, Loss: 0.7055337429046631
Step 7, Loss: 0.9699369668960571
Step 8, Loss: 0.6682460904121399
Step 9, Loss: 0.7611407041549683
Step 10, Loss: 0.7798893451690674
Step 11, Loss: 0.44855356216430664
Step 12, Loss: 0.7075291275978088
Step 13, Loss: 0.6803080439567566
Step 14, Loss: 0.6021474599838257
Step 15, Loss: 0.7360024452209473
Step 16, Loss: 0.7214772701263428
Step 17, Loss: 0.790974497795105
Step 18, Loss: 0.7043339014053345
Step 19, Loss: 0.7717177271842957
Step 20, Loss: 0.618697464466095
Step 21, Loss: 0.39627838134765625
Step 22, Loss: 0.5010679960250854
Step 23, Loss: 0.4721023142337799
Step 24, Loss: 0.4620136320590973
Step 25, Loss: 0.4374282956123352
Step 26, Loss: 0.39665651321411133
Step 27, Loss: 0.5223129391670227
Step 28, Loss: 0.6151712536811829
Step 29, Loss: 0.39134547114372253
Step 30, Loss: 0.7617537379264832
Step 31, Loss: 0.49758002161979675
Liger kernel patches have been reverted.
Step 0, Loss: 6.66957950592041
Step 1, Loss: 0.8468623161315918
Step 2, Loss: 0.8440279364585876
Step 3, Loss: 0.6889015436172485
Step 4, Loss: 0.6646580100059509
Step 5, Loss: 0.7791532874107361
Step 6, Loss: 0.7059386372566223
Step 7, Loss: 0.9621715545654297
Step 8, Loss: 0.6578609347343445
Step 9, Loss: 0.7480449080467224
Step 10, Loss: 0.7625100016593933
Step 11, Loss: 0.4429822564125061
Step 12, Loss: 0.6867876648902893
Step 13, Loss: 0.6716685891151428
Step 14, Loss: 0.6054787635803223
Step 15, Loss: 0.7207801342010498
Step 16, Loss: 0.7118698954582214
Step 17, Loss: 0.7849560976028442
Step 18, Loss: 0.6930171847343445
Step 19, Loss: 0.7568984031677246
Step 20, Loss: 0.6044853925704956
Step 21, Loss: 0.3895394802093506
Step 22, Loss: 0.4902222454547882
Step 23, Loss: 0.46853187680244446
Step 24, Loss: 0.4604288637638092
Step 25, Loss: 0.43637147545814514
Step 26, Loss: 0.38852736353874207
Step 27, Loss: 0.5160799622535706
Step 28, Loss: 0.6136299967765808
Step 29, Loss: 0.3910857141017914
Step 30, Loss: 0.759941041469574
Step 31, Loss: 0.4894213378429413
Liger kernel patches have been reverted.
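
The tolerance formula above is the same one torch.allclose documents (|input - other| <= atol + rtol * |other|); assert_verbose_allclose adds per-element reporting, and its NaN handling differs slightly (matching NaNs count as equal, like equal_nan=True). A quick standalone cross-check with the mini_gemma2 step-0 losses from this log:

    import torch

    t1 = torch.tensor(6.397529602050781)  # baseline loss at step 0 (from the log)
    t2 = torch.tensor(6.66957950592041)   # Liger loss at step 0 (from the log)
    print(torch.allclose(t1, t2, atol=1e-8, rtol=1e-4))  # False, consistent with the failure above
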
_ test_mini_model[mini_granite3-32-0.0001-dtype12-1e-08-0.0001-0.005-1e-05-0.005-1e-05] _

model_name = 'mini_granite3', num_steps = 32, lr = 0.0001, dtype = torch.float32
loss_atol = 1e-08, loss_rtol = 0.0001, logits_atol = 0.005, logits_rtol = 1e-05
param_atol = 0.005, param_rtol = 1e-05

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
        [
            ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_llava",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not LLAVA_AVAILABLE,
                    reason="LLaVa not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_mllama",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not MLLAMA_AVAILABLE,
                    reason="Mllama not available in this version of transformers",
                ),
            ),
            ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(  # qwen2_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_VL_AVAILABLE,
                    reason="Qwen2-VL not available in this version of transformers",
                ),
            ),
            pytest.param(  # qwen2_5_vl requires slightly larger tolerances to pass this test after a bug fix to qwen2_vl in transformers v4.47.0
                "mini_qwen2_5_vl",
                32,
                1e-4,
                torch.float32,
                1e-5,  # 1e-8,
                1e-1,  # 1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_5_VL_AVAILABLE,
                    reason="Qwen2.5-VL not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_olmo2",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not OLMO2_AVAILABLE,
                    reason="OLMO2 not available in this version of transformers",
                ),
            ),
            ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            # TODO: mixtral is flaky so disable the test for now
            # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
            # Gemma 1.1 and 2 have more tolerance because, currently, the kernel is not a perfect match (casts are not done the same way)
            ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_granite3",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-4,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not GRANITE_AVAILABLE,
                    reason="Granite not available in this version of transformers",
                ),
            ),
        ],
    )
    def test_mini_model(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logits_atol,
        logits_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
    
        expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
        actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
    
        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/fp32/test_mini_models.py:810: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor([[10.5235,  3.0442,  0.9047,  0.7144,  0.6839,  0.8117,  0.7413,  1.0123,
          0.7141,  0.8117,  0.8263,  ...  0.6656,  0.4234,  0.5247,  0.4991,
          0.4967,  0.4543,  0.4059,  0.5354,  0.6261,  0.4036,  0.7827,  0.5113]])
tensor2 = tensor([[10.4341,  4.8195,  1.1653,  0.7205,  0.6933,  0.8360,  0.7812,  1.0745,
          0.7657,  0.8629,  0.8533,  ...  0.6831,  0.4444,  0.5392,  0.5154,
          0.5016,  0.4674,  0.4107,  0.5392,  0.6289,  0.4070,  0.7960,  0.5125]])
rtol = 0.0001, atol = 1e-08, max_print = 5

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(torch.isposinf(tensor1), torch.isposinf(tensor2))
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(torch.isneginf(tensor1), torch.isneginf(tensor2))
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}")
            if num_mismatched > max_print:
                mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
    
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 10.523491859436035, tensor2[(0, 0)] = 10.434144020080566
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 3.0442163944244385, tensor2[(0, 1)] = 4.819544792175293
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 0.9047379493713379, tensor2[(0, 2)] = 1.1653170585632324
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 0.7143992185592651, tensor2[(0, 3)] = 0.7205212712287903
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 0.6838905811309814, tensor2[(0, 4)] = 0.6932666897773743
E           ... and 27 more mismatched elements.

test/utils.py:119: AssertionError
----------------------------- Captured stdout call -----------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 10.523491859436035
Step 1, Loss: 3.0442163944244385
Step 2, Loss: 0.9047379493713379
Step 3, Loss: 0.7143992185592651
Step 4, Loss: 0.6838905811309814
Step 5, Loss: 0.8117480278015137
Step 6, Loss: 0.7413424253463745
Step 7, Loss: 1.0122911930084229
Step 8, Loss: 0.7141204476356506
Step 9, Loss: 0.8116552829742432
Step 10, Loss: 0.8263360857963562
Step 11, Loss: 0.5299246907234192
Step 12, Loss: 0.7658659219741821
Step 13, Loss: 0.745334804058075
Step 14, Loss: 0.671735942363739
Step 15, Loss: 0.7948212623596191
Step 16, Loss: 0.7692177891731262
Step 17, Loss: 0.8354676365852356
Step 18, Loss: 0.7253051996231079
Step 19, Loss: 0.8105993866920471
Step 20, Loss: 0.665621280670166
Step 21, Loss: 0.42336153984069824
Step 22, Loss: 0.5246681571006775
Step 23, Loss: 0.49905407428741455
Step 24, Loss: 0.4966582655906677
Step 25, Loss: 0.4542563855648041
Step 26, Loss: 0.40592220425605774
Step 27, Loss: 0.5353959798812866
Step 28, Loss: 0.6261412501335144
Step 29, Loss: 0.40357351303100586
Step 30, Loss: 0.7826501131057739
Step 31, Loss: 0.5113450884819031
Liger kernel patches have been reverted.
Step 0, Loss: 10.434144020080566
Step 1, Loss: 4.819544792175293
Step 2, Loss: 1.1653170585632324
Step 3, Loss: 0.7205212712287903
Step 4, Loss: 0.6932666897773743
Step 5, Loss: 0.8359834551811218
Step 6, Loss: 0.7811504006385803
Step 7, Loss: 1.0745203495025635
Step 8, Loss: 0.7656863927841187
Step 9, Loss: 0.8629040122032166
Step 10, Loss: 0.8532899618148804
Step 11, Loss: 0.5375107526779175
Step 12, Loss: 0.7750334143638611
Step 13, Loss: 0.7481169700622559
Step 14, Loss: 0.6873371601104736
Step 15, Loss: 0.8025739789009094
Step 16, Loss: 0.7945616841316223
Step 17, Loss: 0.8575659394264221
Step 18, Loss: 0.7552111744880676
Step 19, Loss: 0.84697026014328
Step 20, Loss: 0.6830578446388245
Step 21, Loss: 0.44442734122276306
Step 22, Loss: 0.5392183065414429
Step 23, Loss: 0.5154062509536743
Step 24, Loss: 0.5015864968299866
Step 25, Loss: 0.4673548936843872
Step 26, Loss: 0.410695880651474
Step 27, Loss: 0.5391891002655029
Step 28, Loss: 0.628862202167511
Step 29, Loss: 0.407024621963501
Step 30, Loss: 0.795998215675354
Step 31, Loss: 0.5125444531440735
Liger kernel patches have been reverted.
=========================== short test summary info ============================
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype0-1e-08-2e-05-0.0001-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_mllama-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype3-1e-08-1e-05-0.005-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype4-1e-05-0.1-0.005-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_5_vl-32-0.0001-dtype5-1e-05-0.1-0.005-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_olmo2-32-0.0001-dtype6-1e-08-1e-05-0.005-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype7-1e-08-1e-05-0.005-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype8-1e-08-1e-05-0.005-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype9-1e-08-0.0001-0.005-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype10-1e-08-0.0001-0.005-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_gemma2-32-0.0001-dtype11-1e-08-0.0001-0.005-1e-05-0.005-1e-05]
FAILED test/convergence/fp32/test_mini_models.py::test_mini_model[mini_granite3-32-0.0001-dtype12-1e-08-0.0001-0.005-1e-05-0.005-1e-05]
================== 13 failed, 1 warning in 140.12s (0:02:20) ===================
make: *** [Makefile:23: test-convergence] Error 1
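The mismatches above are flagged elementwise by the criterion |tensor1 - tensor2| > atol + rtol * |tensor2|, as implemented in assert_verbose_allclose in the traceback. A minimal sketch (assuming only torch) that reproduces the check on the first two per-step losses from the two captured runs above:

import torch

# First two per-step losses from the two runs in the captured stdout:
# first run -> expected_output, second run -> actual_output.
expected = torch.tensor([10.523491859436035, 3.0442163944244385])
actual = torch.tensor([10.434144020080566, 4.819544792175293])

rtol, atol = 1e-4, 1e-8  # the tolerances shown in the traceback
# Same criterion as assert_verbose_allclose: mismatch if the absolute
# difference exceeds atol + rtol * |actual|.
tolerance = atol + rtol * torch.abs(actual)
print(torch.abs(expected - actual) > tolerance)  # tensor([True, True])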

apply_liger_kernel_to_llama(**kwargs)

# fused_linear_cross_entropy is not supported in mini_granite3
kwargs["fused_linear_cross_entropy"] = True if model_name != "mini_granite3" else False
kwargs["cross_entropy"] = False
@@ -623,6 +697,25 @@ def run_mini_model(
1e-2,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
),
pytest.param(
"mini_llava",
32,
1e-4,
torch.bfloat16,
1e-3,
1e-2,
1e-1,
1e-2,
1e-2,
1e-2,
marks=[
pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
pytest.mark.skipif(
not LLAVA_AVAILABLE,
reason="LLaVa not available in this version of transformers",
),
],
),
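
The positional values in the mini_llava param above are easy to misread; here is a sketch with each value annotated. The parameter names in the comments are assumptions inferred from the test ids in the failure summary (e.g. "mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05"), not taken from the test signature:

import pytest
import torch

_MINI_LLAVA_BF16 = pytest.param(
    "mini_llava",    # model_name
    32,              # num_steps
    1e-4,            # lr
    torch.bfloat16,  # dtype
    1e-3,            # loss_atol
    1e-2,            # loss_rtol
    1e-1,            # logits_atol
    1e-2,            # logits_rtol
    1e-2,            # param_atol
    1e-2,            # param_rtol
    # marks=[...] omitted here; see the skipif conditions in the param above.
)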
pytest.param(
"mini_granite3",
32,