Skip to content

Commit e63339f

Browse files
committed
updates
1 parent 426d3dd commit e63339f

File tree

2 files changed

+38
-42
lines changed

2 files changed

+38
-42
lines changed

src/hangar/dataloaders/grouper.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import numpy as np
22

33
from ..arrayset import ArraysetDataReader
4+
from ..records.hashmachine import array_hash_digest
45

56
from collections import defaultdict
6-
import hashlib
7-
from typing import Sequence, Union, Iterable, NamedTuple
8-
import struct
7+
from typing import Sequence, Union, Iterable, NamedTuple, Tuple
98

109

1110
# -------------------------- typehints ---------------------------------------
@@ -21,21 +20,14 @@
2120
# ------------------------------------------------------------------------------
2221

2322

24-
def _calculate_hash_digest(data: np.ndarray) -> str:
25-
hasher = hashlib.blake2b(data, digest_size=20)
26-
hasher.update(struct.pack(f'<{len(data.shape)}QB', *data.shape, data.dtype.num))
27-
digest = hasher.hexdigest()
28-
return digest
29-
30-
3123
class FakeNumpyKeyDict(object):
3224
def __init__(self, group_spec_samples, group_spec_value, group_digest_spec):
3325
self._group_spec_samples = group_spec_samples
3426
self._group_spec_value = group_spec_value
3527
self._group_digest_spec = group_digest_spec
3628

3729
def __getitem__(self, key: np.ndarray) -> ArraysetSampleNames:
38-
digest = _calculate_hash_digest(key)
30+
digest = array_hash_digest(key)
3931
spec = self._group_digest_spec[digest]
4032
samples = self._group_spec_samples[spec]
4133
return samples
@@ -53,7 +45,7 @@ def __len__(self) -> int:
5345
return len(self._group_digest_spec)
5446

5547
def __contains__(self, key: np.ndarray) -> bool:
56-
digest = _calculate_hash_digest(key)
48+
digest = array_hash_digest(key)
5749
res = True if digest in self._group_digest_spec else False
5850
return res
5951

@@ -69,7 +61,7 @@ def values(self) -> Iterable[ArraysetSampleNames]:
6961
for spec in self._group_digest_spec.values():
7062
yield self._group_spec_samples[spec]
7163

72-
def items(self) -> Iterable[ArraysetSampleNames]:
64+
def items(self) -> Iterable[Tuple[np.ndarray, ArraysetSampleNames]]:
7365
for spec in self._group_digest_spec.values():
7466
yield (self._group_spec_value[spec], self._group_spec_samples[spec])
7567

@@ -81,11 +73,10 @@ def __repr__(self):
8173
def _repr_pretty_(self, p, cycle):
8274
res = f'Mapping: Group Data Value -> Sample Name \n'
8375
for k, v in self.items():
84-
res += f'\n {k} :: {v}'
76+
res += f'\n {k} :: {v} \n'
8577
p.text(res)
8678

8779

88-
8980
# ---------------------------- MAIN METHOD ------------------------------------
9081

9182

@@ -112,7 +103,7 @@ def _setup(self):
112103
for spec, names in self._group_spec_samples.items():
113104
data = self.__arrayset._fs[spec.backend].read_data(spec)
114105
self._group_spec_value[spec] = data
115-
digest = _calculate_hash_digest(data)
106+
digest = array_hash_digest(data)
116107
self._group_digest_spec[digest] = spec
117108

118109
@property

src/hangar/dataloaders/sampler.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Sequence, Union, List, Iterable
22

33
import numpy as np
4+
import numpy.random
45

56
from ..arrayset import ArraysetDataReader
67

@@ -93,32 +94,36 @@ def __init__(self, data_source):
9394
def __iter__(self):
9495
raise NotImplementedError
9596

96-
# NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
97-
#
98-
# Many times we have an abstract class representing a collection/iterable of
99-
# data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
100-
# implementing a `__len__` method. In such cases, we must make sure to not
101-
# provide a default implementation, because both straightforward default
102-
# implementations have their issues:
103-
#
104-
# + `return NotImplemented`:
105-
# Calling `len(subclass_instance)` raises:
106-
# TypeError: 'NotImplementedType' object cannot be interpreted as an integer
107-
#
108-
# + `raise NotImplementedError()`:
109-
# This prevents triggering some fallback behavior. E.g., the built-in
110-
# `list(X)` tries to call `len(X)` first, and executes a different code
111-
# path if the method is not found or `NotImplemented` is returned, while
112-
# raising an `NotImplementedError` will propagate and and make the call
113-
# fail where it could have use `__iter__` to complete the call.
114-
#
115-
# Thus, the only two sensible things to do are
116-
#
117-
# + **not** provide a default `__len__`.
118-
#
119-
# + raise a `TypeError` instead, which is what Python uses when users call
120-
# a method that is not defined on an object.
121-
# (@ssnl verifies that this works on at least Python 3.7.)
97+
def __len__(self):
98+
"""
99+
# NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
100+
#
101+
# Many times we have an abstract class representing a collection/iterable of
102+
# data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
103+
# implementing a `__len__` method. In such cases, we must make sure to not
104+
# provide a default implementation, because both straightforward default
105+
# implementations have their issues:
106+
#
107+
# + `return NotImplemented`:
108+
# Calling `len(subclass_instance)` raises:
109+
# TypeError: 'NotImplementedType' object cannot be interpreted as an integer
110+
#
111+
# + `raise NotImplementedError()`:
112+
# This prevents triggering some fallback behavior. E.g., the built-in
113+
# `list(X)` tries to call `len(X)` first, and executes a different code
114+
# path if the method is not found or `NotImplemented` is returned, while
115+
# raising an `NotImplementedError` will propagate and and make the call
116+
# fail where it could have use `__iter__` to complete the call.
117+
#
118+
# Thus, the only two sensible things to do are
119+
#
120+
# + **not** provide a default `__len__`.
121+
#
122+
# + raise a `TypeError` instead, which is what Python uses when users call
123+
# a method that is not defined on an object.
124+
# (@ssnl verifies that this works on at least Python 3.7.)
125+
"""
126+
raise TypeError
122127

123128

124129
class SequentialSampler(Sampler):

0 commit comments

Comments
 (0)