Skip to content

Commit 1c3e099

Browse files
authored
feat(pt/tf): init-(frz)-model use pretrain script (#3926)
Support `--use-pretrain-script` for the PyTorch and TensorFlow backends when doing init-(frz)-model. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Tests** - Enhanced and added new test cases for deep learning model initialization and evaluation. - Improved setup and cleanup processes for temporary files and directories in tests to ensure a cleaner test environment. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 73312f2 commit 1c3e099

File tree

5 files changed

+242
-8
lines changed

5 files changed

+242
-8
lines changed

deepmd/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,9 @@ def main_parser() -> argparse.ArgumentParser:
258258
parser_train.add_argument(
259259
"--use-pretrain-script",
260260
action="store_true",
261-
help="Use model parameters from the script of the pretrained model instead of user input when doing finetuning. Note: This behavior is default and unchangeable in TensorFlow.",
261+
help="When performing fine-tuning or init-model, "
262+
"utilize the model parameters provided by the script of the pretrained model rather than relying on user input. "
263+
"It is important to note that in TensorFlow, this behavior is the default and cannot be modified for fine-tuning. ",
262264
)
263265
parser_train.add_argument(
264266
"-o",

deepmd/pt/entrypoints/main.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,21 @@ def train(FLAGS):
256256
model_branch=FLAGS.model_branch,
257257
change_model_params=FLAGS.use_pretrain_script,
258258
)
259+
# update init_model or init_frz_model config if necessary
260+
if (
261+
FLAGS.init_model is not None or FLAGS.init_frz_model is not None
262+
) and FLAGS.use_pretrain_script:
263+
if FLAGS.init_model is not None:
264+
init_state_dict = torch.load(FLAGS.init_model, map_location=DEVICE)
265+
if "model" in init_state_dict:
266+
init_state_dict = init_state_dict["model"]
267+
config["model"] = init_state_dict["_extra_state"]["model_params"]
268+
else:
269+
config["model"] = json.loads(
270+
torch.jit.load(
271+
FLAGS.init_frz_model, map_location=DEVICE
272+
).get_model_def_script()
273+
)
259274

260275
# argcheck
261276
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")

deepmd/tf/entrypoints/train.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def train(
6565
is_compress: bool = False,
6666
skip_neighbor_stat: bool = False,
6767
finetune: Optional[str] = None,
68+
use_pretrain_script: bool = False,
6869
**kwargs,
6970
):
7071
"""Run DeePMD model training.
@@ -93,6 +94,9 @@ def train(
9394
skip checking neighbor statistics
9495
finetune : Optional[str]
9596
path to pretrained model or None
97+
use_pretrain_script : bool
98+
Whether to use the model script from the pretrained model when doing init-model or init-frz-model.
99+
Note that this option is true and unchangeable for fine-tuning.
96100
**kwargs
97101
additional arguments
98102
@@ -123,6 +127,41 @@ def train(
123127
jdata, run_opt.finetune
124128
)
125129

130+
if (
131+
run_opt.init_model is not None or run_opt.init_frz_model is not None
132+
) and use_pretrain_script:
133+
from deepmd.tf.utils.errors import (
134+
GraphWithoutTensorError,
135+
)
136+
from deepmd.tf.utils.graph import (
137+
get_tensor_by_name,
138+
get_tensor_by_name_from_graph,
139+
)
140+
141+
err_msg = (
142+
f"The input model: {run_opt.init_model if run_opt.init_model is not None else run_opt.init_frz_model} has no training script. "
143+
f"Please use the model pretrained with v2.1.5 or higher version of DeePMD-kit."
144+
)
145+
if run_opt.init_model is not None:
146+
with tf.Graph().as_default() as graph:
147+
tf.train.import_meta_graph(
148+
f"{run_opt.init_model}.meta", clear_devices=True
149+
)
150+
try:
151+
t_training_script = get_tensor_by_name_from_graph(
152+
graph, "train_attr/training_script"
153+
)
154+
except GraphWithoutTensorError as e:
155+
raise RuntimeError(err_msg) from e
156+
else:
157+
try:
158+
t_training_script = get_tensor_by_name(
159+
run_opt.init_frz_model, "train_attr/training_script"
160+
)
161+
except GraphWithoutTensorError as e:
162+
raise RuntimeError(err_msg) from e
163+
jdata["model"] = json.loads(t_training_script)["model"]
164+
126165
jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")
127166

128167
jdata = normalize(jdata)

source/tests/pt/test_init_frz_model.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import json
3+
import os
4+
import shutil
5+
import tempfile
36
import unittest
47
from argparse import (
58
Namespace,
@@ -21,12 +24,17 @@
2124
DeepPot,
2225
)
2326

27+
from .common import (
28+
run_dp,
29+
)
30+
2431

2532
class TestInitFrzModel(unittest.TestCase):
2633
def setUp(self):
2734
input_json = str(Path(__file__).parent / "water/se_atten.json")
2835
with open(input_json) as f:
2936
config = json.load(f)
37+
config["model"]["descriptor"]["smooth_type_embedding"] = True
3038
config["training"]["numb_steps"] = 1
3139
config["training"]["save_freq"] = 1
3240
config["learning_rate"]["start_lr"] = 1.0
@@ -38,15 +46,30 @@ def setUp(self):
3846
]
3947

4048
self.models = []
41-
for imodel in range(2):
42-
if imodel == 1:
43-
config["training"]["numb_steps"] = 0
44-
trainer = get_trainer(deepcopy(config), init_frz_model=self.models[-1])
49+
for imodel in range(3):
50+
frozen_model = f"frozen_model{imodel}.pth"
51+
if imodel == 0:
52+
temp_config = deepcopy(config)
53+
trainer = get_trainer(temp_config)
54+
elif imodel == 1:
55+
temp_config = deepcopy(config)
56+
temp_config["training"]["numb_steps"] = 0
57+
trainer = get_trainer(temp_config, init_frz_model=self.models[-1])
4558
else:
46-
trainer = get_trainer(deepcopy(config))
47-
trainer.run()
59+
empty_config = deepcopy(config)
60+
empty_config["model"]["descriptor"] = {}
61+
empty_config["model"]["fitting_net"] = {}
62+
empty_config["training"]["numb_steps"] = 0
63+
tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
64+
with open(tmp_input.name, "w") as f:
65+
json.dump(empty_config, f, indent=4)
66+
run_dp(
67+
f"dp --pt train {tmp_input.name} --init-frz-model {self.models[-1]} --use-pretrain-script --skip-neighbor-stat"
68+
)
69+
trainer = None
4870

49-
frozen_model = f"frozen_model{imodel}.pth"
71+
if imodel in [0, 1]:
72+
trainer.run()
5073
ns = Namespace(
5174
model="model.pt",
5275
output=frozen_model,
@@ -58,6 +81,7 @@ def setUp(self):
5881
def test_dp_test(self):
5982
dp1 = DeepPot(str(self.models[0]))
6083
dp2 = DeepPot(str(self.models[1]))
84+
dp3 = DeepPot(str(self.models[2]))
6185
cell = np.array(
6286
[
6387
5.122106549439247480e00,
@@ -96,8 +120,26 @@ def test_dp_test(self):
96120
e1, f1, v1, ae1, av1 = ret1[0], ret1[1], ret1[2], ret1[3], ret1[4]
97121
ret2 = dp2.eval(coord, cell, atype, atomic=True)
98122
e2, f2, v2, ae2, av2 = ret2[0], ret2[1], ret2[2], ret2[3], ret2[4]
123+
ret3 = dp3.eval(coord, cell, atype, atomic=True)
124+
e3, f3, v3, ae3, av3 = ret3[0], ret3[1], ret3[2], ret3[3], ret3[4]
99125
np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10)
126+
np.testing.assert_allclose(e1, e3, rtol=1e-10, atol=1e-10)
100127
np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10)
128+
np.testing.assert_allclose(f1, f3, rtol=1e-10, atol=1e-10)
101129
np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10)
130+
np.testing.assert_allclose(v1, v3, rtol=1e-10, atol=1e-10)
102131
np.testing.assert_allclose(ae1, ae2, rtol=1e-10, atol=1e-10)
132+
np.testing.assert_allclose(ae1, ae3, rtol=1e-10, atol=1e-10)
103133
np.testing.assert_allclose(av1, av2, rtol=1e-10, atol=1e-10)
134+
np.testing.assert_allclose(av1, av3, rtol=1e-10, atol=1e-10)
135+
136+
def tearDown(self):
137+
for f in os.listdir("."):
138+
if f.startswith("frozen_model") and f.endswith(".pth"):
139+
os.remove(f)
140+
if f.startswith("model") and f.endswith(".pt"):
141+
os.remove(f)
142+
if f in ["lcurve.out"]:
143+
os.remove(f)
144+
if f in ["stat_files"]:
145+
shutil.rmtree(f)

source/tests/pt/test_init_model.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import json
3+
import os
4+
import shutil
5+
import tempfile
6+
import unittest
7+
from copy import (
8+
deepcopy,
9+
)
10+
from pathlib import (
11+
Path,
12+
)
13+
14+
import numpy as np
15+
16+
from deepmd.pt.entrypoints.main import (
17+
get_trainer,
18+
)
19+
from deepmd.pt.infer.deep_eval import (
20+
DeepPot,
21+
)
22+
23+
from .common import (
24+
run_dp,
25+
)
26+
27+
28+
class TestInitModel(unittest.TestCase):
29+
def setUp(self):
30+
input_json = str(Path(__file__).parent / "water/se_atten.json")
31+
with open(input_json) as f:
32+
config = json.load(f)
33+
config["model"]["descriptor"]["smooth_type_embedding"] = True
34+
config["training"]["numb_steps"] = 1
35+
config["training"]["save_freq"] = 1
36+
config["learning_rate"]["start_lr"] = 1.0
37+
config["training"]["training_data"]["systems"] = [
38+
str(Path(__file__).parent / "water/data/single")
39+
]
40+
config["training"]["validation_data"]["systems"] = [
41+
str(Path(__file__).parent / "water/data/single")
42+
]
43+
44+
self.models = []
45+
for imodel in range(3):
46+
ckpt_model = f"model{imodel}.ckpt"
47+
if imodel == 0:
48+
temp_config = deepcopy(config)
49+
temp_config["training"]["save_ckpt"] = ckpt_model
50+
trainer = get_trainer(temp_config)
51+
elif imodel == 1:
52+
temp_config = deepcopy(config)
53+
temp_config["training"]["numb_steps"] = 0
54+
temp_config["training"]["save_ckpt"] = ckpt_model
55+
trainer = get_trainer(temp_config, init_model=self.models[-1])
56+
else:
57+
empty_config = deepcopy(config)
58+
empty_config["model"]["descriptor"] = {}
59+
empty_config["model"]["fitting_net"] = {}
60+
empty_config["training"]["numb_steps"] = 0
61+
empty_config["training"]["save_ckpt"] = ckpt_model
62+
tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
63+
with open(tmp_input.name, "w") as f:
64+
json.dump(empty_config, f, indent=4)
65+
run_dp(
66+
f"dp --pt train {tmp_input.name} --init-model {self.models[-1]} --use-pretrain-script --skip-neighbor-stat"
67+
)
68+
trainer = None
69+
70+
if imodel in [0, 1]:
71+
trainer.run()
72+
self.models.append(ckpt_model + ".pt")
73+
74+
def test_dp_test(self):
75+
dp1 = DeepPot(str(self.models[0]))
76+
dp2 = DeepPot(str(self.models[1]))
77+
dp3 = DeepPot(str(self.models[2]))
78+
cell = np.array(
79+
[
80+
5.122106549439247480e00,
81+
4.016537340154059388e-01,
82+
6.951654033828678081e-01,
83+
4.016537340154059388e-01,
84+
6.112136112297989143e00,
85+
8.178091365465004481e-01,
86+
6.951654033828678081e-01,
87+
8.178091365465004481e-01,
88+
6.159552512682983760e00,
89+
]
90+
).reshape(1, 3, 3)
91+
coord = np.array(
92+
[
93+
2.978060152121375648e00,
94+
3.588469695887098077e00,
95+
2.792459820604495491e00,
96+
3.895592322591093115e00,
97+
2.712091020667753760e00,
98+
1.366836847133650501e00,
99+
9.955616170888935690e-01,
100+
4.121324820711413039e00,
101+
1.817239061889086571e00,
102+
3.553661462345699906e00,
103+
5.313046969500791583e00,
104+
6.635182659098815883e00,
105+
6.088601018589653080e00,
106+
6.575011420004332585e00,
107+
6.825240650611076099e00,
108+
]
109+
).reshape(1, -1, 3)
110+
atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1)
111+
112+
ret1 = dp1.eval(coord, cell, atype, atomic=True)
113+
e1, f1, v1, ae1, av1 = ret1[0], ret1[1], ret1[2], ret1[3], ret1[4]
114+
ret2 = dp2.eval(coord, cell, atype, atomic=True)
115+
e2, f2, v2, ae2, av2 = ret2[0], ret2[1], ret2[2], ret2[3], ret2[4]
116+
ret3 = dp3.eval(coord, cell, atype, atomic=True)
117+
e3, f3, v3, ae3, av3 = ret3[0], ret3[1], ret3[2], ret3[3], ret3[4]
118+
np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10)
119+
np.testing.assert_allclose(e1, e3, rtol=1e-10, atol=1e-10)
120+
np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10)
121+
np.testing.assert_allclose(f1, f3, rtol=1e-10, atol=1e-10)
122+
np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10)
123+
np.testing.assert_allclose(v1, v3, rtol=1e-10, atol=1e-10)
124+
np.testing.assert_allclose(ae1, ae2, rtol=1e-10, atol=1e-10)
125+
np.testing.assert_allclose(ae1, ae3, rtol=1e-10, atol=1e-10)
126+
np.testing.assert_allclose(av1, av2, rtol=1e-10, atol=1e-10)
127+
np.testing.assert_allclose(av1, av3, rtol=1e-10, atol=1e-10)
128+
129+
def tearDown(self):
130+
for f in os.listdir("."):
131+
if f.startswith("model") and f.endswith(".pt"):
132+
os.remove(f)
133+
if f in ["lcurve.out"]:
134+
os.remove(f)
135+
if f in ["stat_files"]:
136+
shutil.rmtree(f)

0 commit comments

Comments
 (0)