Skip to content

Commit 1f31ad5

Browse files
committed
Working - but extremely primitive - method to sample and batch arrayset groups
1 parent b12812c commit 1f31ad5

File tree

5 files changed

+346
-103
lines changed

5 files changed

+346
-103
lines changed

src/hangar/__init__.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
def raise_ImportError(message, *args, **kwargs):  # pragma: no cover
    """Raise an ``ImportError`` whose text is ``message``.

    Any additional positional or keyword arguments are accepted and ignored,
    so the function can be called with an arbitrary signature.

    Raises
    ------
    ImportError
        Always.
    """
    error = ImportError(message)
    raise error
1010

11+
# Importing the package must not fail just because a dataloader module cannot
# be imported (presumably because tensorflow / torch is not installed). When
# the import fails, bind the public name to a stand-in which raises a
# descriptive ImportError only at the moment the loader is actually called.
try:  # pragma: no cover
    from .dataloaders.tfloader import make_tf_dataset
except ImportError:  # pragma: no cover
    make_tf_dataset = partial(
        raise_ImportError, "Could not import tensorflow. Install dependencies")

try:  # pragma: no cover
    from .dataloaders.torchloader import make_torch_dataset
except ImportError:  # pragma: no cover
    make_torch_dataset = partial(
        raise_ImportError, "Could not import torch. Install dependencies")

src/hangar/arrayset.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from collections import defaultdict
21
import hashlib
32
import os
43
import warnings
@@ -30,8 +29,9 @@
3029
from .records.parsing import arrayset_record_schema_db_val_from_raw_val
3130

3231

33-
CompatibleArray = NamedTuple(
34-
'CompatibleArray', [('compatible', bool), ('reason', str)])
32+
CompatibleArray = NamedTuple('CompatibleArray', [
33+
('compatible', bool),
34+
('reason', str)])
3535

3636

3737
class ArraysetDataReader(object):
@@ -308,18 +308,6 @@ def backend_opts(self):
308308
"""
309309
return self._dflt_backend_opts
310310

311-
@property
312-
def sample_classes(self):
313-
grouped_spec_names = defaultdict(list)
314-
for name, bespec in self._sspecs.items():
315-
grouped_spec_names[bespec].append(name)
316-
317-
grouped_data_names = {}
318-
for spec, names in grouped_spec_names.items():
319-
data = self._fs[spec.backend].read_data(spec)
320-
grouped_data_names[tuple(data.tolist())] = names
321-
return grouped_data_names
322-
323311
def keys(self, local: bool = False) -> Iterator[Union[str, int]]:
324312
"""generator which yields the names of every sample in the arrayset
325313

src/hangar/dataloaders/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .grouper import GroupedArraysetDataReader
2+
3+
__all__ = ['GroupedArraysetDataReader']

src/hangar/dataloaders/grouper.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import numpy as np
2+
3+
from ..arrayset import ArraysetDataReader
4+
5+
from collections import defaultdict
6+
import hashlib
7+
from typing import Sequence, Union, Iterable, NamedTuple
8+
import struct
9+
10+
11+
# -------------------------- typehints ---------------------------------------
12+
13+
14+
# Names of samples in an arrayset; every name is either a str or an int.
ArraysetSampleNames = Sequence[Union[str, int]]

# One group: the array value identifying the group, together with the names
# of the samples belonging to it (a sequence of names, not a single name).
SampleGroup = NamedTuple('SampleGroup', [
    ('group', np.ndarray),
    ('samples', ArraysetSampleNames)])
19+
20+
21+
# ------------------------------------------------------------------------------
22+
23+
24+
def _calculate_hash_digest(data: np.ndarray) -> str:
25+
hasher = hashlib.blake2b(data, digest_size=20)
26+
hasher.update(struct.pack(f'<{len(data.shape)}QB', *data.shape, data.dtype.num))
27+
digest = hasher.hexdigest()
28+
return digest
29+
30+
31+
class FakeNumpyKeyDict(object):
    """Read-only, dict-like mapping: group array value -> sample names.

    numpy arrays are not hashable, so they cannot be used directly as dict
    keys. Lookups instead digest the array with ``_calculate_hash_digest`` and
    resolve ``digest -> spec -> samples`` through the dicts supplied at
    construction. The dicts are referenced, not copied.

    Parameters
    ----------
    group_spec_samples : dict
        Maps a spec to the names of the samples in that group.
    group_spec_value : dict
        Maps a spec to the array value of that group.
    group_digest_spec : dict
        Maps a digest (as produced by ``_calculate_hash_digest``) to a spec.
    """

    def __init__(self, group_spec_samples, group_spec_value, group_digest_spec):
        self._group_spec_samples = group_spec_samples
        self._group_spec_value = group_spec_value
        self._group_digest_spec = group_digest_spec

    def __getitem__(self, key: np.ndarray) -> ArraysetSampleNames:
        """Return the sample names of the group whose value equals ``key``.

        Raises
        ------
        KeyError
            If no group digest matches the digest of ``key``.
        """
        digest = _calculate_hash_digest(key)
        spec = self._group_digest_spec[digest]
        return self._group_spec_samples[spec]

    def get(self, key: np.ndarray, default=None):
        """Return the sample names for ``key``, or ``default`` if not present.

        Mirrors ``dict.get``: a missing key does not raise.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def __setitem__(self, key, val):
        raise PermissionError('Not User Editable')

    def __delitem__(self, key):
        raise PermissionError('Not User Editable')

    def __len__(self) -> int:
        """Number of groups."""
        return len(self._group_digest_spec)

    def __contains__(self, key: np.ndarray) -> bool:
        """True if the digest of ``key`` matches a known group."""
        return _calculate_hash_digest(key) in self._group_digest_spec

    def __iter__(self) -> Iterable[np.ndarray]:
        """Iterate over the group array values (same as ``keys``)."""
        return self.keys()

    def keys(self) -> Iterable[np.ndarray]:
        """Yield the array value of every group."""
        for spec in self._group_digest_spec.values():
            yield self._group_spec_value[spec]

    def values(self) -> Iterable[ArraysetSampleNames]:
        """Yield the sample names of every group."""
        for spec in self._group_digest_spec.values():
            yield self._group_spec_samples[spec]

    def items(self) -> Iterable[tuple]:
        """Yield ``(group array value, sample names)`` for every group."""
        for spec in self._group_digest_spec.values():
            yield (self._group_spec_value[spec], self._group_spec_samples[spec])

    def __repr__(self) -> str:
        # __repr__ must RETURN a str; printing and returning None makes
        # ``repr(obj)`` raise TypeError.
        lines = ['Mapping: Group Data Value -> Sample Name']
        lines.extend(f'{k} {v}' for k, v in self.items())
        return '\n'.join(lines)

    def _repr_pretty_(self, p, cycle):
        """Pretty-printer hook: writes the mapping via ``p.text``."""
        res = 'Mapping: Group Data Value -> Sample Name \n'
        res += ''.join(f'\n {k} :: {v}' for k, v in self.items())
        p.text(res)
86+
87+
88+
89+
# ---------------------------- MAIN METHOD ------------------------------------
90+
91+
92+
class GroupedArraysetDataReader(object):
    '''Find the sample groups of an arrayset.

    Samples whose entry in the arrayset's ``_sspecs`` mapping is the same spec
    are placed in the same group. For each group the data is read once through
    the arrayset's ``_fs`` backend accessors and digested so that the group
    can later be looked up by its array value.

    Parameters
    ----------
    arrayset : ArraysetDataReader
        Arrayset whose samples are grouped. Extra positional / keyword
        arguments are accepted and ignored.
    '''

    def __init__(self, arrayset: ArraysetDataReader, *args, **kwargs):
        # TODO: Do we actually need to keep this around?
        self.__arrayset = arrayset

        # spec -> [sample names], spec -> array value, digest -> spec
        self._group_spec_samples = defaultdict(list)
        self._group_spec_value = {}
        self._group_digest_spec = {}
        self._setup()

        # The view shares (does not copy) the three dicts filled by _setup.
        self._group_samples = FakeNumpyKeyDict(
            self._group_spec_samples,
            self._group_spec_value,
            self._group_digest_spec,
        )

    def _setup(self):
        '''Populate the spec/sample, spec/value and digest/spec dicts.'''
        aset = self.__arrayset

        # Pass 1: collect sample names under the spec they are stored with.
        for sample_name, sample_spec in aset._sspecs.items():
            self._group_spec_samples[sample_spec].append(sample_name)

        # Pass 2: read each group's data once and index it by digest.
        for group_spec in self._group_spec_samples:
            value = aset._fs[group_spec.backend].read_data(group_spec)
            self._group_spec_value[group_spec] = value
            self._group_digest_spec[_calculate_hash_digest(value)] = group_spec

    @property
    def groups(self) -> Iterable[np.ndarray]:
        '''Generator over the array value of every group.'''
        yield from (
            self._group_spec_value[group_spec]
            for group_spec in self._group_digest_spec.values()
        )

    @property
    def group_samples(self):
        '''Dict-like view mapping group array value -> sample names.'''
        return self._group_samples

0 commit comments

Comments
 (0)