Skip to content

Commit d1de535

Browse files
committed
Tweak test tolerance
1 parent 095923d commit d1de535

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/test_ml.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,22 +101,22 @@ def test_mlp_training():
101101

102102

103103
def test_adam_optimizer():
104-
grid = Grid[Chebyshev](lower=(-1.0,), upper=(1.0,), shape=(64,))
104+
grid = Grid[Chebyshev](lower=(-1.0,), upper=(1.0,), shape=(128,))
105105

106106
x = compute_mesh(grid)
107107

108108
y = vmap(g)(x.flatten()).reshape(x.shape)
109109

110110
topology = (4, 4)
111111
model = MLP(1, 1, topology)
112-
optimizer = JaxOptimizer(jopt.adam)
112+
optimizer = JaxOptimizer(jopt.adam, tol=1e-6)
113113
fit = build_fitting_function(model, optimizer)
114114

115115
params, layout = unpack(model.parameters)
116116
params = fit(params, x, y).params
117117
y_model = model.apply(pack(params, layout), x)
118118

119-
assert np.linalg.norm(y - y_model).item() / x.size < 5e-4
119+
assert np.linalg.norm(y - y_model).item() / x.size < 1e-2
120120

121121
x_plot = np.linspace(-1, 1, 512)
122122
fig, ax = plt.subplots()

0 commit comments

Comments
 (0)