|
| 1 | +"""Test transformation jobs.""" |
| 2 | + |
| 3 | +try: |
| 4 | + import icet |
| 5 | +except ImportError: |
| 6 | + icet = None |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import pytest |
| 10 | +from jobflow import run_locally |
| 11 | +from pymatgen.core import Structure |
| 12 | +from pymatgen.transformations.advanced_transformations import SQSTransformation |
| 13 | + |
| 14 | +from atomate2.common.jobs.transform import SQS |
| 15 | +from atomate2.common.schemas.transform import SQSTask |
| 16 | + |
| 17 | + |
| 18 | +@pytest.fixture(scope="module") |
| 19 | +def simple_alloy() -> Structure: |
| 20 | + """Hexagonal close-packed 50-50 Mg-Al alloy.""" |
| 21 | + return Structure( |
| 22 | + 3.5 |
| 23 | + * np.array( |
| 24 | + [ |
| 25 | + [0.5, -(3.0 ** (0.5)) / 2.0, 0.0], |
| 26 | + [0.5, 3.0 ** (0.5) / 2.0, 0.0], |
| 27 | + [0.0, 0.0, 8 ** (0.5) / 3.0], |
| 28 | + ] |
| 29 | + ), |
| 30 | + [{"Mg": 0.5, "Al": 0.5}, {"Mg": 0.5, "Al": 0.5}], |
| 31 | + [[0.0, 0.0, 0.0], [1.0 / 3.0, 2.0 / 3.0, 0.5]], |
| 32 | + ) |
| 33 | + |
| 34 | + |
| 35 | +@pytest.mark.skipif( |
| 36 | + icet is None, reason="`icet` must be installed to perform this test." |
| 37 | +) |
| 38 | +def test_sqs(tmp_dir, simple_alloy): |
| 39 | + # Probably most common use case - just get one "best" SQS |
| 40 | + sqs_trans = SQSTransformation( |
| 41 | + scaling=4, |
| 42 | + best_only=False, |
| 43 | + sqs_method="icet-enumeration", |
| 44 | + ) |
| 45 | + job = SQS(transformation=sqs_trans).make(simple_alloy) |
| 46 | + |
| 47 | + output = run_locally(job)[job.uuid][1].output |
| 48 | + assert isinstance(output, SQSTask) |
| 49 | + assert output.final_structure.composition.as_dict() == {"Mg": 4, "Al": 4} |
| 50 | + assert isinstance(output.final_structure, Structure) |
| 51 | + assert output.final_structure.is_ordered |
| 52 | + assert all( |
| 53 | + getattr(output, attr) is None for attr in ("sqs_structures", "sqs_scores") |
| 54 | + ) |
| 55 | + assert isinstance(output.transformation, SQSTransformation) |
| 56 | + |
| 57 | + # Now simulate retrieving multiple SQSes |
| 58 | + sqs_trans = SQSTransformation( |
| 59 | + scaling=4, |
| 60 | + best_only=False, |
| 61 | + sqs_method="icet-monte_carlo", |
| 62 | + instances=3, |
| 63 | + icet_sqs_kwargs={"n_steps": 10}, # only 10-step search |
| 64 | + remove_duplicate_structures=False, # needed just to simulate output |
| 65 | + ) |
| 66 | + |
| 67 | + # return up to the two best structures |
| 68 | + job = SQS(transformation=sqs_trans).make(simple_alloy, return_ranked_list=2) |
| 69 | + output = run_locally(job)[job.uuid][1].output |
| 70 | + |
| 71 | + assert isinstance(output, SQSTask) |
| 72 | + |
| 73 | + # return_ranked_list - 1 structures and objective functions should be here |
| 74 | + assert all( |
| 75 | + len(getattr(output, attr)) == 1 for attr in ("sqs_structures", "sqs_scores") |
| 76 | + ) |
| 77 | + |
| 78 | + assert all( |
| 79 | + struct.composition.as_dict() == {"Mg": 4, "Al": 4} |
| 80 | + and isinstance(struct, Structure) |
| 81 | + for struct in output.sqs_structures |
| 82 | + ) |
0 commit comments