Skip to content

Commit 555f829

Browse files
committed
hdl.dsl: bring new naming rules to FSM
1 parent bba4235 commit 555f829

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

amaranth/hdl/_dsl.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66
import sys
77

8-
from .._utils import flatten
8+
from .._utils import flatten, validate_name
99
from ..utils import bits_for
1010
from .. import tracer
1111
from ._ast import *
@@ -164,6 +164,7 @@ def __init__(self, data):
164164
self.decoding = data["decoding"]
165165

166166
def ongoing(self, name):
167+
validate_name(name, "FSM state")
167168
if name not in self.encoding:
168169
self.encoding[name] = len(self.encoding)
169170
fsm_name = self._data["name"]
@@ -426,6 +427,9 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
426427
warnings.warn("`reset=` is deprecated, use `init=` instead",
427428
DeprecationWarning, stacklevel=2)
428429
init = reset
430+
validate_name(name, "FSM name")
431+
validate_name(init, "FSM state", none_ok=True)
432+
validate_name(domain, "FSM clock domain")
429433
fsm_data = self._set_ctrl("FSM", {
430434
"name": name,
431435
"init": init,
@@ -455,6 +459,7 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
455459
@contextmanager
456460
def State(self, name):
457461
self._check_context("FSM State", context="FSM")
462+
validate_name(name, "FSM state")
458463
src_loc = tracer.get_src_loc(src_loc_at=1)
459464
fsm_data = self._get_ctrl("FSM")
460465
if name in fsm_data["states"]:
@@ -481,6 +486,7 @@ def next(self):
481486
@next.setter
482487
def next(self, name):
483488
if self._ctrl_context != "FSM":
489+
validate_name(name, "FSM state")
484490
for level, (ctrl_name, ctrl_data) in enumerate(reversed(self._ctrl_stack)):
485491
if ctrl_name == "FSM":
486492
if name not in ctrl_data["encoding"]:

tests/test_hdl_dsl.py

+31
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,37 @@ def test_FSM_wrong_next(self):
793793
with m.FSM():
794794
m.next = "FOO"
795795

796+
def test_FSM_wrong_name(self):
797+
m = Module()
798+
with self.assertRaisesRegex(TypeError,
799+
r"^FSM name must be a string, not 1$"):
800+
with m.FSM(name=1):
801+
pass
802+
with self.assertRaisesRegex(TypeError,
803+
r"^FSM state must be a string, not 1$"):
804+
with m.FSM(init=1):
805+
pass
806+
with self.assertRaisesRegex(TypeError,
807+
r"^FSM clock domain must be a string, not 1$"):
808+
with m.FSM(domain=1):
809+
pass
810+
with self.assertRaisesRegex(TypeError,
811+
r"^FSM state must be a string, not 1$"):
812+
with m.FSM():
813+
with m.State(1):
814+
pass
815+
with self.assertRaisesRegex(TypeError,
816+
r"^FSM state must be a string, not 1$"):
817+
with m.FSM():
818+
with m.State("FOO"):
819+
m.next = 1
820+
with self.assertRaisesRegex(TypeError,
821+
r"^FSM state must be a string, not 1$"):
822+
with m.FSM() as fsm:
823+
s = fsm.ongoing(1)
824+
with m.State("FOO"):
825+
m.next = "FOO"
826+
796827
def test_If_inside_FSM_wrong(self):
797828
m = Module()
798829
with m.FSM():

0 commit comments

Comments
 (0)