-
Notifications
You must be signed in to change notification settings - Fork 179
Expand file tree
/
Copy pathgen_nexus_system_models.py
More file actions
143 lines (125 loc) · 3.94 KB
/
gen_nexus_system_models.py
File metadata and controls
143 lines (125 loc) · 3.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import annotations
import importlib
import subprocess
import sys
from pathlib import Path
# Exact version of the `nexus-rpc-gen` npm package that run_nexus_rpc_gen
# executes via `npx`. Pinned so that regenerating does not silently pick up a
# different generator release; bump it deliberately.
NEXUS_RPC_GEN_VERSION = "0.1.0-alpha.4"
def main() -> None:
    """Regenerate ``temporalio/nexus/system/_workflow_service_generated.py``.

    Steps, in order:

    1. Run ``nexus-rpc-gen`` over the Nexus YAML schema vendored under the
       sdk-core protos.
    2. Add the operation registry to the generated module (only when services
       are discovered in it).
    3. Normalize the file with ``ruff``: import sorting first, then
       formatting.

    The repository root is taken to be two directories above this script.

    Raises:
        RuntimeError: If the Nexus schema file does not exist.
        subprocess.CalledProcessError: If any external tool exits non-zero.
    """
    repo_root = Path(__file__).resolve().parent.parent
    schema_path = repo_root.joinpath(
        "temporalio",
        "bridge",
        "sdk-core",
        "crates",
        "common",
        "protos",
        "api_upstream",
        "nexus",
        "temporal-proto-models-nexusrpc.yaml",
    )
    generated_path = repo_root.joinpath(
        "temporalio", "nexus", "system", "_workflow_service_generated.py"
    )

    # Fail early with a clear message rather than letting npx report a
    # missing-input error.
    if not schema_path.is_file():
        raise RuntimeError(f"Expected Nexus schema at {schema_path}")

    run_nexus_rpc_gen(output_file=generated_path, input_schema=schema_path)
    add_operation_registry(repo_root, generated_path)

    # Import sorting (ruff's "I" rules) must run before the formatter so the
    # formatter sees the final import layout.
    ruff_subcommands = (
        ["check", "--select", "I", "--fix"],
        ["format"],
    )
    for subcommand in ruff_subcommands:
        subprocess.run(
            ["uv", "run", "ruff", *subcommand, str(generated_path)],
            cwd=repo_root,
            check=True,
        )
def add_operation_registry(repo_root: Path, output_file: Path) -> None:
    """Rewrite *output_file* with ``import typing`` and an operation registry.

    The ``typing`` import is always ensured. The registry literal built by
    ``emit_operation_registry`` is appended only when ``discover_services``
    finds at least one service class in the generated module. When it is
    appended, trailing whitespace is trimmed and a blank line separates it
    from the existing code.

    Args:
        repo_root: Repository root, used to import the generated module.
        output_file: The generated module to rewrite in place.
    """
    updated_source = ensure_typing_import(output_file.read_text())
    discovered = discover_services(repo_root)
    if discovered:
        updated_source = (
            updated_source.rstrip() + "\n\n" + emit_operation_registry(discovered)
        )
    output_file.write_text(updated_source)
def ensure_typing_import(source: str) -> str:
    """Return *source* guaranteed to contain a top-level ``import typing`` line.

    The registry emitted by ``emit_operation_registry`` annotates with
    ``typing.Any``, so the generated module must import ``typing``.

    - If a column-0 ``import typing`` line is already present, *source* is
      returned unchanged.
    - Otherwise the import is inserted right after the first
      ``from __future__ import annotations`` line and its following blank
      line.

    Raises:
        RuntimeError: If the expected future-annotations header is missing.
    """
    # Compare whole lines rather than searching for "\nimport typing\n".
    # The substring search misses an import on the first line, an import on
    # the last line with no trailing newline, and lines ending in "\r" or
    # spaces; each miss caused a duplicate import to be inserted. rstrip()
    # only touches the right side, so the import must still start at column 0
    # (module level), as before.
    if any(line.rstrip() == "import typing" for line in source.splitlines()):
        return source
    marker = "from __future__ import annotations\n\n"
    if marker not in source:
        raise RuntimeError("Expected future-annotations import in generated output")
    return source.replace(marker, marker + "import typing\n", 1)
def discover_services(repo_root: Path) -> list[tuple[str, str, list[tuple[str, str]]]]:
    """Import the generated module and collect its Nexus service classes.

    Import mechanics:

    - *repo_root* is prepended to ``sys.path`` for the duration of the import.
    - Any cached copy of the generated module is evicted from ``sys.modules``,
      and the import finder caches are invalidated, so the freshly written
      file is loaded.
    - Only the leaf module is evicted; already-imported parent packages are
      reused.

    Returns:
        One ``(class_name, service_name, operations)`` tuple per class in the
        module that carries a ``__nexus_service_definition__`` attribute.
        ``operations`` lists ``(method_name, operation_name)`` pairs taken from
        the definition's ``operation_definitions`` values. Order follows the
        module namespace.
    """
    module_name = "temporalio.nexus.system._workflow_service_generated"
    path_entry = str(repo_root)
    sys.path.insert(0, path_entry)
    try:
        sys.modules.pop(module_name, None)
        importlib.invalidate_caches()
        module = importlib.import_module(module_name)
    finally:
        # Remove the entry we added by value instead of popping index 0. The
        # imported module (or anything it imports) may itself have prepended
        # to sys.path, in which case pop(0) would discard someone else's
        # entry and leave repo_root behind.
        if path_entry in sys.path:
            sys.path.remove(path_entry)

    services: list[tuple[str, str, list[tuple[str, str]]]] = []
    for value in vars(module).values():
        if not isinstance(value, type):
            continue
        definition = getattr(value, "__nexus_service_definition__", None)
        if definition is None:
            continue
        operations = [
            (operation_definition.method_name, operation_definition.name)
            for operation_definition in definition.operation_definitions.values()
        ]
        services.append((value.__name__, definition.name, operations))
    return services
def emit_operation_registry(
    services: list[tuple[str, str, list[tuple[str, str]]]],
) -> str:
    """Render the ``__nexus_operation_registry__`` dict literal as source text.

    Each operation of each service becomes one entry, mapping
    ``(service_name, operation_name)`` to ``ClassName.method_name``. Entries
    appear in input order.

    The output always ends with a single newline. Exact spacing is left to the
    ``ruff format`` pass that runs afterwards.

    NOTE(review): the annotation references ``Operation`` unqualified; this
    assumes the generated module already imports that name. Confirm against
    the generator output.
    """
    entries = [
        f" ({service_name!r}, {operation_name!r}): {class_name}.{attr_name},"
        for class_name, service_name, operations in services
        for attr_name, operation_name in operations
    ]
    return "\n".join(
        [
            "__nexus_operation_registry__: dict[",
            " tuple[str, str], Operation[typing.Any, typing.Any]",
            "] = {",
            *entries,
            "}",
        ]
    ) + "\n"
def run_nexus_rpc_gen(*, output_file: Path, input_schema: Path) -> None:
    """Invoke the pinned ``nexus-rpc-gen`` npm package to emit Python models.

    Runs ``npx --yes nexus-rpc-gen@<NEXUS_RPC_GEN_VERSION>`` with
    ``--lang py``. The result is written to *output_file* and generated from
    *input_schema*. The command runs in the current working directory.

    Raises:
        subprocess.CalledProcessError: If the generator exits non-zero.
    """
    generator = f"nexus-rpc-gen@{NEXUS_RPC_GEN_VERSION}"
    # --yes lets npx fetch the pinned package without an interactive prompt.
    command = [
        "npx",
        "--yes",
        generator,
        "--lang",
        "py",
        "--out-file",
        str(output_file),
        str(input_schema),
    ]
    subprocess.run(command, check=True)
if __name__ == "__main__":
    try:
        main()
    except Exception as exc:
        # Emit a one-line summary on stderr, then re-raise so the full
        # traceback and the non-zero exit status are preserved.
        print(f"Failed to generate Nexus system models: {exc}", file=sys.stderr)
        raise