Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions tests/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,12 @@ def test_free_simple(test_context):
ch.free_string(offset)
ch.check()


@for_all_test_contexts
def test_free(test_context):
class CheckFree(xo.Struct):
a = xo.Float64

ch = CheckFree(a=5, _context=test_context)
assert ch._buffer.capacity == 8
assert ch._buffer.chunks == []
Expand All @@ -107,17 +109,17 @@ class CheckFree(xo.Struct):
with pytest.raises(ValueError, match="Cannot free outside of buffer"):
ch._buffer.free(0, 10)
with pytest.raises(ValueError, match="Cannot free outside of buffer"):
ch._buffer.free(7,2)
ch._buffer.free(0,4)
ch._buffer.free(7, 2)
ch._buffer.free(0, 4)
assert len(ch._buffer.chunks) == 1
assert ch._buffer.chunks[0].start == 0
assert ch._buffer.chunks[0].end == 4
ch._buffer.free(0,4) # Does nothing
ch._buffer.free(2,4) # Increases free chunk
ch._buffer.free(0, 4) # Does nothing
ch._buffer.free(2, 4) # Increases free chunk
assert len(ch._buffer.chunks) == 1
assert ch._buffer.chunks[0].start == 0
assert ch._buffer.chunks[0].end == 6
ch._buffer.free(7,1)
ch._buffer.free(7, 1)
assert len(ch._buffer.chunks) == 2
assert ch._buffer.chunks[0].start == 0
assert ch._buffer.chunks[0].end == 6
Expand Down
75 changes: 75 additions & 0 deletions tests/test_hybrid_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,78 @@ class A(xo.HybridClass):
assert np.all(b_dict.pop("a") == [1, 2, 3])
assert b_dict.pop("d") == 8
assert b_dict == {}


def test_to_dict_python_vars():
class TD(xo.HybridClass):
_xofields = {
"_a": xo.Float64[3],
"_b": xo.Int64,
}
_skip_in_to_dict = ["_a", "_b"]
_store_in_to_dict = ["a", "b", "c"]

def __init__(self, **kwargs):
if "_xobject" in kwargs and kwargs["_xobject"] is not None:
self.xoinitialize(**kwargs)
return
kwargs["_a"] = kwargs.pop("a", [1.0, 2.0, 3.0])
kwargs["_b"] = kwargs.pop("b", 0)
self._initialize(**kwargs)

def _initialize(self, **kwargs):
# Need to handle non-xofields manually
c = kwargs.pop("c", -9)
super().__init__(**kwargs)
self._c = c

@property
def a(self):
return self._a

@property
def b(self):
return self._b

@property
def c(self):
return self._c

# Verify that to_dict has all fields, including python-only ones,
# for default initialisation
td1 = TD()
td1_dict = td1.to_dict()
assert td1_dict.pop("__class__") == "TD"
assert all(td1_dict.pop("a") == [1, 2, 3])
assert td1_dict.pop("b") == 0
assert td1_dict.pop("c") == -9
assert td1_dict == {}

# Verify that to_dict has all fields, including python-only ones,
# for custom initialisation
td2 = TD(a=[8, 9, 10], b=40, c=20)
td2_dict = td2.to_dict()
assert td2_dict.pop("__class__") == "TD"
assert all(td2_dict.pop("a") == [8, 9, 10])
assert td2_dict.pop("b") == 40
assert td2_dict.pop("c") == 20
assert td2_dict == {}

# Verify that from_dict works correctly
td2_dict = td2.to_dict()
td3 = TD.from_dict(td2_dict)
assert all(td3.a == td2.a)
assert td3.b == td2.b
assert td3.c == td2.c

# Verify that copy works correctly
td4 = td3.copy()
assert all(td4.a == td2.a)
assert td4.b == td2.b
assert td4.c == td2.c

# Verify that move works correctly
td3.move(_context=xo.ContextCpu(omp_num_threads="auto"))
assert all(td3.a == td2.a)
assert td3.b == td2.b
assert td3.c == td2.c
8 changes: 4 additions & 4 deletions tests/test_to_json.py → tests/test_to_dict.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import xobjects as xo


def test_to_json():
def test_to_dict():
class A(xo.Struct):
a = xo.Float64[:]
b = xo.Int64
Expand All @@ -11,14 +11,14 @@ class Uref(xo.UnionRef):

x = A(a=[2, 3], b=1)
u = Uref(x)
v = Uref(*u._to_json())
v = Uref(*u._to_dict())

assert v.get().a[0] == 2
assert v.get().a[1] == 3
assert v.get().b == 1


def test_to_json_array():
def test_to_dict_array():
class A(xo.Struct):
a = xo.Float64[:]

Expand All @@ -35,7 +35,7 @@ class Uref(xo.UnionRef):
a[1] = A(a=[3])
a[5] = B(c=2, d=1)

b = AUref(a._to_json())
b = AUref(a._to_dict())

assert b[1].a[0] == 3
assert b[5].d == 1
7 changes: 5 additions & 2 deletions xobjects/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,10 +690,13 @@ def _get_inner_types(cls):
return [cls._itemtype]

def _to_json(self):
raise NameError("`_to_json` has been removed. Use `_to_dict` instead.")

def _to_dict(self):
out = []
for v in self: # TODO does not support multidimensional arrays
if hasattr(v, "_to_json"):
vdata = v._to_json()
if hasattr(v, "_to_dict"):
vdata = v._to_dict()
else:
vdata = v
if self._has_refs and v is not None:
Expand Down
6 changes: 5 additions & 1 deletion xobjects/context_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,11 @@ def build_kernels(

extra_include_paths = self.get_installed_c_source_paths()
include_flags = [f"-I{path}" for path in extra_include_paths]
extra_compile_args = (*extra_compile_args, *include_flags, "-DXO_CONTEXT_CUDA")
extra_compile_args = (
*extra_compile_args,
*include_flags,
"-DXO_CONTEXT_CUDA",
)

module = cupy.RawModule(
code=specialized_source, options=extra_compile_args
Expand Down
12 changes: 10 additions & 2 deletions xobjects/hybrid_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,18 @@ def copy(self, _context=None, _buffer=None, _offset=None):
if _context is None and _buffer is None:
_context = self._xobject._buffer.context
# This makes a copy of the xobject
xobject = self._XoStruct(
new_xobject = self._XoStruct(
self._xobject, _context=_context, _buffer=_buffer, _offset=_offset
)
return self.__class__(_xobject=xobject)
new = self.__class__.__new__(self.__class__)
new.__dict__.update(self.__dict__)
for kk, vv in new.__dict__.items():
if kk == "_xobject":
continue
if hasattr(vv, "copy"):
new.__dict__[kk] = vv.copy()
new._xobject = new_xobject
return new

@property
def _buffer(self):
Expand Down
7 changes: 5 additions & 2 deletions xobjects/ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,13 @@ def _get_inner_types(cls):
return cls._reftypes

def _to_json(self):
raise NameError("`_to_json` has been removed. Use `_to_dict` instead.")

def _to_dict(self):
v = self.get()
classname = v.__class__.__name__
if hasattr(v, "_to_json"):
v = v._to_json()
if hasattr(v, "_to_dict"):
v = v._to_dict()
return (classname, v)


Expand Down
10 changes: 5 additions & 5 deletions xobjects/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,15 @@ def _set_offsets(cls, buffer, offset, loffsets):
foffset = offset + cls._fields[index].offset
Int64._to_buffer(buffer, foffset, data_offset)

def _to_dict(self):
return {field.name: field.__get__(self) for field in self._fields}

def _to_json(self):
raise NameError("`_to_json` has been removed. Use `_to_dict` instead.")

def _to_dict(self):
out = {}
for field in self._fields:
v = field.__get__(self)
if hasattr(v, "_to_json"):
v = v._to_json()
if hasattr(v, "_to_dict"):
v = v._to_dict()
out[field.name] = v
return out

Expand Down