diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 21720ea..4f11158 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -95,10 +95,12 @@ def test_free_simple(test_context): ch.free_string(offset) ch.check() + @for_all_test_contexts def test_free(test_context): class CheckFree(xo.Struct): a = xo.Float64 + ch = CheckFree(a=5, _context=test_context) assert ch._buffer.capacity == 8 assert ch._buffer.chunks == [] @@ -107,17 +109,17 @@ class CheckFree(xo.Struct): with pytest.raises(ValueError, match="Cannot free outside of buffer"): ch._buffer.free(0, 10) with pytest.raises(ValueError, match="Cannot free outside of buffer"): - ch._buffer.free(7,2) - ch._buffer.free(0,4) + ch._buffer.free(7, 2) + ch._buffer.free(0, 4) assert len(ch._buffer.chunks) == 1 assert ch._buffer.chunks[0].start == 0 assert ch._buffer.chunks[0].end == 4 - ch._buffer.free(0,4) # Does nothing - ch._buffer.free(2,4) # Increases free chunk + ch._buffer.free(0, 4) # Does nothing + ch._buffer.free(2, 4) # Increases free chunk assert len(ch._buffer.chunks) == 1 assert ch._buffer.chunks[0].start == 0 assert ch._buffer.chunks[0].end == 6 - ch._buffer.free(7,1) + ch._buffer.free(7, 1) assert len(ch._buffer.chunks) == 2 assert ch._buffer.chunks[0].start == 0 assert ch._buffer.chunks[0].end == 6 diff --git a/tests/test_hybrid_class.py b/tests/test_hybrid_class.py index 647b809..f4b0676 100644 --- a/tests/test_hybrid_class.py +++ b/tests/test_hybrid_class.py @@ -330,3 +330,78 @@ class A(xo.HybridClass): assert np.all(b_dict.pop("a") == [1, 2, 3]) assert b_dict.pop("d") == 8 assert b_dict == {} + + +def test_to_dict_python_vars(): + class TD(xo.HybridClass): + _xofields = { + "_a": xo.Float64[3], + "_b": xo.Int64, + } + _skip_in_to_dict = ["_a", "_b"] + _store_in_to_dict = ["a", "b", "c"] + + def __init__(self, **kwargs): + if "_xobject" in kwargs and kwargs["_xobject"] is not None: + self.xoinitialize(**kwargs) + return + kwargs["_a"] = kwargs.pop("a", [1.0, 2.0, 3.0]) + kwargs["_b"] = kwargs.pop("b", 0) + self._initialize(**kwargs) + + def _initialize(self, **kwargs): + # Need to handle non-xofields manually + c = kwargs.pop("c", -9) + super().__init__(**kwargs) + self._c = c + + @property + def a(self): + return self._a + + @property + def b(self): + return self._b + + @property + def c(self): + return self._c + + # Verify that to_dict has all fields, including python-only ones, + # for default initialisation + td1 = TD() + td1_dict = td1.to_dict() + assert td1_dict.pop("__class__") == "TD" + assert all(td1_dict.pop("a") == [1, 2, 3]) + assert td1_dict.pop("b") == 0 + assert td1_dict.pop("c") == -9 + assert td1_dict == {} + + # Verify that to_dict has all fields, including python-only ones, + # for custom initialisation + td2 = TD(a=[8, 9, 10], b=40, c=20) + td2_dict = td2.to_dict() + assert td2_dict.pop("__class__") == "TD" + assert all(td2_dict.pop("a") == [8, 9, 10]) + assert td2_dict.pop("b") == 40 + assert td2_dict.pop("c") == 20 + assert td2_dict == {} + + # Verify that from_dict works correctly + td2_dict = td2.to_dict() + td3 = TD.from_dict(td2_dict) + assert all(td3.a == td2.a) + assert td3.b == td2.b + assert td3.c == td2.c + + # Verify that copy works correctly + td4 = td3.copy() + assert all(td4.a == td2.a) + assert td4.b == td2.b + assert td4.c == td2.c + + # Verify that move works correctly + td3.move(_context=xo.ContextCpu(omp_num_threads="auto")) + assert all(td3.a == td2.a) + assert td3.b == td2.b + assert td3.c == td2.c diff --git a/tests/test_to_json.py b/tests/test_to_dict.py similarity index 85% rename from tests/test_to_json.py rename to tests/test_to_dict.py index 075a18e..040dc87 100644 --- a/tests/test_to_json.py +++ b/tests/test_to_dict.py @@ -1,7 +1,7 @@ import xobjects as xo -def test_to_json(): +def test_to_dict(): class A(xo.Struct): a = xo.Float64[:] b = xo.Int64 @@ -11,14 +11,14 @@ class Uref(xo.UnionRef): x = A(a=[2, 3], b=1) u = Uref(x) - v = Uref(*u._to_json()) + v = Uref(*u._to_dict()) assert v.get().a[0] == 2 assert v.get().a[1] == 3 assert v.get().b == 1 -def test_to_json_array(): +def test_to_dict_array(): class A(xo.Struct): a = xo.Float64[:] @@ -35,7 +35,7 @@ class Uref(xo.UnionRef): a[1] = A(a=[3]) a[5] = B(c=2, d=1) - b = AUref(a._to_json()) + b = AUref(a._to_dict()) assert b[1].a[0] == 3 assert b[5].d == 1 diff --git a/xobjects/array.py b/xobjects/array.py index e49d051..be1191e 100644 --- a/xobjects/array.py +++ b/xobjects/array.py @@ -690,10 +690,13 @@ def _get_inner_types(cls): return [cls._itemtype] def _to_json(self): + raise NameError("`_to_json` has been removed. Use `_to_dict` instead.") + + def _to_dict(self): out = [] for v in self: # TODO does not support multidimensional arrays - if hasattr(v, "_to_json"): - vdata = v._to_json() + if hasattr(v, "_to_dict"): + vdata = v._to_dict() else: vdata = v if self._has_refs and v is not None: diff --git a/xobjects/context_cupy.py b/xobjects/context_cupy.py index a0b9fac..8b93372 100644 --- a/xobjects/context_cupy.py +++ b/xobjects/context_cupy.py @@ -457,7 +457,11 @@ def build_kernels( extra_include_paths = self.get_installed_c_source_paths() include_flags = [f"-I{path}" for path in extra_include_paths] - extra_compile_args = (*extra_compile_args, *include_flags, "-DXO_CONTEXT_CUDA") + extra_compile_args = ( + *extra_compile_args, + *include_flags, + "-DXO_CONTEXT_CUDA", + ) module = cupy.RawModule( code=specialized_source, options=extra_compile_args diff --git a/xobjects/hybrid_class.py b/xobjects/hybrid_class.py index df751d5..c9a09df 100644 --- a/xobjects/hybrid_class.py +++ b/xobjects/hybrid_class.py @@ -340,10 +340,18 @@ def copy(self, _context=None, _buffer=None, _offset=None): if _context is None and _buffer is None: _context = self._xobject._buffer.context # This makes a copy of the xobject - xobject = self._XoStruct( + new_xobject = self._XoStruct( self._xobject, _context=_context, _buffer=_buffer, _offset=_offset ) - return self.__class__(_xobject=xobject) + new = self.__class__.__new__(self.__class__) + new.__dict__.update(self.__dict__) + for kk, vv in new.__dict__.items(): + if kk == "_xobject": + continue + if hasattr(vv, "copy"): + new.__dict__[kk] = vv.copy() + new._xobject = new_xobject + return new @property def _buffer(self): diff --git a/xobjects/ref.py b/xobjects/ref.py index aa7e99a..eaa2671 100644 --- a/xobjects/ref.py +++ b/xobjects/ref.py @@ -294,10 +294,13 @@ def _get_inner_types(cls): return cls._reftypes def _to_json(self): + raise NameError("`_to_json` has been removed. Use `_to_dict` instead.") + + def _to_dict(self): v = self.get() classname = v.__class__.__name__ - if hasattr(v, "_to_json"): - v = v._to_json() + if hasattr(v, "_to_dict"): + v = v._to_dict() return (classname, v) diff --git a/xobjects/struct.py b/xobjects/struct.py index 54e299f..7a1a98c 100644 --- a/xobjects/struct.py +++ b/xobjects/struct.py @@ -366,15 +366,15 @@ def _set_offsets(cls, buffer, offset, loffsets): foffset = offset + cls._fields[index].offset Int64._to_buffer(buffer, foffset, data_offset) - def _to_dict(self): - return {field.name: field.__get__(self) for field in self._fields} - def _to_json(self): + raise NameError("`_to_json` has been removed. Use `_to_dict` instead.") + + def _to_dict(self): out = {} for field in self._fields: v = field.__get__(self) - if hasattr(v, "_to_json"): - v = v._to_json() + if hasattr(v, "_to_dict"): + v = v._to_dict() out[field.name] = v return out