Skip to content

Commit 1cc199f

Browse files
committed
Adding testing for import_data. Fixing bugs found in import_data from testing. Closes #12.
1 parent 5eca702 commit 1cc199f

File tree

2 files changed

+123
-13
lines changed

2 files changed

+123
-13
lines changed

TensorToolbox/import_data.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def import_data(filename):
2424

2525
if data_type == 'tensor':
2626

27-
assert False, f"{data_type} is not currently allowed"
27+
shape = import_shape(fp)
28+
data = import_array(fp, np.prod(shape))
29+
return ttb.tensor().from_data(data, shape)
2830

2931
elif data_type == 'sptensor':
3032

@@ -35,26 +37,22 @@ def import_data(filename):
3537

3638
elif data_type == 'matrix':
3739

38-
assert False, f"{data_type} is not currently allowed"
40+
shape = import_shape(fp)
41+
mat = import_array(fp, np.prod(shape))
42+
mat = np.reshape(mat, np.array(shape))
43+
return mat
3944

4045
elif data_type == 'ktensor':
4146

4247
shape = import_shape(fp)
43-
#print(f"shape: {shape}")
4448
r = import_rank(fp)
45-
#print(f"rank: {r}")
46-
weights = np.array(fp.readline().strip().split(' '),dtype="float")
47-
#print(f"weights: {weights}")
49+
weights = import_array(fp, r)
4850
factor_matrices = []
4951
for n in range(len(shape)):
5052
fac_type = fp.readline().strip()
51-
#print(f"fac_type: {fac_type}")
5253
fac_shape = import_shape(fp)
53-
#print(f"fac_shape: {fac_shape}")
54-
fac = np.zeros(fac_shape, dtype="float")
55-
for r in range(fac_shape[0]):
56-
fac[r,:] = fp.readline().strip().split(' ')
57-
#print(f"fac: {fac}")
54+
fac = import_array(fp, np.prod(fac_shape))
55+
fac = np.reshape(fac, np.array(fac_shape))
5856
factor_matrices.append(fac)
5957
return ttb.ktensor().from_data(weights, factor_matrices)
6058

@@ -87,6 +85,10 @@ def import_sparse_array(fp, n, nz):
8785
vals = np.zeros((nz, 1))
8886
for k in range(nz):
8987
line = fp.readline().strip().split(' ')
90-
subs[k,:] = line[:-1]
88+
# 1-based indexing in file, 0-based indexing in package
89+
subs[k,:] = [np.int64(i)-1 for i in line[:-1]]
9190
vals[k,0] = line[-1]
9291
return subs, vals
92+
93+
def import_array(fp, n):
94+
return np.fromfile(fp, count=n, sep=' ')

tests/test_import_export_data.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 2022 National Technology & Engineering Solutions of Sandia,
2+
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
3+
# U.S. Government retains certain rights in this software.
4+
5+
import numpy as np
6+
import pytest
7+
import os
8+
import TensorToolbox as ttb
9+
10+
@pytest.fixture()
11+
def sample_tensor_2way():
12+
data = np.array([[1., 2., 3.], [4., 5., 6.]])
13+
shape = (2, 3)
14+
params = {'data':data, 'shape': shape}
15+
tensorInstance = ttb.tensor().from_data(data, shape)
16+
return params, tensorInstance
17+
18+
@pytest.fixture()
19+
def sample_tensor_3way():
20+
data = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.])
21+
shape = (2, 3, 2)
22+
params = {'data':np.reshape(data, np.array(shape), order='F'), 'shape': shape}
23+
tensorInstance = ttb.tensor().from_data(data, shape)
24+
return params, tensorInstance
25+
26+
@pytest.fixture()
27+
def sample_tensor_4way():
28+
data = np.arange(1, 82)
29+
shape = (3, 3, 3, 3)
30+
params = {'data':np.reshape(data, np.array(shape), order='F'), 'shape': shape}
31+
tensorInstance = ttb.tensor().from_data(data, shape)
32+
return params, tensorInstance
33+
34+
@pytest.mark.indevelopment
35+
def test_import_data_tensor():
36+
# truth data
37+
T = ttb.tensor.from_data(np.ones((3,3,3)), (3,3,3))
38+
39+
# imported data
40+
data_filename = os.path.join(os.path.dirname(__file__),'data','tensor.tns')
41+
X = ttb.import_data(data_filename)
42+
43+
assert X.shape == (3, 3, 3)
44+
assert T.isequal(X)
45+
46+
@pytest.mark.indevelopment
47+
def test_import_data_sptensor():
48+
# truth data
49+
subs = np.array([[0, 0, 0],[0, 2, 2],[1, 1, 1],[1, 2, 0],[1, 2, 1],[1, 2, 2],
50+
[1, 3, 1],[2, 0, 0],[2, 0, 1],[2, 2, 0],[2, 2, 1],[2, 3, 0],
51+
[2, 3, 2],[3, 0, 0],[3, 0, 1],[3, 2, 0],[4, 0, 2],[4, 3, 2]])
52+
vals = np.reshape(np.array(range(1,19)),(18,1))
53+
shape = (5, 4, 3)
54+
S = ttb.sptensor().from_data(subs, vals, shape)
55+
56+
# imported data
57+
data_filename = os.path.join(os.path.dirname(__file__),'data','sptensor.tns')
58+
X = ttb.import_data(data_filename)
59+
60+
assert S.isequal(X)
61+
62+
@pytest.mark.indevelopment
63+
def test_import_data_ktensor():
64+
# truth data
65+
weights = np.array([3, 2])
66+
fm0 = np.array([[1., 5.], [2., 6.], [3., 7.], [4., 8.]])
67+
fm1 = np.array([[ 2., 7.], [ 3., 8.], [ 4., 9.], [ 5., 10.], [ 6., 11.]])
68+
fm2 = np.array([[3., 6.], [4., 7.], [5., 8.]])
69+
factor_matrices = [fm0, fm1, fm2]
70+
K = ttb.ktensor.from_data(weights, factor_matrices)
71+
72+
# imported data
73+
data_filename = os.path.join(os.path.dirname(__file__),'data','ktensor.tns')
74+
X = ttb.import_data(data_filename)
75+
76+
assert K.isequal(X)
77+
78+
@pytest.mark.indevelopment
79+
def test_import_data_array():
80+
# truth data
81+
M = np.array([[1., 5.], [2., 6.], [3., 7.], [4., 8.]])
82+
print('\nM')
83+
print(M)
84+
85+
# imported data
86+
data_filename = os.path.join(os.path.dirname(__file__),'data','matrix.tns')
87+
X = ttb.import_data(data_filename)
88+
print('\nX')
89+
print(X)
90+
91+
assert (M == X).all()
92+
93+
@pytest.mark.indevelopment
94+
def test_export_data_tensor():
95+
pass
96+
97+
@pytest.mark.indevelopment
98+
def test_export_data_sptensor():
99+
pass
100+
101+
@pytest.mark.indevelopment
102+
def test_export_data_ktensor():
103+
pass
104+
105+
@pytest.mark.indevelopment
106+
def test_export_data_array():
107+
pass
108+

0 commit comments

Comments
 (0)