Skip to content

Commit 5043be4

Browse files
authored
Merge branch 'main' into patch-1
2 parents aaa4ce3 + 3cb18e3 commit 5043be4

File tree

6 files changed

+147
-20
lines changed

6 files changed

+147
-20
lines changed

torchx/cli/cmd_status.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# pyre-strict
99

1010
import argparse
11+
import json
1112
import logging
1213
import sys
1314
from typing import List, Optional
@@ -46,6 +47,11 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
4647
subparser.add_argument(
4748
"--roles", type=str, default="", help="comma separated roles to filter"
4849
)
50+
subparser.add_argument(
51+
"--json",
52+
action="store_true",
53+
help="output the status in JSON format",
54+
)
4955

5056
def run(self, args: argparse.Namespace) -> None:
5157
app_handle = args.app_handle
@@ -54,7 +60,10 @@ def run(self, args: argparse.Namespace) -> None:
5460
app_status = runner.status(app_handle)
5561
filter_roles = parse_list_arg(args.roles)
5662
if app_status:
57-
print(app_status.format(filter_roles))
63+
if args.json:
64+
print(json.dumps(app_status.to_json(filter_roles)))
65+
else:
66+
print(app_status.format(filter_roles))
5867
else:
5968
logger.error(
6069
f"AppDef: {app_id},"

torchx/cli/test/cmd_run_test.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from unittest.mock import MagicMock, patch
2222

2323
from torchx.cli.argparse_util import ArgOnceAction, torchxconfig
24-
2524
from torchx.cli.cmd_run import _parse_component_name_and_args, CmdBuiltins, CmdRun
2625
from torchx.schedulers.local_scheduler import SignalException
2726

@@ -216,38 +215,41 @@ def test_store_experiment_id(self, mock_runner_run: MagicMock) -> None:
216215
self.assertEqual(call_kwargs["parent_run_id"], "experiment_1")
217216

218217
def test_parse_component_name_and_args_no_default(self) -> None:
218+
# set dirs to test tmpdir so tests don't accidentally pick up user's $HOME/.torchxconfig
219+
dirs = [str(self.tmpdir)]
220+
219221
sp = argparse.ArgumentParser(prog="test")
220222
self.assertEqual(
221223
("utils.echo", []),
222-
_parse_component_name_and_args(["utils.echo"], sp),
224+
_parse_component_name_and_args(["utils.echo"], sp, dirs),
223225
)
224226
self.assertEqual(
225227
("utils.echo", []),
226-
_parse_component_name_and_args(["--", "utils.echo"], sp),
228+
_parse_component_name_and_args(["--", "utils.echo"], sp, dirs),
227229
)
228230
self.assertEqual(
229231
("utils.echo", ["--msg", "hello"]),
230-
_parse_component_name_and_args(["utils.echo", "--msg", "hello"], sp),
232+
_parse_component_name_and_args(["utils.echo", "--msg", "hello"], sp, dirs),
231233
)
232234

233235
self.assertEqual(
234236
("utils.echo", ["--msg", "hello", "--", "--"]),
235237
_parse_component_name_and_args(
236-
["utils.echo", "--msg", "hello", "--", "--"], sp
238+
["utils.echo", "--msg", "hello", "--", "--"], sp, dirs
237239
),
238240
)
239241

240242
self.assertEqual(
241243
("utils.echo", ["--msg", "hello", "-", "-"]),
242244
_parse_component_name_and_args(
243-
["utils.echo", "--msg", "hello", "-", "-"], sp
245+
["utils.echo", "--msg", "hello", "-", "-"], sp, dirs
244246
),
245247
)
246248

247249
self.assertEqual(
248250
("utils.echo", ["--msg", "hello", "- ", "- "]),
249251
_parse_component_name_and_args(
250-
["utils.echo", "--msg", "hello", "- ", "- "], sp
252+
["utils.echo", "--msg", "hello", "- ", "- "], sp, dirs
251253
),
252254
)
253255

@@ -274,32 +276,35 @@ def test_parse_component_name_and_args_no_default(self) -> None:
274276
"-m",
275277
],
276278
sp,
279+
dirs,
277280
),
278281
)
279282

280283
with self.assertRaises(SystemExit):
281-
_parse_component_name_and_args(["--"], sp)
284+
_parse_component_name_and_args(["--"], sp, dirs)
282285

283286
with self.assertRaises(SystemExit):
284-
_parse_component_name_and_args(["--msg", "hello"], sp)
287+
_parse_component_name_and_args(["--msg", "hello"], sp, dirs)
285288

286289
with self.assertRaises(SystemExit):
287-
_parse_component_name_and_args(["-m", "hello"], sp)
290+
_parse_component_name_and_args(["-m", "hello"], sp, dirs)
288291

289292
with self.assertRaises(SystemExit):
290-
_parse_component_name_and_args(["-m", "hello", "-m", "repeate"], sp)
293+
_parse_component_name_and_args(["-m", "hello", "-m", "repeate"], sp, dirs)
291294

292295
with self.assertRaises(SystemExit):
293-
_parse_component_name_and_args(["--msg", "hello", "--msg", "repeate"], sp)
296+
_parse_component_name_and_args(
297+
["--msg", "hello", "--msg", "repeate"], sp, dirs
298+
)
294299

295300
with self.assertRaises(SystemExit):
296301
_parse_component_name_and_args(
297-
["--msg ", "hello", "--msg ", "repeate"], sp
302+
["--msg ", "hello", "--msg ", "repeate"], sp, dirs
298303
)
299304

300305
with self.assertRaises(SystemExit):
301306
_parse_component_name_and_args(
302-
["--m", "hello", "--", "--msg", "msg", "--msg", "repeate"], sp
307+
["--m", "hello", "--", "--msg", "msg", "--msg", "repeate"], sp, dirs
303308
)
304309

305310
def test_parse_component_name_and_args_with_default(self) -> None:

torchx/specs/api.py

+25
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,15 @@ class RoleStatus:
538538
role: str
539539
replicas: List[ReplicaStatus]
540540

541+
def to_json(self) -> Dict[str, Any]:
542+
"""
543+
Convert the RoleStatus to a json object.
544+
"""
545+
return {
546+
"role": self.role,
547+
"replicas": [asdict(replica) for replica in self.replicas],
548+
}
549+
541550

542551
@dataclass
543552
class AppStatus:
@@ -657,6 +666,21 @@ def _format_role_status(
657666
replica_data += self._format_replica_status(replica)
658667
return f"{replica_data}"
659668

669+
def to_json(self, filter_roles: Optional[List[str]] = None) -> Dict[str, Any]:
670+
"""
671+
Convert the AppStatus to a json object, including RoleStatus.
672+
"""
673+
roles = self._get_role_statuses(self.roles, filter_roles)
674+
675+
return {
676+
"state": str(self.state),
677+
"num_restarts": self.num_restarts,
678+
"roles": [role_status.to_json() for role_status in roles],
679+
"msg": self.msg,
680+
"structured_error_msg": self.structured_error_msg,
681+
"url": self.ui_url,
682+
}
683+
660684
def format(
661685
self,
662686
filter_roles: Optional[List[str]] = None,
@@ -672,6 +696,7 @@ def format(
672696
"""
673697
roles_data = ""
674698
roles = self._get_role_statuses(self.roles, filter_roles)
699+
675700
for role_status in roles:
676701
roles_data += self._format_role_status(role_status)
677702
return Template(_APP_STATUS_FORMAT_TEMPLATE).substitute(

torchx/specs/test/api_test.py

+36
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,42 @@ def test_format_app_status(self) -> None:
176176
# Split and compare to aviod AssertionError.
177177
self.assertEqual(expected_message.split(), actual_message.split())
178178

179+
def test_app_status_in_json(self) -> None:
180+
app_status = self._get_test_app_status()
181+
result = app_status.to_json()
182+
error_msg = '{"message":{"message":"error","errorCode":-1,"extraInfo":{"timestamp":1293182}}}'
183+
self.assertDictEqual(
184+
result,
185+
{
186+
"state": "RUNNING",
187+
"num_restarts": 0,
188+
"roles": [
189+
{
190+
"role": "worker",
191+
"replicas": [
192+
{
193+
"id": 0,
194+
"state": 5,
195+
"role": "worker",
196+
"hostname": "localhost",
197+
"structured_error_msg": error_msg,
198+
},
199+
{
200+
"id": 1,
201+
"state": 3,
202+
"role": "worker",
203+
"hostname": "localhost",
204+
"structured_error_msg": "<NONE>",
205+
},
206+
],
207+
}
208+
],
209+
"msg": "",
210+
"structured_error_msg": "<NONE>",
211+
"url": None,
212+
},
213+
)
214+
179215

180216
class ResourceTest(unittest.TestCase):
181217
def test_copy_resource(self) -> None:

torchx/util/modules.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,13 @@
88

99
import importlib
1010
from types import ModuleType
11-
from typing import Callable, Optional, Union
11+
from typing import Callable, Optional, TypeVar, Union
1212

1313

1414
def load_module(path: str) -> Union[ModuleType, Optional[Callable[..., object]]]:
1515
"""
1616
Loads and returns the module/module attr represented by the ``path``: ``full.module.path:optional_attr``
1717
18-
::
19-
20-
2118
1. ``load_module("this.is.a_module:fn")`` -> equivalent to ``this.is.a_module.fn``
2219
1. ``load_module("this.is.a_module")`` -> equivalent to ``this.is.a_module``
2320
"""
@@ -33,3 +30,36 @@ def load_module(path: str) -> Union[ModuleType, Optional[Callable[..., object]]]
3330
return getattr(module, method) if method else module
3431
except Exception:
3532
return None
33+
34+
35+
T = TypeVar("T")
36+
37+
38+
def import_attr(name: str, attr: str, default: T) -> T:
39+
"""
40+
Imports ``name.attr`` and returns it if the module is found.
41+
Otherwise, returns the specified ``default``.
42+
Useful when getting an attribute from an optional dependency.
43+
44+
Note that the ``default`` parameter is intentionally not an optional
45+
since this function is intended to be used with modules that may not be
46+
installed as a dependency. Therefore the caller must ALWAYS provide a
47+
sensible default.
48+
49+
Usage:
50+
51+
.. code-block:: python
52+
53+
aws_resources = import_attr("torchx.specs.named_resources_aws", "NAMED_RESOURCES", default={})
54+
all_resources.update(aws_resources)
55+
56+
Raises:
57+
AttributeError: If the module exists (e.g. can be imported)
58+
but does not have an attribute with name ``attr``.
59+
"""
60+
try:
61+
mod = importlib.import_module(name)
62+
except ModuleNotFoundError:
63+
return default
64+
else:
65+
return getattr(mod, attr)

torchx/util/test/modules_test.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
import unittest
810

9-
from torchx.util.modules import load_module
11+
from torchx.util.modules import import_attr, load_module
1012

1113

1214
class ModulesTest(unittest.TestCase):
@@ -21,3 +23,23 @@ def test_load_module_method(self) -> None:
2123
import os
2224

2325
self.assertEqual(result, os.path.join)
26+
27+
def test_try_import(self) -> None:
28+
def _join(_0: str, *_1: str) -> str:
29+
return "" # should never be called
30+
31+
os_path_join = import_attr("os.path", "join", default=_join)
32+
import os
33+
34+
self.assertEqual(os.path.join, os_path_join)
35+
36+
def test_try_import_non_existent_module(self) -> None:
37+
should_default = import_attr("non.existent", "foo", default="bar")
38+
self.assertEqual("bar", should_default)
39+
40+
def test_try_import_non_existent_attr(self) -> None:
41+
def _join(_0: str, *_1: str) -> str:
42+
return "" # should never be called
43+
44+
with self.assertRaises(AttributeError):
45+
import_attr("os.path", "joyin", default=_join)

0 commit comments

Comments
 (0)