diff --git a/.github/workflows/opt.yaml b/.github/workflows/opt.yaml
index 4238803..0e963b9 100644
--- a/.github/workflows/opt.yaml
+++ b/.github/workflows/opt.yaml
@@ -143,8 +143,8 @@ jobs:
           input_file=${{ inputs.path }} # path to the input YAML file
           input_dir=$(dirname $input_file) # parent directory of input YAML file

-          git add $input_dir/output/{dde,icrmsd,rmsd,tfd}.csv
-          git add $input_dir/output/{dde,rmsd,rmsd_cdf,tfd,tfd_cdf,bonds,angles,dihedrals,impropers}.png
+          git add $input_dir/output/*_{dde,icrmsd,rmsd,tfd}.csv
+          git add $input_dir/output/*_{dde,rmsd,rmsd_cdf,tfd,tfd_cdf,bonds,angles,dihedrals,impropers}.png

           git commit -m "Add benchmark results"
           git push
diff --git a/config.py b/config.py
index 776dcbe..842bc2d 100644
--- a/config.py
+++ b/config.py
@@ -5,7 +5,7 @@

 @dataclass
 class Config:
-    forcefield: str
+    forcefields: list[str]
     datasets: list[str]

     @classmethod
diff --git a/main.py b/main.py
index a1a6b7e..0d5e26c 100644
--- a/main.py
+++ b/main.py
@@ -19,19 +19,27 @@
 def make_csvs(store, forcefield, out_dir):
+    ff_name = forcefield.split(".offxml")[0]
     print("getting DDEs")
-    store.get_dde(forcefield, skip_check=True).to_csv(f"{out_dir}/dde.csv")
+    store.get_dde(forcefield, skip_check=True).to_csv(f"{out_dir}/{ff_name}_dde.csv")
     print("getting RMSDs")
-    store.get_rmsd(forcefield, skip_check=True).to_csv(f"{out_dir}/rmsd.csv")
+    store.get_rmsd(forcefield, skip_check=True).to_csv(f"{out_dir}/{ff_name}_rmsd.csv")
     print("getting TFDs")
-    store.get_tfd(forcefield, skip_check=True).to_csv(f"{out_dir}/tfd.csv")
+    store.get_tfd(forcefield, skip_check=True).to_csv(f"{out_dir}/{ff_name}_tfd.csv")
     print("getting internal coordinate RMSDs")
     store.get_internal_coordinate_rmsd(forcefield, skip_check=True).to_csv(
-        f"{out_dir}/icrmsd.csv"
+        f"{out_dir}/{ff_name}_icrmsd.csv"
     )


-def _main(forcefield, dataset, sqlite_file, out_dir, procs, invalidate_cache):
+def _main(
+    forcefields: list[str],
+    dataset: str,
+    sqlite_file: str,
+    out_dir: str,
+    procs: int,
+    invalidate_cache: bool,
+):
     if invalidate_cache and os.path.exists(sqlite_file):
         os.remove(sqlite_file)
     if os.path.exists(sqlite_file):
@@ -43,15 +51,16 @@ def _main(forcefield, dataset, sqlite_file, out_dir, procs, invalidate_cache):
         crc = QCArchiveDataset.model_validate_json(inp.read())
         store = MoleculeStore.from_qcarchive_dataset(crc, sqlite_file)

-    print("started optimizing store", flush=True)
-    start = time.time()
-    store.optimize_mm(force_field=forcefield, n_processes=procs)
-    print(f"finished optimizing after {time.time() - start} sec")
+    for forcefield in forcefields:
+        print(f"started optimizing store with {forcefield}", flush=True)
+        start = time.time()
+        store.optimize_mm(force_field=forcefield, n_processes=procs)
+        print(f"finished optimizing after {time.time() - start} sec")

-    if not os.path.exists(out_dir):
-        os.makedirs(out_dir)
+        if not os.path.exists(out_dir):
+            os.makedirs(out_dir)

-    make_csvs(store, forcefield, out_dir)
+        make_csvs(store, forcefield, out_dir)


 if __name__ == "__main__":