|
13 | 13 | _pop_states, |
14 | 14 | _slice_state, |
15 | 15 | concatenate_states, |
16 | | - infer_property_scope, |
| 16 | + get_attrs_for_scope, |
17 | 17 | initialize_state, |
18 | 18 | ) |
19 | 19 |
|
|
24 | 24 | from pymatgen.core import Structure |
25 | 25 |
|
26 | 26 |
|
27 | | -def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None: |
28 | | - """Test inference of property scope.""" |
29 | | - scope = infer_property_scope(si_sim_state) |
30 | | - assert set(scope["global"]) == {"pbc"} |
31 | | - assert set(scope["per_atom"]) == { |
| 27 | +def test_get_attrs_for_scope(si_sim_state: ts.SimState) -> None: |
| 28 | + """Test getting attributes for a scope.""" |
| 29 | + per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) |
| 30 | + assert set(per_atom_attrs.keys()) == { |
32 | 31 | "positions", |
33 | 32 | "masses", |
34 | 33 | "atomic_numbers", |
35 | 34 | "system_idx", |
36 | 35 | } |
37 | | - assert set(scope["per_system"]) == {"cell"} |
| 36 | + per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) |
| 37 | + assert set(per_system_attrs.keys()) == {"cell"} |
| 38 | + global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) |
| 39 | + assert set(global_attrs.keys()) == {"pbc"} |
38 | 40 |
|
39 | 41 |
|
40 | | -def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None: |
41 | | - """Test inference of property scope.""" |
42 | | - state = MDState( |
43 | | - **asdict(si_sim_state), |
44 | | - momenta=torch.randn_like(si_sim_state.positions), |
45 | | - forces=torch.randn_like(si_sim_state.positions), |
46 | | - energy=torch.zeros((1,)), |
47 | | - ) |
48 | | - scope = infer_property_scope(state) |
49 | | - assert set(scope["global"]) == {"pbc"} |
50 | | - assert set(scope["per_atom"]) == { |
51 | | - "positions", |
52 | | - "masses", |
53 | | - "atomic_numbers", |
54 | | - "system_idx", |
55 | | - "forces", |
56 | | - "momenta", |
57 | | - } |
58 | | - assert set(scope["per_system"]) == {"cell", "energy"} |
| 42 | +def test_all_attributes_must_be_specified_in_scopes() -> None: |
| 43 | + """Test that an error is raised when we forget to specify the scope |
| 44 | + for an attribute in a child SimState class.""" |
| 45 | + with pytest.raises(TypeError) as excinfo: |
| 46 | + |
| 47 | + class ChildState(SimState): |
| 48 | + attribute_specified_in_scopes: bool |
| 49 | + attribute_not_specified_in_scopes: bool |
| 50 | + |
| 51 | + _atom_attributes = ( |
| 52 | + SimState._atom_attributes | {"attribute_specified_in_scopes"} # noqa: SLF001 |
| 53 | + ) |
| 54 | + |
| 55 | + assert "attribute_not_specified_in_scopes" in str(excinfo.value) |
| 56 | + assert "attribute_specified_in_scopes" not in str(excinfo.value) |
| 57 | + |
| 58 | + |
| 59 | +def test_no_duplicate_attributes_in_scopes() -> None: |
| 60 | + """Test that no attributes are specified in multiple scopes.""" |
| 61 | + |
| 62 | + # Capture the exception information using "as excinfo" |
| 63 | + with pytest.raises(TypeError) as excinfo: |
| 64 | + |
| 65 | + class ChildState(SimState): |
| 66 | + duplicated_attribute: bool |
| 67 | + |
| 68 | + _system_attributes = SimState._system_attributes | {"duplicated_attribute"} # noqa: SLF001 |
| 69 | + _global_attributes = SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001 |
| 70 | + |
| 71 | + assert "are declared multiple times" in str(excinfo.value) |
| 72 | + assert "duplicated_attribute" in str(excinfo.value) |
59 | 73 |
|
60 | 74 |
|
61 | 75 | def test_slice_substate( |
@@ -497,6 +511,11 @@ def test_column_vector_cell(si_sim_state: ts.SimState) -> None: |
497 | 511 | class DeformState(SimState, DeformGradMixin): |
498 | 512 | """Test class that combines SimState with DeformGradMixin.""" |
499 | 513 |
|
| 514 | + _system_attributes = ( |
| 515 | + SimState._system_attributes # noqa: SLF001 |
| 516 | + | DeformGradMixin._system_attributes # noqa: SLF001 |
| 517 | + ) |
| 518 | + |
500 | 519 | def __init__( |
501 | 520 | self, |
502 | 521 | *args, |
|
0 commit comments