Skip to content

[Types] Support wiring Bit and Bits[1] #1070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions magma/bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,18 @@
* Subtype of the Digital type
* Implementation of hwtypes.AbstractBit
"""
import keyword
import typing as tp
import functools
import hwtypes as ht
from hwtypes.bit_vector_abc import AbstractBit, TypeFamily
from .t import Direction
from .t import Direction, Type
from .digital import Digital, DigitalMeta
from .digital import VCC, GND # TODO(rsetaluri): only here for b.c.

from magma.compatibility import IntegerTypes
from magma.debug import debug_wire
from magma.family import get_family
from magma.interface import IO
from magma.language_utils import primitive_to_python
from magma.protocol_type import magma_type, MagmaProtocol
from magma.protocol_type import magma_value
from magma.operator_utils import output_only


Expand Down Expand Up @@ -113,11 +110,20 @@ def ite(self, t_branch, f_branch):

@debug_wire
def wire(self, o, debug_info):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, why is this not in Digital? why do we need to override it in Bit?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now this is just for Bit <-> Bits[1], if we wanted T <-> Array[1, T] then we would move it to Digital.

o = magma_value(o)
# Cast to Bit here so we don't get a Digital instead
if isinstance(o, (IntegerTypes, bool, ht.Bit)):
o = Bit(o)
if type(o).is_bits_1():
o = o[0]
return super().wire(o, debug_info)

@classmethod
def is_wireable(cls, rhs):
if issubclass(rhs, Type) and rhs.is_bits_1():
return True
return DigitalMeta.is_wireable(cls, rhs)


BitIn = Bit[Direction.In]
BitOut = Bit[Direction.Out]
Expand Down
11 changes: 9 additions & 2 deletions magma/bits.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,15 @@ def is_wireable(cls, rhs):
return True
if issubclass(cls, UInt) and issubclass(rhs, SInt):
return False
elif issubclass(cls, SInt) and issubclass(rhs, UInt):
if issubclass(cls, SInt) and issubclass(rhs, UInt):
return False
if len(cls) == 1 and issubclass(rhs, Bit):
return True
return super().is_wireable(rhs)

def is_bits_1(cls):
return len(cls) == 1


class Bits(Array, AbstractBitVector, metaclass=BitsMeta):
__hash__ = Array.__hash__
Expand Down Expand Up @@ -237,6 +242,7 @@ def __int__(self):

@debug_wire
def wire(self, other, debug_info):
from .conversions import bits
if isinstance(other, (IntegerTypes, BitVector)):
N = (other.bit_length()
if isinstance(other, IntegerTypes)
Expand All @@ -245,8 +251,9 @@ def wire(self, other, debug_info):
raise ValueError(
f"Cannot convert integer {other} "
f"(bit_length={other.bit_length()}) to Bits ({len(self)})")
from .conversions import bits
other = bits(other, len(self))
if isinstance(other, Bit) and len(self) == 1:
other = bits(other, 1)
super().wire(other, debug_info)

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions magma/t.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ def undirected_t(cls):
def is_directed(cls):
return cls is not cls.qualify(Direction.Undirected)

def is_bits_1(self):
return False


@lru_cache()
def In(T):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_circuit/test_new_style_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def definition(io):

def test_defn_wiring_error(caplog):
class _Foo(m.Circuit):
io = m.IO(I=m.In(m.Bit), O=m.In(m.Bit), O1=m.Out(m.Bits[1]))
io = m.IO(I=m.In(m.Bit), O=m.In(m.Bit), O1=m.Out(m.Bits[2]))

m.wire(io.I, io.O)
m.wire(io.I, io.O1)
Expand All @@ -108,13 +108,13 @@ class _Foo(m.Circuit):
assert has_error(caplog,
"Cannot wire _Foo.I (Out(Bit)) to _Foo.O (Out(Bit))")
assert has_error(caplog,
"Cannot wire _Foo.I (Out(Bit)) to _Foo.O1 (In(Bits[1]))")
"Cannot wire _Foo.I (Out(Bit)) to _Foo.O1 (In(Bits[2]))")


@wrap_with_context_manager(logging_level("DEBUG"))
def test_inst_wiring_error(caplog):
class _Bar(m.Circuit):
io = m.IO(I=m.In(m.Bits[1]), O=m.Out(m.Bits[1]))
io = m.IO(I=m.In(m.Bits[2]), O=m.Out(m.Bits[2]))

class _Foo(m.Circuit):
io = m.IO(I=m.In(m.Bit), O=m.Out(m.Bit))
Expand All @@ -125,10 +125,10 @@ class _Foo(m.Circuit):

assert has_error(
caplog,
"Cannot wire _Foo.I (Out(Bit)) to _Foo._Bar_inst0.I (In(Bits[1]))")
"Cannot wire _Foo.I (Out(Bit)) to _Foo._Bar_inst0.I (In(Bits[2]))")
assert has_error(
caplog,
"Cannot wire _Foo._Bar_inst0.O (Out(Bits[1])) to _Foo.O (In(Bit))")
"Cannot wire _Foo._Bar_inst0.O (Out(Bits[2])) to _Foo.O (In(Bit))")
assert has_error(caplog, "_Foo.O not driven")
assert has_debug(caplog, "_Foo.O: Unconnected")
assert has_error(caplog, "_Foo._Bar_inst0.I not driven")
Expand Down
15 changes: 15 additions & 0 deletions tests/test_wire/test_wireable.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import magma as m


Expand All @@ -20,3 +21,17 @@ class Main2(m.Circuit):
Cannot wire Main2.a (Out(SInt[16])) to Main2.b (In(UInt[16]))\
"""
assert caplog.messages[1][-len(expected):] == expected


@pytest.mark.parametrize('Ts', [
(m.Bit, m.Bits[1]),
(m.Bits[1], m.Bit),
])
def test_bit_bits1(Ts):
class Main(m.Circuit):
io = m.IO(a=m.In(Ts[0]), b=m.Out(Ts[1]))
io.b @= io.a

# NOTE: We call compile here to ensure a wiring error was not reported
# (otherwise it would raise an exception)
m.compile('build/Main', Main)