Skip to content

Commit b995e68

Browse files
authored
Merge pull request #315 from zyannes/patch-1
Allow recursive references in atdpy
2 parents 3dab250 + e360768 commit b995e68

File tree

5 files changed

+73
-11
lines changed

5 files changed

+73
-11
lines changed

CHANGES.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
Master
22
------------------
33

4+
* atdpy: Support recursive definitions
45
* atdts: fix nullable object field writer (#312)
56

67

atdpy/src/lib/Codegen.ml

+2-11
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ methods and functions to convert data from/to JSON.
195195
# Disable flake8 entirely on this file:
196196
# flake8: noqa
197197

198+
# Import annotations to allow forward references
199+
from __future__ import annotations
198200
from dataclasses import dataclass
199201
from typing import Any, Callable, Dict, List, NoReturn, Optional, Tuple, Union
200202

@@ -1168,19 +1170,8 @@ let module_body env x =
11681170
|> List.rev
11691171
|> spaced
11701172

1171-
let extract_definition_names (items : A.module_body) =
1172-
List.map (fun (Type (loc, (name, param, an), e)) -> name) items
1173-
11741173
let definition_group ~atd_filename env
11751174
(is_recursive, (items: A.module_body)) : B.t =
1176-
if is_recursive then
1177-
A.error (
1178-
sprintf "recursive definitions are not supported by atdpy \
1179-
at this time: types %s in %S"
1180-
(extract_definition_names items
1181-
|> String.concat ", ")
1182-
atd_filename
1183-
);
11841175
[
11851176
Inline (module_body env items);
11861177
]

atdpy/test/atd-input/everything.atd

+6
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,9 @@ type require_field <python decorator="deco.deco1"
4444
decorator="dataclass(order=True)"> = {
4545
req: string;
4646
}
47+
48+
type recursive_class = {
49+
id: int;
50+
flag: bool;
51+
children: recursive_class list;
52+
}

atdpy/test/python-expected/everything.py

+35
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# Disable flake8 entirely on this file:
88
# flake8: noqa
99

10+
from __future__ import annotations
1011
from dataclasses import dataclass
1112
from typing import Any, Callable, Dict, List, NoReturn, Optional, Tuple, Union
1213

@@ -240,6 +241,40 @@ def write_nullable(x: Any) -> Any:
240241
from dataclasses import dataclass
241242

242243

244+
@dataclass
245+
class RecursiveClass:
246+
"""Original type: recursive_class = { ... }"""
247+
248+
id: int
249+
flag: bool
250+
children: List[RecursiveClass]
251+
252+
@classmethod
253+
def from_json(cls, x: Any) -> 'RecursiveClass':
254+
if isinstance(x, dict):
255+
return cls(
256+
id=_atd_read_int(x['id']) if 'id' in x else _atd_missing_json_field('RecursiveClass', 'id'),
257+
flag=_atd_read_bool(x['flag']) if 'flag' in x else _atd_missing_json_field('RecursiveClass', 'flag'),
258+
children=_atd_read_list(RecursiveClass.from_json)(x['children']) if 'children' in x else _atd_missing_json_field('RecursiveClass', 'children'),
259+
)
260+
else:
261+
_atd_bad_json('RecursiveClass', x)
262+
263+
def to_json(self) -> Any:
264+
res: Dict[str, Any] = {}
265+
res['id'] = _atd_write_int(self.id)
266+
res['flag'] = _atd_write_bool(self.flag)
267+
res['children'] = _atd_write_list((lambda x: x.to_json()))(self.children)
268+
return res
269+
270+
@classmethod
271+
def from_json_string(cls, x: str) -> 'RecursiveClass':
272+
return cls.from_json(json.loads(x))
273+
274+
def to_json_string(self, **kw: Any) -> str:
275+
return json.dumps(self.to_json(), **kw)
276+
277+
243278
@dataclass
244279
class Root_:
245280
"""Original type: kind = [ ... | Root | ... ]"""

atdpy/test/python-tests/test_atdpy.py

+29
Original file line numberDiff line numberDiff line change
@@ -196,5 +196,34 @@ def test_pair() -> None:
196196
)
197197

198198

199+
def test_recursive_class() -> None:
200+
child1 = e.RecursiveClass(id=1, flag=True, children=[])
201+
child2 = e.RecursiveClass(id=2, flag=True, children=[])
202+
a_obj = e.RecursiveClass(id=0, flag=False, children=[child1, child2])
203+
a_str = a_obj.to_json_string(indent=2)
204+
205+
b_str = """{
206+
"id": 0,
207+
"flag": false,
208+
"children": [
209+
{
210+
"id": 1,
211+
"flag": true,
212+
"children": []
213+
},
214+
{
215+
"id": 2,
216+
"flag": true,
217+
"children": []
218+
}
219+
]
220+
}"""
221+
b_obj = e.RecursiveClass.from_json_string(a_str)
222+
b_str2 = b_obj.to_json_string(indent=2)
223+
224+
assert b_str == b_str2
225+
assert b_str2 == a_str
226+
227+
199228
# print updated json
200229
test_everything_to_json()

0 commit comments

Comments
 (0)