Skip to content

Commit c28a4bd

Browse files
authored
Change SVD type in pod.py (#449)
* Change SVD type in pod.py
1 parent d94256f commit c28a4bd

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

pina/model/layers/pod.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from .stride import Stride
66
from .utils_convolution import optimizing
7-
7+
import warnings
88

99
class PODBlock(torch.nn.Module):
1010
"""
@@ -85,15 +85,15 @@ def scale_coefficients(self):
8585
"""
8686
return self.__scale_coefficients
8787

88-
def fit(self, X):
88+
def fit(self, X, randomized=True):
8989
"""
9090
Set the POD basis by performing the singular value decomposition of the
9191
given tensor. If `self.scale_coefficients` is True, the coefficients
9292
are scaled after the projection to have zero mean and unit variance.
9393
9494
:param torch.Tensor X: The tensor to be reduced.
9595
"""
96-
self._fit_pod(X)
96+
self._fit_pod(X, randomized)
9797

9898
if self.__scale_coefficients:
9999
self._fit_scaler(torch.matmul(self._basis, X.T))
@@ -112,16 +112,24 @@ def _fit_scaler(self, coeffs):
112112
"mean": torch.mean(coeffs, dim=1),
113113
}
114114

115-
def _fit_pod(self, X):
115+
def _fit_pod(self, X, randomized):
116116
"""
117117
Private method that computes the POD basis of the given tensor and stores it in the private member `_basis`.
118118
119119
:param torch.Tensor X: The tensor to be reduced.
120120
"""
121121
if X.device.type == "mps": # svd_lowrank not arailable for mps
122+
warnings.warn(
123+
"svd_lowrank not available for mps, using svd instead."
124+
"This may slow down computations.", ResourceWarning
125+
)
122126
self._basis = torch.svd(X.T)[0].T
123127
else:
124-
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
128+
if randomized:
129+
warnings.warn("Considering a randomized algorithm to compute the POD basis")
130+
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
131+
else:
132+
self._basis = torch.svd(X.T)[0].T
125133

126134
def forward(self, X):
127135
"""

tests/test_layers/test_pod.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ def test_fit(rank, scale):
2525

2626
@pytest.mark.parametrize("scale", [True, False])
2727
@pytest.mark.parametrize("rank", [1, 2, 10])
28-
def test_fit(rank, scale):
28+
@pytest.mark.parametrize("randomized", [True, False])
29+
def test_fit(rank, scale, randomized):
2930
pod = PODBlock(rank, scale)
30-
pod.fit(toy_snapshots)
31+
pod.fit(toy_snapshots, randomized)
3132
n_snap = toy_snapshots.shape[0]
3233
dof = toy_snapshots.shape[1]
3334
assert pod.basis.shape == (rank, dof)
@@ -65,18 +66,20 @@ def test_forward():
6566

6667
@pytest.mark.parametrize("scale", [True, False])
6768
@pytest.mark.parametrize("rank", [1, 2, 10])
68-
def test_expand(rank, scale):
69+
@pytest.mark.parametrize("randomized", [True, False])
70+
def test_expand(rank, scale, randomized):
6971
pod = PODBlock(rank, scale)
70-
pod.fit(toy_snapshots)
72+
pod.fit(toy_snapshots, randomized)
7173
c = pod(toy_snapshots)
7274
torch.testing.assert_close(pod.expand(c), toy_snapshots)
7375
torch.testing.assert_close(pod.expand(c[0]), toy_snapshots[0].unsqueeze(0))
7476

7577
@pytest.mark.parametrize("scale", [True, False])
7678
@pytest.mark.parametrize("rank", [1, 2, 10])
77-
def test_reduce_expand(rank, scale):
79+
@pytest.mark.parametrize("randomized", [True, False])
80+
def test_reduce_expand(rank, scale, randomized):
7881
pod = PODBlock(rank, scale)
79-
pod.fit(toy_snapshots)
82+
pod.fit(toy_snapshots, randomized)
8083
torch.testing.assert_close(
8184
pod.expand(pod.reduce(toy_snapshots)),
8285
toy_snapshots)

0 commit comments

Comments
 (0)