@@ -25,9 +25,10 @@ def test_fit(rank, scale):
2525
2626@pytest .mark .parametrize ("scale" , [True , False ])
2727@pytest .mark .parametrize ("rank" , [1 , 2 , 10 ])
28- def test_fit (rank , scale ):
28+ @pytest .mark .parametrize ("randomized" , [True , False ])
29+ def test_fit (rank , scale , randomized ):
2930 pod = PODBlock (rank , scale )
30- pod .fit (toy_snapshots )
31+ pod .fit (toy_snapshots , randomized )
3132 n_snap = toy_snapshots .shape [0 ]
3233 dof = toy_snapshots .shape [1 ]
3334 assert pod .basis .shape == (rank , dof )
@@ -65,18 +66,20 @@ def test_forward():
6566
6667@pytest .mark .parametrize ("scale" , [True , False ])
6768@pytest .mark .parametrize ("rank" , [1 , 2 , 10 ])
68- def test_expand (rank , scale ):
69+ @pytest .mark .parametrize ("randomized" , [True , False ])
70+ def test_expand (rank , scale , randomized ):
6971 pod = PODBlock (rank , scale )
70- pod .fit (toy_snapshots )
72+ pod .fit (toy_snapshots , randomized )
7173 c = pod (toy_snapshots )
7274 torch .testing .assert_close (pod .expand (c ), toy_snapshots )
7375 torch .testing .assert_close (pod .expand (c [0 ]), toy_snapshots [0 ].unsqueeze (0 ))
7476
7577@pytest .mark .parametrize ("scale" , [True , False ])
7678@pytest .mark .parametrize ("rank" , [1 , 2 , 10 ])
77- def test_reduce_expand (rank , scale ):
79+ @pytest .mark .parametrize ("randomized" , [True , False ])
80+ def test_reduce_expand (rank , scale , randomized ):
7881 pod = PODBlock (rank , scale )
79- pod .fit (toy_snapshots )
82+ pod .fit (toy_snapshots , randomized )
8083 torch .testing .assert_close (
8184 pod .expand (pod .reduce (toy_snapshots )),
8285 toy_snapshots )
0 commit comments