Skip to content

Commit 92dd4ed

Browse files
committed
FEAT: implement different algorithms for NumericalIntegral
1 parent 4c1c131 commit 92dd4ed

File tree

6 files changed

+250
-40
lines changed

6 files changed

+250
-40
lines changed

docs/usage/sympy.ipynb

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,137 @@
399399
")"
400400
]
401401
},
402+
{
403+
"cell_type": "markdown",
404+
"metadata": {},
405+
"source": [
406+
"Note that the choice of {attr}`.NumericalIntegral.algorithm` is important. For vectorized integrals with large input arrays, it is better to use [JAX](https://docs.jax.dev) as a numerical backend and any of the [`quadax`](http://quadax.readthedocs.io) algorithms as the integration method.\n",
407+
"\n",
408+
"Below, we see a comparison between the {func}`quadax.romberg` and {func}`quadax.quadgk` algorithms for a contour integral in the complex plane. The {func}`quadax.quadgk` algorithm is much faster than the {func}`quadax.romberg` algorithm, but is less accurate when $a$ is close to the contour."
409+
]
410+
},
411+
{
412+
"cell_type": "code",
413+
"execution_count": null,
414+
"metadata": {
415+
"tags": [
416+
"hide-input"
417+
]
418+
},
419+
"outputs": [],
420+
"source": [
421+
"a, z = sp.symbols(\"a z\")\n",
422+
"f = z**3 / (z - a)\n",
423+
"f"
424+
]
425+
},
426+
{
427+
"cell_type": "code",
428+
"execution_count": null,
429+
"metadata": {},
430+
"outputs": [],
431+
"source": [
432+
"r, phi = sp.symbols(\"r phi\")\n",
433+
"f_C = f.subs(z, r * sp.exp(sp.I * phi))\n",
434+
"quadgk_expr = NumericalIntegral(f_C, (phi, 0, 2 * sp.pi), algorithm=\"quadax.quadgk\")\n",
435+
"romberg_expr = NumericalIntegral(f_C, (phi, 0, 2 * sp.pi), algorithm=\"quadax.romberg\")\n",
436+
"romberg_expr"
437+
]
438+
},
439+
{
440+
"cell_type": "code",
441+
"execution_count": null,
442+
"metadata": {},
443+
"outputs": [],
444+
"source": [
445+
"args = (a, r)\n",
446+
"quadgk_func = sp.lambdify(args, quadgk_expr, modules=\"jax\")\n",
447+
"romberg_func = sp.lambdify(args, romberg_expr, modules=\"jax\")"
448+
]
449+
},
450+
{
451+
"cell_type": "code",
452+
"execution_count": null,
453+
"metadata": {
454+
"tags": [
455+
"hide-input"
456+
]
457+
},
458+
"outputs": [],
459+
"source": [
460+
"import time\n",
461+
"\n",
462+
"import jax.numpy as jnp\n",
463+
"\n",
464+
"x_max = 4\n",
465+
"X, Y = jnp.meshgrid(\n",
466+
" jnp.linspace(-x_max, +x_max, num=300),\n",
467+
" jnp.linspace(-x_max, +x_max, num=300),\n",
468+
")\n",
469+
"start = time.perf_counter()\n",
470+
"Z = quadgk_func(a=X + Y * 1j, r=3).block_until_ready()\n",
471+
"end = time.perf_counter()\n",
472+
"print(f\"Computation took {end - start:.2f} seconds for {X.shape[0]}x{X.shape[1]} grid\")\n",
473+
"\n",
474+
"z_max = 30\n",
475+
"fig, axes = plt.subplots(ncols=2, figsize=(8.5, 5), sharey=True)\n",
476+
"fig.suptitle(\"Integral $I_a$ computed with quadax.quadgk\", y=0.95)\n",
477+
"ax_real, ax_imag = axes\n",
478+
"ax_real.set_title(\"Real part of $I_a$\")\n",
479+
"ax_imag.set_title(\"Imaginary part of $I_a$\")\n",
480+
"ax_real.set_xlabel(\"Re $a$\")\n",
481+
"ax_real.set_ylabel(\"Im $a$\")\n",
482+
"ax_imag.set_xlabel(\"Re $a$\")\n",
483+
"for ax, Z_proj in zip(axes, [Z.real, Z.imag], strict=True):\n",
484+
" ax.pcolormesh(\n",
485+
" X, Y, Z_proj, cmap=\"RdBu_r\", rasterized=True, vmin=-z_max, vmax=+z_max\n",
486+
" )\n",
487+
"fig.tight_layout()\n",
488+
"plt.show()"
489+
]
490+
},
491+
{
492+
"cell_type": "code",
493+
"execution_count": null,
494+
"metadata": {
495+
"tags": [
496+
"hide-input"
497+
]
498+
},
499+
"outputs": [],
500+
"source": [
501+
"X, Y = jnp.meshgrid(\n",
502+
" jnp.linspace(-x_max, +x_max, num=50),\n",
503+
" jnp.linspace(-x_max, +x_max, num=50),\n",
504+
")\n",
505+
"start = time.perf_counter()\n",
506+
"Z = romberg_func(a=X + Y * 1j, r=3).block_until_ready()\n",
507+
"end = time.perf_counter()\n",
508+
"print(f\"Computation took {end - start:.2f} seconds for {X.shape[0]}x{X.shape[1]} grid\")\n",
509+
"\n",
510+
"fig, axes = plt.subplots(ncols=2, figsize=(8.5, 5), sharey=True)\n",
511+
"fig.suptitle(\"Integral $I_a$ computed with quadax.romberg\", y=0.95)\n",
512+
"ax_real, ax_imag = axes\n",
513+
"ax_real.set_title(\"Real part\")\n",
514+
"ax_imag.set_title(\"Imaginary part\")\n",
515+
"ax_real.set_xlabel(\"Re $a$\")\n",
516+
"ax_real.set_ylabel(\"Im $a$\")\n",
517+
"ax_imag.set_xlabel(\"Re $a$\")\n",
518+
"for ax, z in zip(axes, [Z.real, Z.imag], strict=True):\n",
519+
" ax.pcolormesh(X, Y, z, cmap=\"RdBu_r\", rasterized=True, vmin=-z_max, vmax=+z_max)\n",
520+
"fig.tight_layout()\n",
521+
"plt.show()"
522+
]
523+
},
524+
{
525+
"cell_type": "markdown",
526+
"metadata": {},
527+
"source": [
528+
":::{seealso}\n",
529+
"The performance of different integration algorithms for computing the dispersion integral in different scenarios is investigated in [this notebook](./dynamics/integration-algorithms.ipynb).\n",
530+
":::"
531+
]
532+
},
402533
{
403534
"cell_type": "markdown",
404535
"metadata": {},

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ test = [
118118
"pytest-cov",
119119
"pytest-profiling",
120120
"pytest-xdist",
121+
"quadax",
121122
]
122123
types = [
123124
"IPython",
@@ -365,6 +366,8 @@ filterwarnings = [
365366
"ignore:datetime.datetime.utcfromtimestamp\\(\\) is deprecated and scheduled for removal in a future version.*:DeprecationWarning",
366367
"ignore:unclosed .*:ResourceWarning",
367368
'ignore:Widget\..* is deprecated\.:DeprecationWarning',
369+
'ignore:atleast_1d requires ndarray or scalar arguments.*:DeprecationWarning',
370+
'ignore:jax\.core\.(un)?mapped_aval is deprecated\.:DeprecationWarning',
368371
]
369372
markers = ["slow: marks tests as slow (select with '-m slow')"]
370373
minversion = "9.0"

src/ampform/dynamics/phasespace.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,12 +298,15 @@ class ChewMandelstamIntegral(sp.Expr):
298298
meson_radius: Any = 1
299299
name: str | None = argument(default=None, sympify=False)
300300
algorithm: tuple[str, str] | None = argument(
301-
default=None, sympify=False, kw_only=True
301+
default=None, kw_only=True, sympify=False
302302
)
303+
"""See :attr:`.NumericalIntegral.algorithm`."""
303304
configuration: dict[str, Any] | None = argument(
304-
default=None, sympify=False, kw_only=True
305+
default=None, kw_only=True, sympify=False
305306
)
306-
dummify: bool = argument(default=True, sympify=False, kw_only=True)
307+
"""See :attr:`.NumericalIntegral.configuration`."""
308+
dummify: bool = argument(default=True, kw_only=True, sympify=False)
309+
"""Whether to dummify the integration variable. See :attr:`.NumericalIntegral.dummify`."""
307310

308311
def evaluate(self) -> sp.Expr:
309312
s, m1, m2, L, s_prime, epsilon, meson_radius, *_ = self.args # noqa: N806

src/ampform/sympy/__init__.py

Lines changed: 80 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import itertools
1616
import re
1717
import sys
18-
import warnings
1918
from abc import abstractmethod
2019
from typing import TYPE_CHECKING, Any, cast
2120

@@ -373,14 +372,35 @@ class NumericalIntegral(sp.Integral):
373372
"""
374373

375374
function: sp.Expr
375+
"""Integrand of the integral."""
376376
limits: tuple[sp.Symbol, sp.Basic, sp.Basic]
377-
algorithm: tuple[str, str] | None = argument(
378-
default=None, sympify=False, kw_only=True
379-
)
377+
"""Integration variable and its limits (can be `~sympy.core.numbers.Infinity`)."""
378+
algorithm: str | None = argument(default=None, kw_only=True, sympify=False)
379+
"""Name of the numerical integration algorithm to use when lambdifying this integral.
380+
381+
The algorithm should be in the format :code:`module.function`, for instance
382+
:func:`scipy.integrate.quad_vec` or :func:`quadax.quadgk`. By default, the algorithm
383+
is :func:`quadax.romberg` when lambdifying to JAX and
384+
:func:`scipy.integrate.quad_vec` when lambdifying to NumPy.
385+
"""
380386
configuration: dict[str, Any] | None = argument(
381387
default=None, sympify=False, kw_only=True
382388
)
389+
"""Keyword arguments for the numerical integration algorithm.
390+
391+
For example, for :func:`scipy.integrate.quad_vec`, one can set the relative
392+
tolerance with :code:`configuration={'epsrel': 1e-5}`.
393+
"""
383394
dummify: bool = argument(default=True, sympify=False, kw_only=True)
395+
"""Replace the integration variable with a dummy symbol before lambdification.
396+
397+
The integrand expression is lambdified to a :code:`lambda` function. Therefore, when
398+
the integrand expresssion contains the integration variable in a non-trivial way,
399+
and the expression is lambdified using common sub-expressions, it is better to
400+
replace it with a unique `~sympy.core.symbol.Dummy` symbol that does not appear
401+
anywhere else in the expression tree, so that is not pulled out of the
402+
:code:`lambda` function.
403+
"""
384404

385405
@override
386406
def doit(self, **hints):
@@ -391,54 +411,82 @@ def doit(self, **hints):
391411
}
392412
return self.func(*args, **kwargs)
393413

414+
@override
415+
def _jaxcode(self, printer, *args) -> str: # ty:ignore[invalid-explicit-override]
416+
algorithm = self.algorithm or "quadax.romberg"
417+
if algorithm.startswith("quadax"):
418+
return self.__to_quadax_like(printer, algorithm)
419+
return self.__to_scipy_like(printer, algorithm)
420+
394421
@override
395422
def _numpycode(self, printer, *args) -> str: # ty:ignore[invalid-explicit-override]
396-
module, algorithm = self.algorithm or ("scipy.integrate", "quad_vec")
397-
if module.startswith("scipy"):
398-
_warn_if_scipy_not_installed()
423+
algorithm = self.algorithm or "scipy.integrate.quad_vec"
424+
if algorithm.startswith("quadax"):
425+
return self.__to_quadax_like(printer, algorithm)
426+
return self.__to_scipy_like(printer, algorithm)
427+
428+
def __to_quadax_like(self, printer, algorithm: str) -> str:
429+
"""https://quadax.readthedocs.io."""
430+
integrate, integrand, x, a, b = self.__prepare_components(printer, algorithm)
431+
src = _generate_function_call(
432+
integrate,
433+
fun=f"lambda {x}: {integrand}",
434+
interval=f"({a}, {b})",
435+
**self.configuration or {},
436+
)
437+
return f"{src}[0]"
438+
439+
def __to_scipy_like(self, printer, algorithm: str) -> str:
440+
"""https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.quad_vec.html."""
441+
integrate, integrand, x, a, b = self.__prepare_components(printer, algorithm)
442+
kwargs = self.configuration or {}
443+
src = _generate_function_call(
444+
integrate, f"lambda {x}: {integrand}", a, b, **kwargs
445+
)
446+
return f"{src}[0]"
447+
448+
def __prepare_components(
449+
self, printer, algorithm: str
450+
) -> tuple[str, str, str, str, str]:
399451
integration_vars, limits = _unpack_integral_limits(self)
400452
if len(limits) != 1 or len(integration_vars) != 1:
401453
msg = f"Cannot handle {len(limits)}-dimensional integrals"
402454
raise ValueError(msg)
403455
x = integration_vars[0]
404456
a, b = limits[0]
405-
expr = self.args[0]
457+
integrand = self.function
406458
if self.dummify:
407459
dummy = sp.Dummy()
408-
expr = expr.xreplace({x: dummy})
460+
integrand = integrand.xreplace({x: dummy})
409461
x = dummy
410-
integrate_numerically = algorithm
411-
printer.module_imports[module].add(integrate_numerically)
412-
src = _generate_function_call(
413-
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.quad_vec.html
414-
integrate_numerically,
415-
f"lambda {printer._print(x)}: {printer._print(expr)}",
462+
parts = algorithm.split(".")
463+
if len(parts) < 2: # noqa: PLR2004
464+
msg = f"Algorithm should be in format 'module.function', got '{algorithm}'"
465+
raise ValueError(msg)
466+
module_name = ".".join(parts[:-1])
467+
algorithm_name = parts[-1]
468+
printer.module_imports[module_name].add(algorithm_name)
469+
return (
470+
algorithm_name,
471+
printer._print(integrand),
472+
printer._print(x),
416473
printer._print(a),
417474
printer._print(b),
418-
**self.configuration or {},
419475
)
420-
return f"{src}[0]"
421476

422477

423-
def _generate_function_call(func_name: str, *args, **kwargs) -> str:
478+
def _generate_function_call(func_name: str, /, *args, **kwargs) -> str:
424479
"""Generate a function call string with the given function name, arguments, and keyword arguments.
425480
426481
>>> _generate_function_call("quad_vec", "f", 0, 1, epsabs=1e-5)
427482
'quad_vec(f, 0, 1, epsabs=1e-05)'
483+
>>> _generate_function_call("quadgk", fun="lambda x: x**2", interval=(0, 1))
484+
'quadgk(fun=lambda x: x**2, interval=(0, 1))'
428485
"""
429-
src = f"{func_name}({', '.join(map(str, args))}"
430-
for key, value in kwargs.items():
431-
src += f", {key}={value}"
486+
src = f"{func_name}("
487+
src += ", ".join(map(str, args))
488+
if args:
489+
src += ", "
490+
src += ", ".join(f"{key}={value}" for key, value in kwargs.items())
432491
src += ")"
433492
return src
434-
435-
436-
def _warn_if_scipy_not_installed() -> None:
437-
try:
438-
import scipy # noqa: F401, PLC0415
439-
except ImportError:
440-
warnings.warn(
441-
"Scipy is not installed. Install with 'pip install scipy' or with 'pip"
442-
" install ampform[scipy]'",
443-
stacklevel=1,
444-
)

tests/sympy/test_integral.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,40 @@
66

77

88
class TestNumericalIntegral:
9-
@pytest.mark.parametrize("call_doit", [True, False])
10-
@pytest.mark.parametrize("configuration", [{}, {"limit": 10}, {"limit": 100}])
9+
@pytest.mark.parametrize(
10+
("backend", "algorithm", "configuration"),
11+
[
12+
("jax", "quadax.quadgk", None),
13+
("jax", "quadax.romberg", None),
14+
("numpy", "scipy.integrate.quad_vec", None),
15+
("numpy", "scipy.integrate.quad_vec", {"limit": 10}),
16+
("numpy", None, None),
17+
],
18+
)
19+
@pytest.mark.parametrize("call_doit", [False, True])
20+
@pytest.mark.parametrize("dummify", [False, True])
1121
def test_real_value_function(
12-
self, call_doit: bool, configuration: dict[str, int | None]
22+
self,
23+
algorithm: str | None,
24+
backend: str,
25+
call_doit: bool,
26+
configuration: dict[str, int | None],
27+
dummify: bool,
1328
):
1429
x = sp.symbols("x")
15-
integral_expr = NumericalIntegral(x**2, (x, 1, 3), configuration=configuration)
30+
integral_expr = NumericalIntegral(
31+
x**2,
32+
(x, 1, 3),
33+
algorithm=algorithm,
34+
configuration=configuration,
35+
dummify=dummify,
36+
)
1637
if call_doit:
1738
integral_expr = integral_expr.doit()
39+
assert integral_expr.algorithm == algorithm
1840
assert integral_expr.configuration == configuration
19-
func = sp.lambdify(args=[], expr=integral_expr)
41+
assert integral_expr.dummify is dummify
42+
func = sp.lambdify([], integral_expr, backend)
2043
assert func() == 26 / 3 # noqa: RUF069
2144

2245
@pytest.mark.parametrize(

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)