1 change: 0 additions & 1 deletion src/hangar/__init__.py
@@ -8,7 +8,6 @@
def raise_ImportError(message, *args, **kwargs):
    raise ImportError(message)


try:
    from .dataloaders.tfloader import make_tf_dataset
except ImportError:
6 changes: 4 additions & 2 deletions src/hangar/arrayset.py
@@ -1,3 +1,4 @@
from collections import defaultdict
import os
import warnings
from multiprocessing import cpu_count, get_context
@@ -26,8 +27,9 @@
from .records.parsing import arrayset_record_schema_db_val_from_raw_val


CompatibleArray = NamedTuple(
    'CompatibleArray', [('compatible', bool), ('reason', str)])
CompatibleArray = NamedTuple('CompatibleArray', [
    ('compatible', bool),
    ('reason', str)])


class ArraysetDataReader(object):
3 changes: 3 additions & 0 deletions src/hangar/dataloaders/__init__.py
@@ -0,0 +1,3 @@
from .grouper import GroupedArraysetDataReader

__all__ = ['GroupedArraysetDataReader']
8 changes: 4 additions & 4 deletions src/hangar/dataloaders/common.py
@@ -33,20 +33,20 @@ def __init__(self,
        if len(arraysets) == 0:
            raise ValueError('len(arraysets) cannot == 0')

        aset_lens = set()
        # aset_lens = set()
        all_keys = []
        all_remote_keys = []
        for aset in arraysets:
            if aset.iswriteable is True:
                raise TypeError(f'Cannot load arraysets opened in `write-enabled` checkout.')
            self.arrayset_array.append(aset)
            self.arrayset_names.append(aset.name)
            aset_lens.add(len(aset))
            # aset_lens.add(len(aset))
            all_keys.append(set(aset.keys()))
            all_remote_keys.append(set(aset.remote_reference_keys))

        if len(aset_lens) > 1:
            warnings.warn('Arraysets do not contain equal number of samples', UserWarning)
        # if len(aset_lens) > 1:
        # warnings.warn('Arraysets do not contain equal number of samples', UserWarning)

        common_keys = set.intersection(*all_keys)
        remote_keys = set.union(*all_remote_keys)
116 changes: 116 additions & 0 deletions src/hangar/dataloaders/grouper.py
@@ -0,0 +1,116 @@
import numpy as np

from ..arrayset import ArraysetDataReader
from ..records.hashmachine import array_hash_digest

from collections import defaultdict
from typing import Sequence, Union, Iterable, NamedTuple, Tuple


# -------------------------- typehints ---------------------------------------


ArraysetSampleNames = Sequence[Union[str, int]]

SampleGroup = NamedTuple('SampleGroup', [
    ('group', np.ndarray),
    ('samples', Union[str, int])])


# ------------------------------------------------------------------------------


class FakeNumpyKeyDict(object):
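    '''Read-only mapping from a group's array value to the sample names in that group.

    np.ndarray keys are not hashable, so lookups hash the key with
    array_hash_digest and resolve the digest to the stored backend spec.
    '''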
    def __init__(self, group_spec_samples, group_spec_value, group_digest_spec):
        self._group_spec_samples = group_spec_samples
        self._group_spec_value = group_spec_value
        self._group_digest_spec = group_digest_spec

    def __getitem__(self, key: np.ndarray) -> ArraysetSampleNames:
        digest = array_hash_digest(key)
        spec = self._group_digest_spec[digest]
        samples = self._group_spec_samples[spec]
        return samples

    def get(self, key: np.ndarray) -> ArraysetSampleNames:
        return self.__getitem__(key)

    def __setitem__(self, key, val):
        raise PermissionError('Not User Editable')

    def __delitem__(self, key):
        raise PermissionError('Not User Editable')

    def __len__(self) -> int:
        return len(self._group_digest_spec)

    def __contains__(self, key: np.ndarray) -> bool:
        digest = array_hash_digest(key)
        res = True if digest in self._group_digest_spec else False
        return res

    def __iter__(self) -> Iterable[np.ndarray]:
        for spec in self._group_digest_spec.values():
            yield self._group_spec_value[spec]

    def keys(self) -> Iterable[np.ndarray]:
        for spec in self._group_digest_spec.values():
            yield self._group_spec_value[spec]

    def values(self) -> Iterable[ArraysetSampleNames]:
        for spec in self._group_digest_spec.values():
            yield self._group_spec_samples[spec]

    def items(self) -> Iterable[Tuple[np.ndarray, ArraysetSampleNames]]:
        for spec in self._group_digest_spec.values():
            yield (self._group_spec_value[spec], self._group_spec_samples[spec])

    def __repr__(self):
        res = 'Mapping: Group Data Value -> Sample Name\n'
        for k, v in self.items():
            res += f'{k} :: {v}\n'
        return res

    def _repr_pretty_(self, p, cycle):
        res = f'Mapping: Group Data Value -> Sample Name \n'
        for k, v in self.items():
            res += f'\n {k} :: {v} \n'
        p.text(res)


# ---------------------------- MAIN METHOD ------------------------------------


class GroupedArraysetDataReader(object):
    '''Pass in an arrayset and automatically find sample groups.
    '''

    def __init__(self, arrayset: ArraysetDataReader, *args, **kwargs):

        self.__arrayset = arrayset  # TODO: Do we actually need to keep this around?
        self._group_spec_samples = defaultdict(list)
        self._group_spec_value = {}
        self._group_digest_spec = {}

        self._setup()
        self._group_samples = FakeNumpyKeyDict(
            self._group_spec_samples,
            self._group_spec_value,
            self._group_digest_spec)

    def _setup(self):
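        # Group sample names by backend spec (samples sharing a spec reference
        # identical stored data), then read each group's data once and index
        # the spec by the array's hash digest so groups can be looked up by value.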
        for name, bespec in self.__arrayset._sspecs.items():
            self._group_spec_samples[bespec].append(name)
        for spec, names in self._group_spec_samples.items():
            data = self.__arrayset._fs[spec.backend].read_data(spec)
            self._group_spec_value[spec] = data
            digest = array_hash_digest(data)
            self._group_digest_spec[digest] = spec

    @property
    def groups(self) -> Iterable[np.ndarray]:
        for spec in self._group_digest_spec.values():
            yield self._group_spec_value[spec]

    @property
    def group_samples(self):
        return self._group_samples
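
For reviewers, a minimal usage sketch of the new loader (not part of the diff). It assumes an existing repository and an arrayset named 'label'; the path and arrayset name are hypothetical, and the checkout is opened read-only since the class wraps an ArraysetDataReader.

from hangar import Repository
from hangar.dataloaders import GroupedArraysetDataReader

# hypothetical repository path and arrayset name
repo = Repository('/path/to/repo')
co = repo.checkout()  # read-only checkout
grouped = GroupedArraysetDataReader(co.arraysets['label'])

# iterate the unique group values found in the arrayset and look up
# the sample names that share each value
for group_value in grouped.groups:
    sample_names = grouped.group_samples[group_value]
    print(group_value, '->', sample_names)

co.close()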