Skip to content

Commit 87f64be

Browse files
committed
torch data utils added; torch is an optional dependency of the module
1 parent c15d44a commit 87f64be

File tree

2 files changed

+70
-37
lines changed

2 files changed

+70
-37
lines changed

modAL/utils/data.py

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
import pandas as pd
55
import scipy.sparse as sp
66

7+
try:
8+
import torch
9+
except:
10+
pass
11+
12+
713
modALinput = Union[sp.csr_matrix, pd.DataFrame, np.ndarray, list]
814

915

@@ -26,7 +32,13 @@ def data_vstack(blocks: Sequence[modALinput]) -> modALinput:
2632
elif isinstance(blocks[0], list):
2733
return np.concatenate(blocks).tolist()
2834

29-
raise TypeError('%s datatype is not supported' % type(blocks[0]))
35+
try:
36+
if torch.is_tensor(blocks[0]):
37+
return torch.cat(blocks)
38+
except:
39+
pass
40+
41+
raise TypeError("%s datatype is not supported" % type(blocks[0]))
3042

3143

3244
def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
@@ -48,7 +60,13 @@ def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
4860
elif isinstance(blocks[0], list):
4961
return np.hstack(blocks).tolist()
5062

51-
TypeError('%s datatype is not supported' % type(blocks[0]))
63+
try:
64+
if torch.is_tensor(blocks[0]):
65+
return torch.cat(blocks, dim=1)
66+
except:
67+
pass
68+
69+
TypeError("%s datatype is not supported" % type(blocks[0]))
5270

5371

5472
def add_row(X: modALinput, row: modALinput):
@@ -68,8 +86,9 @@ def add_row(X: modALinput, row: modALinput):
6886
return data_vstack([X, row])
6987

7088

71-
def retrieve_rows(X: modALinput,
72-
I: Union[int, List[int], np.ndarray]) -> Union[sp.csc_matrix, np.ndarray, pd.DataFrame]:
89+
def retrieve_rows(
90+
X: modALinput, I: Union[int, List[int], np.ndarray]
91+
) -> Union[sp.csc_matrix, np.ndarray, pd.DataFrame]:
7392
"""
7493
Returns the rows I from the data set X
7594
@@ -78,34 +97,34 @@ def retrieve_rows(X: modALinput,
7897
* pandas series in case of a pandas data frame
7998
* row in case of list or numpy format
8099
"""
81-
if sp.issparse(X):
82-
# Out of the sparse matrix formats (sp.csc_matrix, sp.csr_matrix, sp.bsr_matrix,
83-
# sp.lil_matrix, sp.dok_matrix, sp.coo_matrix, sp.dia_matrix), only sp.bsr_matrix, sp.coo_matrix
84-
# and sp.dia_matrix don't support indexing and need to be converted to a sparse format
85-
# that does support indexing. It seems conversion to CSR is currently most efficient.
86-
87-
try:
88-
return X[I]
89-
except:
90-
sp_format = X.getformat()
91-
return X.tocsr()[I].asformat(sp_format)
92-
elif isinstance(X, pd.DataFrame):
93-
return X.iloc[I]
94-
elif isinstance(X, list):
95-
return np.array(X)[I].tolist()
96-
elif isinstance(X, dict):
97-
X_return = {}
98-
for key, value in X.items():
99-
X_return[key] = retrieve_rows(value, I)
100-
return X_return
101-
elif isinstance(X, np.ndarray):
102-
return X[I]
103-
104-
raise TypeError('%s datatype is not supported' % type(X))
105100

101+
try:
102+
return X[I]
103+
except:
104+
if sp.issparse(X):
105+
# Out of the sparse matrix formats (sp.csc_matrix, sp.csr_matrix, sp.bsr_matrix,
106+
# sp.lil_matrix, sp.dok_matrix, sp.coo_matrix, sp.dia_matrix), only sp.bsr_matrix, sp.coo_matrix
107+
# and sp.dia_matrix don't support indexing and need to be converted to a sparse format
108+
# that does support indexing. It seems conversion to CSR is currently most efficient.
106109

107-
def drop_rows(X: modALinput,
108-
I: Union[int, List[int], np.ndarray]) -> Union[sp.csc_matrix, np.ndarray, pd.DataFrame]:
110+
sp_format = X.getformat()
111+
return X.tocsr()[I].asformat(sp_format)
112+
elif isinstance(X, pd.DataFrame):
113+
return X.iloc[I]
114+
elif isinstance(X, list):
115+
return np.array(X)[I].tolist()
116+
elif isinstance(X, dict):
117+
X_return = {}
118+
for key, value in X.items():
119+
X_return[key] = retrieve_rows(value, I)
120+
return X_return
121+
122+
raise TypeError("%s datatype is not supported" % type(X))
123+
124+
125+
def drop_rows(
126+
X: modALinput, I: Union[int, List[int], np.ndarray]
127+
) -> Union[sp.csc_matrix, np.ndarray, pd.DataFrame]:
109128
"""
110129
Returns X without the row(s) at index/indices I
111130
"""
@@ -120,7 +139,13 @@ def drop_rows(X: modALinput,
120139
elif isinstance(X, list):
121140
return np.delete(X, I, axis=0).tolist()
122141

123-
raise TypeError('%s datatype is not supported' % type(X))
142+
try:
143+
if torch.is_tensor(blocks[0]):
144+
return torch.cat(blocks)
145+
except:
146+
X[[True if row not in I else False for row in range(X.size(0))]]
147+
148+
raise TypeError("%s datatype is not supported" % type(X))
124149

125150

126151
def enumerate_data(X: modALinput):
@@ -141,17 +166,18 @@ def enumerate_data(X: modALinput):
141166
# numpy arrays and lists can readily be enumerated
142167
return enumerate(X)
143168

144-
raise TypeError('%s datatype is not supported' % type(X))
169+
raise TypeError("%s datatype is not supported" % type(X))
145170

146171

147172
def data_shape(X: modALinput):
148173
"""
149174
Returns the shape of the data set X
150175
"""
151-
if sp.issparse(X) or isinstance(X, pd.DataFrame) or isinstance(X, np.ndarray):
152-
# scipy.sparse, pandas and numpy all support .shape
176+
try:
177+
# scipy.sparse, torch, pandas and numpy all support .shape
153178
return X.shape
154-
elif isinstance(X, list):
155-
return np.array(X).shape
179+
except:
180+
if isinstance(X, list):
181+
return np.array(X).shape
156182

157-
raise TypeError('%s datatype is not supported' % type(X))
183+
raise TypeError("%s datatype is not supported" % type(X))

tests/core_tests.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,13 @@ def test_data_vstack(self):
189189
np.concatenate((a, b))
190190
)
191191

192+
# torch.Tensors
193+
a, b = torch.ones(2, 2), torch.ones(2, 2)
194+
torch.testing.assert_allclose(
195+
modAL.utils.data.data_vstack((a, b)),
196+
torch.cat((a, b))
197+
)
198+
192199
# not supported formats
193200
self.assertRaises(TypeError, modAL.utils.data.data_vstack, (1, 1))
194201

0 commit comments

Comments
 (0)