diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..0879bdb9 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 120 +ignore = W291,W503,W504,E123,E126,E203,E402,E701 +per-file-ignores = __init__.py: F401 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 00000000..4e26c59e --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,99 @@ +name: Publish +on: + release: + types: [published] + branches: [master] + +jobs: + build_and_test: + strategy: + matrix: + python-version: [ 3.6, 3.8 ] + os: [ macos-latest, ubuntu-latest, windows-latest ] + fail-fast: false + runs-on: ${{ matrix.os }} + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Check version + shell: bash + run: | + python -m pip install --upgrade pip + + python -m pip install torchsde + pypi_info=$(pip list | grep torchsde) + pypi_version=$(echo ${pypi_info} | cut -d " " -f2) + python -m pip uninstall -y torchsde + + python setup.py install + master_info=$(pip list | grep torchsde) + master_version=$(echo ${master_info} | cut -d " " -f2) + python -m pip uninstall -y torchsde + + python -c "import itertools as it + import sys + _, pypi_version, master_version = sys.argv + pypi_version_ = [int(i) for i in pypi_version.split('.')] + master_version_ = [int(i) for i in master_version.split('.')] + pypi_version__ = tuple(p for m, p in it.zip_longest(master_version_, pypi_version_, fillvalue=0)) + master_version__ = tuple(m for m, p in it.zip_longest(master_version_, pypi_version_, fillvalue=0)) + sys.exit(master_version__ <= pypi_version__)" ${pypi_version} ${master_version} + + - name: Install dependencies + run: | + python -m pip install flake8 pytest wheel + + - name: Lint with flake8 + run: | + python -m flake8 . + + - name: Build + shell: bash + run: | + python setup.py sdist bdist_wheel + rm -f dist/*.egg + + - name: Run sdist tests + shell: bash + run: | + python -m pip install dist/*.tar.gz + python -m pytest + python -m pip uninstall -y torchsde + + - name: Run bdist_wheel tests + shell: bash + run: | + python -m pip install dist/*.whl + python -m pytest + python -m pip uninstall -y torchsde + + - name: Upload builds + if: matrix.python-version == '3.8' && matrix.os == 'ubuntu-latest' + uses: actions/upload-artifact@v2 + with: + name: build-artifact + path: dist/ + + publish: + needs: [ build_and_test ] + strategy: + matrix: + os: [ ubuntu-latest ] + runs-on: ${{ matrix.os }} + steps: + - name: Download builds + uses: actions/download-artifact@v2 + with: + name: build-artifact + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@v1.4.2 + with: + user: ${{ secrets.pypi_username }} + password: ${{ secrets.pypi_password }} diff --git a/examples/cont_ddpm.py b/examples/cont_ddpm.py index f3e0d22e..2576c0a5 100644 --- a/examples/cont_ddpm.py +++ b/examples/cont_ddpm.py @@ -331,11 +331,11 @@ def plot(imgs, path): if global_step % pause_every == 0: logging.warning(f'global_step: {global_step:06d}, loss: {loss:.4f}') - img_path = os.path.join(train_dir, f'ode_samples', f'global_step_{global_step:07d}.png') + img_path = os.path.join(train_dir, 'ode_samples', f'global_step_{global_step:07d}.png') ode_samples = reverse.ode_sample_final(tau=tau) plot(ode_samples, img_path) - img_path = os.path.join(train_dir, f'sde_samples', f'global_step_{global_step:07d}.png') + img_path = os.path.join(train_dir, 'sde_samples', f'global_step_{global_step:07d}.png') sde_samples = reverse.sde_sample_final(tau=tau) plot(sde_samples, img_path) diff --git a/tests/__init__.py b/tests/__init__.py index fc9b9eeb..6913f02e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index f5c669b4..ac13a095 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -38,7 +38,8 @@ def _methods(): yield SDE_TYPES.stratonovich, METHODS.reversible_heun, None -@pytest.mark.parametrize("sde_cls", [problems.ExDiagonal, problems.ExScalar, problems.ExAdditive, problems.NeuralGeneral]) +@pytest.mark.parametrize("sde_cls", [problems.ExDiagonal, problems.ExScalar, problems.ExAdditive, + problems.NeuralGeneral]) @pytest.mark.parametrize("sde_type, method, options", _methods()) @pytest.mark.parametrize('adaptive', (False,)) def test_against_numerical(sde_cls, sde_type, method, options, adaptive): diff --git a/torchsde/_brownian/brownian_interval.py b/torchsde/_brownian/brownian_interval.py index 9a3a876b..afde9cb9 100644 --- a/torchsde/_brownian/brownian_interval.py +++ b/torchsde/_brownian/brownian_interval.py @@ -63,14 +63,14 @@ def _check_tensor_info(*tensors, size, dtype, device): devices += [t.device for t in tensors] if len(sizes) == 0: - raise ValueError(f"Must either specify `size` or pass in `W` or `H` to implicitly define the size.") + raise ValueError("Must either specify `size` or pass in `W` or `H` to implicitly define the size.") if not all(i == sizes[0] for i in sizes): - raise ValueError(f"Multiple sizes found. Make sure `size` and `W` or `H` are consistent.") + raise ValueError("Multiple sizes found. Make sure `size` and `W` or `H` are consistent.") if not all(i == dtypes[0] for i in dtypes): - raise ValueError(f"Multiple dtypes found. Make sure `dtype` and `W` or `H` are consistent.") + raise ValueError("Multiple dtypes found. Make sure `dtype` and `W` or `H` are consistent.") if not all(i == devices[0] for i in devices): - raise ValueError(f"Multiple devices found. Make sure `device` and `W` or `H` are consistent.") + raise ValueError("Multiple devices found. Make sure `device` and `W` or `H` are consistent.") # Make sure size is a tuple (not a torch.Size) for neat repr-printing purposes. return tuple(sizes[0]), dtypes[0], devices[0] diff --git a/torchsde/_core/adjoint_sde.py b/torchsde/_core/adjoint_sde.py index ceb8e4de..ed8b85d6 100644 --- a/torchsde/_core/adjoint_sde.py +++ b/torchsde/_core/adjoint_sde.py @@ -373,4 +373,5 @@ def g_prod_and_gdg_prod_diagonal(self, t, y_aug, v1, v2): # For Ito/Stratonovic create_graph=requires_grad ) vjp_y_and_params = misc.seq_sub(prod_partials_adj_y_and_params, mixed_partials_adj_y_and_params) - return self._g_prod(g_prod, y, adj_y, requires_grad), misc.flatten((vg_dg_vjp, *vjp_y_and_params)).unsqueeze(0) + return self._g_prod(g_prod, y, adj_y, requires_grad), misc.flatten((vg_dg_vjp, + *vjp_y_and_params)).unsqueeze(0) diff --git a/torchsde/_core/methods/log_ode.py b/torchsde/_core/methods/log_ode.py index ae75bc48..b662d921 100644 --- a/torchsde/_core/methods/log_ode.py +++ b/torchsde/_core/methods/log_ode.py @@ -30,9 +30,9 @@ class LogODEMidpoint(base_solver.BaseSDESolver): def __init__(self, sde, **kwargs): if isinstance(sde, adjoint_sde.AdjointSDE): - raise ValueError(f"Log-ODE schemes cannot be used for adjoint SDEs, because they require " - f"direct access to the diffusion, whilst adjoint SDEs rely on a more efficient " - f"diffusion-vector product. Use a different method instead.") + raise ValueError("Log-ODE schemes cannot be used for adjoint SDEs, because they require " + "direct access to the diffusion, whilst adjoint SDEs rely on a more efficient " + "diffusion-vector product. Use a different method instead.") self.strong_order = 0.5 if sde.noise_type == NOISE_TYPES.general else 1.0 super(LogODEMidpoint, self).__init__(sde=sde, **kwargs) diff --git a/torchsde/_core/methods/reversible_heun.py b/torchsde/_core/methods/reversible_heun.py index ff242bd4..c1434ce0 100644 --- a/torchsde/_core/methods/reversible_heun.py +++ b/torchsde/_core/methods/reversible_heun.py @@ -42,7 +42,7 @@ from .. import adjoint_sde from .. import base_solver from .. import misc -from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS, METHODS, METHOD_OPTIONS +from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS, METHODS class ReversibleHeun(base_solver.BaseSDESolver): diff --git a/torchsde/_core/methods/srk.py b/torchsde/_core/methods/srk.py index 58bae16a..6144d78c 100644 --- a/torchsde/_core/methods/srk.py +++ b/torchsde/_core/methods/srk.py @@ -44,9 +44,9 @@ def __init__(self, sde, **kwargs): self.step = self.diagonal_or_scalar_step if isinstance(sde, adjoint_sde.AdjointSDE): - raise ValueError(f"Stochastic Runge–Kutta methods cannot be used for adjoint SDEs, because it requires " - f"direct access to the diffusion, whilst adjoint SDEs rely on a more efficient " - f"diffusion-vector product. Use a different method instead.") + raise ValueError("Stochastic Runge–Kutta methods cannot be used for adjoint SDEs, because it requires " + "direct access to the diffusion, whilst adjoint SDEs rely on a more efficient " + "diffusion-vector product. Use a different method instead.") super(SRK, self).__init__(sde=sde, **kwargs) diff --git a/torchsde/_core/misc.py b/torchsde/_core/misc.py index 865f1530..18d00e38 100644 --- a/torchsde/_core/misc.py +++ b/torchsde/_core/misc.py @@ -71,7 +71,7 @@ def stable_division(a, b, epsilon=1e-7): def vjp(outputs, inputs, **kwargs): if torch.is_tensor(inputs): inputs = [inputs] - _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] # Workaround for PyTorch bug #39784. + _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] # Workaround for PyTorch bug #39784. # noqa: 74 if torch.is_tensor(outputs): outputs = [outputs] @@ -85,7 +85,7 @@ def jvp(outputs, inputs, grad_inputs=None, **kwargs): # Unlike `torch.autograd.functional.jvp`, this function avoids repeating forward computation. if torch.is_tensor(inputs): inputs = [inputs] - _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] # Workaround for PyTorch bug #39784. + _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] # Workaround for PyTorch bug #39784. # noqa: 88 if torch.is_tensor(outputs): outputs = [outputs] diff --git a/torchsde/_core/sdeint.py b/torchsde/_core/sdeint.py index 09706a3c..5523bf94 100644 --- a/torchsde/_core/sdeint.py +++ b/torchsde/_core/sdeint.py @@ -122,13 +122,13 @@ def check_contract(sde, y0, ts, bm, method, adaptive, options, names, logqp): sde = base_sde.RenameMethodsSDE(sde, **names_to_change) if not hasattr(sde, "noise_type"): - raise ValueError(f"sde does not have the attribute noise_type.") + raise ValueError("sde does not have the attribute noise_type.") if sde.noise_type not in NOISE_TYPES: raise ValueError(f"Expected noise type in {NOISE_TYPES}, but found {sde.noise_type}.") if not hasattr(sde, "sde_type"): - raise ValueError(f"sde does not have the attribute sde_type.") + raise ValueError("sde does not have the attribute sde_type.") if sde.sde_type not in SDE_TYPES: raise ValueError(f"Expected sde type in {SDE_TYPES}, but found {sde.sde_type}.") @@ -160,7 +160,7 @@ def check_contract(sde, y0, ts, bm, method, adaptive, options, names, logqp): if not torch.is_tensor(ts): if not isinstance(ts, (tuple, list)) or not all(isinstance(t, (float, int)) for t in ts): - raise ValueError(f"Evaluation times `ts` must be a 1-D Tensor or list/tuple of floats.") + raise ValueError("Evaluation times `ts` must be a 1-D Tensor or list/tuple of floats.") ts = torch.tensor(ts, dtype=y0.dtype, device=y0.device) if not misc.is_strictly_increasing(ts): raise ValueError("Evaluation times `ts` must be strictly increasing.") @@ -275,8 +275,8 @@ def _check_2d_or_3d(name, shape): options = options.copy() if adaptive and method == METHODS.euler and sde.noise_type != NOISE_TYPES.additive: - warnings.warn(f"Numerical solution is not guaranteed to converge to the correct solution when using adaptive " - f"time-stepping with the Euler--Maruyama method with non-additive noise.") + warnings.warn("Numerical solution is not guaranteed to converge to the correct solution when using adaptive " + "time-stepping with the Euler--Maruyama method with non-additive noise.") return sde, y0, ts, bm, method, options diff --git a/torchsde/types.py b/torchsde/types.py index 6a92616f..b17b8b27 100644 --- a/torchsde/types.py +++ b/torchsde/types.py @@ -13,7 +13,7 @@ # limitations under the License. # We import from `typing` more than what's enough, so that other modules can import from this file and not `typing`. -from typing import Sequence, Union, Optional, Any, Dict, Tuple, Callable +from typing import Sequence, Union, Optional, Any, Dict, Tuple, Callable # noqa: F401 import torch