Skip to content

Commit 97b3756

Browse files
committed
Fix pyright
1 parent 380bc68 commit 97b3756

File tree

9 files changed

+92
-63
lines changed

9 files changed

+92
-63
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies = [
1414
"aioserial",
1515
"numpy",
1616
"pydantic",
17-
"pvi~=0.9.0",
17+
"pvi~=0.10.0",
1818
"pytango",
1919
"softioc",
2020
]

src/fastcs/backends/epics/gui.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
LED,
88
ButtonPanel,
99
ComboBox,
10-
Component,
10+
ComponentUnion,
1111
Device,
1212
Grid,
1313
Group,
14-
ReadWidget,
14+
ReadWidgetUnion,
1515
SignalR,
1616
SignalRW,
1717
SignalW,
@@ -22,7 +22,7 @@
2222
TextWrite,
2323
ToggleButton,
2424
Tree,
25-
WriteWidget,
25+
WriteWidgetUnion,
2626
)
2727
from pydantic import ValidationError
2828

@@ -56,7 +56,7 @@ def _get_pv(self, attr_path: list[str], name: str):
5656
return f"{attr_prefix}:{name.title().replace('_', '')}"
5757

5858
@staticmethod
59-
def _get_read_widget(attribute: AttrR) -> ReadWidget:
59+
def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion:
6060
match attribute.datatype:
6161
case Bool():
6262
return LED()
@@ -68,7 +68,7 @@ def _get_read_widget(attribute: AttrR) -> ReadWidget:
6868
raise FastCSException(f"Unsupported type {type(datatype)}: {datatype}")
6969

7070
@staticmethod
71-
def _get_write_widget(attribute: AttrW) -> WriteWidget:
71+
def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion:
7272
match attribute.allowed_values:
7373
case allowed_values if allowed_values is not None:
7474
return ComboBox(choices=allowed_values)
@@ -87,7 +87,7 @@ def _get_write_widget(attribute: AttrW) -> WriteWidget:
8787

8888
def _get_attribute_component(
8989
self, attr_path: list[str], name: str, attribute: Attribute
90-
):
90+
) -> SignalR | SignalW | SignalRW:
9191
pv = self._get_pv(attr_path, name)
9292
name = name.title().replace("_", "")
9393

@@ -108,6 +108,8 @@ def _get_attribute_component(
108108
case AttrW():
109109
write_widget = self._get_write_widget(attribute)
110110
return SignalW(name=name, write_pv=pv, write_widget=write_widget)
111+
case _:
112+
raise FastCSException(f"Unsupported attribute type: {type(attribute)}")
111113

112114
def _get_command_component(self, attr_path: list[str], name: str):
113115
pv = self._get_pv(attr_path, name)
@@ -136,8 +138,8 @@ def create_gui(self, options: EpicsGUIOptions | None = None) -> None:
136138
formatter = DLSFormatter()
137139
formatter.format(device, options.output_path)
138140

139-
def extract_mapping_components(self, mapping: SingleMapping) -> list[Component]:
140-
components: Tree[Component] = []
141+
def extract_mapping_components(self, mapping: SingleMapping) -> Tree:
142+
components: Tree = []
141143
attr_path = mapping.controller.path
142144

143145
for name, sub_controller in mapping.controller.get_sub_controllers().items():
@@ -151,7 +153,7 @@ def extract_mapping_components(self, mapping: SingleMapping) -> list[Component]:
151153
)
152154
)
153155

154-
groups: dict[str, list[Component]] = {}
156+
groups: dict[str, list[ComponentUnion]] = {}
155157
for attr_name, attribute in mapping.attributes.items():
156158
try:
157159
signal = self._get_attribute_component(

src/fastcs/backends/epics/ioc.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ async def async_record_set(value: T):
162162
record.set(enum_value_to_index(attribute, value))
163163
else:
164164

165-
async def async_record_set(value: T): # type: ignore
165+
async def async_record_set(value: T):
166166
record.set(value)
167167

168168
record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute)
@@ -173,8 +173,10 @@ async def async_record_set(value: T): # type: ignore
173173

174174
def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper:
175175
if attr_is_enum(attribute):
176-
# https://github.com/python/mypy/issues/16789
177-
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) # type: ignore
176+
assert attribute.allowed_values is not None and all(
177+
isinstance(v, str) for v in attribute.allowed_values
178+
)
179+
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False))
178180
return builder.mbbIn(pv, **state_keys)
179181

180182
match attribute.datatype:
@@ -210,7 +212,7 @@ async def async_write_display(value: T):
210212
async def on_update(value):
211213
await attribute.process_without_display_update(value)
212214

213-
async def async_write_display(value: T): # type: ignore
215+
async def async_write_display(value: T):
214216
record.set(value, process=False)
215217

216218
record = _get_output_record(
@@ -223,7 +225,10 @@ async def async_write_display(value: T): # type: ignore
223225

224226
def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any:
225227
if attr_is_enum(attribute):
226-
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) # type: ignore
228+
assert attribute.allowed_values is not None and all(
229+
isinstance(v, str) for v in attribute.allowed_values
230+
)
231+
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False))
227232
return builder.mbbOut(pv, always_update=True, on_update=on_update, **state_keys)
228233

229234
match attribute.datatype:

src/fastcs/backends/tango/dsr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _collect_dev_properties(mapping: Mapping) -> dict[str, Any]:
144144

145145
def _collect_dev_init(mapping: Mapping) -> dict[str, Callable]:
146146
async def init_device(tango_device: Device):
147-
await server.Device.init_device(tango_device)
147+
await server.Device.init_device(tango_device) # type: ignore
148148
tango_device.set_state(DevState.ON)
149149
await mapping.controller.connect()
150150

src/fastcs/connections/ip_connection.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,43 +12,59 @@ class IPConnectionSettings:
1212
port: int = 25565
1313

1414

15-
class IPConnection:
16-
def __init__(self):
17-
self._reader, self._writer = (None, None)
15+
@dataclass
16+
class StreamConnection:
17+
reader: asyncio.StreamReader
18+
writer: asyncio.StreamWriter
19+
20+
def __post_init__(self):
1821
self._lock = asyncio.Lock()
1922

20-
async def connect(self, settings: IPConnectionSettings):
21-
self._reader, self._writer = await asyncio.open_connection(
22-
settings.ip, settings.port
23-
)
23+
async def __aenter__(self):
24+
await self._lock.acquire()
25+
return self
2426

25-
def ensure_connected(self):
26-
if self._reader is None or self._writer is None:
27+
async def __aexit__(self, exc_type, exc_val, exc_tb):
28+
self._lock.release()
29+
30+
async def send_message(self, message) -> None:
31+
self.writer.write(message.encode("utf-8"))
32+
await self.writer.drain()
33+
34+
async def receive_response(self) -> str:
35+
data = await self.reader.readline()
36+
return data.decode("utf-8")
37+
38+
async def close(self):
39+
self.writer.close()
40+
await self.writer.wait_closed()
41+
42+
43+
class IPConnection:
44+
def __init__(self):
45+
self.__connection = None
46+
47+
@property
48+
def _connection(self) -> StreamConnection:
49+
if self.__connection is None:
2750
raise DisconnectedError("Need to call connect() before using IPConnection.")
2851

52+
return self.__connection
53+
54+
async def connect(self, settings: IPConnectionSettings):
55+
reader, writer = await asyncio.open_connection(settings.ip, settings.port)
56+
self.__connection = StreamConnection(reader, writer)
57+
2958
async def send_command(self, message) -> None:
30-
async with self._lock:
31-
self.ensure_connected()
32-
await self._send_message(message)
59+
async with self._connection as connection:
60+
await connection.send_message(message)
3361

3462
async def send_query(self, message) -> str:
35-
async with self._lock:
36-
self.ensure_connected()
37-
await self._send_message(message)
38-
return await self._receive_response()
63+
async with self._connection as connection:
64+
await connection.send_message(message)
65+
return await connection.receive_response()
3966

40-
# TODO: Figure out type hinting for connections. TypeGuard fails to work as expected
4167
async def close(self):
42-
async with self._lock:
43-
self.ensure_connected()
44-
self._writer.close()
45-
await self._writer.wait_closed()
46-
self._reader, self._writer = (None, None)
47-
48-
async def _send_message(self, message) -> None:
49-
self._writer.write(message.encode("utf-8"))
50-
await self._writer.drain()
51-
52-
async def _receive_response(self) -> str:
53-
data = await self._reader.readline()
54-
return data.decode("utf-8")
68+
async with self._connection as connection:
69+
await connection.close()
70+
self.__connection = None

src/fastcs/connections/serial_connection.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,33 +20,33 @@ def __init__(self):
2020
self._lock = asyncio.Lock()
2121

2222
async def connect(self, settings: SerialConnectionSettings) -> None:
23-
self.stream = aioserial.AioSerial(port=settings.port, baudrate=settings.baud)
23+
self.__stream = aioserial.AioSerial(port=settings.port, baudrate=settings.baud)
2424

25-
def ensure_open(self):
26-
if self.stream is None:
25+
@property
26+
def _stream(self) -> aioserial.AioSerial:
27+
if self.__stream is None:
2728
raise NotOpenedError(
2829
"Need to call connect() before using SerialConnection."
2930
)
3031

32+
return self.__stream
33+
3134
async def send_command(self, message: bytes) -> None:
3235
async with self._lock:
33-
self.ensure_open()
3436
await self._send_message(message)
3537

3638
async def send_query(self, message: bytes, response_size: int) -> bytes:
3739
async with self._lock:
38-
self.ensure_open()
3940
await self._send_message(message)
4041
return await self._receive_response(response_size)
4142

42-
async def close(self) -> None:
43-
async with self._lock:
44-
self.ensure_open()
45-
self.stream.close()
46-
self.stream = None
47-
4843
async def _send_message(self, message):
49-
await self.stream.write_async(message)
44+
await self._stream.write_async(message)
5045

5146
async def _receive_response(self, size):
52-
return await self.stream.read_async(size)
47+
return await self._stream.read_async(size)
48+
49+
async def close(self) -> None:
50+
async with self._lock:
51+
self._stream.close()
52+
self.__stream = None

tests/backends/epics/test_gui.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SignalRW,
77
SignalW,
88
SignalX,
9+
TextFormat,
910
TextRead,
1011
TextWrite,
1112
ToggleButton,
@@ -53,7 +54,7 @@ def test_get_components(mapping):
5354
SignalRW(
5455
name="StringEnum",
5556
read_pv="DEVICE:StringEnum_RBV",
56-
read_widget=TextRead(format="string"),
57+
read_widget=TextRead(format=TextFormat.string),
5758
write_pv="DEVICE:StringEnum",
5859
write_widget=ComboBox(choices=["red", "green", "blue"]),
5960
),

tests/backends/epics/test_ioc_system.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
from p4p import Value
12
from p4p.client.thread import Context
23

34

45
def test_ioc(ioc: None):
56
ctxt = Context("pva")
67

7-
parent_pvi = ctxt.get("DEVICE:PVI").todict()
8+
_parent_pvi = ctxt.get("DEVICE:PVI")
9+
assert isinstance(_parent_pvi, Value)
10+
parent_pvi = _parent_pvi.todict()
811
assert all(f in parent_pvi for f in ("alarm", "display", "timeStamp", "value"))
912
assert parent_pvi["display"] == {"description": "The records in this controller"}
1013
assert parent_pvi["value"] == {
@@ -14,7 +17,9 @@ def test_ioc(ioc: None):
1417
}
1518

1619
child_pvi_pv = parent_pvi["value"]["child"]["d"]
17-
child_pvi = ctxt.get(child_pvi_pv).todict()
20+
_child_pvi = ctxt.get(child_pvi_pv)
21+
assert isinstance(_child_pvi, Value)
22+
child_pvi = _child_pvi.todict()
1823
assert all(f in child_pvi for f in ("alarm", "display", "timeStamp", "value"))
1924
assert child_pvi["display"] == {"description": "The records in this controller"}
2025
assert child_pvi["value"] == {

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def ioc():
106106

107107
start_time = time.monotonic()
108108
while "iocRun: All initialization complete" not in (
109-
process.stdout.readline().strip()
109+
process.stdout.readline().strip() # type: ignore
110110
):
111111
if time.monotonic() - start_time > 10:
112112
raise TimeoutError("IOC did not start in time")

0 commit comments

Comments
 (0)