Skip to content

Commit c0073b6

Browse files
committed
Fix torchscript tests (#12336)
* Fix torchscript tests * Better test * Remove bogus print
1 parent 0b752bf commit c0073b6

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

tests/test_modeling_common.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -564,13 +564,34 @@ def _create_and_check_torchscript(self, config, inputs_dict):
564564
model_state_dict = model.state_dict()
565565
loaded_model_state_dict = loaded_model.state_dict()
566566

567+
non_persistent_buffers = {}
568+
for key in loaded_model_state_dict.keys():
569+
if key not in model_state_dict.keys():
570+
non_persistent_buffers[key] = loaded_model_state_dict[key]
571+
572+
loaded_model_state_dict = {
573+
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
574+
}
575+
567576
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
568577

578+
model_buffers = list(model.buffers())
579+
for non_persistent_buffer in non_persistent_buffers.values():
580+
found_buffer = False
581+
for i, model_buffer in enumerate(model_buffers):
582+
if torch.equal(non_persistent_buffer, model_buffer):
583+
found_buffer = True
584+
break
585+
586+
self.assertTrue(found_buffer)
587+
model_buffers.pop(i)
588+
569589
models_equal = True
570590
for layer_name, p1 in model_state_dict.items():
571-
p2 = loaded_model_state_dict[layer_name]
572-
if p1.data.ne(p2.data).sum() > 0:
573-
models_equal = False
591+
if layer_name in loaded_model_state_dict:
592+
p2 = loaded_model_state_dict[layer_name]
593+
if p1.data.ne(p2.data).sum() > 0:
594+
models_equal = False
574595

575596
self.assertTrue(models_equal)
576597

0 commit comments

Comments
 (0)