Skip to content

Commit b9ff6f6

Browse files
committed
Lint format
1 parent 06905ea commit b9ff6f6

File tree

8 files changed

+180
-150
lines changed

8 files changed

+180
-150
lines changed

src/metatrain/pet/model.py

Lines changed: 113 additions & 82 deletions
Large diffs are not rendered by default.

src/metatrain/utils/additive/_base_composition.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,10 @@ def forward(
410410
)
411411
else: # per_pair
412412
sample_values.append(
413-
torch.tensor([int(A), int(i), int(i), 0, 0, 0], dtype=torch.int32)
413+
torch.tensor(
414+
[int(A), int(i), int(i), 0, 0, 0],
415+
dtype=torch.int32,
416+
)
414417
)
415418
if self.sample_kinds[output_name] == "per_atom":
416419
sample_labels = Labels(
@@ -419,7 +422,14 @@ def forward(
419422
)
420423
else:
421424
sample_labels = Labels(
422-
["system", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"],
425+
[
426+
"system",
427+
"first_atom",
428+
"second_atom",
429+
"cell_shift_a",
430+
"cell_shift_b",
431+
"cell_shift_c",
432+
],
423433
torch.vstack(sample_values),
424434
)
425435
X = self._compute_X_per_atom(
@@ -550,7 +560,7 @@ def _include_key(key: LabelsEntry) -> bool:
550560
if values are 0 and 1 respectively (indicating an invariant block of a
551561
spherical target).
552562
- If the key has names ["o3_lambda", "o3_sigma", "s2_pi], it is included
553-
if values are 0, 1, 0 respectively, indicating an unsymmetrized
563+
if values are 0, 1, 0 respectively, indicating an unsymmetrized
554564
invariant block of a per-pair target.
555565
"""
556566
include_key = False

src/metatrain/utils/additive/remove.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from metatensor.torch.operations._multiply import _multiply_block_constant
99
from metatomic.torch import System
1010

11+
from ..basis import get_onsite_samples_mask
1112
from ..data import TargetInfo
1213
from ..evaluate_model import evaluate_model
13-
from ..basis import get_onsite_samples_mask
1414

1515

1616
def remove_additive(
@@ -78,7 +78,7 @@ def remove_additive(
7878
values[onsite_sample_idxs] = old_block.values.detach()
7979
else:
8080
values = old_block.values.detach()
81-
81+
8282
device = targets[target_key].block(block_key).values.device
8383
block = metatensor.torch.TensorBlock(
8484
values=values.to(device=device),

src/metatrain/utils/augmentation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,7 @@ def _apply_wigner_D_matrices(
157157
block.samples.values[:, 0], return_inverse=True
158158
)
159159
split_values = [
160-
values[inverse_indices == i]
161-
for i in range(len(unique_system_ids))
160+
values[inverse_indices == i] for i in range(len(unique_system_ids))
162161
]
163162
else:
164163
if "atom" in block.samples.names:

src/metatrain/utils/basis.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import torch
2-
from metatensor.torch import LabelsEntry, Labels, TensorMap
1+
from typing import List
32

3+
import torch
4+
from metatensor.torch import Labels, TensorMap
45

56

67
def is_spherical_atomic_basis(target: TensorMap) -> bool:
@@ -10,12 +11,9 @@ def is_spherical_atomic_basis(target: TensorMap) -> bool:
1011
:param target: The target tensor map to check.
1112
:returns: True if the target is a spherical atomic basis, False otherwise.
1213
"""
13-
if not (
14-
"o3_lambda" in target.keys.names
15-
and "o3_sigma" in target.keys.names
16-
):
14+
if not ("o3_lambda" in target.keys.names and "o3_sigma" in target.keys.names):
1715
return False
18-
16+
1917
# i.e. electron density on basis
2018
if (
2119
target.sample_names == ["system", "atom"]
@@ -25,19 +23,27 @@ def is_spherical_atomic_basis(target: TensorMap) -> bool:
2523
):
2624
return True
2725

28-
2926
# i.e. hamiltonian
3027
if (
31-
target.sample_names == ["system", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"]
28+
target.sample_names
29+
== [
30+
"system",
31+
"first_atom",
32+
"second_atom",
33+
"cell_shift_a",
34+
"cell_shift_b",
35+
"cell_shift_c",
36+
]
3237
and len(target[0].components) == 1
3338
and target[0].components[0].names == ["o3_mu"]
34-
and target.property_names == ['first_atom_type', 'second_atom_type', 'l_1', 'l_2', 'n_1', 'n_2']
39+
and target.property_names
40+
== ["first_atom_type", "second_atom_type", "l_1", "l_2", "n_1", "n_2"]
3541
):
3642
return True
3743

38-
3944
return False
4045

46+
4147
def get_block_sample_idxs_per_atom(
4248
sample_labels_with_types: Labels,
4349
atomic_basis: Labels,
@@ -49,14 +55,10 @@ def get_block_sample_idxs_per_atom(
4955
respectively.
5056
"""
5157
assert sample_labels_with_types.names == ["system", "atom", "center_type"]
52-
58+
5359
# find the center types that have basis functions with the given o3_lambda
5460
valid_center_types = torch.sort(
55-
torch.unique(
56-
atomic_basis.values[
57-
atomic_basis.values[:, 0] == o3_lambda
58-
][:, 1]
59-
)
61+
torch.unique(atomic_basis.values[atomic_basis.values[:, 0] == o3_lambda][:, 1])
6062
)[0]
6163

6264
# find the sample indices that have these center types
@@ -83,11 +85,7 @@ def get_block_sample_idxs_per_pair(
8385
"""
8486

8587
if node_or_edge == "node":
86-
assert sample_labels_with_types.names == [
87-
"system",
88-
"atom",
89-
"center_type"
90-
]
88+
assert sample_labels_with_types.names == ["system", "atom", "center_type"]
9189
# Find the pairs of atom types that can be in this block
9290
atomic_type_pairs = atomic_basis.values[
9391
atomic_basis.select(
@@ -104,7 +102,8 @@ def get_block_sample_idxs_per_pair(
104102
][:, 3]
105103
sample_idxs = sample_labels_with_types.select(
106104
Labels(
107-
["center_type"], center_types.reshape(-1, 1),
105+
["center_type"],
106+
center_types.reshape(-1, 1),
108107
)
109108
)
110109

@@ -136,12 +135,9 @@ def get_block_sample_idxs_per_pair(
136135
torch.hstack(
137136
[
138137
atomic_type_pairs[:, 3:],
139-
torch.full(
140-
(atomic_type_pairs.shape[0], 1),
141-
s2_pi
142-
)
138+
torch.full((atomic_type_pairs.shape[0], 1), s2_pi),
143139
]
144-
)
140+
),
145141
)
146142
)
147143

@@ -297,4 +293,3 @@ def symmetrize_edge_features(
297293
edge_features_sym_m1,
298294
]
299295
)
300-

src/metatrain/utils/data/readers/metatensor.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,7 @@ def read_generic(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]:
7272
# actual properties of the tensor maps
7373
target_info.layout = _empty_tensor_map_like(tensor_map)
7474

75-
system_ids = metatensor.torch.unique_metadata(
76-
tensor_map, "samples", "system"
77-
)
75+
system_ids = metatensor.torch.unique_metadata(tensor_map, "samples", "system")
7876
selections = [
7977
Labels(
8078
names=["system"],

src/metatrain/utils/data/target_info.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Union
1+
from typing import List, Union
22

33
import torch
44
from metatensor.torch import Labels, TensorBlock, TensorMap
@@ -49,12 +49,9 @@ def gradients(self) -> List[str]:
4949
@property
5050
def per_atom(self) -> bool:
5151
"""Whether the target is per atom."""
52-
return (
53-
"atom" in self.layout.block(0).samples.names
54-
or (
55-
"first_atom" in self.layout.block(0).samples.names
56-
and "second_atom" in self.layout.block(0).samples.names
57-
)
52+
return "atom" in self.layout.block(0).samples.names or (
53+
"first_atom" in self.layout.block(0).samples.names
54+
and "second_atom" in self.layout.block(0).samples.names
5855
)
5956

6057
def __repr__(self):
@@ -180,7 +177,7 @@ def _check_layout(self, layout: TensorMap) -> None:
180177
o3_lambda, o3_sigma, s2_pi = (
181178
int(key.values[0].item()),
182179
int(key.values[1].item()),
183-
None
180+
None,
184181
)
185182
else:
186183
assert len(key.names) == 3
@@ -191,33 +188,33 @@ def _check_layout(self, layout: TensorMap) -> None:
191188
)
192189
if o3_sigma not in [-1, 1]:
193190
raise ValueError(
194-
"The layout ``TensorMap`` of a spherical tensor target should "
195-
"have key dimension 'o3_sigma' that is either -1 or 1."
196-
f"Found '{o3_sigma}' instead."
191+
"The layout ``TensorMap`` of a spherical tensor "
192+
"target should have key dimension 'o3_sigma' that "
193+
f"is either -1 or 1. Found '{o3_sigma}' instead."
197194
)
198195
if o3_lambda < 0:
199196
raise ValueError(
200-
"The layout ``TensorMap`` of a spherical tensor target should "
201-
"have key dimension 'o3_lambda' that is non-negative."
202-
f"Found '{o3_lambda}' instead."
197+
"The layout ``TensorMap`` of a spherical tensor "
198+
"target should have key dimension 'o3_lambda' that "
199+
f"is non-negative. Found '{o3_lambda}' instead."
203200
)
204201
if s2_pi is not None:
205202
if s2_pi not in [-1, 0, +1]:
206203
raise ValueError(
207-
"The layout ``TensorMap`` of a spherical tensor target should "
208-
"have key dimension 's2_pi' that is either -1, 0, or +1."
209-
f"Found '{s2_pi}' instead."
204+
"The layout ``TensorMap`` of a spherical tensor "
205+
"target should have key dimension 's2_pi' that "
206+
f"is either -1, 0, or +1. Found '{s2_pi}' instead."
210207
)
211208
components = block.components
212209
if len(components) != 1:
213210
raise ValueError(
214-
"The layout ``TensorMap`` of a spherical tensor target should "
215-
"have a single component."
211+
"The layout ``TensorMap`` of a spherical tensor "
212+
"target should have a single component."
216213
)
217214
if len(components[0]) != 2 * o3_lambda + 1:
218215
raise ValueError(
219-
"Each ``TensorBlock`` of a spherical tensor target should have "
220-
"a component with 2*o3_lambda + 1 elements."
216+
"Each ``TensorBlock`` of a spherical tensor target"
217+
"should have a component with 2*o3_lambda + 1 elements."
221218
f"Found '{len(components[0])}' elements instead."
222219
)
223220
if len(block.gradients_list()) > 0:
@@ -414,7 +411,6 @@ def _get_cartesian_target_info(target: DictConfig) -> TargetInfo:
414411

415412

416413
def _get_spherical_target_info(target: DictConfig) -> TargetInfo:
417-
418414
irreps = target["type"]["spherical"]["irreps"]
419415
atomic_basis = target["type"]["spherical"].get("atomic_basis", None)
420416

@@ -427,14 +423,11 @@ def _get_spherical_target_info(target: DictConfig) -> TargetInfo:
427423
atomic_basis = Labels(
428424
names=atomic_basis_names,
429425
values=torch.tensor(
430-
[
431-
[i[name] for name in atomic_basis_names]
432-
for i in atomic_basis
433-
],
426+
[[i[name] for name in atomic_basis_names] for i in atomic_basis],
434427
dtype=torch.int32,
435428
),
436429
)
437-
430+
438431
# Infer the sample names
439432
if target["per_atom"]:
440433
if atomic_basis is None:

src/metatrain/utils/loss.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ def __call__(
9898
"TensorMapLoss requires the two TensorMaps to have the same "
9999
"components."
100100
)
101-
if not torch.all(block_1.samples.values[:, 1:] == block_2.samples.values[:, 1:]):
101+
if not torch.all(
102+
block_1.samples.values[:, 1:] == block_2.samples.values[:, 1:]
103+
):
102104
raise ValueError(
103105
"TensorMapLoss requires the two TensorMaps "
104106
"to have the same samples."
@@ -158,7 +160,9 @@ def __call__(
158160
assert values_1.shape[0] == 0
159161
else:
160162
loss += (
161-
self.weight * self.losses["values"](values_1, values_2) / sliding_weight
163+
self.weight
164+
* self.losses["values"](values_1, values_2)
165+
/ sliding_weight
162166
)
163167
for gradient_name in block_2.gradients_list():
164168
gradient_weight = self.gradient_weights[gradient_name]

0 commit comments

Comments
 (0)