Skip to content

Commit 29db791

Browse files
authored
feat(pt): add datafile option for change-bias (#3945)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added optional `--datafile` argument to specify a file for system data processing. - **Bug Fixes** - Improved `help` messages for `--datafile` argument to clarify its usage. - **Tests** - Enhanced test coverage for changing bias with a new method that handles data from a system file. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 1c3e099 commit 29db791

File tree

3 files changed

+46
-4
lines changed

3 files changed

+46
-4
lines changed

deepmd/main.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def main_parser() -> argparse.ArgumentParser:
370370
"--datafile",
371371
default=None,
372372
type=str,
373-
help="The path to file of test list.",
373+
help="The path to the datafile, each line of which is a path to one data system.",
374374
)
375375
parser_tst.add_argument(
376376
"-S",
@@ -685,6 +685,13 @@ def main_parser() -> argparse.ArgumentParser:
685685
type=str,
686686
help="The system dir. Recursively detect systems in this directory",
687687
)
688+
parser_change_bias_source.add_argument(
689+
"-f",
690+
"--datafile",
691+
default=None,
692+
type=str,
693+
help="The path to the datafile, each line of which is a path to one data system.",
694+
)
688695
parser_change_bias_source.add_argument(
689696
"-b",
690697
"--bias-value",

deepmd/pt/entrypoints/main.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,12 @@ def change_bias(FLAGS):
469469
updated_model = model_to_change
470470
else:
471471
# calculate bias on given systems
472-
data_systems = process_systems(expand_sys_str(FLAGS.system))
472+
if FLAGS.datafile is not None:
473+
with open(FLAGS.datafile) as datalist:
474+
all_sys = datalist.read().splitlines()
475+
else:
476+
all_sys = expand_sys_str(FLAGS.system)
477+
data_systems = process_systems(all_sys)
473478
data_single = DpLoaderSet(
474479
data_systems,
475480
1,

source/tests/pt/test_change_bias.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
import shutil
5+
import tempfile
56
import unittest
67
from copy import (
78
deepcopy,
@@ -36,6 +37,9 @@
3637
to_torch_tensor,
3738
)
3839

40+
from .common import (
41+
run_dp,
42+
)
3943
from .model.test_permutation import (
4044
model_se_e2_a,
4145
)
@@ -77,12 +81,15 @@ def setUp(self):
7781
self.model_path_data_bias = Path(current_path) / (
7882
model_name + "data_bias" + ".pt"
7983
)
84+
self.model_path_data_file_bias = Path(current_path) / (
85+
model_name + "data_file_bias" + ".pt"
86+
)
8087
self.model_path_user_bias = Path(current_path) / (
8188
model_name + "user_bias" + ".pt"
8289
)
8390

8491
def test_change_bias_with_data(self):
85-
os.system(
92+
run_dp(
8693
f"dp --pt change-bias {self.model_path!s} -s {self.data_file[0]} -o {self.model_path_data_bias!s}"
8794
)
8895
state_dict = torch.load(str(self.model_path_data_bias), map_location=DEVICE)
@@ -99,9 +106,32 @@ def test_change_bias_with_data(self):
99106
expected_bias = expected_model.get_out_bias()
100107
torch.testing.assert_close(updated_bias, expected_bias)
101108

109+
def test_change_bias_with_data_sys_file(self):
110+
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
111+
with open(tmp_file.name, "w") as f:
112+
f.writelines([sys + "\n" for sys in self.data_file])
113+
run_dp(
114+
f"dp --pt change-bias {self.model_path!s} -f {tmp_file.name} -o {self.model_path_data_file_bias!s}"
115+
)
116+
state_dict = torch.load(
117+
str(self.model_path_data_file_bias), map_location=DEVICE
118+
)
119+
model_params = state_dict["model"]["_extra_state"]["model_params"]
120+
model_for_wrapper = get_model_for_wrapper(model_params)
121+
wrapper = ModelWrapper(model_for_wrapper)
122+
wrapper.load_state_dict(state_dict["model"])
123+
updated_bias = wrapper.model["Default"].get_out_bias()
124+
expected_model = model_change_out_bias(
125+
self.trainer.wrapper.model["Default"],
126+
self.sampled,
127+
_bias_adjust_mode="change-by-statistic",
128+
)
129+
expected_bias = expected_model.get_out_bias()
130+
torch.testing.assert_close(updated_bias, expected_bias)
131+
102132
def test_change_bias_with_user_defined(self):
103133
user_bias = [0.1, 3.2, -0.5]
104-
os.system(
134+
run_dp(
105135
f"dp --pt change-bias {self.model_path!s} -b {' '.join([str(_) for _ in user_bias])} -o {self.model_path_user_bias!s}"
106136
)
107137
state_dict = torch.load(str(self.model_path_user_bias), map_location=DEVICE)

0 commit comments

Comments
 (0)