Commit cd39c8e

Replace error by warning when loading an architecture in another (#11207)
* Replace error by warning when loading an architecture in another
* Style
* Style again
* Add a test
* Adapt old test
1 parent 4906a29

File tree: 3 files changed, +16 −15 lines changed


src/transformers/configuration_utils.py (+5 −4)

@@ -399,10 +399,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],

        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-        if config_dict.get("model_type", False) and hasattr(cls, "model_type"):
-            assert (
-                config_dict["model_type"] == cls.model_type
-            ), f"You tried to initiate a model of type '{cls.model_type}' with a pretrained model of type '{config_dict['model_type']}'"
+        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+            logger.warn(
+                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+            )

        return cls.from_dict(config_dict, **kwargs)
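
In practice, the change above means that loading a checkpoint's configuration into a config class of a different architecture no longer aborts. A minimal sketch of the new behaviour, assuming the public t5-small checkpoint is reachable (any checkpoint whose config.json declares a different model_type would do):

from transformers import BertConfig

# config.json of "t5-small" declares model_type "t5", which does not match
# BertConfig.model_type ("bert"). Before this commit the mismatch raised an
# AssertionError; now it only logs a warning and returns a BertConfig built
# from the T5 config dict.
config = BertConfig.from_pretrained("t5-small")
# Logged warning: "You are using a model of type t5 to instantiate a model of
# type bert. This is not supported for all configurations of models and can
# yield errors."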

tests/test_modeling_bert_generation.py (+6 −7)

@@ -231,13 +231,7 @@ def create_and_check_for_causal_lm(
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            input_mask,
-            token_labels,
-        ) = config_and_inputs
+        config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict

@@ -259,6 +253,11 @@ def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

+    def test_model_as_bert(self):
+        config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
+        config.model_type = "bert"
+        self.model_tester.create_and_check_model(config, input_ids, input_mask, token_labels)
+
    def test_model_as_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

tests/test_modeling_common.py (+5 −4)

@@ -22,9 +22,9 @@
import unittest
from typing import List, Tuple

-from transformers import is_torch_available
+from transformers import is_torch_available, logging
from transformers.file_utils import WEIGHTS_NAME
-from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
+from transformers.testing_utils import CaptureLogger, require_torch, require_torch_multi_gpu, slow, torch_device


if is_torch_available():

@@ -1295,6 +1295,7 @@ def test_model_from_pretrained_with_different_pretrained_model_name(self):
        model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
        self.assertIsNotNone(model)

-        with self.assertRaises(Exception) as context:
+        logger = logging.get_logger("transformers.configuration_utils")
+        with CaptureLogger(logger) as cl:
            BertModel.from_pretrained(TINY_T5)
-        self.assertTrue("You tried to initiate a model of type" in str(context.exception))
+        self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
