Skip to content

Commit 7c2381d

Browse files
authored
Api update (#227)
* Revert "Release 0.2.0" This reverts commit 060647c. * updated public api * removed unused second outputs from operators
1 parent 060647c commit 7c2381d

File tree

6 files changed

+87
-21
lines changed

6 files changed

+87
-21
lines changed

CHANGELOG.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
55

66
## [Unreleased]
7-
8-
## [0.2.0] - 2023-12-18
97
### Fixed
108
- Fixed arguments error in helmholtz notebook
119

@@ -99,8 +97,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
9997
- Pml for 1D and 3D simulations.
10098
- Plotting functions of `jwave.utils` now work with both `Field`s and arrays.
10199

102-
[Unreleased]: https://github.com/ucl-bug/jwave/compare/0.2.0...master
103-
[0.2.0]: https://github.com/ucl-bug/jwave/compare/0.1.5...0.2.0
100+
[Unreleased]: https://github.com/ucl-bug/jwave/compare/0.1.5...master
104101
[0.1.5]: https://github.com/ucl-bug/jwave/compare/0.1.4...0.1.5
105102
[0.1.4]: https://github.com/ucl-bug/jwave/compare/0.1.3...0.1.4
106103
[0.1.3]: https://github.com/ucl-bug/jwave/compare/0.1.2...0.1.3
@@ -111,4 +108,3 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
111108
[0.0.3]: https://github.com/ucl-bug/jwave/compare/0.0.2...0.0.3
112109
[0.0.2]: https://github.com/ucl-bug/jwave/compare/0.0.1...0.0.2
113110
[0.0.1]: https://github.com/ucl-bug/jwave/releases/tag/0.0.1
114-

jwave/__init__.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,53 @@
1414
# License along with j-Wave. If not, see <https://www.gnu.org/licenses/>.
1515

1616
# nopycln: file
17-
from jaxdf.discretization import *
17+
from jaxdf import (
18+
operator,
19+
Continuous,
20+
Domain,
21+
FiniteDifferences,
22+
FourierSeries,
23+
Field,
24+
Linear,
25+
OnGrid
26+
)
27+
28+
from .acoustics import (
29+
angular_spectrum,
30+
born_iteration,
31+
born_series,
32+
db2neper,
33+
helmholtz_solver_verbose,
34+
helmholtz_solver,
35+
helmholtz,
36+
homogeneous_helmholtz_green,
37+
laplacian_with_pml,
38+
mass_conservation_rhs,
39+
momentum_conservation_rhs,
40+
pml,
41+
pressure_from_density,
42+
rayleigh_integral,
43+
scale_source_helmholtz,
44+
scattering_potential,
45+
simulate_wave_propagation,
46+
spectral,
47+
wave_propagation_symplectic_step,
48+
wavevector,
49+
TimeWavePropagationSettings,
50+
)
51+
from .geometry import (
52+
BLISensors,
53+
DistributedTransducer,
54+
Medium,
55+
Sensors,
56+
Sources,
57+
TimeAxis,
58+
TimeHarmonicSource,
59+
)
1860

1961
from jwave import acoustics as ac
2062
from jwave import geometry as geometry
63+
from jwave import logger as logger
64+
from jwave import phantoms as phantoms
2165
from jwave import signal_processing as signal_processing
2266
from jwave import utils as utils

jwave/acoustics/__init__.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,31 @@
1414
# License along with j-Wave. If not, see <https://www.gnu.org/licenses/>.
1515

1616
# nopycln: file
17-
from .operators import *
18-
from .time_harmonic import *
19-
from .time_varying import *
17+
from .conversion import db2neper
18+
from .operators import (
19+
helmholtz,
20+
laplacian_with_pml,
21+
scale_source_helmholtz,
22+
wavevector,
23+
)
24+
from .time_harmonic import (
25+
angular_spectrum,
26+
born_iteration,
27+
born_series,
28+
helmholtz_solver,
29+
helmholtz_solver_verbose,
30+
homogeneous_helmholtz_green,
31+
rayleigh_integral,
32+
scattering_potential
33+
)
34+
from .time_varying import (
35+
mass_conservation_rhs,
36+
momentum_conservation_rhs,
37+
pressure_from_density,
38+
simulate_wave_propagation,
39+
wave_propagation_symplectic_step,
40+
TimeWavePropagationSettings,
41+
)
42+
43+
from . import spectral
44+
from . import pml

jwave/acoustics/time_harmonic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def body_fun(carry):
342342

343343
out_field = _cbs_unnorm_units(out_field, _conversion)
344344

345-
return out_field, None
345+
return out_field
346346

347347

348348
@operator
@@ -377,7 +377,7 @@ def born_iteration(field: Field,
377377
G = homogeneous_helmholtz_green(V1 + src, k0=k0, epsilon=epsilon)
378378
V2 = scattering_potential(field - G, k_sq, k0=k0, epsilon=epsilon)
379379

380-
return field - (1j / epsilon) * V2, params
380+
return field - (1j / epsilon) * V2
381381

382382

383383
@operator
@@ -401,7 +401,7 @@ def scattering_potential(field: Field,
401401

402402
k = k_sq - k0**2 - 1j * epsilon
403403
out = field * k
404-
return out, params
404+
return out
405405

406406

407407
@operator
@@ -430,7 +430,7 @@ def homogeneous_helmholtz_green(field: FourierSeries,
430430
u_fft = jnp.fft.fftn(u)
431431
Gu_fft = g_fourier * u_fft
432432
Gu = jnp.fft.ifftn(Gu_fft)
433-
return field.replace_params(Gu), params
433+
return field.replace_params(Gu)
434434

435435

436436
@operator
@@ -500,7 +500,7 @@ def direc_exp_term(x, y, z):
500500
# Weights of the Rayleigh integral
501501
weights = jax.vmap(jax.vmap(direc_exp_term, in_axes=(0, 0, 0)),
502502
in_axes=(0, 0, 0))(R[..., 0], R[..., 1], R[..., 2])
503-
return jnp.sum(weights * pressure.on_grid) * area, None
503+
return jnp.sum(weights * pressure.on_grid) * area
504504

505505

506506
@operator
@@ -560,7 +560,7 @@ def helm_func(u):
560560
)[0]
561561
elif method == "bicgstab":
562562
out = bicgstab(helm_func, source, guess, tol=tol, maxiter=maxiter)[0]
563-
return -1j * omega * out, None
563+
return -1j * omega * out
564564

565565

566566
def helmholtz_solver_verbose(

jwave/acoustics/time_varying.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@ def __init__(
8282
self.smooth_initial = smooth_initial
8383

8484

85-
default_time_wave_prop_settings = TimeWavePropagationSettings()
86-
8785

8886
def _shift_rho(rho0, direction, dx):
8987
if isinstance(rho0, OnGrid):
@@ -382,7 +380,7 @@ def simulate_wave_propagation(
382380
medium: Medium[OnGrid],
383381
time_axis: TimeAxis,
384382
*,
385-
settings: TimeWavePropagationSettings = default_time_wave_prop_settings,
383+
settings: TimeWavePropagationSettings = TimeWavePropagationSettings(),
386384
sources=None,
387385
sensors=None,
388386
u0=None,
@@ -533,7 +531,7 @@ def simulate_wave_propagation(
533531
medium: Medium[FourierSeries],
534532
time_axis: TimeAxis,
535533
*,
536-
settings: TimeWavePropagationSettings = default_time_wave_prop_settings,
534+
settings: TimeWavePropagationSettings = TimeWavePropagationSettings(),
537535
sources=None,
538536
sensors=None,
539537
u0=None,

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "jwave"
3-
version = "0.2.0"
3+
version = "0.1.5"
44
description = "Fast and differentiable acoustic simulations in JAX."
55
authors = [
66
"Antonio Stanziola <[email protected]>",
@@ -108,9 +108,12 @@ split_before_logical_operator = true
108108

109109
[tool.pytest.ini_options]
110110
addopts = """\
111-
--doctest-modules \
111+
--doctest-modules\
112112
"""
113113

114+
[tool.pytest_env]
115+
CUDA_VISIBLE_DEVICES = ""
116+
114117
[tool.coverage.report]
115118
exclude_lines = [
116119
'if TYPE_CHECKING:',

0 commit comments

Comments
 (0)