diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py index 8c77cc2cf..69ca86cf5 100644 --- a/amaranth/lib/data.py +++ b/amaranth/lib/data.py @@ -219,12 +219,16 @@ def const(self, init): elif isinstance(init, Sequence): iterator = enumerate(init) else: - raise TypeError("Layout constant initializer must be a mapping or a sequence, not {!r}" - .format(init)) + raise TypeError(f"Layout constant initializer must be a mapping or a sequence, not " + f"{init!r}") int_value = 0 for key, key_value in iterator: - field = self[key] + try: + field = self[key] + except KeyError: + raise ValueError(f"Layout constant initializer refers to key {key!r}, which is not " + f"a part of the layout") cast_field_shape = Shape.cast(field.shape) if isinstance(field.shape, ShapeCastable): key_value = hdl.Const.cast(hdl.Const(key_value, field.shape)) @@ -1079,9 +1083,18 @@ def __eq__(self, other): elif isinstance(other, Const) and self.__layout == other.__layout: return self.__target == other.__target else: + cause = None + if isinstance(other, (dict, list)): + try: + other_as_const = self.__layout.const(other) + except (TypeError, ValueError) as exc: + cause = exc + else: + return self == other_as_const raise TypeError( - f"Constant with layout {self.__layout!r} can only be compared to another view or " - f"constant with the same layout, not {other!r}") + f"Constant with layout {self.__layout!r} can only be compared to another view, " + f"a constant with the same layout, or a dictionary or a list that can be converted " + f"to a constant with the same layout, not {other!r}") from cause def __ne__(self, other): if isinstance(other, View) and self.__layout == other._View__layout: @@ -1089,9 +1102,18 @@ def __ne__(self, other): elif isinstance(other, Const) and self.__layout == other.__layout: return self.__target != other.__target else: + cause = None + if isinstance(other, (dict, list)): + try: + other_as_const = self.__layout.const(other) + except (TypeError, ValueError) as exc: + cause = exc + else: + return self != other_as_const raise TypeError( - f"Constant with layout {self.__layout!r} can only be compared to another view or " - f"constant with the same layout, not {other!r}") + f"Constant with layout {self.__layout!r} can only be compared to another view, " + f"a constant with the same layout, or a dictionary or a list that can be converted " + f"to a constant with the same layout, not {other!r}") from cause def __add__(self, other): raise TypeError("Cannot perform arithmetic operations on a lib.data.Const") diff --git a/tests/test_lib_data.py b/tests/test_lib_data.py index ceafd638b..86a75c1da 100644 --- a/tests/test_lib_data.py +++ b/tests/test_lib_data.py @@ -477,6 +477,10 @@ def test_const_wrong(self): r"^Layout constant initializer must be a mapping or a sequence, not " r"<.+?object.+?>$"): sl.const(object()) + with self.assertRaisesRegex(ValueError, + r"^Layout constant initializer refers to key 'g', which is not a part " + r"of the layout$"): + sl.const({"g": 1}) sl2 = data.StructLayout({"f": unsigned(2)}) with self.assertRaisesRegex(ValueError, r"^Const layout StructLayout.* differs from shape layout StructLayout.*$"): @@ -740,7 +744,7 @@ def test_bug_837_array_layout_getattr(self): r"^View with an array layout does not have fields$"): Signal(data.ArrayLayout(unsigned(1), 1), init=[0]).init - def test_eq(self): + def test_compare(self): s1 = Signal(data.StructLayout({"a": unsigned(2)})) s2 = Signal(data.StructLayout({"a": unsigned(2)})) s3 = Signal(data.StructLayout({"a": unsigned(1), "b": unsigned(1)})) @@ -969,11 +973,12 @@ def test_bug_837_array_layout_getattr(self): r"^Constant with an array layout does not have fields$"): data.Const(data.ArrayLayout(unsigned(1), 1), 0).init - def test_eq(self): + def test_compare(self): c1 = data.Const(data.StructLayout({"a": unsigned(2)}), 1) c2 = data.Const(data.StructLayout({"a": unsigned(2)}), 1) c3 = data.Const(data.StructLayout({"a": unsigned(2)}), 2) c4 = data.Const(data.StructLayout({"a": unsigned(1), "b": unsigned(1)}), 2) + c5 = data.Const(data.ArrayLayout(2, 4), 0b11100100) s1 = Signal(data.StructLayout({"a": unsigned(2)})) self.assertTrue(c1 == c2) self.assertFalse(c1 != c2) @@ -983,13 +988,23 @@ def test_eq(self): self.assertRepr(c1 != s1, "(!= (const 2'd1) (sig s1))") self.assertRepr(s1 == c1, "(== (sig s1) (const 2'd1))") self.assertRepr(s1 != c1, "(!= (sig s1) (const 2'd1))") + self.assertTrue(c1 == {"a": 1}) + self.assertFalse(c1 == {"a": 2}) + self.assertFalse(c1 != {"a": 1}) + self.assertTrue(c1 != {"a": 2}) + self.assertTrue(c5 == [0,1,2,3]) + self.assertFalse(c5 == [0,1,3,3]) + self.assertFalse(c5 != [0,1,2,3]) + self.assertTrue(c5 != [0,1,3,3]) with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c1 == c4 with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c1 != c4 with self.assertRaisesRegex(TypeError, r"^View with layout .* can only be compared to another view or constant with " @@ -1000,21 +1015,45 @@ def test_eq(self): r"the same layout, not .*$"): s1 != c4 with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c4 == s1 with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c4 != s1 with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c1 == Const(0, 2) with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c1 != Const(0, 2) + with self.assertRaisesRegex(TypeError, + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): + c1 == {"b": 1} + with self.assertRaisesRegex(TypeError, + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): + c1 != {"b": 1} + with self.assertRaisesRegex(TypeError, + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): + c5 == [0,1,2,3,4] + with self.assertRaisesRegex(TypeError, + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): + c5 != [0,1,2,3,4] def test_operator(self): s1 = data.Const(data.StructLayout({"a": unsigned(2)}), 2)