Skip to content

Commit a0e8e01

Browse files
committed
Redo new copy implementation without need for flag
1 parent f643c43 commit a0e8e01

File tree

2 files changed

+12
-23
lines changed

2 files changed

+12
-23
lines changed

tests/test_hybrid_class.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,7 @@ class TD(xo.HybridClass):
339339
"_b": xo.Int64,
340340
}
341341
_skip_in_to_dict = ["_a", "_b"]
342-
_store_in_to_dict = ["a", "b"]
343-
_extra_fields_in_to_dict = ["c"]
342+
_store_in_to_dict = ["a", "b", "c"]
344343

345344
def __init__(self, **kwargs):
346345
if "_xobject" in kwargs and kwargs["_xobject"] is not None:
@@ -353,7 +352,7 @@ def __init__(self, **kwargs):
353352
def _initialize(self, **kwargs):
354353
# Need to handle non-xofields manually
355354
c = kwargs.pop("c", -9)
356-
self.xoinitialize(**kwargs)
355+
super().__init__(**kwargs)
357356
self._c = c
358357

359358
@property

xobjects/hybrid_class.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,7 @@ def to_dict(self, copy_to_cpu=True):
292292

293293
skip_fields = set(getattr(obj, "_skip_in_to_dict", []))
294294
additional_fields = set(getattr(obj, "_store_in_to_dict", []))
295-
extra_fields = set(getattr(obj, "_extra_fields_in_to_dict", []))
296295
fields_to_store = (set(obj._fields) - skip_fields) | additional_fields
297-
fields_to_store |= extra_fields
298296

299297
defaults = {}
300298
for field in obj._XoStruct._fields:
@@ -343,24 +341,18 @@ def copy(self, _context=None, _buffer=None, _offset=None):
343341
if _context is None and _buffer is None:
344342
_context = self._xobject._buffer.context
345343
# This makes a copy of the xobject
346-
xobject = self._XoStruct(
344+
new_xobject = self._XoStruct(
347345
self._xobject, _context=_context, _buffer=_buffer, _offset=_offset
348346
)
349-
# Get the python-only attributes
350-
additional_fields = set(getattr(self, "_extra_fields_in_to_dict", []))
351-
python_kwargs = {}
352-
for field in additional_fields:
353-
vv = getattr(self, field)
354-
if hasattr(vv, "copy"):
355-
try:
356-
python_kwargs[field] = vv.copy(
357-
_context=_context, _buffer=_buffer, _offset=_offset
358-
)
359-
except TypeError:
360-
python_kwargs[field] = vv.copy()
361-
else:
362-
python_kwargs[field] = vv
363-
return self.__class__(_xobject=xobject, **python_kwargs)
347+
new = self.__class__.__new__(self.__class__)
348+
new.__dict__.update(self.__dict__)
349+
for kk, vv in new.__dict__.items():
350+
if kk == '_xobject':
351+
continue
352+
if hasattr(vv, 'copy'):
353+
new.__dict__[kk] = vv.copy()
354+
new._xobject = new_xobject
355+
return new
364356

365357
@property
366358
def _buffer(self):
@@ -407,8 +399,6 @@ def __repr__(self):
407399
if hasattr(self, "_add_to_repr"):
408400
fnames += self._add_to_repr
409401
fnames += [fname for fname in self._fields]
410-
if hasattr(self, "_extra_fields_in_to_dict"):
411-
fnames += self._extra_fields_in_to_dict
412402
if hasattr(self, "_skip_in_repr"):
413403
fnames = [ff for ff in fnames if ff not in self._skip_in_repr]
414404

0 commit comments

Comments
 (0)