Skip to content

Commit 3e76a92

Browse files
committed
Added test
1 parent 14c0847 commit 3e76a92

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

tests/test_hybrid_class.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,79 @@ class A(xo.HybridClass):
330330
assert np.all(b_dict.pop("a") == [1, 2, 3])
331331
assert b_dict.pop("d") == 8
332332
assert b_dict == {}
333+
334+
335+
def test_to_dict_python_vars():
336+
class TD(xo.HybridClass):
337+
_xofields = {
338+
"_a": xo.Float64[3],
339+
"_b": xo.Int64,
340+
}
341+
_skip_in_to_dict = ['_a', '_b']
342+
_store_in_to_dict = ['a', 'b']
343+
_extra_fields_in_to_dict = ['c']
344+
345+
def __init__(self, **kwargs):
346+
if '_xobject' in kwargs and kwargs['_xobject'] is not None:
347+
self._initialize(**kwargs)
348+
return
349+
kwargs['_a'] = kwargs.pop('a', [1., 2., 3.])
350+
kwargs['_b'] = kwargs.pop('b', 0)
351+
self._initialize(**kwargs)
352+
353+
def _initialize(self, **kwargs):
354+
# Need to handle non-xofields manually
355+
c = kwargs.pop("c", -9)
356+
self.xoinitialize(**kwargs)
357+
self._c = c
358+
359+
@property
360+
def a(self):
361+
return self._a
362+
363+
@property
364+
def b(self):
365+
return self._b
366+
367+
@property
368+
def c(self):
369+
return self._c
370+
371+
# Verify that to_dict has all fields, including python-only ones,
372+
# for default initialisation
373+
td1 = TD()
374+
td1_dict = td1.to_dict()
375+
assert td1_dict.pop("__class__") == "TD"
376+
assert all(td1_dict.pop("a") == [1, 2, 3])
377+
assert td1_dict.pop("b") == 0
378+
assert td1_dict.pop("c") == -9
379+
assert td1_dict == {}
380+
381+
# Verify that to_dict has all fields, including python-only ones,
382+
# for custom initialisation
383+
td2 = TD(a=[8,9,10], b=40, c=20)
384+
td2_dict = td2.to_dict()
385+
assert td2_dict.pop("__class__") == "TD"
386+
assert all(td2_dict.pop("a") == [8, 9, 10])
387+
assert td2_dict.pop("b") == 40
388+
assert td2_dict.pop("c") == 20
389+
assert td2_dict == {}
390+
391+
# Verify that from_dict works correctly
392+
td2_dict = td2.to_dict()
393+
td3 = TD.from_dict(td2_dict)
394+
assert all(td3.a == td2.a)
395+
assert td3.b == td2.b
396+
assert td3.c == td2.c
397+
398+
# Verify that copy works correctly
399+
td4 = td3.copy()
400+
assert all(td4.a == td2.a)
401+
assert td4.b == td2.b
402+
assert td4.c == td2.c
403+
404+
# Verify that move works correctly
405+
td3.move(_context=xo.ContextCpu(omp_num_threads='auto'))
406+
assert all(td3.a == td2.a)
407+
assert td3.b == td2.b
408+
assert td3.c == td2.c

0 commit comments

Comments
 (0)