Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ inference, train, etc
### Fixed
- Circular import in sevenn.checkpoint (dev0)
- Fix typing issues
- Added missing typings (especially return type)


## [0.11.0]
Expand Down
8 changes: 4 additions & 4 deletions sevenn/atom_graph_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Dict, Optional

import torch
import torch_geometric.data
Expand Down Expand Up @@ -33,7 +33,7 @@ def __init__(
pos: Optional[torch.Tensor] = None,
edge_attr: Optional[torch.Tensor] = None,
**kwargs
):
) -> None:
super(AtomGraphData, self).__init__(x, edge_index, edge_attr, pos=pos)
self[KEY.NODE_ATTR] = x # ?
for k, v in kwargs.items():
Expand All @@ -47,7 +47,7 @@ def to_numpy_dict(self):
}
return dct

def fit_dimension(self):
def fit_dimension(self) -> 'AtomGraphData':
per_atom_keys = [
KEY.ATOMIC_NUMBERS,
KEY.ATOMIC_ENERGY,
Expand All @@ -66,7 +66,7 @@ def fit_dimension(self):
return self

@staticmethod
def from_numpy_dict(dct):
def from_numpy_dict(dct: Dict[str, Any]) -> 'AtomGraphData':
for k, v in dct.items():
if k == KEY.CELL_SHIFT:
dct[k] = torch.Tensor(v) # this is special
Expand Down
11 changes: 6 additions & 5 deletions sevenn/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
import torch.jit
import torch.jit._script
from ase.atoms import Atoms
from ase.calculators.calculator import Calculator, all_changes
from ase.calculators.mixing import SumCalculator
from ase.data import chemical_symbols
Expand Down Expand Up @@ -39,7 +40,7 @@ def __init__(
enable_cueq: bool = False,
sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info
**kwargs,
):
) -> None:
"""Initialize SevenNetCalculator.

Parameters
Expand Down Expand Up @@ -163,7 +164,7 @@ def __init__(
'energies',
]

def set_atoms(self, atoms):
def set_atoms(self, atoms: Atoms) -> None:
# called by ase, when atoms.calc = calc
zs = tuple(set(atoms.get_atomic_numbers()))
for z in zs:
Expand All @@ -173,7 +174,7 @@ def set_atoms(self, atoms):
f'Model do not know atomic number: {z}, (knows: {sp})'
)

def output_to_results(self, output):
def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]:
energy = output[KEY.PRED_TOTAL_ENERGY].detach().cpu().item()
num_atoms = output['num_atoms'].item()
atomic_energies = output[KEY.ATOMIC_ENERGY].detach().cpu().numpy().flatten()
Expand Down Expand Up @@ -233,7 +234,7 @@ def __init__(
vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au
cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au
**kwargs,
):
) -> None:
"""Initialize SevenNetD3Calculator. CUDA required.

Parameters
Expand Down Expand Up @@ -351,7 +352,7 @@ def __init__(
vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au
cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au
**kwargs,
):
) -> None:
super().__init__(**kwargs)

if not torch.cuda.is_available():
Expand Down
27 changes: 17 additions & 10 deletions sevenn/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import warnings
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

import pandas as pd
from ase.atoms import Atoms
from packaging.version import Version
from torch import Tensor
from torch import load as torch_load
Expand All @@ -20,7 +21,9 @@
from sevenn.nn.sequential import AtomGraphSequential


def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6):
def assert_atoms(
atoms1: Atoms, atoms2: Atoms, rtol: float = 1e-5, atol: float = 1e-6
) -> None:
import numpy as np

def acl(a, b, rtol=rtol, atol=atol):
Expand All @@ -39,7 +42,9 @@ def acl(a, b, rtol=rtol, atol=atol):
# assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies())


def copy_state_dict(state_dict) -> dict:
def copy_state_dict(
state_dict: Union[Dict[str, Any], List[Any], Tensor],
) -> Dict[str, Any]:
if isinstance(state_dict, dict):
return {key: copy_state_dict(value) for key, value in state_dict.items()}
elif isinstance(state_dict, list):
Expand All @@ -55,8 +60,10 @@ def _config_cp_routine(config):
cp_ver = Version(config.get('version', None))
this_ver = Version(sevenn.__version__)
if cp_ver > this_ver:
warnings.warn(f'The checkpoint version ({cp_ver}) is newer than this source'
f'({this_ver}). This may cause unexpected behaviors')
warnings.warn(
f'The checkpoint version ({cp_ver}) is newer than this source'
f'({this_ver}). This may cause unexpected behaviors'
)

defaults = {**consts.model_defaults(config)}
config = compat.patch_old_config(config) # type: ignore
Expand Down Expand Up @@ -177,7 +184,7 @@ class SevenNetCheckpoint:
Tool box for checkpoint processed from SevenNet.
"""

def __init__(self, checkpoint_path: Union[pathlib.Path, str]):
def __init__(self, checkpoint_path: Union[pathlib.Path, str]) -> None:
self._checkpoint_path = os.path.abspath(checkpoint_path)
self._config = None
self._epoch = None
Expand Down Expand Up @@ -322,7 +329,7 @@ def build_model(self, backend: Optional[str] = None) -> AtomGraphSequential:
assert len(missing) == 0, f'Missing keys: {missing}'
return model

def yaml_dict(self, mode: str) -> dict:
def yaml_dict(self, mode: str) -> Dict[str, Any]:
"""
Return dict for input.yaml from checkpoint config
Dataset paths and statistic values are removed intentionally
Expand Down Expand Up @@ -410,10 +417,10 @@ def yaml_dict(self, mode: str) -> dict:

def append_modal(
self,
dst_config,
dst_config: Dict[str, Any],
original_modal_name: str = 'origin',
working_dir: str = os.getcwd(),
):
) -> Dict[str, Any]:
""" """
import sevenn.train.modal_dataset as modal_dataset
from sevenn.model_build import init_shift_scale
Expand Down Expand Up @@ -536,7 +543,7 @@ def append_modal(

return new_state_dict

def get_checkpoint_dict(self) -> dict:
def get_checkpoint_dict(self) -> Dict[str, Any]:
"""
Return duplicate of this checkpoint with new hash and time.
Convenient for creating variant of the checkpoint
Expand Down
Loading