5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,9 +5,14 @@ All notable changes to this project will be documented in this file.
## [0.12.0-dev]
### Added
- Example config for fine-tuning the SevenNet-MF-ompa model
- FlashTP support (https://github.com/SNU-ARC/flashTP)

### Changed
- `ninja` added as a dependency

### Fixed
- Multi-modal model fine-tuning: pass `modal_map` into config during multi-fidelity continual training (#232)
- Parallel deploy for omat models fixed


## [0.11.2]
9 changes: 5 additions & 4 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "sevenn"
version = "0.11.2.post1"
version = "0.12.0.dev"
authors = [
{ name = "Yutack Park", email = "[email protected]" },
{ name = "Haekwan Jeon", email = "[email protected]" },
@@ -29,12 +29,13 @@ dependencies = [
"matscipy",
"pandas",
"requests",
"ninja",
"setuptools>=61.0"
]
[project.optional-dependencies]
test = ["pytest", "pytest-cov>=5"]
cueq12 = ["cuequivariance>=0.4.0; python_version >= '3.10'", "cuequivariance-torch>=0.4.0; python_version >= '3.10'", "cuequivariance-ops-torch-cu12; python_version >= '3.10'"]
cueq11 = ["cuequivariance>=0.4.0; python_version >= '3.10'", "cuequivariance-torch>=0.4.0; python_version >= '3.10'", "cuequivariance-ops-torch-cu11; python_version >= '3.10'"]
test = ["pytest", "pytest-cov>=5", "ipython"]
cueq12 = ["cuequivariance>=0.6.0; python_version >= '3.10'", "cuequivariance-torch>=0.6.0; python_version >= '3.10'", "cuequivariance-ops-torch-cu12; python_version >= '3.10'"]
flashTP = ["flashTP @ git+https://github.com/SNU-ARC/[email protected]"]

[project.scripts]
sevenn = "sevenn.main.sevenn:main"
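Since the cueq11 extra is dropped and FlashTP arrives as a git dependency here, it can help to sanity-check which optional backends actually import after `pip install "sevenn[cueq12]"` or `pip install "sevenn[flashTP]"`. A hedged sketch; `flashTP_e3nn` is the import name this PR uses in sevenn_patch_lammps.py, and `cuequivariance_torch` is assumed to be cuequivariance-torch's import name:

```python
import importlib.util

def backend_status() -> dict:
    """Report which optional acceleration backends are importable."""
    backends = {
        'cueq': 'cuequivariance_torch',  # from the cueq12 extra
        'flash': 'flashTP_e3nn',         # import name used by this PR
    }
    return {
        name: importlib.util.find_spec(module) is not None
        for name, module in backends.items()
    }

print(backend_status())  # e.g. {'cueq': True, 'flash': False}
```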
2 changes: 2 additions & 0 deletions sevenn/_const.py
@@ -122,6 +122,7 @@ def error_record_condition(x):
KEY.SELF_CONNECTION_TYPE: 'nequip',
KEY.INTERACTION_TYPE: 'nequip',
KEY._NORMALIZE_SPH: True,
KEY.USE_FLASH_TP: False,
KEY.CUEQUIVARIANCE_CONFIG: {},
}

@@ -168,6 +169,7 @@
),
KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE,
KEY._NORMALIZE_SPH: bool,
KEY.USE_FLASH_TP: bool,
KEY.CUEQUIVARIANCE_CONFIG: dict,
}

1 change: 1 addition & 0 deletions sevenn/_keys.py
@@ -220,6 +220,7 @@
INTERACTION_TYPE = 'interaction_type'
TRAIN_AVG_NUM_NEIGH = 'train_avg_num_neigh' # deprecated

USE_FLASH_TP = 'use_flash_tp'
CUEQUIVARIANCE_CONFIG = 'cuequivariance_config'

_NORMALIZE_SPH = '_normalize_sph'
19 changes: 15 additions & 4 deletions sevenn/calculator.py
@@ -37,7 +37,8 @@ def __init__(
file_type: str = 'checkpoint',
device: Union[torch.device, str] = 'auto',
modal: Optional[str] = None,
enable_cueq: bool = False,
enable_cueq: Optional[bool] = None,
enable_flash: Optional[bool] = None,
sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info
**kwargs,
) -> None:
@@ -56,8 +57,10 @@ def __init__(
modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa,
it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24)
case insensitive
enable_cueq: bool, default=False
enable_cueq: bool | None, default=None (use the checkpoint's backend)
if True, use cuEquivariance to accelerate inference.
enable_flash: bool | None, default=None (use the checkpoint's backend)
if True, use FlashTP to accelerate inference.
sevennet_config: dict | None, default=None
Not used, but can be used to carry meta information of this calculator
"""
@@ -72,6 +75,13 @@
if file_type not in allowed_file_types:
raise ValueError(f'file_type not in {allowed_file_types}')

enable_cueq = os.getenv('SEVENNET_ENABLE_CUEQ') == '1' or enable_cueq
enable_flash = os.getenv('SEVENNET_ENABLE_FLASH') == '1' or enable_flash

if enable_cueq and file_type in ['model_instance', 'torchscript']:
warnings.warn(
'file_type should be checkpoint to enable cueq. cueq set to False'
@@ -91,8 +101,9 @@ def __init__(
if file_type == 'checkpoint' and isinstance(model, str):
cp = util.load_checkpoint(model)

backend = 'e3nn' if not enable_cueq else 'cueq'
model_loaded = cp.build_model(backend)
model_loaded = cp.build_model(
enable_cueq=enable_cueq, enable_flash=enable_flash
)
model_loaded.set_is_batch_data(False)

self.type_map = cp.config[KEY.TYPE_MAP]
57 changes: 42 additions & 15 deletions sevenn/checkpoint.py
@@ -203,6 +203,12 @@ def __repr__(self) -> str:
cfg = self.config # just alias
if len(cfg) == 0:
return ''

try:
cp_using_cueq = self.config['cuequivariance_config']['use']
except KeyError:
cp_using_cueq = False

dct = {
'Sevennet version': cfg.get('version', 'Not found'),
'When': self.time,
@@ -215,6 +221,8 @@ def __repr__(self) -> str:
'Self connection type': cfg.get('self_connection_type', 'nequip'),
'Last epoch': self.epoch,
'Elements': len(cfg.get('chemical_species', [])),
'cuEquivariance used': cp_using_cueq,
'FlashTP used': self.config.get('use_flash_tp', False),
}
if cfg.get('use_modality', False):
dct['Modality'] = ', '.join(list(cfg.get('_modal_map', {}).keys()))
@@ -299,37 +307,56 @@ def _load(self) -> None:

self._loaded = True

def build_model(self, backend: Optional[str] = None) -> AtomGraphSequential:
def build_model(
self,
*,
enable_cueq: Optional[bool] = None,
enable_flash: Optional[bool] = None,
_flash_lammps: bool = False,
) -> AtomGraphSequential:
"""
Breaking change: the `backend` string argument is removed; pass
`enable_cueq` / `enable_flash` instead (None keeps the checkpoint's backend).
"""
from .model_build import build_E3_equivariant_model

use_cue = not backend or backend.lower() in ['cue', 'cueq']
try:
cp_using_cue = self.config[KEY.CUEQUIVARIANCE_CONFIG]['use']
cp_using_cueq = self.config[KEY.CUEQUIVARIANCE_CONFIG]['use']
except KeyError:
cp_using_cue = False
cp_using_cueq = False
enable_cueq = cp_using_cueq if enable_cueq is None else enable_cueq

if (not backend) or (use_cue == cp_using_cue):
cp_using_flash = self.config.get(KEY.USE_FLASH_TP, False)
enable_flash = cp_using_flash if enable_flash is None else enable_flash

assert not _flash_lammps or enable_flash
cfg_new = self.config
cfg_new['_flash_lammps'] = _flash_lammps

if (cp_using_cueq, cp_using_flash) == (enable_cueq, enable_flash):
# requested backends match the checkpoint's; load the state dict as-is
model = build_E3_equivariant_model(self.config)
model = build_E3_equivariant_model(cfg_new)
state_dict = compat.patch_state_dict_if_old(
self.model_state_dict, self.config, model
)
missing, not_used = model.load_state_dict(state_dict, strict=True)
assert len(missing) == 0, f'Missing keys: {missing}'
if len(not_used) > 0:
warnings.warn(f'Some keys are not used: {not_used}', UserWarning)
else:
cfg_new = self.config
cfg_new[KEY.CUEQUIVARIANCE_CONFIG] = {'use': use_cue}
print('Converting model backend...')

cfg_new[KEY.CUEQUIVARIANCE_CONFIG] = {'use': enable_cueq}
cfg_new[KEY.USE_FLASH_TP] = enable_flash
model = build_E3_equivariant_model(cfg_new)
stct_src = compat.patch_state_dict_if_old(
self.model_state_dict, self.config, model
)

state_dict = _convert_e3nn_and_cueq(
stct_src, model.state_dict(), self.config, from_cueq=cp_using_cue
stct_src, model.state_dict(), self.config, from_cueq=cp_using_cueq
)
missing, not_used = model.load_state_dict(state_dict, strict=False)

missing, not_used = model.load_state_dict(state_dict, strict=False)
if len(not_used) > 0:
warnings.warn(f'Some keys are not used: {not_used}', UserWarning)

assert len(missing) == 0, f'Missing keys: {missing}'
return model

def yaml_dict(self, mode: str) -> Dict[str, Any]:
@@ -503,7 +530,7 @@ def append_modal(
scaler = init_shift_scale(dst_config)

# finally, prepare updated continuable state dict using above
orig_model = self.build_model()
orig_model = self.build_model(enable_cueq=False, enable_flash=False)
orig_state_dict = orig_model.state_dict()

new_state_dict = copy_state_dict(orig_state_dict)
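To make the new keyword-only contract of `build_model` concrete, a hedged sketch (`load_checkpoint` is the same `sevenn.util` helper `calculator.py` uses above; the checkpoint file name is hypothetical):

```python
from sevenn.util import load_checkpoint

cp = load_checkpoint('checkpoint_best.pth')  # hypothetical path

# No flags: rebuild with whatever backend the checkpoint records.
model = cp.build_model()

# Explicit flags convert between backends, e.g. plain e3nn for CPU runs:
model = cp.build_model(enable_cueq=False, enable_flash=False)
```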
16 changes: 16 additions & 0 deletions sevenn/main/sevenn.py
@@ -46,13 +46,20 @@ def run(args):
distributed = args.distributed
distributed_backend = args.distributed_backend
use_cue = args.enable_cueq
use_flash = args.enable_flashTP

if use_cue:
import sevenn.nn.cue_helper

if not sevenn.nn.cue_helper.is_cue_available():
raise ImportError('cuEquivariance not installed.')

if use_flash:
import sevenn.nn.flash_helper

if not sevenn.nn.flash_helper.is_flash_available():
raise ImportError('FlashTP not installed or no GPU found.')

if working_dir is None:
working_dir = os.getcwd()
elif not os.path.isdir(working_dir):
@@ -108,6 +115,9 @@
else:
model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True})

if use_flash:
model_config[KEY.USE_FLASH_TP] = True

logger.print_config(model_config, data_config, train_config)
# don't have to distinguish configs inside program
global_config.update(model_config)
@@ -147,6 +157,12 @@ def cmd_parser_train(parser):
help='use cuEq accelerations for training',
action='store_true',
)
ag.add_argument(
'-flashTP',
'--enable_flashTP',
help='use flashTP accelerations for training',
action='store_true',
)
ag.add_argument(
'-w',
'--working_dir',
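`run()` above gates FlashTP on `sevenn.nn.flash_helper.is_flash_available()`, which this diff does not show. Judging from the 'not installed or no GPU found' message, a plausible sketch of such a probe (an assumption, not the real helper):

```python
import torch

def is_flash_available() -> bool:
    """Assumed shape of sevenn.nn.flash_helper.is_flash_available()."""
    try:
        import flashTP_e3nn  # noqa: F401  (FlashTP's import name, per this PR)
    except ImportError:
        return False
    # FlashTP ships CUDA kernels, so a visible GPU is also required.
    return torch.cuda.is_available()
```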
17 changes: 15 additions & 2 deletions sevenn/main/sevenn_get_model.py
@@ -36,6 +36,12 @@ def add_args(parser):
help='Modality of multi-modal model',
type=str,
)
ag.add_argument(
'-flashTP',
'--enable_flashTP',
help='use FlashTP. LAMMPS must be patched with FlashTP support (sevenn_patch_lammps --flashTP).',
action='store_true',
)


def run(args):
@@ -47,6 +53,7 @@ def run(args):
get_parallel = args.get_parallel
get_serial = not get_parallel
modal = args.modal
use_flash = args.enable_flashTP

if output_prefix is None:
output_prefix = 'deployed_parallel' if not get_serial else 'deployed_serial'
@@ -57,10 +64,16 @@
else:
checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint)

if use_flash:
import sevenn.nn.flash_helper

if not sevenn.nn.flash_helper.is_flash_available():
raise ImportError('FlashTP not installed or no GPU found.')

if get_serial:
deploy(checkpoint_path, output_prefix, modal)
deploy(checkpoint_path, output_prefix, modal, use_flash=use_flash)
else:
deploy_parallel(checkpoint_path, output_prefix, modal)
deploy_parallel(checkpoint_path, output_prefix, modal, use_flash=use_flash)


# legacy way
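End to end, the new flag simply reaches `deploy`/`deploy_parallel` as `use_flash`. A hedged Python equivalent of `sevenn_get_model 7net-0 -flashTP` (assuming `deploy` lives in `sevenn.scripts.deploy`, as in released SevenNet; `pretrained_name_to_path` is used above in this file):

```python
import sevenn.util
from sevenn.scripts.deploy import deploy  # assumed module path

ckpt = sevenn.util.pretrained_name_to_path('7net-0')
deploy(ckpt, 'deployed_serial', None, use_flash=True)  # modal=None
```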
40 changes: 40 additions & 0 deletions sevenn/main/sevenn_patch_lammps.py
@@ -1,5 +1,6 @@
import argparse
import os
import os.path as osp
import subprocess

from sevenn import __version__
@@ -21,6 +22,7 @@ def add_args(parser):
ag = parser
ag.add_argument('lammps_dir', help='Path to LAMMPS source', type=str)
ag.add_argument('--d3', help='Enable D3 support', action='store_true')
ag.add_argument('--flashTP', help='Enable flashTP', action='store_true')
# cxx_standard is detected automatically


@@ -39,8 +41,46 @@ def run(args):
d3_support = '0'
print(' - D3 support disabled')

so_lammps = ''
if args.flashTP:
try:
import flashTP_e3nn.flashTP as hook
except ImportError:
raise ImportError('FlashTP import failed.')

flash_dir = osp.abspath(osp.dirname(hook.__file__))

so_files = []
so_lammps = []
for ls in os.listdir(flash_dir):
fpath = osp.join(flash_dir, ls)
if ls.endswith('.so'):
so_files.append(fpath)
if ls.endswith('.so') and 'lammps' in ls:
so_lammps.append(fpath)
if len(so_files) == 0:
raise ValueError(
f'FlashTP .so file not found. The dir searched: {flash_dir}'
)
if len(so_lammps) == 0:
raise ValueError(
f'FlashTP lammps .so file not found. The dir searched: {flash_dir}'
)
elif len(so_lammps) > 1:
raise ValueError(f'More than one lammps .so file found: {so_lammps}')
so_lammps = so_lammps[0]

print(' - FlashTP support enabled.')
else:
flash_dir = None

script = f'{pair_e3gnn_dir}/patch_lammps.sh'
cmd = f'{script} {lammps_dir} {cxx_standard} {d3_support}'

if args.flashTP:
assert osp.isfile(so_lammps)
cmd += f' {so_lammps}'

res = subprocess.run(cmd.split())
return res.returncode  # propagate the patch script's exit status

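The `.so` discovery above builds two running lists from `os.listdir`; a tighter hedged equivalent with `glob`, keeping the same exactly-one-LAMMPS-library check:

```python
import glob
import os.path as osp

def find_flashtp_lammps_so(flash_dir: str) -> str:
    """Locate the single LAMMPS-facing FlashTP .so under flash_dir."""
    so_files = glob.glob(osp.join(flash_dir, '*.so'))
    if not so_files:
        raise ValueError(f'FlashTP .so file not found. The dir searched: {flash_dir}')
    lammps_so = [p for p in so_files if 'lammps' in osp.basename(p)]
    if len(lammps_so) != 1:
        raise ValueError(f'Expected exactly one lammps .so, found: {lammps_so}')
    return lammps_so[0]
```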