@@ -70,6 +70,56 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
7070 tuple[str, bool, Any]: The tag, whether the value is of the correct type, and the possibly fixed value.
7171 """
7272
73+ def is_equal_to (self , val1 : Any | list [Any ], obj2 : AbstractTag , val2 : Any | list [Any ]) -> bool :
74+ """Check if the two values are equal.
75+
76+ Args:
77+ val1 (Any): The value of this tag object.
78+ obj2 (AbstractTag): The other tag object.
79+ val2 (Any): The value of the other tag object.
80+
81+ Returns:
82+ bool: True if the two tag object/value pairs are equal, False otherwise.
83+ """
84+ if self .can_repeat :
85+ if not obj2 .can_repeat :
86+ return False
87+ val1 = val1 if isinstance (val1 , list ) else [val1 ]
88+ val2 = val2 if isinstance (val2 , list ) else [val2 ]
89+ if len (val1 ) != len (val2 ):
90+ return False
91+ return all (True in [self ._is_equal_to (v1 , obj2 , v2 ) for v2 in val2 ] for v1 in val1 )
92+ return self ._is_equal_to (val1 , obj2 , val2 )
93+
94+ @abstractmethod
95+ def _is_equal_to (self , val1 : Any , obj2 : AbstractTag , val2 : Any ) -> bool :
96+ """Check if the two values are equal.
97+
98+ Used to check if the two values are equal. Assumes val1 and val2 are single elements.
99+
100+ Args:
101+ val1 (Any): The value of this tag object.
102+ obj2 (AbstractTag): The other tag object.
103+ val2 (Any): The value of the other tag object.
104+
105+ Returns:
106+ bool: True if the two tag object/value pairs are equal, False otherwise.
107+ """
108+
109+ def _is_same_tagtype (
110+ self ,
111+ obj2 : AbstractTag ,
112+ ) -> bool :
113+ """Check if the two values are equal.
114+
115+ Args:
116+ obj2 (AbstractTag): The other tag object.
117+
118+ Returns:
119+ bool: True if the two tag object/value pairs are equal, False otherwise.
120+ """
121+ return isinstance (self , type (obj2 ))
122+
73123 def _validate_value_type (
74124 self , type_check : type , tag : str , value : Any , try_auto_type_fix : bool = False
75125 ) -> tuple [str , bool , Any ]:
@@ -258,6 +308,19 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
258308 """
259309 return self ._validate_value_type (bool , tag , value , try_auto_type_fix = try_auto_type_fix )
260310
311+ def _is_equal_to (self , val1 : Any , obj2 : AbstractTag , val2 : Any ) -> bool :
312+ """Check if the two values are equal.
313+
314+ Args:
315+ val1 (Any): The value of this tag object.
316+ obj2 (AbstractTag): The other tag object.
317+ val2 (Any): The value of the other tag object.
318+
319+ Returns:
320+ bool: True if the two tag object/value pairs are equal, False otherwise.
321+ """
322+ return self ._is_same_tagtype (obj2 ) and val1 == val2
323+
261324 def raise_value_error (self , tag : str , value : str ) -> None :
262325 """Raise a ValueError for the value string.
263326
@@ -335,6 +398,23 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
335398 """
336399 return self ._validate_value_type (str , tag , value , try_auto_type_fix = try_auto_type_fix )
337400
401+ def _is_equal_to (self , val1 : Any , obj2 : AbstractTag , val2 : Any ) -> bool :
402+ """Check if the two values are equal.
403+
404+ Args:
405+ val1 (Any): The value of this tag object.
406+ obj2 (AbstractTag): The other tag object.
407+ val2 (Any): The value of the other tag object.
408+
409+ Returns:
410+ bool: True if the two tag object/value pairs are equal, False otherwise.
411+ """
412+ if self ._is_same_tagtype (obj2 ):
413+ if not all (isinstance (x , str ) for x in (val1 , val2 )):
414+ raise ValueError ("Both values must be strings for StrTag comparison" )
415+ return val1 .strip () == val2 .strip ()
416+ return False
417+
338418 def read (self , tag : str , value : str ) -> str :
339419 """Read the value string for this tag.
340420
@@ -379,6 +459,8 @@ class AbstractNumericTag(AbstractTag):
379459 ub : float | None = None # upper bound
380460 lb_incl : bool = True # lower bound inclusive
381461 ub_incl : bool = True # upper bound inclusive
462+ eq_atol : float = 1.0e-8 # absolute tolerance for equality check
463+ eq_rtol : float = 1.0e-5 # relative tolerance for equality check
382464
383465 def val_is_within_bounds (self , value : float ) -> bool :
384466 """Check if the value is within the bounds.
@@ -425,6 +507,22 @@ def validate_value_bounds(
425507 return False , self .get_invalid_value_error_str (tag , value )
426508 return True , ""
427509
510+ def _is_equal_to (self , val1 , obj2 , val2 ):
511+ """Check if the two values are equal.
512+
513+ Used to check if the two values are equal. Doesn't need to be redefined for IntTag and FloatTag.
514+
515+ Args:
516+ val1 (Any): The value of this tag object.
517+ obj2 (AbstractTag): The other tag object.
518+ val2 (Any): The value of the other tag object.
519+ rtol (float, optional): Relative tolerance. Defaults to 1.e-5.
520+ atol (float, optional): Absolute tolerance. Defaults to 1.e-8.
521+ Returns:
522+ bool: True if the two tag object/value pairs are equal, False otherwise.
523+ """
524+ return self ._is_same_tagtype (obj2 ) and np .isclose (val1 , val2 , rtol = self .eq_rtol , atol = self .eq_atol )
525+
428526
429527@dataclass
430528class IntTag (AbstractNumericTag ):
@@ -620,6 +718,10 @@ def get_token_len(self) -> int:
620718 """
621719 return self ._get_token_len ()
622720
721+ def _is_equal_to (self , val1 , obj2 , val2 ):
722+ return True # TODO: We still need to actually implement initmagmom as a multi-format tag
723+ # raise NotImplementedError("equality not yet implemented for InitMagMomTag")
724+
623725
624726@dataclass
625727class TagContainer (AbstractTag ):
@@ -1013,6 +1115,28 @@ def get_dict_representation(self, tag: str, value: list) -> dict | list[dict]:
10131115 list_value = self ._make_str_for_dict (tag , value )
10141116 return self .read (tag , list_value )
10151117
1118+ def _is_equal_to (self , val1 , obj2 , val2 ):
1119+ """Check if the two values are equal.
1120+
1121+ Return False if (checked in following order)
1122+ - obj2 is not a TagContainer
1123+ - all of val1's subtags are not in val2
1124+ - val1 and val2 are not the same length (different number of subtags)
1125+ - at least one subtag in val1 is not equal to the corresponding subtag in val2
1126+ """
1127+ if self ._is_same_tagtype (obj2 ):
1128+ if isinstance (val1 , dict ) and isinstance (val2 , dict ):
1129+ if all (subtag in val2 for subtag in val1 ) and (len (list (val1 .keys ())) == len (list (val2 .keys ()))):
1130+ for subtag , subtag_type in self .subtags .items ():
1131+ if (subtag in val1 ) and (
1132+ not subtag_type .is_equal_to (val1 [subtag ], obj2 .subtags [subtag ], val2 [subtag ])
1133+ ):
1134+ return False
1135+ return True
1136+ return False
1137+ raise ValueError ("Values must be in dictionary format for TagContainer comparison" )
1138+ return False
1139+
10161140
10171141# TODO: Write StructureDefferedTagContainer back in (commented out code block removed
10181142# on 11/4/24) and make usable for tags like initial-magnetic-moments
@@ -1162,6 +1286,9 @@ def get_token_len(self) -> int:
11621286 """
11631287 raise NotImplementedError ("This method is not supposed to be called directly on MultiformatTag objects!" )
11641288
1289+ def _is_equal_to (self , val1 , obj2 , val2 ):
1290+ raise NotImplementedError ("This method is not supposed to be called directly on MultiformatTag objects!" )
1291+
11651292
11661293@dataclass
11671294class BoolTagContainer (TagContainer ):
0 commit comments