Commit 802915a

Enable CompatibilityTest
1 parent 99beb35 commit 802915a

2 files changed: +149 additions, -145 deletions

2 files changed

+149
-145
lines changed

paddleformers/transformers/configuration_utils.py

Lines changed: 0 additions & 4 deletions

@@ -622,10 +622,6 @@ def __init__(self, **kwargs):
         # parameter for model dtype
         if "torch_dtype" in kwargs:
             self.dtype = kwargs.pop("torch_dtype")
-        # else:
-        #     import paddle
-
-        #     self.dtype = kwargs.pop("dtype", paddle.get_default_dtype())
 
         # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
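
The four deleted lines were already commented out, so this hunk is dead-code cleanup: an explicit torch_dtype kwarg is still mapped onto the config's dtype attribute, and the inert fallback to paddle.get_default_dtype() is simply removed. A minimal usage sketch of the retained behaviour, assuming this __init__ belongs to the PretrainedConfig class exported from paddleformers.transformers (hypothetical illustration, not part of the commit):

# Sketch only; assumption: PretrainedConfig is exported from paddleformers.transformers.
from paddleformers.transformers import PretrainedConfig

config = PretrainedConfig(torch_dtype="bfloat16")
print(config.dtype)  # "bfloat16" -- the torch_dtype kwarg is popped into dtype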

tests/transformers/qwen2_5_vl/test_modeling.py

Lines changed: 149 additions & 141 deletions
@@ -18,7 +18,9 @@
 import tempfile
 import unittest
 
+import numpy as np
 import paddle
+from parameterized import parameterized
 
 from paddleformers.transformers import (
     AutoProcessor,
@@ -28,6 +30,7 @@
     process_vision_info,
 )
 from paddleformers.transformers.video_utils import load_video
+from tests.testing_utils import require_package
 from tests.transformers.test_configuration_common import ConfigTester
 from tests.transformers.test_generation_utils import GenerationTesterMixin
 from tests.transformers.test_modeling_common import (
@@ -779,144 +782,149 @@ def test_model_tiny_logits_with_video(self):
         self.assertTrue(paddle.allclose(output[0, 150, 10000:10030], EXPECTED_SLICE, atol=1e-3, rtol=1e-3))
 
 
-# class Qwen2_5_VLCompatibilityTest(unittest.TestCase):
-#     @classmethod
-#     @require_package("transformers", "torch")
-#     def setUpClass(cls) -> None:
-#         from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration
-
-#         # when python application is done, `TemporaryDirectory` will be free
-#         cls.torch_model_path = tempfile.TemporaryDirectory().name
-#         tiny_vision_config = {
-#             "depth": 4,
-#             "intermediate_size": 95,
-#             "hidden_size": 64,
-#             "out_hidden_size": 128,
-#             "fullatt_block_indexes": [1, 3],
-#         }
-#         tiny_rope_scaling = {"type": "mrope", "mrope_section": [1]}
-#         config = Qwen2_5_VLConfig(
-#             hidden_size=64,
-#             intermediate_size=344,
-#             num_hidden_layers=2,
-#             vision_config=tiny_vision_config,
-#             rope_scaling=tiny_rope_scaling,
-#         )
-
-#         input_ids = np.random.randint(100, 200, [1, 20]).astype("int64")
-#         visual_token_ids = [config.vision_start_token_id] + [config.image_token_id] * 2 + [config.vision_end_token_id]
-#         input_ids[:, 10 : 10 + len(visual_token_ids)] = visual_token_ids
-
-#         attention_mask = np.ones([1, 20], dtype="int64")
-#         pixel_values = np.random.randn(2 * 2, 1176).astype("float32")
-#         image_grid_thw = np.array([[1, 2, 2]], dtype="int64")
-#         cls.inputs = {
-#             "input_ids": input_ids,
-#             "pixel_values": pixel_values,
-#             "image_grid_thw": image_grid_thw,
-#             "attention_mask": attention_mask,
-#         }
-
-#         model = Qwen2_5_VLForConditionalGeneration(config)
-#         model.save_pretrained(cls.torch_model_path)
-
-#     @require_package("transformers", "torch")
-#     def test_Qwen2_5_VL_converter(self):
-
-#         # 1. forward the paddle model
-#         from paddleformers.transformers import Qwen2_5_VLModel
-
-#         paddle_inputs = {k: paddle.to_tensor(v) for k, v in self.inputs.items()}
-#         paddle_model = Qwen2_5_VLModel.from_pretrained(
-#             self.torch_model_path, convert_from_hf=True, dtype="float32"
-#         ).eval()
-#         paddle_logit = paddle_model(**paddle_inputs)[0]
-
-#         # 2. forward the torch model
-#         import torch
-#         from transformers import Qwen2_5_VLModel
-
-#         torch_inputs = {k: torch.tensor(v) for k, v in self.inputs.items()}
-#         torch_model = Qwen2_5_VLModel.from_pretrained(self.torch_model_path, torch_dtype=torch.float32).eval()
-#         torch_logit = torch_model(**torch_inputs)[0]
-
-#         # 3. compare the result between paddle and torch
-#         self.assertTrue(
-#             np.allclose(
-#                 paddle_logit.detach().cpu().reshape([-1])[:9].astype("float32").numpy(),
-#                 torch_logit.detach().cpu().reshape([-1])[:9].float().numpy(),
-#                 atol=1e-2,
-#                 rtol=1e-2,
-#             )
-#         )
-
-#     @require_package("transformers", "torch")
-#     def test_Qwen2_5_VL_converter_from_local_dir(self):
-#         with tempfile.TemporaryDirectory() as tempdir:
-
-#             # 1. forward the torch model
-#             import torch
-#             from transformers import Qwen2_5_VLModel
-
-#             torch_inputs = {k: torch.tensor(v) for k, v in self.inputs.items()}
-#             torch_model = Qwen2_5_VLModel.from_pretrained(self.torch_model_path, torch_dtype=torch.float32)
-#             torch_model.eval()
-#             torch_model.save_pretrained(tempdir)
-#             torch_logit = torch_model(**torch_inputs)[0]
-
-#             # 2. forward the paddle model
-#             from paddleformers.transformers import Qwen2_5_VLModel
-
-#             paddle_inputs = {k: paddle.to_tensor(v) for k, v in self.inputs.items()}
-#             paddle_model = Qwen2_5_VLModel.from_pretrained(tempdir, convert_from_hf=True, dtype="float32")
-#             paddle_model.eval()
-#             paddle_logit = paddle_model(**paddle_inputs)[0]
-
-#             # 3. compare the result between paddle and torch
-#             self.assertTrue(
-#                 np.allclose(
-#                     paddle_logit.detach().cpu().reshape([-1])[:9].astype("float32").numpy(),
-#                     torch_logit.detach().cpu().reshape([-1])[:9].float().numpy(),
-#                     atol=1e-2,
-#                     rtol=1e-2,
-#                 )
-#             )
-
-#     @parameterized.expand([("Qwen2_5_VLModel",), ("Qwen2_5_VLForConditionalGeneration",)])
-#     @require_package("transformers", "torch")
-#     def test_Qwen2_5_VL_classes_from_local_dir(self, class_name, pytorch_class_name: str | None = None):
-#         pytorch_class_name = pytorch_class_name or class_name
-#         with tempfile.TemporaryDirectory() as tempdir:
-
-#             # 1. forward the torch model
-#             import torch
-#             import transformers
-
-#             torch_inputs = {k: torch.tensor(v) for k, v in self.inputs.items()}
-#             torch_model_class = getattr(transformers, pytorch_class_name)
-#             torch_model = torch_model_class.from_pretrained(self.torch_model_path, torch_dtype=torch.float32).eval()
-
-#             torch_model.save_pretrained(tempdir)
-#             torch_logit = torch_model(**torch_inputs)[0]
-
-#             # 2. forward the paddle model
-#             from paddleformers import transformers
-
-#             paddle_inputs = {k: paddle.to_tensor(v) for k, v in self.inputs.items()}
-#             paddle_model_class = getattr(transformers, class_name)
-#             paddle_model = paddle_model_class.from_pretrained(tempdir, convert_from_hf=True, dtype="float32").eval()
-
-#             if class_name == "Qwen2_5_VLModel":
-#                 paddle_logit = paddle_model(**paddle_inputs)[0]
-#             else:
-#                 paddle_logit = paddle_model(**paddle_inputs)["logits"]
-
-#             # 3. compare the result between paddle and torch
-#             self.assertTrue(
-#                 np.allclose(
-#                     paddle_logit.detach().cpu().reshape([-1])[:9].astype("float32").numpy(),
-#                     torch_logit.detach().cpu().reshape([-1])[:9].float().numpy(),
-#                     atol=1e-2,
-#                     rtol=1e-2,
-#                 )
-#             )
+class Qwen2_5_VLCompatibilityTest(unittest.TestCase):
+    @classmethod
+    @require_package("transformers", "torch")
+    def setUpClass(cls) -> None:
+        from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration
+
+        # when python application is done, `TemporaryDirectory` will be free
+        cls.torch_model_path = tempfile.TemporaryDirectory().name
+        tiny_vision_config = {
+            "depth": 4,
+            "intermediate_size": 95,
+            "hidden_size": 64,
+            "out_hidden_size": 128,
+            "fullatt_block_indexes": [1, 3],
+        }
+        tiny_rope_scaling = {"type": "mrope", "mrope_section": [1]}
+        config = Qwen2_5_VLConfig(
+            hidden_size=64,
+            intermediate_size=344,
+            num_hidden_layers=2,
+            vision_config=tiny_vision_config,
+            rope_scaling=tiny_rope_scaling,
+            vision_start_token_id=151652,
+            vision_end_token_id=151653,
+            image_token_id=151655,
+        )
+
+        input_ids = np.random.randint(0, 200, [1, 20]).astype("int64")
+        visual_token_ids = (
+            [config.vision_start_token_id] + [config.image_token_id] * 2 + [config.vision_start_token_id]
+        )
+        input_ids[:, 10 : 10 + len(visual_token_ids)] = visual_token_ids
+
+        attention_mask = np.ones([1, 20], dtype="int64")
+        pixel_values = np.random.randn(2 * 2, 1176).astype("float32")
+        image_grid_thw = np.array([[1, 2, 2]], dtype="int64")
+        cls.inputs = {
+            "input_ids": input_ids,
+            "pixel_values": pixel_values,
+            "image_grid_thw": image_grid_thw,
+            "attention_mask": attention_mask,
+        }
+
+        model = Qwen2_5_VLForConditionalGeneration(config)
+        model.save_pretrained(cls.torch_model_path)
+
+    @require_package("transformers", "torch")
+    def test_Qwen2_5_VL_converter(self):
+
+        # 1. forward the paddle model
+        from paddleformers.transformers import Qwen2_5_VLModel
+
+        paddle_inputs = {k: paddle.to_tensor(v) for k, v in self.inputs.items()}
+        paddle_model = Qwen2_5_VLModel.from_pretrained(
+            self.torch_model_path, convert_from_hf=True, dtype="float32"
+        ).eval()
+        paddle_logit = paddle_model(**paddle_inputs)[0]
+
+        # 2. forward the torch model
+        import torch
+        from transformers import Qwen2_5_VLModel
+
+        torch_inputs = {k: torch.tensor(v) for k, v in self.inputs.items()}
+        torch_model = Qwen2_5_VLModel.from_pretrained(self.torch_model_path, torch_dtype=torch.float32).eval()
+        torch_logit = torch_model(**torch_inputs)[0]
+
+        # 3. compare the result between paddle and torch
+        self.assertTrue(
+            np.allclose(
+                paddle_logit.detach().cpu().reshape([-1])[:9].astype("float32").numpy(),
+                torch_logit.detach().cpu().reshape([-1])[:9].float().numpy(),
+                atol=1e-2,
+                rtol=1e-2,
+            )
+        )
+
+    @require_package("transformers", "torch")
+    def test_Qwen2_5_VL_converter_from_local_dir(self):
+        with tempfile.TemporaryDirectory() as tempdir:
+
+            # 1. forward the torch model
+            import torch
+            from transformers import Qwen2_5_VLModel
+
+            torch_inputs = {k: torch.tensor(v) for k, v in self.inputs.items()}
+            torch_model = Qwen2_5_VLModel.from_pretrained(self.torch_model_path, torch_dtype=torch.float32)
+            torch_model.eval()
+            torch_model.save_pretrained(tempdir)
+            torch_logit = torch_model(**torch_inputs)[0]
+
+            # 2. forward the paddle model
+            from paddleformers.transformers import Qwen2_5_VLModel
+
+            paddle_inputs = {k: paddle.to_tensor(v) for k, v in self.inputs.items()}
+            paddle_model = Qwen2_5_VLModel.from_pretrained(tempdir, convert_from_hf=True, dtype="float32")
+            paddle_model.eval()
+            paddle_logit = paddle_model(**paddle_inputs)[0]
+
+            # 3. compare the result between paddle and torch
+            self.assertTrue(
+                np.allclose(
+                    paddle_logit.detach().cpu().reshape([-1])[:9].astype("float32").numpy(),
+                    torch_logit.detach().cpu().reshape([-1])[:9].float().numpy(),
+                    atol=1e-2,
+                    rtol=1e-2,
+                )
+            )
+
+    @parameterized.expand([("Qwen2_5_VLModel",), ("Qwen2_5_VLForConditionalGeneration",)])
+    @require_package("transformers", "torch")
+    def test_Qwen2_5_VL_classes_from_local_dir(self, class_name, pytorch_class_name: str | None = None):
+        pytorch_class_name = pytorch_class_name or class_name
+        with tempfile.TemporaryDirectory() as tempdir:
+
+            # 1. forward the torch model
+            import torch
+            import transformers
+
+            torch_inputs = {k: torch.tensor(v) for k, v in self.inputs.items()}
+            torch_model_class = getattr(transformers, pytorch_class_name)
+            torch_model = torch_model_class.from_pretrained(self.torch_model_path, torch_dtype=torch.float32).eval()
+
+            torch_model.save_pretrained(tempdir)
+            torch_logit = torch_model(**torch_inputs)[0]
+
+            # 2. forward the paddle model
+            from paddleformers import transformers
+
+            paddle_inputs = {k: paddle.to_tensor(v) for k, v in self.inputs.items()}
+            paddle_model_class = getattr(transformers, class_name)
+            paddle_model = paddle_model_class.from_pretrained(tempdir, convert_from_hf=True, dtype="float32").eval()
+
+            if class_name == "Qwen2_5_VLModel":
+                paddle_logit = paddle_model(**paddle_inputs)[0]
+            else:
+                paddle_logit = paddle_model(**paddle_inputs)["logits"]
+
+            # 3. compare the result between paddle and torch
+            self.assertTrue(
+                np.allclose(
+                    paddle_logit.detach().cpu().reshape([-1])[:9].astype("float32").numpy(),
+                    torch_logit.detach().cpu().reshape([-1])[:9].float().numpy(),
+                    atol=1e-2,
+                    rtol=1e-2,
+                )
+            )
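
With the class uncommented, the compatibility suite runs again; the @require_package("transformers", "torch") decorator skips every test when either dependency is missing. A possible way to exercise just this class (a sketch, assuming pytest is used as the test runner):

# Sketch only: run the re-enabled compatibility tests in isolation.
# @require_package skips them when transformers or torch is not installed.
import pytest

pytest.main(["tests/transformers/qwen2_5_vl/test_modeling.py::Qwen2_5_VLCompatibilityTest", "-v"])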
