Skip to content

Commit d0e5f5d

Browse files
add test for SQS jobs
1 parent 2a3a5fa commit d0e5f5d

File tree

3 files changed

+88
-2
lines changed

3 files changed

+88
-2
lines changed

src/atomate2/common/jobs/transform.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ def check_structure(structure: Structure, scaling: Sequence[int]) -> Structure:
7070

7171
if isinstance(scaling, int):
7272
nsites = scaling * len(struct)
73-
elif hasattr(scaling, "__len__") and len(scaling) == 3:
73+
elif (
74+
hasattr(scaling, "__len__")
75+
and all(isinstance(sf, int) for sf in scaling)
76+
and len(scaling) == 3
77+
):
7478
nsites = len(struct * scaling)
7579
else:
7680
raise ValueError(

src/atomate2/common/schemas/transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class SQSTask(TransformTask):
4949
sqs_structures: list[Structure] | None = Field(
5050
None, description="A list of other good SQS candidates."
5151
)
52-
sqs_scores: list[Structure] | None = Field(
52+
sqs_scores: list[float] | None = Field(
5353
None,
5454
description=(
5555
"The objective function values for the structures in `sqs_structures`"
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Test transformation jobs."""
2+
3+
try:
4+
import icet
5+
except ImportError:
6+
icet = None
7+
8+
import numpy as np
9+
import pytest
10+
from jobflow import run_locally
11+
from pymatgen.core import Structure
12+
from pymatgen.transformations.advanced_transformations import SQSTransformation
13+
14+
from atomate2.common.jobs.transform import SQS
15+
from atomate2.common.schemas.transform import SQSTask
16+
17+
18+
@pytest.fixture(scope="module")
19+
def simple_alloy() -> Structure:
20+
"""Hexagonal close-packed 50-50 Mg-Al alloy."""
21+
return Structure(
22+
3.5
23+
* np.array(
24+
[
25+
[0.5, -(3.0 ** (0.5)) / 2.0, 0.0],
26+
[0.5, 3.0 ** (0.5) / 2.0, 0.0],
27+
[0.0, 0.0, 8 ** (0.5) / 3.0],
28+
]
29+
),
30+
[{"Mg": 0.5, "Al": 0.5}, {"Mg": 0.5, "Al": 0.5}],
31+
[[0.0, 0.0, 0.0], [1.0 / 3.0, 2.0 / 3.0, 0.5]],
32+
)
33+
34+
35+
@pytest.mark.skipif(
36+
icet is None, reason="`icet` must be installed to perform this test."
37+
)
38+
def test_sqs(tmp_dir, simple_alloy):
39+
# Probably most common use case - just get one "best" SQS
40+
sqs_trans = SQSTransformation(
41+
scaling=4,
42+
best_only=False,
43+
sqs_method="icet-enumeration",
44+
)
45+
job = SQS(transformation=sqs_trans).make(simple_alloy)
46+
47+
output = run_locally(job)[job.uuid][1].output
48+
assert isinstance(output, SQSTask)
49+
assert output.final_structure.composition.as_dict() == {"Mg": 4, "Al": 4}
50+
assert isinstance(output.final_structure, Structure)
51+
assert output.final_structure.is_ordered
52+
assert all(
53+
getattr(output, attr) is None for attr in ("sqs_structures", "sqs_scores")
54+
)
55+
assert isinstance(output.transformation, SQSTransformation)
56+
57+
# Now simulate retrieving multiple SQSes
58+
sqs_trans = SQSTransformation(
59+
scaling=4,
60+
best_only=False,
61+
sqs_method="icet-monte_carlo",
62+
instances=3,
63+
icet_sqs_kwargs={"n_steps": 10}, # only 10-step search
64+
remove_duplicate_structures=False, # needed just to simulate output
65+
)
66+
67+
# return up to the two best structures
68+
job = SQS(transformation=sqs_trans).make(simple_alloy, return_ranked_list=2)
69+
output = run_locally(job)[job.uuid][1].output
70+
71+
assert isinstance(output, SQSTask)
72+
73+
# return_ranked_list - 1 structures and objective functions should be here
74+
assert all(
75+
len(getattr(output, attr)) == 1 for attr in ("sqs_structures", "sqs_scores")
76+
)
77+
78+
assert all(
79+
struct.composition.as_dict() == {"Mg": 4, "Al": 4}
80+
and isinstance(struct, Structure)
81+
for struct in output.sqs_structures
82+
)

0 commit comments

Comments
 (0)