Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/metatrain/pet/tests/test_continue.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import shutil

import metatensor
Expand Down Expand Up @@ -56,7 +57,7 @@ def test_continue(monkeypatch, tmp_path):

dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["training"]["num_epochs"] = 0
loss_conf = OmegaConf.create({"mtt::U0": CONF_LOSS.copy()})
OmegaConf.resolve(loss_conf)
Expand Down
7 changes: 4 additions & 3 deletions src/metatrain/pet/tests/test_finetuning.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import shutil

import pytest
Expand Down Expand Up @@ -140,7 +141,7 @@ def test_finetuning_restart(monkeypatch, tmp_path):

dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
hypers = copy.deepcopy(DEFAULT_HYPERS)

hypers["training"]["num_epochs"] = 1

Expand All @@ -166,7 +167,7 @@ def test_finetuning_restart(monkeypatch, tmp_path):
assert isinstance(model_finetune, PET)
model_finetune.restart(dataset_info)

hypers = DEFAULT_HYPERS.copy()
hypers = copy.deepcopy(DEFAULT_HYPERS)

hypers["training"]["num_epochs"] = 0

Expand Down Expand Up @@ -203,7 +204,7 @@ def test_finetuning_restart(monkeypatch, tmp_path):
["lora_" in name for name, _ in model_finetune_restart.named_parameters()]
)

hypers = DEFAULT_HYPERS.copy()
hypers = copy.deepcopy(DEFAULT_HYPERS)

hypers["training"]["num_epochs"] = 0

Expand Down
6 changes: 4 additions & 2 deletions src/metatrain/pet/tests/test_functionality.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import metatensor.torch as mts
import pytest
import torch
Expand Down Expand Up @@ -354,7 +356,7 @@ def test_output_per_atom():
def test_fixed_composition_weights():
"""Tests the correctness of the json schema for fixed_composition_weights"""

hypers = DEFAULT_HYPERS.copy()
hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["training"]["fixed_composition_weights"] = {
"energy": {
1: 1.0,
Expand All @@ -370,7 +372,7 @@ def test_fixed_composition_weights():

def test_fixed_composition_weights_error():
"""Test that only input of type Dict[str, Dict[int, float]] are allowed."""
hypers = DEFAULT_HYPERS.copy()
hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["training"]["fixed_composition_weights"] = {"energy": {"H": 300.0}}
hypers = OmegaConf.create(hypers)
with pytest.raises(ValueError, match=r"'H' does not match '\^\[0-9\]\+\$'"):
Expand Down
5 changes: 3 additions & 2 deletions src/metatrain/pet/tests/test_long_range.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import copy

import pytest


pytest.importorskip("torchpme")

import copy

import torch
from metatomic.torch import ModelOutput, System
Expand Down Expand Up @@ -79,7 +80,7 @@ def test_long_range_training(use_ewald):
targets, target_info_dict = read_targets(OmegaConf.create(conf))
targets = {"energy": targets["energy"]}
dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]})
hypers = DEFAULT_HYPERS.copy()
hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["training"]["num_epochs"] = 2
hypers["training"]["scheduler_patience"] = 1
hypers["training"]["fixed_composition_weights"] = {}
Expand Down
3 changes: 2 additions & 1 deletion src/metatrain/pet/tests/test_regression.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import random

import numpy as np
Expand Down Expand Up @@ -95,7 +96,7 @@ def test_regression_energies_forces_train(device):
targets, target_info_dict = read_targets(OmegaConf.create(conf))
targets = {"energy": targets["energy"]}
dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]})
hypers = DEFAULT_HYPERS.copy()
hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["training"]["num_epochs"] = 2
hypers["training"]["scheduler_patience"] = 1
hypers["training"]["fixed_composition_weights"] = {}
Expand Down
Loading