|
| 1 | +# Copyright 2022 National Technology & Engineering Solutions of Sandia, |
| 2 | +# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the |
| 3 | +# U.S. Government retains certain rights in this software. |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import pytest |
| 7 | +import os |
| 8 | +import TensorToolbox as ttb |
| 9 | + |
| 10 | +@pytest.fixture() |
| 11 | +def sample_tensor_2way(): |
| 12 | + data = np.array([[1., 2., 3.], [4., 5., 6.]]) |
| 13 | + shape = (2, 3) |
| 14 | + params = {'data':data, 'shape': shape} |
| 15 | + tensorInstance = ttb.tensor().from_data(data, shape) |
| 16 | + return params, tensorInstance |
| 17 | + |
| 18 | +@pytest.fixture() |
| 19 | +def sample_tensor_3way(): |
| 20 | + data = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.]) |
| 21 | + shape = (2, 3, 2) |
| 22 | + params = {'data':np.reshape(data, np.array(shape), order='F'), 'shape': shape} |
| 23 | + tensorInstance = ttb.tensor().from_data(data, shape) |
| 24 | + return params, tensorInstance |
| 25 | + |
| 26 | +@pytest.fixture() |
| 27 | +def sample_tensor_4way(): |
| 28 | + data = np.arange(1, 82) |
| 29 | + shape = (3, 3, 3, 3) |
| 30 | + params = {'data':np.reshape(data, np.array(shape), order='F'), 'shape': shape} |
| 31 | + tensorInstance = ttb.tensor().from_data(data, shape) |
| 32 | + return params, tensorInstance |
| 33 | + |
| 34 | +@pytest.mark.indevelopment |
| 35 | +def test_import_data_tensor(): |
| 36 | + # truth data |
| 37 | + T = ttb.tensor.from_data(np.ones((3,3,3)), (3,3,3)) |
| 38 | + |
| 39 | + # imported data |
| 40 | + data_filename = os.path.join(os.path.dirname(__file__),'data','tensor.tns') |
| 41 | + X = ttb.import_data(data_filename) |
| 42 | + |
| 43 | + assert X.shape == (3, 3, 3) |
| 44 | + assert T.isequal(X) |
| 45 | + |
| 46 | +@pytest.mark.indevelopment |
| 47 | +def test_import_data_sptensor(): |
| 48 | + # truth data |
| 49 | + subs = np.array([[0, 0, 0],[0, 2, 2],[1, 1, 1],[1, 2, 0],[1, 2, 1],[1, 2, 2], |
| 50 | + [1, 3, 1],[2, 0, 0],[2, 0, 1],[2, 2, 0],[2, 2, 1],[2, 3, 0], |
| 51 | + [2, 3, 2],[3, 0, 0],[3, 0, 1],[3, 2, 0],[4, 0, 2],[4, 3, 2]]) |
| 52 | + vals = np.reshape(np.array(range(1,19)),(18,1)) |
| 53 | + shape = (5, 4, 3) |
| 54 | + S = ttb.sptensor().from_data(subs, vals, shape) |
| 55 | + |
| 56 | + # imported data |
| 57 | + data_filename = os.path.join(os.path.dirname(__file__),'data','sptensor.tns') |
| 58 | + X = ttb.import_data(data_filename) |
| 59 | + |
| 60 | + assert S.isequal(X) |
| 61 | + |
| 62 | +@pytest.mark.indevelopment |
| 63 | +def test_import_data_ktensor(): |
| 64 | + # truth data |
| 65 | + weights = np.array([3, 2]) |
| 66 | + fm0 = np.array([[1., 5.], [2., 6.], [3., 7.], [4., 8.]]) |
| 67 | + fm1 = np.array([[ 2., 7.], [ 3., 8.], [ 4., 9.], [ 5., 10.], [ 6., 11.]]) |
| 68 | + fm2 = np.array([[3., 6.], [4., 7.], [5., 8.]]) |
| 69 | + factor_matrices = [fm0, fm1, fm2] |
| 70 | + K = ttb.ktensor.from_data(weights, factor_matrices) |
| 71 | + |
| 72 | + # imported data |
| 73 | + data_filename = os.path.join(os.path.dirname(__file__),'data','ktensor.tns') |
| 74 | + X = ttb.import_data(data_filename) |
| 75 | + |
| 76 | + assert K.isequal(X) |
| 77 | + |
| 78 | +@pytest.mark.indevelopment |
| 79 | +def test_import_data_array(): |
| 80 | + # truth data |
| 81 | + M = np.array([[1., 5.], [2., 6.], [3., 7.], [4., 8.]]) |
| 82 | + print('\nM') |
| 83 | + print(M) |
| 84 | + |
| 85 | + # imported data |
| 86 | + data_filename = os.path.join(os.path.dirname(__file__),'data','matrix.tns') |
| 87 | + X = ttb.import_data(data_filename) |
| 88 | + print('\nX') |
| 89 | + print(X) |
| 90 | + |
| 91 | + assert (M == X).all() |
| 92 | + |
| 93 | +@pytest.mark.indevelopment |
| 94 | +def test_export_data_tensor(): |
| 95 | + pass |
| 96 | + |
| 97 | +@pytest.mark.indevelopment |
| 98 | +def test_export_data_sptensor(): |
| 99 | + pass |
| 100 | + |
| 101 | +@pytest.mark.indevelopment |
| 102 | +def test_export_data_ktensor(): |
| 103 | + pass |
| 104 | + |
| 105 | +@pytest.mark.indevelopment |
| 106 | +def test_export_data_array(): |
| 107 | + pass |
| 108 | + |
0 commit comments