Skip to content

[CompileGuard] Fix CompileGuardSelect #1294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion magma/backend/coreir/coreir_compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import subprocess
from magma.compile_guard import CompileGuardSelect
from magma.compiler import Compiler
from magma.config import EnvConfig, config
from magma.backend.coreir.insert_coreir_wires import insert_coreir_wires
Expand Down Expand Up @@ -84,7 +85,7 @@ def __init__(self, main, basename, opts):
def compile(self):
result = {}
result["symbol_table"] = symbol_table = SymbolTable()
elaborate_all_pass(self.main, generators=(Mux,))
elaborate_all_pass(self.main, generators=(Mux, CompileGuardSelect,))
insert_coreir_wires(self.main)
insert_wrap_casts(self.main)
raise_logs_as_exceptions_pass(self.main)
Expand Down
25 changes: 24 additions & 1 deletion magma/backend/mlir/hardware_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from magma.circuit import AnonymousCircuitType, CircuitKind, DefineCircuitKind
from magma.clock import Reset, ResetN, AsyncReset, AsyncResetN
from magma.common import filter_by_key, assert_false
from magma.compile_guard import get_compile_guard_data
from magma.compile_guard import get_compile_guard_data, CompileGuardSelect
from magma.digital import Digital, DigitalMeta
from magma.inline_verilog_expression import InlineVerilogExpression
from magma.inline_verilog2 import InlineVerilog2
Expand Down Expand Up @@ -880,6 +880,27 @@ def visit_magma_xmr_source(self, module: ModuleWrapper) -> bool:
sv.ReadInOutOp(operands=[in_out], results=[result])
return True

@wrap_with_not_implemented_error
def visit_magma_compile_guard_select(self, module: ModuleWrapper) -> bool:
inst = module.module
defn = type(inst)
assert isinstance(defn, CompileGuardSelect)
assert len(defn.keys) + 1 == len(module.operands)
assert len(module.results) == 1
result = module.results[0]
mlir_type = magma_type_to_mlir_type(defn.T)
reg = self.ctx.new_value(hw.InOutType(mlir_type))
sv.RegOp(results=[reg])
with contextlib.ExitStack() as stack:
for i, key in enumerate(defn.keys):
if_def = sv.IfDefOp(key)
stack.enter_context(push_block(if_def.then_block))
sv.AssignOp(operands=[reg, module.operands[i]])
stack.enter_context(push_block(if_def.else_block))
sv.AssignOp(operands=[reg, module.operands[-1]])
sv.ReadInOutOp(operands=[reg], results=[result])
return True

@wrap_with_not_implemented_error
def visit_inline_verilog(self, module: ModuleWrapper) -> bool:
inst = module.module
Expand Down Expand Up @@ -921,6 +942,8 @@ def visit_instance(self, module: ModuleWrapper) -> bool:
return self.visit_magma_xmr_sink(module)
if isinstance(defn, XMRSource):
return self.visit_magma_xmr_source(module)
if isinstance(defn, CompileGuardSelect):
return self.visit_magma_compile_guard_select(module)
if getattr(defn, "inline_verilog_strs", []):
return self.visit_inline_verilog(module)
if isprimitive(defn):
Expand Down
56 changes: 28 additions & 28 deletions magma/compile_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,44 +140,44 @@ def _is_simple_type(T: Kind) -> bool:

class _CompileGuardSelect(Generator2):
def __init__(self, T: Kind, keys: Tuple[str]):
# NOTE(rsetaluri): We need to add this check because the implementation
assert keys, "Expected at least one key"
self.T = T
self.keys = keys
self.io = IO(
**{f"I{i}": In(T) for i in range(len(self.keys) + 1)},
O=Out(T),
)
T_str = type_to_sanitized_string(T)
self.name = f"CompileGuardSelect_{T_str}_{'_'.join(keys)}_default"
self.primitive = True

def elaborate(self):
# NOTE(rsetaluri): We need to add this check because this implementation
# of this generator emits verilog directly, and thereby requires that no
# transformations happen to the port names/types. If the type is not
# "simple" (i.e. Bit or Bits[N]) then the assumption breaks down and
# this implementation will not work.
if not _is_simple_type(T):
raise TypeError(f"Unsupported type: {T}")
num_keys = len(keys)
assert num_keys > 1
self.io = IO(**{f"I{i}": In(T) for i in range(num_keys)}, O=Out(T))
if not _is_simple_type(self.T):
raise TypeError(f"Unsupported type: {self.T}")
self.verilog = ""
for i, key in enumerate(keys):
if i == 0:
stmt = f"`ifdef {key}"
elif key == "default":
assert i == (num_keys - 1)
stmt = "`else"
else:
stmt = f"`elsif {key}"
self.verilog += f"""\
{stmt}
assign O = I{i};
"""
for i, key in enumerate(self.keys):
pred = f"`ifdef {key}" if i == 0 else f"`elsif {key}"
self.verilog += f"{pred}\n assign O = I{i};\n"
self.verilog += f"`else\n assign O = I{len(self.keys)};\n"
self.verilog += "`endif"
T_str = type_to_sanitized_string(T)
self.name = f"CompileGuardSelect_{T_str}_{'_'.join(keys)}"


CompileGuardSelect = _CompileGuardSelect


def compile_guard_select(**kwargs):
try:
default = kwargs.pop("default")
except KeyError:
raise ValueError("Expected default argument") from None
if not (len(kwargs) > 1):
raise ValueError("Expected at least one key besides default")
# We rely on insertion order to make the default the last element for the
# generated if/elif/else code.
kwargs["default"] = default
T, _ = infer_mux_type(list(kwargs.values()))
raise KeyError("Expected default argument") from None
if not kwargs: # kwargs is empty
raise KeyError("Expected at least one key besides default")
values = tuple(kwargs.values()) + (default,)
T, values = infer_mux_type(values)
Select = _CompileGuardSelect(T, tuple(kwargs.keys()))
return Select()(*kwargs.values())
return Select()(*values)
68 changes: 68 additions & 0 deletions tests/gold/test_compile_guard_select_two_keys.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
module coreir_reg #(
parameter width = 1,
parameter clk_posedge = 1,
parameter init = 1
) (
input clk,
input [width-1:0] in,
output [width-1:0] out
);
reg [width-1:0] outReg=init;
wire real_clk;
assign real_clk = clk_posedge ? clk : ~clk;
always @(posedge real_clk) begin
outReg <= in;
end
assign out = outReg;
endmodule

module Register (
input I,
output O,
input CLK
);
wire [0:0] reg_P1_inst0_out;
coreir_reg #(
.clk_posedge(1'b1),
.init(1'h0),
.width(1)
) reg_P1_inst0 (
.clk(CLK),
.in(I),
.out(reg_P1_inst0_out)
);
assign O = reg_P1_inst0_out[0];
endmodule

module CompileGuardSelect_Bit_COND1_default (
input I0,
input I1,
output O
);
`ifdef COND1
assign O = I0;
`else
assign O = I1;
`endif
endmodule

module _Top (
input I,
output O,
input CLK
);
wire Register_inst0_O;
wire magma_Bit_xor_inst0_out;
CompileGuardSelect_Bit_COND1_default CompileGuardSelect_Bit_COND1_default_inst0 (
.I0(Register_inst0_O),
.I1(I),
.O(O)
);
Register Register_inst0 (
.I(magma_Bit_xor_inst0_out),
.O(Register_inst0_O),
.CLK(CLK)
);
assign magma_Bit_xor_inst0_out = I ^ 1'b1;
endmodule

13 changes: 8 additions & 5 deletions tests/test_backend/test_mlir/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,12 +488,15 @@ class xmr_bind_asserts(m.Circuit):
class simple_compile_guard(m.Circuit):
io = m.IO(I=m.In(m.Bit), O=m.Out(m.Bit)) + m.ClockIO()
with m.compile_guard(
"COND1", defn_name="COND1_compile_guard", type="defined"):
out = m.Register(m.Bit)()(io.I)
"COND1", defn_name="COND1_compile_guard", type="defined"
):
m.Register(m.Bit)()(io.I)
with m.compile_guard(
"COND2", defn_name="COND2_compile_guard", type="undefined"):
out = m.Register(m.Bit)()(io.I)
io.O @= io.I
"COND2", defn_name="COND2_compile_guard", type="undefined"
):
m.Register(m.Bit)()(io.I)
out = m.compile_guard_select(COND1=io.I, COND2=~io.I, default=0)
io.O @= out


m.passes.clock.WireClockPass(simple_compile_guard).run()
Expand Down
16 changes: 15 additions & 1 deletion tests/test_backend/test_mlir/golds/simple_compile_guard.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,27 @@ module attributes {circt.loweringOptions = "locationInfoStyle=none"} {
%0 = sv.read_inout %1 : !hw.inout<i1>
}
hw.module @simple_compile_guard(%I: i1, %CLK: i1) -> (O: i1) {
%1 = hw.constant -1 : i1
%0 = comb.xor %1, %I : i1
%2 = hw.constant 0 : i1
%4 = sv.reg : !hw.inout<i1>
sv.ifdef "COND1" {
sv.assign %4, %I : i1
} else {
sv.ifdef "COND2" {
sv.assign %4, %0 : i1
} else {
sv.assign %4, %2 : i1
}
}
%3 = sv.read_inout %4 : !hw.inout<i1>
sv.ifdef "COND1" {
hw.instance "COND1_compile_guard" @COND1_compile_guard(port_0: %I: i1, port_1: %CLK: i1) -> ()
}
sv.ifdef "COND2" {
} else {
hw.instance "COND2_compile_guard" @COND2_compile_guard(port_0: %I: i1, port_1: %CLK: i1) -> ()
}
hw.output %I : i1
hw.output %3 : i1
}
}
10 changes: 9 additions & 1 deletion tests/test_backend/test_mlir/golds/simple_compile_guard.v
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,26 @@ module simple_compile_guard(
output O
);

reg _GEN;
`ifdef COND1
assign _GEN = I;
COND1_compile_guard COND1_compile_guard (
.port_0 (I),
.port_1 (CLK)
);
`else // COND1
`ifdef COND2
assign _GEN = ~I;
`else // COND2
assign _GEN = 1'h0;
`endif // COND2
`endif // COND1
`ifndef COND2
COND2_compile_guard COND2_compile_guard (
.port_0 (I),
.port_1 (CLK)
);
`endif // not def COND2
assign O = I;
assign O = _GEN;
endmodule

48 changes: 41 additions & 7 deletions tests/test_compile_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import magma as m
import magma.testing
from magma.passes.elaborate_circuit import elaborate_circuit
import fault as f


Expand Down Expand Up @@ -253,20 +254,50 @@ class _Top(m.Circuit):


def test_compile_guard_select_basic():

class _Top(m.Circuit):
io = m.IO(I=m.In(m.Bit), O=m.Out(m.Bit)) + m.ClockIO()

x = m.Register(m.Bit)()(io.I ^ 1)
y = m.Register(m.Bit)()(io.I)

io.O @= m.compile_guard_select(
COND1=x, COND2=y, default=io.I
)
io.O @= m.compile_guard_select(COND1=x, COND2=y, default=io.I)

basename = "test_compile_guard_select_basic"
m.compile(f"build/{basename}", _Top, inline=True)
assert m.testing.check_files_equal(
__file__, f"build/{basename}.v", f"gold/{basename}.v")
__file__, f"build/{basename}.v", f"gold/{basename}.v"
)


def test_compile_guard_select_two_keys():

class _Top(m.Circuit):
io = m.IO(I=m.In(m.Bit), O=m.Out(m.Bit)) + m.ClockIO()
x = m.Register(m.Bit)()(io.I ^ 1)
io.O @= m.compile_guard_select(COND1=x, default=io.I)

basename = "test_compile_guard_select_two_keys"
m.compile(f"build/{basename}", _Top, inline=True)
assert m.testing.check_files_equal(
__file__, f"build/{basename}.v", f"gold/{basename}.v"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we update these to use mlir/mlir-verilog instead of coreir? Will save us doing it in the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do these in a follow up PR.



def test_compile_guard_select_no_default():
with pytest.raises(KeyError):

class _Top(m.Circuit):
io = m.IO(I=m.In(m.Bit), O=m.Out(m.Bit)) + m.ClockIO()
x = m.Register(m.Bit)()(io.I ^ 1)
io.O @= m.compile_guard_select(COND1=x)


def test_compile_guard_select_only_default():
with pytest.raises(KeyError):

class _Top(m.Circuit):
io = m.IO(I=m.In(m.Bit), O=m.Out(m.Bit)) + m.ClockIO()
x = m.Register(m.Bit)()(io.I ^ 1)
io.O @= m.compile_guard_select(default=x)


def test_compile_guard_select_complex_type():
Expand All @@ -277,7 +308,10 @@ def make_top():
class _Top(m.Circuit):
io = m.IO(I0=m.In(T), I1=m.In(T), O=m.Out(T))
io.O @= m.compile_guard_select(
COND1=io.I0, COND2=io.I1, default=io.I0)
COND1=io.I0, COND2=io.I1, default=io.I0
)

elaborate_circuit(type(_Top.instances[0]))

with pytest.raises(TypeError):
make_top()
Expand Down