Skip to content

Commit 2e034f3

Browse files
authored
Merge pull request #154 from freddieknets/PythonVarsInHybridClass
Python vars in hybrid class
2 parents 06c9845 + 651f4ec commit 2e034f3

File tree

8 files changed

+116
-21
lines changed

8 files changed

+116
-21
lines changed

tests/test_buffer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,12 @@ def test_free_simple(test_context):
9595
ch.free_string(offset)
9696
ch.check()
9797

98+
9899
@for_all_test_contexts
99100
def test_free(test_context):
100101
class CheckFree(xo.Struct):
101102
a = xo.Float64
103+
102104
ch = CheckFree(a=5, _context=test_context)
103105
assert ch._buffer.capacity == 8
104106
assert ch._buffer.chunks == []
@@ -107,17 +109,17 @@ class CheckFree(xo.Struct):
107109
with pytest.raises(ValueError, match="Cannot free outside of buffer"):
108110
ch._buffer.free(0, 10)
109111
with pytest.raises(ValueError, match="Cannot free outside of buffer"):
110-
ch._buffer.free(7,2)
111-
ch._buffer.free(0,4)
112+
ch._buffer.free(7, 2)
113+
ch._buffer.free(0, 4)
112114
assert len(ch._buffer.chunks) == 1
113115
assert ch._buffer.chunks[0].start == 0
114116
assert ch._buffer.chunks[0].end == 4
115-
ch._buffer.free(0,4) # Does nothing
116-
ch._buffer.free(2,4) # Increases free chunk
117+
ch._buffer.free(0, 4) # Does nothing
118+
ch._buffer.free(2, 4) # Increases free chunk
117119
assert len(ch._buffer.chunks) == 1
118120
assert ch._buffer.chunks[0].start == 0
119121
assert ch._buffer.chunks[0].end == 6
120-
ch._buffer.free(7,1)
122+
ch._buffer.free(7, 1)
121123
assert len(ch._buffer.chunks) == 2
122124
assert ch._buffer.chunks[0].start == 0
123125
assert ch._buffer.chunks[0].end == 6

tests/test_hybrid_class.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,78 @@ class A(xo.HybridClass):
330330
assert np.all(b_dict.pop("a") == [1, 2, 3])
331331
assert b_dict.pop("d") == 8
332332
assert b_dict == {}
333+
334+
335+
def test_to_dict_python_vars():
336+
class TD(xo.HybridClass):
337+
_xofields = {
338+
"_a": xo.Float64[3],
339+
"_b": xo.Int64,
340+
}
341+
_skip_in_to_dict = ["_a", "_b"]
342+
_store_in_to_dict = ["a", "b", "c"]
343+
344+
def __init__(self, **kwargs):
345+
if "_xobject" in kwargs and kwargs["_xobject"] is not None:
346+
self.xoinitialize(**kwargs)
347+
return
348+
kwargs["_a"] = kwargs.pop("a", [1.0, 2.0, 3.0])
349+
kwargs["_b"] = kwargs.pop("b", 0)
350+
self._initialize(**kwargs)
351+
352+
def _initialize(self, **kwargs):
353+
# Need to handle non-xofields manually
354+
c = kwargs.pop("c", -9)
355+
super().__init__(**kwargs)
356+
self._c = c
357+
358+
@property
359+
def a(self):
360+
return self._a
361+
362+
@property
363+
def b(self):
364+
return self._b
365+
366+
@property
367+
def c(self):
368+
return self._c
369+
370+
# Verify that to_dict has all fields, including python-only ones,
371+
# for default initialisation
372+
td1 = TD()
373+
td1_dict = td1.to_dict()
374+
assert td1_dict.pop("__class__") == "TD"
375+
assert all(td1_dict.pop("a") == [1, 2, 3])
376+
assert td1_dict.pop("b") == 0
377+
assert td1_dict.pop("c") == -9
378+
assert td1_dict == {}
379+
380+
# Verify that to_dict has all fields, including python-only ones,
381+
# for custom initialisation
382+
td2 = TD(a=[8, 9, 10], b=40, c=20)
383+
td2_dict = td2.to_dict()
384+
assert td2_dict.pop("__class__") == "TD"
385+
assert all(td2_dict.pop("a") == [8, 9, 10])
386+
assert td2_dict.pop("b") == 40
387+
assert td2_dict.pop("c") == 20
388+
assert td2_dict == {}
389+
390+
# Verify that from_dict works correctly
391+
td2_dict = td2.to_dict()
392+
td3 = TD.from_dict(td2_dict)
393+
assert all(td3.a == td2.a)
394+
assert td3.b == td2.b
395+
assert td3.c == td2.c
396+
397+
# Verify that copy works correctly
398+
td4 = td3.copy()
399+
assert all(td4.a == td2.a)
400+
assert td4.b == td2.b
401+
assert td4.c == td2.c
402+
403+
# Verify that move works correctly
404+
td3.move(_context=xo.ContextCpu(omp_num_threads="auto"))
405+
assert all(td3.a == td2.a)
406+
assert td3.b == td2.b
407+
assert td3.c == td2.c

tests/test_to_json.py renamed to tests/test_to_dict.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import xobjects as xo
22

33

4-
def test_to_json():
4+
def test_to_dict():
55
class A(xo.Struct):
66
a = xo.Float64[:]
77
b = xo.Int64
@@ -11,14 +11,14 @@ class Uref(xo.UnionRef):
1111

1212
x = A(a=[2, 3], b=1)
1313
u = Uref(x)
14-
v = Uref(*u._to_json())
14+
v = Uref(*u._to_dict())
1515

1616
assert v.get().a[0] == 2
1717
assert v.get().a[1] == 3
1818
assert v.get().b == 1
1919

2020

21-
def test_to_json_array():
21+
def test_to_dict_array():
2222
class A(xo.Struct):
2323
a = xo.Float64[:]
2424

@@ -35,7 +35,7 @@ class Uref(xo.UnionRef):
3535
a[1] = A(a=[3])
3636
a[5] = B(c=2, d=1)
3737

38-
b = AUref(a._to_json())
38+
b = AUref(a._to_dict())
3939

4040
assert b[1].a[0] == 3
4141
assert b[5].d == 1

xobjects/array.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -690,10 +690,13 @@ def _get_inner_types(cls):
690690
return [cls._itemtype]
691691

692692
def _to_json(self):
693+
raise NameError("`_to_json` has been removed. Use `_to_dict` instead.")
694+
695+
def _to_dict(self):
693696
out = []
694697
for v in self: # TODO does not support multidimensional arrays
695-
if hasattr(v, "_to_json"):
696-
vdata = v._to_json()
698+
if hasattr(v, "_to_dict"):
699+
vdata = v._to_dict()
697700
else:
698701
vdata = v
699702
if self._has_refs and v is not None:

xobjects/context_cupy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,11 @@ def build_kernels(
457457

458458
extra_include_paths = self.get_installed_c_source_paths()
459459
include_flags = [f"-I{path}" for path in extra_include_paths]
460-
extra_compile_args = (*extra_compile_args, *include_flags, "-DXO_CONTEXT_CUDA")
460+
extra_compile_args = (
461+
*extra_compile_args,
462+
*include_flags,
463+
"-DXO_CONTEXT_CUDA",
464+
)
461465

462466
module = cupy.RawModule(
463467
code=specialized_source, options=extra_compile_args

xobjects/hybrid_class.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,10 +340,18 @@ def copy(self, _context=None, _buffer=None, _offset=None):
340340
if _context is None and _buffer is None:
341341
_context = self._xobject._buffer.context
342342
# This makes a copy of the xobject
343-
xobject = self._XoStruct(
343+
new_xobject = self._XoStruct(
344344
self._xobject, _context=_context, _buffer=_buffer, _offset=_offset
345345
)
346-
return self.__class__(_xobject=xobject)
346+
new = self.__class__.__new__(self.__class__)
347+
new.__dict__.update(self.__dict__)
348+
for kk, vv in new.__dict__.items():
349+
if kk == "_xobject":
350+
continue
351+
if hasattr(vv, "copy"):
352+
new.__dict__[kk] = vv.copy()
353+
new._xobject = new_xobject
354+
return new
347355

348356
@property
349357
def _buffer(self):

xobjects/ref.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,13 @@ def _get_inner_types(cls):
294294
return cls._reftypes
295295

296296
def _to_json(self):
297+
raise NameError("`_to_json` has been removed. Use `_to_dict` instead.")
298+
299+
def _to_dict(self):
297300
v = self.get()
298301
classname = v.__class__.__name__
299-
if hasattr(v, "_to_json"):
300-
v = v._to_json()
302+
if hasattr(v, "_to_dict"):
303+
v = v._to_dict()
301304
return (classname, v)
302305

303306

xobjects/struct.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -366,15 +366,15 @@ def _set_offsets(cls, buffer, offset, loffsets):
366366
foffset = offset + cls._fields[index].offset
367367
Int64._to_buffer(buffer, foffset, data_offset)
368368

369-
def _to_dict(self):
370-
return {field.name: field.__get__(self) for field in self._fields}
371-
372369
def _to_json(self):
370+
raise NameError("`_to_json` has been removed. Use `_to_dict` instead.")
371+
372+
def _to_dict(self):
373373
out = {}
374374
for field in self._fields:
375375
v = field.__get__(self)
376-
if hasattr(v, "_to_json"):
377-
v = v._to_json()
376+
if hasattr(v, "_to_dict"):
377+
v = v._to_dict()
378378
out[field.name] = v
379379
return out
380380

0 commit comments

Comments
 (0)