1- from typing import Dict , List , Union
1+ from typing import List , Union
22
33import torch
44from metatensor .torch import Labels , TensorBlock , TensorMap
@@ -49,12 +49,9 @@ def gradients(self) -> List[str]:
4949 @property
5050 def per_atom (self ) -> bool :
5151 """Whether the target is per atom."""
52- return (
53- "atom" in self .layout .block (0 ).samples .names
54- or (
55- "first_atom" in self .layout .block (0 ).samples .names
56- and "second_atom" in self .layout .block (0 ).samples .names
57- )
52+ return "atom" in self .layout .block (0 ).samples .names or (
53+ "first_atom" in self .layout .block (0 ).samples .names
54+ and "second_atom" in self .layout .block (0 ).samples .names
5855 )
5956
6057 def __repr__ (self ):
@@ -180,7 +177,7 @@ def _check_layout(self, layout: TensorMap) -> None:
180177 o3_lambda , o3_sigma , s2_pi = (
181178 int (key .values [0 ].item ()),
182179 int (key .values [1 ].item ()),
183- None
180+ None ,
184181 )
185182 else :
186183 assert len (key .names ) == 3
@@ -191,33 +188,33 @@ def _check_layout(self, layout: TensorMap) -> None:
191188 )
192189 if o3_sigma not in [- 1 , 1 ]:
193190 raise ValueError (
194- "The layout ``TensorMap`` of a spherical tensor target should "
195- "have key dimension 'o3_sigma' that is either -1 or 1. "
196- f"Found '{ o3_sigma } ' instead."
191+ "The layout ``TensorMap`` of a spherical tensor "
192+ "target should have key dimension 'o3_sigma' that "
193+ f"is either -1 or 1. Found '{ o3_sigma } ' instead."
197194 )
198195 if o3_lambda < 0 :
199196 raise ValueError (
200- "The layout ``TensorMap`` of a spherical tensor target should "
201- "have key dimension 'o3_lambda' that is non-negative. "
202- f"Found '{ o3_lambda } ' instead."
197+ "The layout ``TensorMap`` of a spherical tensor "
198+ "target should have key dimension 'o3_lambda' that "
199+ f"is non-negative. Found '{ o3_lambda } ' instead."
203200 )
204201 if s2_pi is not None :
205202 if s2_pi not in [- 1 , 0 , + 1 ]:
206203 raise ValueError (
207- "The layout ``TensorMap`` of a spherical tensor target should "
208- "have key dimension 's2_pi' that is either -1, 0, or +1. "
209- f"Found '{ s2_pi } ' instead."
204+ "The layout ``TensorMap`` of a spherical tensor "
205+ "target should have key dimension 's2_pi' that "
206+ f"is either -1, 0, or +1. Found '{ s2_pi } ' instead."
210207 )
211208 components = block .components
212209 if len (components ) != 1 :
213210 raise ValueError (
214- "The layout ``TensorMap`` of a spherical tensor target should "
215- "have a single component."
211+ "The layout ``TensorMap`` of a spherical tensor "
212+ "target should have a single component."
216213 )
217214 if len (components [0 ]) != 2 * o3_lambda + 1 :
218215 raise ValueError (
219- "Each ``TensorBlock`` of a spherical tensor target should have "
220- "a component with 2*o3_lambda + 1 elements."
216+ "Each ``TensorBlock`` of a spherical tensor target"
217+ "should have a component with 2*o3_lambda + 1 elements."
221218 f"Found '{ len (components [0 ])} ' elements instead."
222219 )
223220 if len (block .gradients_list ()) > 0 :
@@ -414,7 +411,6 @@ def _get_cartesian_target_info(target: DictConfig) -> TargetInfo:
414411
415412
416413def _get_spherical_target_info (target : DictConfig ) -> TargetInfo :
417-
418414 irreps = target ["type" ]["spherical" ]["irreps" ]
419415 atomic_basis = target ["type" ]["spherical" ].get ("atomic_basis" , None )
420416
@@ -427,14 +423,11 @@ def _get_spherical_target_info(target: DictConfig) -> TargetInfo:
427423 atomic_basis = Labels (
428424 names = atomic_basis_names ,
429425 values = torch .tensor (
430- [
431- [i [name ] for name in atomic_basis_names ]
432- for i in atomic_basis
433- ],
426+ [[i [name ] for name in atomic_basis_names ] for i in atomic_basis ],
434427 dtype = torch .int32 ,
435428 ),
436429 )
437-
430+
438431 # Infer the sample names
439432 if target ["per_atom" ]:
440433 if atomic_basis is None :
0 commit comments