@@ -330,3 +330,79 @@ class A(xo.HybridClass):
330330 assert np .all (b_dict .pop ("a" ) == [1 , 2 , 3 ])
331331 assert b_dict .pop ("d" ) == 8
332332 assert b_dict == {}
333+
334+
335+ def test_to_dict_python_vars ():
336+ class TD (xo .HybridClass ):
337+ _xofields = {
338+ "_a" : xo .Float64 [3 ],
339+ "_b" : xo .Int64 ,
340+ }
341+ _skip_in_to_dict = ['_a' , '_b' ]
342+ _store_in_to_dict = ['a' , 'b' ]
343+ _extra_fields_in_to_dict = ['c' ]
344+
345+ def __init__ (self , ** kwargs ):
346+ if '_xobject' in kwargs and kwargs ['_xobject' ] is not None :
347+ self ._initialize (** kwargs )
348+ return
349+ kwargs ['_a' ] = kwargs .pop ('a' , [1. , 2. , 3. ])
350+ kwargs ['_b' ] = kwargs .pop ('b' , 0 )
351+ self ._initialize (** kwargs )
352+
353+ def _initialize (self , ** kwargs ):
354+ # Need to handle non-xofields manually
355+ c = kwargs .pop ("c" , - 9 )
356+ self .xoinitialize (** kwargs )
357+ self ._c = c
358+
359+ @property
360+ def a (self ):
361+ return self ._a
362+
363+ @property
364+ def b (self ):
365+ return self ._b
366+
367+ @property
368+ def c (self ):
369+ return self ._c
370+
371+ # Verify that to_dict has all fields, including python-only ones,
372+ # for default initialisation
373+ td1 = TD ()
374+ td1_dict = td1 .to_dict ()
375+ assert td1_dict .pop ("__class__" ) == "TD"
376+ assert all (td1_dict .pop ("a" ) == [1 , 2 , 3 ])
377+ assert td1_dict .pop ("b" ) == 0
378+ assert td1_dict .pop ("c" ) == - 9
379+ assert td1_dict == {}
380+
381+ # Verify that to_dict has all fields, including python-only ones,
382+ # for custom initialisation
383+ td2 = TD (a = [8 ,9 ,10 ], b = 40 , c = 20 )
384+ td2_dict = td2 .to_dict ()
385+ assert td2_dict .pop ("__class__" ) == "TD"
386+ assert all (td2_dict .pop ("a" ) == [8 , 9 , 10 ])
387+ assert td2_dict .pop ("b" ) == 40
388+ assert td2_dict .pop ("c" ) == 20
389+ assert td2_dict == {}
390+
391+ # Verify that from_dict works correctly
392+ td2_dict = td2 .to_dict ()
393+ td3 = TD .from_dict (td2_dict )
394+ assert all (td3 .a == td2 .a )
395+ assert td3 .b == td2 .b
396+ assert td3 .c == td2 .c
397+
398+ # Verify that copy works correctly
399+ td4 = td3 .copy ()
400+ assert all (td4 .a == td2 .a )
401+ assert td4 .b == td2 .b
402+ assert td4 .c == td2 .c
403+
404+ # Verify that move works correctly
405+ td3 .move (_context = xo .ContextCpu (omp_num_threads = 'auto' ))
406+ assert all (td3 .a == td2 .a )
407+ assert td3 .b == td2 .b
408+ assert td3 .c == td2 .c
0 commit comments