Skip to content

Commit 87c5c6a

Browse files
pass method argument to fast laplacian (#648)
1 parent dc808c1 commit 87c5c6a

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

pina/operator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def fast_laplacian(output_, input_, components, d, method="std"):
221221
divergence of the gradient. Default is ``std``.
222222
:return: The computed laplacian tensor.
223223
:rtype: LabelTensor
224+
:raises ValueError: If the passed method is neither ``std`` nor ``divgrad``.
224225
"""
225226
# Scalar laplacian
226227
if output_.shape[-1] == 1:
@@ -415,8 +416,13 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
415416
components, d = _check_values(
416417
output_=output_, input_=input_, components=components, d=d
417418
)
419+
418420
return fast_laplacian(
419-
output_=output_, input_=input_, components=components, d=d
421+
output_=output_,
422+
input_=input_,
423+
components=components,
424+
d=d,
425+
method=method,
420426
)
421427

422428

tests/test_operator.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,8 @@ def test_divergence(f):
253253
Function(),
254254
ids=["scalar_scalar", "scalar_vector", "vector_scalar", "vector_vector"],
255255
)
256-
def test_laplacian(f):
256+
@pytest.mark.parametrize("method", ["std", "divgrad"])
257+
def test_laplacian(f, method):
257258

258259
# Unpack the function
259260
func_input, func, _, _, func_lap = f
@@ -265,7 +266,7 @@ def test_laplacian(f):
265266
output_ = LabelTensor(output_, labels)
266267

267268
# Compute the true laplacian and the pina laplacian
268-
pina_lap = laplacian(output_=output_, input_=input_)
269+
pina_lap = laplacian(output_=output_, input_=input_, method=method)
269270
true_lap = func_lap(input_)
270271

271272
# Check the shape and labels of the laplacian
@@ -276,24 +277,34 @@ def test_laplacian(f):
276277
assert torch.allclose(pina_lap, true_lap)
277278

278279
# Test if labels are handled correctly
279-
laplacian(output_=output_, input_=input_, components=output_.labels[0])
280-
laplacian(output_=output_, input_=input_, d=input_.labels[0])
280+
laplacian(
281+
output_=output_,
282+
input_=input_,
283+
components=output_.labels[0],
284+
method=method,
285+
)
286+
laplacian(output_=output_, input_=input_, d=input_.labels[0], method=method)
281287

282288
# Should fail if input not a LabelTensor
283289
with pytest.raises(TypeError):
284-
laplacian(output_=output_, input_=input_.tensor)
290+
laplacian(output_=output_, input_=input_.tensor, method=method)
285291

286292
# Should fail if output not a LabelTensor
287293
with pytest.raises(TypeError):
288-
laplacian(output_=output_.tensor, input_=input_)
294+
laplacian(output_=output_.tensor, input_=input_, method=method)
289295

290296
# Should fail for non-existent input labels
291297
with pytest.raises(RuntimeError):
292-
laplacian(output_=output_, input_=input_, d=["x", "y"])
298+
laplacian(output_=output_, input_=input_, d=["x", "y"], method=method)
293299

294300
# Should fail for non-existent output labels
295301
with pytest.raises(RuntimeError):
296-
laplacian(output_=output_, input_=input_, components=["a", "b", "c"])
302+
laplacian(
303+
output_=output_,
304+
input_=input_,
305+
components=["a", "b", "c"],
306+
method=method,
307+
)
297308

298309

299310
def test_advection_scalar():

0 commit comments

Comments (0)