Skip to content

Commit 53e07f4

Browse files
committed
working - but extremly primitive - method to sample and batch arrayset groups
1 parent 60300b6 commit 53e07f4

File tree

5 files changed

+346
-102
lines changed

5 files changed

+346
-102
lines changed

src/hangar/__init__.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
def raise_ImportError(message, *args, **kwargs): # pragma: no cover
99
raise ImportError(message)
1010

11+
from .dataloaders.tfloader import make_tf_dataset
12+
from .dataloaders.torchloader import make_torch_dataset
1113

12-
try: # pragma: no cover
13-
from .dataloaders.tfloader import make_tf_dataset
14-
except ImportError: # pragma: no cover
15-
make_tf_dataset = partial(raise_ImportError, "Could not import tensorflow. Install dependencies")
14+
# try: # pragma: no cover
15+
# from .dataloaders.tfloader import make_tf_dataset
16+
# except ImportError: # pragma: no cover
17+
# make_tf_dataset = partial(raise_ImportError, "Could not import tensorflow. Install dependencies")
1618

17-
try: # pragma: no cover
18-
from .dataloaders.torchloader import make_torch_dataset
19-
except ImportError: # pragma: no cover
20-
make_torch_dataset = partial(raise_ImportError, "Could not import torch. Install dependencies")
19+
# try: # pragma: no cover
20+
# from .dataloaders.torchloader import make_torch_dataset
21+
# except ImportError: # pragma: no cover
22+
# make_torch_dataset = partial(raise_ImportError, "Could not import torch. Install dependencies")

src/hangar/arrayset.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@
2727
from .records.parsing import arrayset_record_schema_db_val_from_raw_val
2828

2929

30-
CompatibleArray = NamedTuple(
31-
'CompatibleArray', [('compatible', bool), ('reason', str)])
30+
CompatibleArray = NamedTuple('CompatibleArray', [
31+
('compatible', bool),
32+
('reason', str)])
3233

3334

3435
class ArraysetDataReader(object):
@@ -305,18 +306,6 @@ def backend_opts(self):
305306
"""
306307
return self._dflt_backend_opts
307308

308-
@property
309-
def sample_classes(self):
310-
grouped_spec_names = defaultdict(list)
311-
for name, bespec in self._sspecs.items():
312-
grouped_spec_names[bespec].append(name)
313-
314-
grouped_data_names = {}
315-
for spec, names in grouped_spec_names.items():
316-
data = self._fs[spec.backend].read_data(spec)
317-
grouped_data_names[tuple(data.tolist())] = names
318-
return grouped_data_names
319-
320309
def keys(self, local: bool = False) -> Iterator[Union[str, int]]:
321310
"""generator which yields the names of every sample in the arrayset
322311

src/hangar/dataloaders/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .grouper import GroupedArraysetDataReader
2+
3+
__all__ = ['GroupedArraysetDataReader']

src/hangar/dataloaders/grouper.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import numpy as np
2+
3+
from ..arrayset import ArraysetDataReader
4+
5+
from collections import defaultdict
6+
import hashlib
7+
from typing import Sequence, Union, Iterable, NamedTuple
8+
import struct
9+
10+
11+
# -------------------------- typehints ---------------------------------------
12+
13+
14+
ArraysetSampleNames = Sequence[Union[str, int]]
15+
16+
SampleGroup = NamedTuple('SampleGroup', [
17+
('group', np.ndarray),
18+
('samples', Union[str, int])])
19+
20+
21+
# ------------------------------------------------------------------------------
22+
23+
24+
def _calculate_hash_digest(data: np.ndarray) -> str:
25+
hasher = hashlib.blake2b(data, digest_size=20)
26+
hasher.update(struct.pack(f'<{len(data.shape)}QB', *data.shape, data.dtype.num))
27+
digest = hasher.hexdigest()
28+
return digest
29+
30+
31+
class FakeNumpyKeyDict(object):
32+
def __init__(self, group_spec_samples, group_spec_value, group_digest_spec):
33+
self._group_spec_samples = group_spec_samples
34+
self._group_spec_value = group_spec_value
35+
self._group_digest_spec = group_digest_spec
36+
37+
def __getitem__(self, key: np.ndarray) -> ArraysetSampleNames:
38+
digest = _calculate_hash_digest(key)
39+
spec = self._group_digest_spec[digest]
40+
samples = self._group_spec_samples[spec]
41+
return samples
42+
43+
def get(self, key: np.ndarray) -> ArraysetSampleNames:
44+
return self.__getitem__(key)
45+
46+
def __setitem__(self, key, val):
47+
raise PermissionError('Not User Editable')
48+
49+
def __delitem__(self, key):
50+
raise PermissionError('Not User Editable')
51+
52+
def __len__(self) -> int:
53+
return len(self._group_digest_spec)
54+
55+
def __contains__(self, key: np.ndarray) -> bool:
56+
digest = _calculate_hash_digest(key)
57+
res = True if digest in self._group_digest_spec else False
58+
return res
59+
60+
def __iter__(self) -> Iterable[np.ndarray]:
61+
for spec in self._group_digest_spec.values():
62+
yield self._group_spec_value[spec]
63+
64+
def keys(self) -> Iterable[np.ndarray]:
65+
for spec in self._group_digest_spec.values():
66+
yield self._group_spec_value[spec]
67+
68+
def values(self) -> Iterable[ArraysetSampleNames]:
69+
for spec in self._group_digest_spec.values():
70+
yield self._group_spec_samples[spec]
71+
72+
def items(self) -> Iterable[ArraysetSampleNames]:
73+
for spec in self._group_digest_spec.values():
74+
yield (self._group_spec_value[spec], self._group_spec_samples[spec])
75+
76+
def __repr__(self):
77+
print('Mapping: Group Data Value -> Sample Name')
78+
for k, v in self.items():
79+
print(k, v)
80+
81+
def _repr_pretty_(self, p, cycle):
82+
res = f'Mapping: Group Data Value -> Sample Name \n'
83+
for k, v in self.items():
84+
res += f'\n {k} :: {v}'
85+
p.text(res)
86+
87+
88+
89+
# ---------------------------- MAIN METHOD ------------------------------------
90+
91+
92+
class GroupedArraysetDataReader(object):
93+
'''Pass in an arrayset and automatically find sample groups.
94+
'''
95+
96+
def __init__(self, arrayset: ArraysetDataReader, *args, **kwargs):
97+
98+
self.__arrayset = arrayset # TODO: Do we actually need to keep this around?
99+
self._group_spec_samples = defaultdict(list)
100+
self._group_spec_value = {}
101+
self._group_digest_spec = {}
102+
103+
self._setup()
104+
self._group_samples = FakeNumpyKeyDict(
105+
self._group_spec_samples,
106+
self._group_spec_value,
107+
self._group_digest_spec)
108+
109+
def _setup(self):
110+
for name, bespec in self.__arrayset._sspecs.items():
111+
self._group_spec_samples[bespec].append(name)
112+
for spec, names in self._group_spec_samples.items():
113+
data = self.__arrayset._fs[spec.backend].read_data(spec)
114+
self._group_spec_value[spec] = data
115+
digest = _calculate_hash_digest(data)
116+
self._group_digest_spec[digest] = spec
117+
118+
@property
119+
def groups(self) -> Iterable[np.ndarray]:
120+
for spec in self._group_digest_spec.values():
121+
yield self._group_spec_value[spec]
122+
123+
@property
124+
def group_samples(self):
125+
return self._group_samples

0 commit comments

Comments
 (0)