Skip to content

Commit 96e9c06

Browse files
support one_to_many transformations, and add test for both behaviors
1 parent 89c63b5 commit 96e9c06

File tree

2 files changed

+71
-6
lines changed

2 files changed

+71
-6
lines changed

src/atomate2/common/jobs/transform.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,42 @@ class Transformer(Maker):
2525
"""Apply a pymatgen transformation, as a job.
2626
2727
For many of the standard and advanced transformations,
28-
this will work just by supplying the transformation.
28+
this should "just work" by supplying the transformation.
2929
"""
3030

3131
transformation: AbstractTransformation
3232
name: str = "pymatgen transformation maker"
3333

3434
@job
35-
def make(self, structure: Structure, **kwargs) -> TransformTask:
36-
"""Run the transformation."""
35+
def make(
36+
self, structure: Structure, **kwargs
37+
) -> TransformTask | list[TransformTask]:
38+
"""Evaluate the transformation.
39+
40+
Parameters
41+
----------
42+
structure : Structure to transform
43+
**kwargs : to pass to the `apply_transformation` method
44+
45+
Returns
46+
-------
47+
list of TransformTask, if `self.transformation.is_one_to_many`
48+
(many structures are produced from a single transformation)
49+
50+
TransformTask, otherwise
51+
"""
3752
transformed_structure = self.transformation.apply_transformation(
3853
structure, **kwargs
3954
)
55+
if self.transformation.is_one_to_many:
56+
return [
57+
TransformTask(
58+
input_structure=structure,
59+
final_structure=dct["structure"],
60+
transformation=dct.get("transformation") or self.transformation,
61+
)
62+
for dct in transformed_structure
63+
]
4064
return TransformTask(
4165
input_structure=structure,
4266
final_structure=transformed_structure,

tests/common/jobs/test_transform.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77

88
import numpy as np
99
import pytest
10-
from jobflow import run_locally
10+
from jobflow import Flow, run_locally
1111
from pymatgen.core import Structure
1212
from pymatgen.transformations.advanced_transformations import SQSTransformation
13+
from pymatgen.transformations.standard_transformations import (
14+
OrderDisorderedStructureTransformation,
15+
OxidationStateDecorationTransformation,
16+
)
1317

14-
from atomate2.common.jobs.transform import SQS
15-
from atomate2.common.schemas.transform import SQSTask
18+
from atomate2.common.jobs.transform import SQS, Transformer
19+
from atomate2.common.schemas.transform import SQSTask, TransformTask
1620

1721

1822
@pytest.fixture(scope="module")
@@ -32,6 +36,43 @@ def simple_alloy() -> Structure:
3236
)
3337

3438

39+
def test_simple_and_advanced():
40+
# simple disordered zincblende structure
41+
structure = Structure(
42+
3.8 * np.array([[0.0, 0.5, 0.5], [0.5, 0.0, 0.5], [0.5, 0.5, 0.0]]),
43+
["Zn", {"S": 0.75, "Se": 0.25}],
44+
[[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]],
45+
).to_conventional()
46+
47+
oxi_dict = {"Zn": 2, "S": -2, "Se": -2}
48+
oxi_job = Transformer(
49+
name="oxistate", transformation=OxidationStateDecorationTransformation(oxi_dict)
50+
).make(structure)
51+
52+
odst_job = Transformer(
53+
name="odst", transformation=OrderDisorderedStructureTransformation()
54+
).make(oxi_job.output.final_structure, return_ranked_list=2)
55+
56+
flow = Flow([oxi_job, odst_job])
57+
resp = run_locally(flow)
58+
59+
oxi_state_output = resp[oxi_job.uuid][1].output
60+
assert isinstance(oxi_state_output, TransformTask)
61+
62+
# check correct assignment of oxidation states
63+
assert all(
64+
specie.oxi_state == oxi_dict.get(specie.element.value)
65+
for site in oxi_state_output.final_structure
66+
for specie in site.species
67+
)
68+
69+
odst_output = resp[odst_job.uuid][1].output
70+
# return_ranked_list = 2, so should get 2 output docs
71+
assert len(odst_output) == 2
72+
assert all(isinstance(doc, TransformTask) for doc in odst_output)
73+
assert all(doc.final_structure.is_ordered for doc in odst_output)
74+
75+
3576
@pytest.mark.skipif(
3677
icet is None, reason="`icet` must be installed to perform this test."
3778
)

0 commit comments

Comments
 (0)