Skip to content

Make data.Const comparable with list and dict objects compatible with the layout #1420

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions amaranth/lib/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,16 @@ def const(self, init):
elif isinstance(init, Sequence):
iterator = enumerate(init)
else:
raise TypeError("Layout constant initializer must be a mapping or a sequence, not {!r}"
.format(init))
raise TypeError(f"Layout constant initializer must be a mapping or a sequence, not "
f"{init!r}")

int_value = 0
for key, key_value in iterator:
field = self[key]
try:
field = self[key]
except KeyError:
raise ValueError(f"Layout constant initializer refers to key {key!r}, which is not "
f"a part of the layout")
cast_field_shape = Shape.cast(field.shape)
if isinstance(field.shape, ShapeCastable):
key_value = hdl.Const.cast(hdl.Const(key_value, field.shape))
Expand Down Expand Up @@ -1079,19 +1083,37 @@ def __eq__(self, other):
elif isinstance(other, Const) and self.__layout == other.__layout:
return self.__target == other.__target
else:
cause = None
if isinstance(other, (dict, list)):
try:
other_as_const = self.__layout.const(other)
except (TypeError, ValueError) as exc:
cause = exc
else:
return self == other_as_const
raise TypeError(
f"Constant with layout {self.__layout!r} can only be compared to another view or "
f"constant with the same layout, not {other!r}")
f"Constant with layout {self.__layout!r} can only be compared to another view, "
f"a constant with the same layout, or a dictionary or a list that can be converted "
f"to a constant with the same layout, not {other!r}") from cause

def __ne__(self, other):
if isinstance(other, View) and self.__layout == other._View__layout:
return self.as_value() != other._View__target
elif isinstance(other, Const) and self.__layout == other.__layout:
return self.__target != other.__target
else:
cause = None
if isinstance(other, (dict, list)):
try:
other_as_const = self.__layout.const(other)
except (TypeError, ValueError) as exc:
cause = exc
else:
return self != other_as_const
raise TypeError(
f"Constant with layout {self.__layout!r} can only be compared to another view or "
f"constant with the same layout, not {other!r}")
f"Constant with layout {self.__layout!r} can only be compared to another view, "
f"a constant with the same layout, or a dictionary or a list that can be converted "
f"to a constant with the same layout, not {other!r}") from cause

def __add__(self, other):
raise TypeError("Cannot perform arithmetic operations on a lib.data.Const")
Expand Down
67 changes: 53 additions & 14 deletions tests/test_lib_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,10 @@ def test_const_wrong(self):
r"^Layout constant initializer must be a mapping or a sequence, not "
r"<.+?object.+?>$"):
sl.const(object())
with self.assertRaisesRegex(ValueError,
r"^Layout constant initializer refers to key 'g', which is not a part "
r"of the layout$"):
sl.const({"g": 1})
sl2 = data.StructLayout({"f": unsigned(2)})
with self.assertRaisesRegex(ValueError,
r"^Const layout StructLayout.* differs from shape layout StructLayout.*$"):
Expand Down Expand Up @@ -740,7 +744,7 @@ def test_bug_837_array_layout_getattr(self):
r"^View with an array layout does not have fields$"):
Signal(data.ArrayLayout(unsigned(1), 1), init=[0]).init

def test_eq(self):
def test_compare(self):
s1 = Signal(data.StructLayout({"a": unsigned(2)}))
s2 = Signal(data.StructLayout({"a": unsigned(2)}))
s3 = Signal(data.StructLayout({"a": unsigned(1), "b": unsigned(1)}))
Expand Down Expand Up @@ -969,11 +973,12 @@ def test_bug_837_array_layout_getattr(self):
r"^Constant with an array layout does not have fields$"):
data.Const(data.ArrayLayout(unsigned(1), 1), 0).init

def test_eq(self):
def test_compare(self):
c1 = data.Const(data.StructLayout({"a": unsigned(2)}), 1)
c2 = data.Const(data.StructLayout({"a": unsigned(2)}), 1)
c3 = data.Const(data.StructLayout({"a": unsigned(2)}), 2)
c4 = data.Const(data.StructLayout({"a": unsigned(1), "b": unsigned(1)}), 2)
c5 = data.Const(data.ArrayLayout(2, 4), 0b11100100)
s1 = Signal(data.StructLayout({"a": unsigned(2)}))
self.assertTrue(c1 == c2)
self.assertFalse(c1 != c2)
Expand All @@ -983,13 +988,23 @@ def test_eq(self):
self.assertRepr(c1 != s1, "(!= (const 2'd1) (sig s1))")
self.assertRepr(s1 == c1, "(== (sig s1) (const 2'd1))")
self.assertRepr(s1 != c1, "(!= (sig s1) (const 2'd1))")
self.assertTrue(c1 == {"a": 1})
self.assertFalse(c1 == {"a": 2})
self.assertFalse(c1 != {"a": 1})
self.assertTrue(c1 != {"a": 2})
self.assertTrue(c5 == [0,1,2,3])
self.assertFalse(c5 == [0,1,3,3])
self.assertFalse(c5 != [0,1,2,3])
self.assertTrue(c5 != [0,1,3,3])
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
r"^Constant with layout .* can only be compared to another view, a constant "
r"with the same layout, or a dictionary or a list that can be converted to "
r"a constant with the same layout, not .*$"):
c1 == c4
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
r"^Constant with layout .* can only be compared to another view, a constant "
r"with the same layout, or a dictionary or a list that can be converted to "
r"a constant with the same layout, not .*$"):
c1 != c4
with self.assertRaisesRegex(TypeError,
r"^View with layout .* can only be compared to another view or constant with "
Expand All @@ -1000,21 +1015,45 @@ def test_eq(self):
r"the same layout, not .*$"):
s1 != c4
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
r"^Constant with layout .* can only be compared to another view, a constant "
r"with the same layout, or a dictionary or a list that can be converted to "
r"a constant with the same layout, not .*$"):
c4 == s1
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
r"^Constant with layout .* can only be compared to another view, a constant "
r"with the same layout, or a dictionary or a list that can be converted to "
r"a constant with the same layout, not .*$"):
c4 != s1
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
r"^Constant with layout .* can only be compared to another view, a constant "
r"with the same layout, or a dictionary or a list that can be converted to "
r"a constant with the same layout, not .*$"):
c1 == Const(0, 2)
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view or constant with "
r"the same layout, not .*$"):
r"^Constant with layout .* can only be compared to another view, a constant "
r"with the same layout, or a dictionary or a list that can be converted to "
r"a constant with the same layout, not .*$"):
c1 != Const(0, 2)
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view, a constant "
r"with the same layout, or a dictionary or a list that can be converted to "
r"a constant with the same layout, not .*$"):
c1 == {"b": 1}
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view, a constant "
r"with the same layout, or a dictionary or a list that can be converted to "
r"a constant with the same layout, not .*$"):
c1 != {"b": 1}
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view, a constant "
r"with the same layout, or a dictionary or a list that can be converted to "
r"a constant with the same layout, not .*$"):
c5 == [0,1,2,3,4]
with self.assertRaisesRegex(TypeError,
r"^Constant with layout .* can only be compared to another view, a constant "
r"with the same layout, or a dictionary or a list that can be converted to "
r"a constant with the same layout, not .*$"):
c5 != [0,1,2,3,4]

def test_operator(self):
s1 = data.Const(data.StructLayout({"a": unsigned(2)}), 2)
Expand Down