Skip to content

Commit c15d44a

Browse files
author
Max Keller
committed
Remove torch dependency from data.py
1 parent 855be41 commit c15d44a

File tree

5 files changed

+4
-25
lines changed

5 files changed

+4
-25
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ after_success:
2222
matrix:
2323
include:
2424
install:
25-
- pip install numpy==1.20 scikit-learn==0.18 scipy==0.18
25+
- pip install numpy==1.20 scikit-learn==0.18 scipy==0.18 torch==1.8.1
2626
- pip install codecov
2727
- pip install coverage
2828
- pip install .

modAL/utils/data.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44
import pandas as pd
55
import scipy.sparse as sp
6-
import torch
76

87
modALinput = Union[sp.csr_matrix, pd.DataFrame, np.ndarray, list]
98

@@ -26,8 +25,6 @@ def data_vstack(blocks: Sequence[modALinput]) -> modALinput:
2625
return np.concatenate(blocks)
2726
elif isinstance(blocks[0], list):
2827
return np.concatenate(blocks).tolist()
29-
elif torch.is_tensor(blocks[0]):
30-
return torch.cat(blocks)
3128

3229
raise TypeError('%s datatype is not supported' % type(blocks[0]))
3330

@@ -50,8 +47,6 @@ def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
5047
return np.hstack(blocks)
5148
elif isinstance(blocks[0], list):
5249
return np.hstack(blocks).tolist()
53-
elif torch.is_tensor(blocks[0]):
54-
return torch.cat(blocks, dim=1)
5550

5651
TypeError('%s datatype is not supported' % type(blocks[0]))
5752

@@ -65,8 +60,6 @@ def add_row(X: modALinput, row: modALinput):
6560
row] """
6661
if isinstance(X, np.ndarray):
6762
return np.vstack((X, row))
68-
elif torch.is_tensor(X):
69-
return torch.cat((X, row))
7063
elif isinstance(X, list):
7164
return np.vstack((X, row)).tolist()
7265

@@ -107,8 +100,6 @@ def retrieve_rows(X: modALinput,
107100
return X_return
108101
elif isinstance(X, np.ndarray):
109102
return X[I]
110-
elif torch.is_tensor(X):
111-
return X[I]
112103

113104
raise TypeError('%s datatype is not supported' % type(X))
114105

@@ -128,9 +119,6 @@ def drop_rows(X: modALinput,
128119
return np.delete(X, I, axis=0)
129120
elif isinstance(X, list):
130121
return np.delete(X, I, axis=0).tolist()
131-
elif torch.is_tensor(X):
132-
return X[[True if row not in I else False
133-
for row in range(X.size(0))]]
134122

135123
raise TypeError('%s datatype is not supported' % type(X))
136124

@@ -149,8 +137,8 @@ def enumerate_data(X: modALinput):
149137
return enumerate(X.tocsr())
150138
elif isinstance(X, pd.DataFrame):
151139
return X.iterrows()
152-
elif isinstance(X, np.ndarray) or isinstance(X, list) or torch.is_tensor(X):
153-
# numpy arrays, torch tensors and lists can readily be enumerated
140+
elif isinstance(X, np.ndarray) or isinstance(X, list):
141+
# numpy arrays and lists can readily be enumerated
154142
return enumerate(X)
155143

156144
raise TypeError('%s datatype is not supported' % type(X))
@@ -165,7 +153,5 @@ def data_shape(X: modALinput):
165153
return X.shape
166154
elif isinstance(X, list):
167155
return np.array(X).shape
168-
elif torch.is_tensor(X):
169-
return tuple(X.size())
170156

171157
raise TypeError('%s datatype is not supported' % type(X))

rtd_requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,3 @@ ipykernel
55
nbsphinx
66
pandas
77
skorch
8-
torch

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@
1111
packages=['modAL', 'modAL.models', 'modAL.utils'],
1212
classifiers=['Development Status :: 4 - Beta'],
1313
install_requires=['numpy==1.20.0', 'scikit-learn>=0.18',
14-
'scipy>=0.18', 'pandas>=1.1.0', 'skorch==0.9.0', 'torch>=1.8.1'],
14+
'scipy>=0.18', 'pandas>=1.1.0', 'skorch==0.9.0'],
1515
)

tests/core_tests.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,6 @@ def test_data_vstack(self):
181181
self.assertEqual((modAL.utils.data.data_vstack(
182182
(a, b)) != sp.vstack((a, b))).sum(), 0)
183183

184-
# pytorch tensors
185-
a, b = torch.randn(n_samples, n_features), torch.randn(
186-
n_samples, n_features)
187-
self.assertTrue(
188-
torch.equal(modAL.utils.data.data_vstack((a, b)), torch.cat((a, b))))
189-
190184
# lists
191185
a, b = np.random.rand(n_samples, n_features).tolist(), np.random.rand(
192186
n_samples, n_features).tolist()

0 commit comments

Comments
 (0)