Skip to content

Commit 6daf837

Browse files
Fix type invariance causing too strict typing (#115)
* Use covariant containers for type annotations * formatting * stop mutating passed env arg * run isort * fix getting TERM env var in pystemd.run * fix isort formatting
1 parent e993135 commit 6daf837

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

pystemd/__init__.pyi

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# the root directory of this source tree.
77
#
88

9+
from collections.abc import Mapping, Sequence
910
from typing import Any, Protocol
1011

1112
from pystemd import machine1, systemd1
@@ -19,15 +20,15 @@ class SupportsFileno(Protocol):
1920
def fileno(self) -> int: ...
2021

2122
def run(
22-
cmd: list[str | bytes] | str | bytes,
23+
cmd: Sequence[str | bytes] | str | bytes,
2324
address: str | bytes | None = None,
2425
service_type: str | bytes | None = None,
2526
name: str | bytes | None = None,
2627
user: str | bytes | None = None,
2728
user_mode: bool = ...,
2829
nice: int | None = None,
2930
runtime_max_sec: int | float | None = None,
30-
env: dict[str | bytes, str | bytes] | None = None,
31+
env: Mapping[str, str | bytes] | Mapping[bytes, str | bytes] | None = None,
3132
extra: dict[bytes, Any] | None = None,
3233
cwd: str | bytes | None = None,
3334
machine: str | bytes | None = None,
@@ -44,8 +45,8 @@ def run(
4445
stderr: int | SupportsFileno | None = None,
4546
_wait_polling: int | float | None = None,
4647
slice_: str | bytes | None = None,
47-
stop_cmd: list[str | bytes] | str | bytes | None = None,
48-
stop_post_cmd: list[str | bytes] | str | bytes | None = None,
49-
start_pre_cmd: list[str | bytes] | str | bytes | None = None,
50-
start_post_cmd: list[str | bytes] | str | bytes | None = None,
48+
stop_cmd: Sequence[str | bytes] | str | bytes | None = None,
49+
stop_post_cmd: Sequence[str | bytes] | str | bytes | None = None,
50+
start_pre_cmd: Sequence[str | bytes] | str | bytes | None = None,
51+
start_post_cmd: Sequence[str | bytes] | str | bytes | None = None,
5152
) -> systemd1.Unit: ...

pystemd/run.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import sys
1616
import termios
1717
import tty
18+
from collections.abc import Mapping, Sequence
1819
from contextlib import ExitStack
1920
from selectors import EVENT_READ, DefaultSelector
2021
from typing import Any, Protocol
@@ -53,15 +54,15 @@ def get_fno(obj: int | SupportsFileno | None) -> int | None:
5354

5455

5556
def run(
56-
cmd: list[str | bytes] | str | bytes,
57+
cmd: Sequence[str | bytes] | str | bytes,
5758
address: str | bytes | None = None,
5859
service_type: str | bytes | None = None,
5960
name: str | bytes | None = None,
6061
user: str | bytes | None = None,
6162
user_mode: bool = USER_MODE,
6263
nice: int | None = None,
6364
runtime_max_sec: int | float | None = None,
64-
env: dict[str | bytes, str | bytes] | None = None,
65+
env: Mapping[str, str | bytes] | Mapping[bytes, str | bytes] | None = None,
6566
extra: dict[bytes, Any] | None = None,
6667
cwd: str | bytes | None = None,
6768
machine: str | bytes | None = None,
@@ -78,10 +79,10 @@ def run(
7879
stderr: int | SupportsFileno | None = None,
7980
_wait_polling: int | float | None = None,
8081
slice_: str | bytes | None = None,
81-
stop_cmd: list[str | bytes] | str | bytes | None = None,
82-
stop_post_cmd: list[str | bytes] | str | bytes | None = None,
83-
start_pre_cmd: list[str | bytes] | str | bytes | None = None,
84-
start_post_cmd: list[str | bytes] | str | bytes | None = None,
82+
stop_cmd: Sequence[str | bytes] | str | bytes | None = None,
83+
stop_post_cmd: Sequence[str | bytes] | str | bytes | None = None,
84+
start_pre_cmd: Sequence[str | bytes] | str | bytes | None = None,
85+
start_post_cmd: Sequence[str | bytes] | str | bytes | None = None,
8586
) -> Unit:
8687
"""
8788
pystemd.run imitates systemd-run, but with a pythonic feel to it.
@@ -164,7 +165,9 @@ def bus_factory():
164165
runtime_max_usec = (runtime_max_sec or 0) * 10**6 or runtime_max_sec
165166

166167
stdin, stdout, stderr = get_fno(stdin), get_fno(stdout), get_fno(stderr)
167-
env = env or {}
168+
env_dict: dict[bytes, str | bytes] = (
169+
{x2char_star(k): v for k, v in env.items()} if env else {}
170+
)
168171
unit_properties: dict[bytes, object] = {}
169172

170173
extra = extra or {}
@@ -219,9 +222,9 @@ def bus_factory():
219222
sel.register(stdin, EVENT_READ)
220223

221224
if None not in (stdout, pty_master):
222-
if os.getenv("TERM"):
223-
# pyrefly: ignore [missing-attribute]
224-
env[b"TERM"] = env.get(b"TERM", os.getenv("TERM").encode())
225+
term = os.getenv("TERM")
226+
if term:
227+
env_dict[b"TERM"] = env_dict.get(b"TERM", term.encode())
225228

226229
# pyrefly: ignore [bad-argument-type]
227230
sel.register(pty_master, EVENT_READ)
@@ -263,8 +266,8 @@ def bus_factory():
263266
b"Nice": nice,
264267
b"RuntimeMaxUSec": runtime_max_usec,
265268
b"Environment": [
266-
b"%s=%s" % (x2char_star(key), x2char_star(value))
267-
for key, value in env.items()
269+
b"%s=%s" % (key, x2char_star(value))
270+
for key, value in env_dict.items()
268271
]
269272
or None,
270273
}

0 commit comments

Comments
 (0)