Skip to content

Commit dc0f59c

Browse files
authored
Make get_norm_layer repr test tolerant of PyTorch bias= field (#8880)
PyTorch >= 2.13 adds an optional 'bias=' token to GroupNorm/InstanceNorm __repr__, breaking the exact-string match in test_norm_layer. Normalize the repr by stripping the bias= field so the test passes on PyTorch versions with or without it (backward- and forward-compatible). ### Description test_norm_layer in tests/networks/layers/test_get_layers.py asserts that a norm layer's repr() exactly equals a hard-coded string. PyTorch >= 2.13 added an optional bias= token to GroupNorm / InstanceNorm __repr__, which breaks that exact-string match. This normalizes both the actual and expected repr by stripping the bias= field, so the test passes on PyTorch versions with or without it — backward- and forward-compatible, with no change to production code. ### Types of changes - [x] Non-breaking change (fix that would not break existing functionality). - [x] Quick tests passed locally (python -m tests.networks.layers.test_get_layers — 7/7). PyTorch < 2.13: GroupNorm(1, 1, eps=1e-05, affine=True) PyTorch >= 2.13: GroupNorm(1, 1, eps=1e-05, affine=True, bias=True) The fix strips , bias=(True|False) from both sides before comparison. Local checks: targeted test 7/7 OK; black/isort/ruff clean; full pre-commit hooks pass. Signed-off-by: Hans Johnson <hans-johnson@uiowa.edu>
1 parent e33941c commit dc0f59c

1 file changed

Lines changed: 15 additions & 1 deletion

File tree

tests/networks/layers/test_get_layers.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,26 @@
1111

1212
from __future__ import annotations
1313

14+
import re
1415
import unittest
1516

1617
from parameterized import parameterized
1718

1819
from monai.networks.layers import get_act_layer, get_dropout_layer, get_norm_layer
1920

21+
22+
def _strip_bias_field(text: str) -> str:
23+
"""Strip the optional PyTorch >= 2.13 ``, bias=True|False`` repr fragment.
24+
25+
Args:
26+
text: Layer string representation to normalize.
27+
28+
Returns:
29+
The representation with any ``, bias=True|False`` removed.
30+
"""
31+
return re.sub(r",\s*bias=(?:True|False)", "", text)
32+
33+
2034
TEST_CASE_NORM = [
2135
[{"name": ("group", {"num_groups": 1})}, "GroupNorm(1, 1, eps=1e-05, affine=True)"],
2236
[
@@ -41,7 +55,7 @@ class TestGetLayers(unittest.TestCase):
4155
@parameterized.expand(TEST_CASE_NORM)
4256
def test_norm_layer(self, input_param, expected):
4357
layer = get_norm_layer(**input_param)
44-
self.assertEqual(f"{layer}", expected)
58+
self.assertEqual(_strip_bias_field(f"{layer}"), _strip_bias_field(expected))
4559

4660
@parameterized.expand(TEST_CASE_ACT)
4761
def test_act_layer(self, input_param, expected):

0 commit comments

Comments
 (0)