Skip to content

Commit d50f0d9

Browse files
committed
Add __get__ accessor cache
1 parent 2dd1ce5 commit d50f0d9

File tree

2 files changed

+104
-15
lines changed

2 files changed

+104
-15
lines changed

quivr/columns.py

Lines changed: 93 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ def __get__(self, obj: tables.Table, objtype: type) -> T: ...
232232
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, T]:
233233
if obj is None:
234234
return self
235+
# Do NOT cache subtable objects: they are mutable via attributes,
236+
# and caching would make attribute mutations appear persistent.
235237
array = obj.table.column(self.name)
236238

237239
metadata = self.metadata
@@ -287,7 +289,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.Int8Array: ...
287289
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Int8Array]:
288290
if obj is None:
289291
return self
290-
return _fast_combine_chunks(obj.table[self.name])
292+
cached = obj._array_cache.get(self.name)
293+
if cached is not None:
294+
return cached # type: ignore[return-value]
295+
arr = _fast_combine_chunks(obj.table[self.name])
296+
obj._array_cache[self.name] = arr
297+
return arr
291298

292299

293300
class Int16Column(Column):
@@ -317,7 +324,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.Int16Array: ...
317324
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Int16Array]:
318325
if obj is None:
319326
return self
320-
return _fast_combine_chunks(obj.table[self.name])
327+
cached = obj._array_cache.get(self.name)
328+
if cached is not None:
329+
return cached # type: ignore[return-value]
330+
arr = _fast_combine_chunks(obj.table[self.name])
331+
obj._array_cache[self.name] = arr
332+
return arr
321333

322334

323335
class Int32Column(Column):
@@ -347,7 +359,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.Int32Array: ...
347359
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Int32Array]:
348360
if obj is None:
349361
return self
350-
return _fast_combine_chunks(obj.table[self.name])
362+
cached = obj._array_cache.get(self.name)
363+
if cached is not None:
364+
return cached # type: ignore[return-value]
365+
arr = _fast_combine_chunks(obj.table[self.name])
366+
obj._array_cache[self.name] = arr
367+
return arr
351368

352369

353370
class Int64Column(Column):
@@ -377,7 +394,13 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.Int64Array: ...
377394
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.Int64Array]:
378395
if obj is None:
379396
return self
380-
return _fast_combine_chunks(obj.table[self.name])
397+
# Array access cache: avoid recombining chunks repeatedly
398+
cached = obj._array_cache.get(self.name)
399+
if cached is not None:
400+
return cached # type: ignore[return-value]
401+
arr = _fast_combine_chunks(obj.table[self.name])
402+
obj._array_cache[self.name] = arr
403+
return arr
381404

382405

383406
class UInt8Column(Column):
@@ -407,7 +430,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt8Array: ...
407430
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.UInt8Array]:
408431
if obj is None:
409432
return self
410-
return _fast_combine_chunks(obj.table[self.name])
433+
cached = obj._array_cache.get(self.name)
434+
if cached is not None:
435+
return cached # type: ignore[return-value]
436+
arr = _fast_combine_chunks(obj.table[self.name])
437+
obj._array_cache[self.name] = arr
438+
return arr
411439

412440

413441
class UInt16Column(Column):
@@ -437,7 +465,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt16Array: ...
437465
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.UInt16Array]:
438466
if obj is None:
439467
return self
440-
return _fast_combine_chunks(obj.table[self.name])
468+
cached = obj._array_cache.get(self.name)
469+
if cached is not None:
470+
return cached # type: ignore[return-value]
471+
arr = _fast_combine_chunks(obj.table[self.name])
472+
obj._array_cache[self.name] = arr
473+
return arr
441474

442475

443476
class UInt32Column(Column):
@@ -467,7 +500,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt32Array: ...
467500
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.UInt32Array]:
468501
if obj is None:
469502
return self
470-
return _fast_combine_chunks(obj.table[self.name])
503+
cached = obj._array_cache.get(self.name)
504+
if cached is not None:
505+
return cached # type: ignore[return-value]
506+
arr = _fast_combine_chunks(obj.table[self.name])
507+
obj._array_cache[self.name] = arr
508+
return arr
471509

472510

473511
class UInt64Column(Column):
@@ -497,7 +535,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.UInt64Array: ...
497535
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.UInt64Array]:
498536
if obj is None:
499537
return self
500-
return _fast_combine_chunks(obj.table[self.name])
538+
cached = obj._array_cache.get(self.name)
539+
if cached is not None:
540+
return cached # type: ignore[return-value]
541+
arr = _fast_combine_chunks(obj.table[self.name])
542+
obj._array_cache[self.name] = arr
543+
return arr
501544

502545

503546
class Float16Column(Column):
@@ -535,7 +578,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.lib.HalfFloatArray: ..
535578
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.lib.HalfFloatArray]:
536579
if obj is None:
537580
return self
538-
return _fast_combine_chunks(obj.table[self.name])
581+
cached = obj._array_cache.get(self.name)
582+
if cached is not None:
583+
return cached # type: ignore[return-value]
584+
arr = _fast_combine_chunks(obj.table[self.name])
585+
obj._array_cache[self.name] = arr
586+
return arr
539587

540588

541589
class Float32Column(Column):
@@ -565,7 +613,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.lib.FloatArray: ...
565613
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.lib.FloatArray]:
566614
if obj is None:
567615
return self
568-
return _fast_combine_chunks(obj.table[self.name])
616+
cached = obj._array_cache.get(self.name)
617+
if cached is not None:
618+
return cached # type: ignore[return-value]
619+
arr = _fast_combine_chunks(obj.table[self.name])
620+
obj._array_cache[self.name] = arr
621+
return arr
569622

570623

571624
class Float64Column(Column):
@@ -595,7 +648,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.lib.DoubleArray: ...
595648
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.lib.DoubleArray]:
596649
if obj is None:
597650
return self
598-
return _fast_combine_chunks(obj.table[self.name])
651+
cached = obj._array_cache.get(self.name)
652+
if cached is not None:
653+
return cached # type: ignore[return-value]
654+
arr = _fast_combine_chunks(obj.table[self.name])
655+
obj._array_cache[self.name] = arr
656+
return arr
599657

600658

601659
class BooleanColumn(Column):
@@ -623,7 +681,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.BooleanArray: ...
623681
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[pa.BooleanArray, Self]:
624682
if obj is None:
625683
return self
626-
return _fast_combine_chunks(obj.table[self.name])
684+
cached = obj._array_cache.get(self.name)
685+
if cached is not None:
686+
return cached # type: ignore[return-value]
687+
arr = _fast_combine_chunks(obj.table[self.name])
688+
obj._array_cache[self.name] = arr
689+
return arr
627690

628691

629692
class StringColumn(Column):
@@ -657,7 +720,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.StringArray: ...
657720
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[pa.StringArray, Self]:
658721
if obj is None:
659722
return self
660-
return _fast_combine_chunks(obj.table[self.name])
723+
cached = obj._array_cache.get(self.name)
724+
if cached is not None:
725+
return cached # type: ignore[return-value]
726+
arr = _fast_combine_chunks(obj.table[self.name])
727+
obj._array_cache[self.name] = arr
728+
return arr
661729

662730

663731
class LargeBinaryColumn(Column):
@@ -686,7 +754,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.LargeBinaryArray: ...
686754
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.LargeBinaryArray]:
687755
if obj is None:
688756
return self
689-
return _fast_combine_chunks(obj.table[self.name])
757+
cached = obj._array_cache.get(self.name)
758+
if cached is not None:
759+
return cached # type: ignore[return-value]
760+
arr = _fast_combine_chunks(obj.table[self.name])
761+
obj._array_cache[self.name] = arr
762+
return arr
690763

691764

692765
class LargeStringColumn(Column):
@@ -715,7 +788,12 @@ def __get__(self, obj: tables.Table, objtype: type) -> pa.LargeStringArray: ...
715788
def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, pa.LargeStringArray]:
716789
if obj is None:
717790
return self
718-
return _fast_combine_chunks(obj.table[self.name])
791+
cached = obj._array_cache.get(self.name)
792+
if cached is not None:
793+
return cached # type: ignore[return-value]
794+
arr = _fast_combine_chunks(obj.table[self.name])
795+
obj._array_cache[self.name] = arr
796+
return arr
719797

720798

721799
class Date32Column(Column):

quivr/tables.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ def __init_subclass__(cls: Type["Table"], **kwargs: Any):
135135

136136
def __init__(self, table: pa.Table, **kwargs: AttributeValueType):
137137
self.table = table
138+
# Per-instance caches for frequently accessed derived values.
139+
# Tables are immutable, so these caches are safe to maintain.
140+
self._array_cache: dict[str, pa.Array] = {}
141+
self._subtable_cache: dict[str, "Table"] = {}
142+
self._column_metadata_cache: dict[str, dict[bytes, bytes]] = {}
138143
for name, value in kwargs.items():
139144
if name in self._quivr_attributes:
140145
setattr(self, name, value)
@@ -1159,13 +1164,19 @@ def _unpack_string_metadata(cls, metadata: dict[str, str]) -> dict[bytes, bytes]
11591164

11601165
def _metadata_for_column(self, column_name: str) -> dict[bytes, bytes]:
11611166
"""Return a dictionary of metadata associated with a subtable column."""
1167+
# Serve from cache if available
1168+
cached = self._column_metadata_cache.get(column_name)
1169+
if cached is not None:
1170+
return cached
11621171
result: dict[bytes, bytes] = {}
11631172
if self.table.schema.metadata is None:
1173+
self._column_metadata_cache[column_name] = result
11641174
return result
11651175
column_name_bytes = (column_name + ".").encode("utf8")
11661176
for key, value in self.table.schema.metadata.items():
11671177
if key.startswith(column_name_bytes):
11681178
result[key[len(column_name_bytes) :]] = value
1179+
self._column_metadata_cache[column_name] = result
11691180
return result
11701181

11711182
@classmethod

0 commit comments

Comments
 (0)