@@ -330,3 +330,78 @@ class A(xo.HybridClass):
330330 assert np .all (b_dict .pop ("a" ) == [1 , 2 , 3 ])
331331 assert b_dict .pop ("d" ) == 8
332332 assert b_dict == {}
333+
334+
335+ def test_to_dict_python_vars ():
336+ class TD (xo .HybridClass ):
337+ _xofields = {
338+ "_a" : xo .Float64 [3 ],
339+ "_b" : xo .Int64 ,
340+ }
341+ _skip_in_to_dict = ["_a" , "_b" ]
342+ _store_in_to_dict = ["a" , "b" , "c" ]
343+
344+ def __init__ (self , ** kwargs ):
345+ if "_xobject" in kwargs and kwargs ["_xobject" ] is not None :
346+ self .xoinitialize (** kwargs )
347+ return
348+ kwargs ["_a" ] = kwargs .pop ("a" , [1.0 , 2.0 , 3.0 ])
349+ kwargs ["_b" ] = kwargs .pop ("b" , 0 )
350+ self ._initialize (** kwargs )
351+
352+ def _initialize (self , ** kwargs ):
353+ # Need to handle non-xofields manually
354+ c = kwargs .pop ("c" , - 9 )
355+ super ().__init__ (** kwargs )
356+ self ._c = c
357+
358+ @property
359+ def a (self ):
360+ return self ._a
361+
362+ @property
363+ def b (self ):
364+ return self ._b
365+
366+ @property
367+ def c (self ):
368+ return self ._c
369+
370+ # Verify that to_dict has all fields, including python-only ones,
371+ # for default initialisation
372+ td1 = TD ()
373+ td1_dict = td1 .to_dict ()
374+ assert td1_dict .pop ("__class__" ) == "TD"
375+ assert all (td1_dict .pop ("a" ) == [1 , 2 , 3 ])
376+ assert td1_dict .pop ("b" ) == 0
377+ assert td1_dict .pop ("c" ) == - 9
378+ assert td1_dict == {}
379+
380+ # Verify that to_dict has all fields, including python-only ones,
381+ # for custom initialisation
382+ td2 = TD (a = [8 , 9 , 10 ], b = 40 , c = 20 )
383+ td2_dict = td2 .to_dict ()
384+ assert td2_dict .pop ("__class__" ) == "TD"
385+ assert all (td2_dict .pop ("a" ) == [8 , 9 , 10 ])
386+ assert td2_dict .pop ("b" ) == 40
387+ assert td2_dict .pop ("c" ) == 20
388+ assert td2_dict == {}
389+
390+ # Verify that from_dict works correctly
391+ td2_dict = td2 .to_dict ()
392+ td3 = TD .from_dict (td2_dict )
393+ assert all (td3 .a == td2 .a )
394+ assert td3 .b == td2 .b
395+ assert td3 .c == td2 .c
396+
397+ # Verify that copy works correctly
398+ td4 = td3 .copy ()
399+ assert all (td4 .a == td2 .a )
400+ assert td4 .b == td2 .b
401+ assert td4 .c == td2 .c
402+
403+ # Verify that move works correctly
404+ td3 .move (_context = xo .ContextCpu (omp_num_threads = "auto" ))
405+ assert all (td3 .a == td2 .a )
406+ assert td3 .b == td2 .b
407+ assert td3 .c == td2 .c
0 commit comments