4
4
import pandas as pd
5
5
import scipy .sparse as sp
6
6
7
+ try :
8
+ import torch
9
+ except :
10
+ pass
11
+
12
+
7
13
modALinput = Union [sp .csr_matrix , pd .DataFrame , np .ndarray , list ]
8
14
9
15
@@ -26,7 +32,13 @@ def data_vstack(blocks: Sequence[modALinput]) -> modALinput:
26
32
elif isinstance (blocks [0 ], list ):
27
33
return np .concatenate (blocks ).tolist ()
28
34
29
- raise TypeError ('%s datatype is not supported' % type (blocks [0 ]))
35
+ try :
36
+ if torch .is_tensor (blocks [0 ]):
37
+ return torch .cat (blocks )
38
+ except :
39
+ pass
40
+
41
+ raise TypeError ("%s datatype is not supported" % type (blocks [0 ]))
30
42
31
43
32
44
def data_hstack (blocks : Sequence [modALinput ]) -> modALinput :
@@ -48,7 +60,13 @@ def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
48
60
elif isinstance (blocks [0 ], list ):
49
61
return np .hstack (blocks ).tolist ()
50
62
51
- TypeError ('%s datatype is not supported' % type (blocks [0 ]))
63
+ try :
64
+ if torch .is_tensor (blocks [0 ]):
65
+ return torch .cat (blocks , dim = 1 )
66
+ except :
67
+ pass
68
+
69
+ TypeError ("%s datatype is not supported" % type (blocks [0 ]))
52
70
53
71
54
72
def add_row (X : modALinput , row : modALinput ):
@@ -68,8 +86,9 @@ def add_row(X: modALinput, row: modALinput):
68
86
return data_vstack ([X , row ])
69
87
70
88
71
def retrieve_rows(
    X: modALinput, I: Union[int, List[int], np.ndarray]
) -> Union[sp.csc_matrix, np.ndarray, pd.DataFrame]:
    """
    Returns the rows I from the data set X

    Known container types are dispatched explicitly *before* falling back to
    plain ``X[I]`` indexing, because ``X[I]`` does not mean "select rows" for
    every type: on a pandas data frame it selects columns and on a dict it
    looks up a key.

    For a single index, the result is as follows:
    * pandas series in case of a pandas data frame
    * row in case of list or numpy format

    Args:
        X: The data set. Scipy sparse matrices, pandas data frames, lists,
            numpy arrays and dicts of those are handled explicitly; any other
            object is indexed directly with ``X[I]``.
        I: Index or indices of the rows to retrieve.

    Returns:
        The selected rows. Lists are returned as lists, dicts as dicts with
        the same keys (each value indexed recursively), sparse matrices in
        their original sparse format.

    Raises:
        TypeError: if X is of a type that cannot be indexed.
    """
    if sp.issparse(X):
        # Out of the sparse matrix formats (sp.csc_matrix, sp.csr_matrix, sp.bsr_matrix,
        # sp.lil_matrix, sp.dok_matrix, sp.coo_matrix, sp.dia_matrix), only sp.bsr_matrix, sp.coo_matrix
        # and sp.dia_matrix don't support indexing and need to be converted to a sparse format
        # that does support indexing. It seems conversion to CSR is currently most efficient.
        try:
            return X[I]
        except (TypeError, NotImplementedError, ValueError, IndexError):
            # Retry through CSR; a genuinely invalid index raises again here
            # and propagates to the caller instead of being swallowed.
            sp_format = X.getformat()
            return X.tocsr()[I].asformat(sp_format)

    if isinstance(X, pd.DataFrame):
        # Positional row selection; X[I] would select columns instead.
        return X.iloc[I]

    if isinstance(X, list):
        return np.array(X)[I].tolist()

    if isinstance(X, dict):
        # Apply the same row selection to every value of the dict.
        return {key: retrieve_rows(value, I) for key, value in X.items()}

    if isinstance(X, np.ndarray):
        return X[I]

    # Any other indexable container (e.g. a tensor) is indexed directly.
    try:
        return X[I]
    except TypeError as err:
        raise TypeError("%s datatype is not supported" % type(X)) from err
123
+
124
+
125
+ def drop_rows (
126
+ X : modALinput , I : Union [int , List [int ], np .ndarray ]
127
+ ) -> Union [sp .csc_matrix , np .ndarray , pd .DataFrame ]:
109
128
"""
110
129
Returns X without the row(s) at index/indices I
111
130
"""
@@ -120,7 +139,13 @@ def drop_rows(X: modALinput,
120
139
elif isinstance (X , list ):
121
140
return np .delete (X , I , axis = 0 ).tolist ()
122
141
123
- raise TypeError ('%s datatype is not supported' % type (X ))
142
+ try :
143
+ if torch .is_tensor (blocks [0 ]):
144
+ return torch .cat (blocks )
145
+ except :
146
+ X [[True if row not in I else False for row in range (X .size (0 ))]]
147
+
148
+ raise TypeError ("%s datatype is not supported" % type (X ))
124
149
125
150
126
151
def enumerate_data (X : modALinput ):
@@ -141,17 +166,18 @@ def enumerate_data(X: modALinput):
141
166
# numpy arrays and lists can readily be enumerated
142
167
return enumerate (X )
143
168
144
- raise TypeError (' %s datatype is not supported' % type (X ))
169
+ raise TypeError (" %s datatype is not supported" % type (X ))
145
170
146
171
147
172
def data_shape(X: modALinput):
    """
    Returns the shape of the data set X

    Args:
        X: The data set. Lists are converted to a numpy array to determine
            their shape; every other object must expose a ``.shape``
            attribute (scipy.sparse, torch, pandas and numpy all do).

    Returns:
        The shape of X.

    Raises:
        TypeError: if X is not a list and has no ``.shape`` attribute.
    """
    if isinstance(X, list):
        return np.array(X).shape

    try:
        # scipy.sparse, torch, pandas and numpy all support .shape
        return X.shape
    except AttributeError as err:
        # Only a missing attribute means "unsupported type"; any other error
        # is left to propagate instead of being masked.
        raise TypeError(" %s datatype is not supported" % type(X)) from err
0 commit comments