Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
932 changes: 932 additions & 0 deletions docs/examples/png_example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/examples/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.10.4"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion pmwd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pmwd.configuration import Configuration
from pmwd.cosmology import Cosmology, SimpleLCDM, Planck18, E2, H_deriv, Omega_m_a
from pmwd.boltzmann import (transfer_integ, transfer_fit, transfer, growth_integ,
growth, varlin_integ, varlin, boltzmann, linear_power)
growth, varlin_integ, varlin, boltzmann, linear_power, linear_transfer)
from pmwd.particles import (Particles, ptcl_enmesh,
ptcl_pos, ptcl_rpos, ptcl_rsd, ptcl_los)
from pmwd.scatter import scatter
Expand Down
45 changes: 45 additions & 0 deletions pmwd/boltzmann.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,48 @@ def linear_power(k, a, cosmo, conf):
Plin *= D**2

return Plin.astype(float_dtype)

def linear_transfer(k, a, cosmo, conf):
r"""Linear matter transfer function at given wavenumbers and scale factors.

Parameters
----------
k : array_like
Wavenumbers in [1/L].
a : array_like or None
Scale factors. If None, output is not scaled by growth.
cosmo : Cosmology
conf : Configuration

Returns
-------
Tlin : jax.numpy.ndarray of (k * a * 1.).dtype
Linear matter transfer function.

Raises
------
ValueError
If not in 3D.

"""

if conf.dim != 3:
raise ValueError(f'dim={conf.dim} not supported')

k = jnp.asarray(k)
float_dtype = jnp.promote_types(k.dtype, float)

T = transfer(k, cosmo, conf)

# TF: the 3/5 is because the primordial amplitude A_s is given for \zeta instead of \Phi
Tlin = (3/5) * (2/3) * (conf.c / conf.H_0)**2 / cosmo.Omega_m * T

if a is not None:
a = jnp.asarray(a)
float_dtype = jnp.promote_types(float_dtype, a.dtype)

D = growth(a, cosmo, conf)

Tlin *= D

return Tlin.astype(float_dtype)
10 changes: 9 additions & 1 deletion pmwd/cosmology.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class Cosmology:
Dark energy equation of state linear parameter. Default is None.
h : float ArrayLike
Hubble constant in unit of 100 [km/s/Mpc].

f_nl_loc_ : float or jax.numpy.ndarray, optional
amplitude of local primordial non-Gaussianity. Default is None.
"""

conf: Configuration = field(repr=False)
Expand All @@ -64,6 +65,8 @@ class Cosmology:
w_0_fixed: ClassVar[float] = -1
w_a_: Optional[ArrayLike] = None
w_a_fixed: ClassVar[float] = 0
f_nl_loc_: Optional[ArrayLike] = None
f_nl_loc_fixed: ClassVar[float] = 0

transfer: Optional[Array] = field(default=None, compare=False)

Expand Down Expand Up @@ -148,6 +151,11 @@ def w_a(self):
"""Dark energy equation of state linear parameter."""
return self.w_a_fixed if self.w_a_ is None else self.w_a_

@property
def f_nl_loc(self):
"""Amplitude of local primordial non-Gaussianity."""
return self.f_nl_loc_fixed if self.f_nl_loc_ is None else self.f_nl_loc_

@property
def sigma8(self):
"""Linear matter rms overdensity within a tophat sphere of 8 Mpc/h radius at a=1."""
Expand Down
47 changes: 40 additions & 7 deletions pmwd/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import random
import jax.numpy as jnp

from pmwd.boltzmann import linear_power
from pmwd.boltzmann import linear_power, linear_transfer
from pmwd.pm_util import fftfreq, fftfwd, fftinv


Expand Down Expand Up @@ -65,7 +65,7 @@ def _safe_sqrt_bwd(y, y_cot):


@partial(jit, static_argnums=4)
@partial(checkpoint, static_argnums=4)
# @partial(checkpoint, static_argnums=4)
def linear_modes(modes, cosmo, conf, a=None, real=False):
"""Linear matter overdensity Fourier or real modes.

Expand Down Expand Up @@ -100,14 +100,47 @@ def linear_modes(modes, cosmo, conf, a=None, real=False):
if a is not None:
a = jnp.asarray(a, dtype=conf.float_dtype)

Plin = linear_power(k, a, cosmo, conf)

if jnp.isrealobj(modes):
modes = fftfwd(modes, norm='ortho')

modes *= _safe_sqrt(Plin * conf.box_vol)

if cosmo.f_nl_loc_ is not None:
Tlin = linear_transfer(k, a, cosmo, conf)*k*k
Pprim = 2*jnp.pi**2. * cosmo.A_s * (k/cosmo.k_pivot)**(cosmo.n_s-1.)\
* k**(-3.)
Pprim = Pprim.at[0,0,0].set(0.)

modes *= _safe_sqrt(Pprim / conf.ptcl_cell_vol)

modes = fftinv(modes, norm='ortho')
modes = jnp.fft.rfftn(modes)

# TF: padding for antialiasing (factor of (3/2)**3. for the change in dimension)
modes_NG = jnp.fft.fftshift(modes,axes=[0,1])
modes_NG = jnp.pad(modes_NG, ((conf.ptcl_grid_shape[0]//4,conf.ptcl_grid_shape[0]//4),(conf.ptcl_grid_shape[1]//4,conf.ptcl_grid_shape[1]//4),(0,conf.ptcl_grid_shape[2]//4))) * (3/2)**3.
modes_NG = jnp.fft.ifftshift(modes_NG,axes=[0,1])

# TF: square the modes in real space
modes_NG = jnp.fft.rfftn(jnp.fft.irfftn(modes_NG)**2.)

# TF: downsampling (factor of (3/2)**3. for the change in dimension)
modes_NG = jnp.fft.fftshift(modes_NG,axes=[0,1])
modes_NG = modes_NG[conf.ptcl_grid_shape[0]//4:-conf.ptcl_grid_shape[0]//4, conf.ptcl_grid_shape[1]//4:-conf.ptcl_grid_shape[1]//4,:-conf.ptcl_grid_shape[2]//4] / (3/2)**3.
modes_NG = jnp.fft.ifftshift(modes_NG,axes=[0,1])

# TF: add to the gaussian modes, factor of 3/5 is because we are generating \zeta and f_nl is defined for \Phi
modes = jnp.fft.irfftn(modes)
modes_NG = jnp.fft.irfftn(modes_NG)
modes = modes + 3/5 * cosmo.f_nl_loc * (modes_NG - jnp.mean(modes_NG))
modes = modes.astype(conf.float_dtype)

# TF: apply transfer function
modes = fftfwd(modes, norm='ortho')
modes *= Tlin * conf.box_vol / jnp.sqrt(conf.ptcl_num)
else:
Plin = linear_power(k, a, cosmo, conf)
modes *= _safe_sqrt(Plin * conf.box_vol)

if real:
modes = fftinv(modes, shape=conf.ptcl_grid_shape, norm=conf.ptcl_spacing)

return modes
return modes