Skip to content

Commit a68beff

Browse files
committed
add __eq__ for HashAlgorithm and padding instances
1 parent 5044290 commit a68beff

File tree

10 files changed

+341
-7
lines changed

10 files changed

+341
-7
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ Changelog
1818
* Removed the deprecated ``CAST5``, ``SEED``, ``IDEA``, and ``Blowfish``
1919
classes from the cipher module. These are still available in
2020
:doc:`/hazmat/decrepit/index`.
21+
* Make instances of
22+
:class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm` as well as
23+
instances of classes in
24+
:mod:`~cryptography.hazmat.primitives.asymmetric.padding`
25+
comparable.
2126

2227
.. _v45-0-6:
2328

docs/hazmat/primitives/asymmetric/cloudhsm.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ if you only need a subset of functionality.
8888
... Maps the cryptography padding and algorithm to the corresponding KMS signing algorithm.
8989
... This is specific to your implementation.
9090
... """
91-
... if isinstance(padding, PKCS1v15) and isinstance(algorithm, hashes.SHA256):
91+
... if padding == PKCS1v15() and algorithm == hashes.SHA256():
9292
... return b"RSA_PKCS1_V1_5_SHA_256"
9393
... else:
9494
... raise NotImplementedError()

docs/x509/reference.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ Loading Certificate Revocation Lists
248248
>>> from cryptography import x509
249249
>>> from cryptography.hazmat.primitives import hashes
250250
>>> crl = x509.load_pem_x509_crl(pem_crl_data)
251-
>>> isinstance(crl.signature_hash_algorithm, hashes.SHA256)
251+
>>> crl.signature_hash_algorithm == hashes.SHA256()
252252
True
253253

254254
.. function:: load_der_x509_crl(data)
@@ -287,7 +287,7 @@ Loading Certificate Signing Requests
287287
>>> from cryptography import x509
288288
>>> from cryptography.hazmat.primitives import hashes
289289
>>> csr = x509.load_pem_x509_csr(pem_req_data)
290-
>>> isinstance(csr.signature_hash_algorithm, hashes.SHA256)
290+
>>> csr.signature_hash_algorithm == hashes.SHA256()
291291
True
292292

293293
.. function:: load_der_x509_csr(data)
@@ -477,7 +477,7 @@ X.509 Certificate Object
477477
.. doctest::
478478

479479
>>> from cryptography.hazmat.primitives import hashes
480-
>>> isinstance(cert.signature_hash_algorithm, hashes.SHA256)
480+
>>> cert.signature_hash_algorithm == hashes.SHA256()
481481
True
482482

483483
.. attribute:: signature_algorithm_oid
@@ -716,7 +716,7 @@ X.509 CRL (Certificate Revocation List) Object
716716
.. doctest::
717717

718718
>>> from cryptography.hazmat.primitives import hashes
719-
>>> isinstance(crl.signature_hash_algorithm, hashes.SHA256)
719+
>>> crl.signature_hash_algorithm == hashes.SHA256()
720720
True
721721

722722
.. attribute:: signature_algorithm_oid
@@ -1119,7 +1119,7 @@ X.509 CSR (Certificate Signing Request) Object
11191119
.. doctest::
11201120

11211121
>>> from cryptography.hazmat.primitives import hashes
1122-
>>> isinstance(csr.signature_hash_algorithm, hashes.SHA256)
1122+
>>> csr.signature_hash_algorithm == hashes.SHA256()
11231123
True
11241124

11251125
.. attribute:: signature_algorithm_oid

src/cryptography/hazmat/primitives/asymmetric/padding.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import abc
8+
import typing
89

910
from cryptography.hazmat.primitives import hashes
1011
from cryptography.hazmat.primitives._asymmetric import (
@@ -16,6 +17,9 @@
1617
class PKCS1v15(AsymmetricPadding):
1718
name = "EMSA-PKCS1-v1_5"
1819

20+
def __eq__(self, other: typing.Any) -> bool:
21+
return isinstance(other, PKCS1v15)
22+
1923

2024
class _MaxLength:
2125
"Sentinel value for `MAX_LENGTH`."
@@ -56,6 +60,18 @@ def __init__(
5660

5761
self._salt_length = salt_length
5862

63+
def __eq__(self, other: typing.Any) -> bool:
64+
if isinstance(self._salt_length, int):
65+
eq_salt_length = self._salt_length == other._salt_length
66+
else:
67+
eq_salt_length = self._salt_length is other._salt_length
68+
69+
return (
70+
isinstance(other, PSS)
71+
and eq_salt_length
72+
and self._mgf == other._mgf
73+
)
74+
5975
@property
6076
def mgf(self) -> MGF:
6177
return self._mgf
@@ -77,6 +93,14 @@ def __init__(
7793
self._algorithm = algorithm
7894
self._label = label
7995

96+
def __eq__(self, other: typing.Any) -> bool:
97+
return (
98+
isinstance(other, OAEP)
99+
and self._mgf == other._mgf
100+
and self._algorithm == other._algorithm
101+
and self._label == other._label
102+
)
103+
80104
@property
81105
def algorithm(self) -> hashes.HashAlgorithm:
82106
return self._algorithm
@@ -89,6 +113,13 @@ def mgf(self) -> MGF:
89113
class MGF(metaclass=abc.ABCMeta):
90114
_algorithm: hashes.HashAlgorithm
91115

116+
@abc.abstractmethod
117+
def __eq__(self, other: typing.Any) -> bool:
118+
"""
119+
Implement equality checking.
120+
"""
121+
...
122+
92123

93124
class MGF1(MGF):
94125
def __init__(self, algorithm: hashes.HashAlgorithm):
@@ -97,6 +128,9 @@ def __init__(self, algorithm: hashes.HashAlgorithm):
97128

98129
self._algorithm = algorithm
99130

131+
def __eq__(self, other: typing.Any) -> bool:
132+
return isinstance(other, MGF1) and self._algorithm == other._algorithm
133+
100134

101135
def calculate_max_pss_salt_length(
102136
key: rsa.RSAPrivateKey | rsa.RSAPublicKey,

src/cryptography/hazmat/primitives/hashes.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import abc
8+
import typing
89

910
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
1011
from cryptography.utils import Buffer
@@ -36,6 +37,13 @@
3637

3738

3839
class HashAlgorithm(metaclass=abc.ABCMeta):
40+
@abc.abstractmethod
41+
def __eq__(self, other: typing.Any) -> bool:
42+
"""
43+
Implement equality checking.
44+
"""
45+
...
46+
3947
@property
4048
@abc.abstractmethod
4149
def name(self) -> str:
@@ -103,66 +111,99 @@ class SHA1(HashAlgorithm):
103111
digest_size = 20
104112
block_size = 64
105113

114+
def __eq__(self, other: typing.Any) -> bool:
115+
return isinstance(other, SHA1)
116+
106117

107118
class SHA512_224(HashAlgorithm): # noqa: N801
108119
name = "sha512-224"
109120
digest_size = 28
110121
block_size = 128
111122

123+
def __eq__(self, other: typing.Any) -> bool:
124+
return isinstance(other, SHA512_224)
125+
112126

113127
class SHA512_256(HashAlgorithm): # noqa: N801
114128
name = "sha512-256"
115129
digest_size = 32
116130
block_size = 128
117131

132+
def __eq__(self, other: typing.Any) -> bool:
133+
return isinstance(other, SHA512_256)
134+
118135

119136
class SHA224(HashAlgorithm):
120137
name = "sha224"
121138
digest_size = 28
122139
block_size = 64
123140

141+
def __eq__(self, other: typing.Any) -> bool:
142+
return isinstance(other, SHA224)
143+
124144

125145
class SHA256(HashAlgorithm):
126146
name = "sha256"
127147
digest_size = 32
128148
block_size = 64
129149

150+
def __eq__(self, other: typing.Any) -> bool:
151+
return isinstance(other, SHA256)
152+
130153

131154
class SHA384(HashAlgorithm):
132155
name = "sha384"
133156
digest_size = 48
134157
block_size = 128
135158

159+
def __eq__(self, other: typing.Any) -> bool:
160+
return isinstance(other, SHA384)
161+
136162

137163
class SHA512(HashAlgorithm):
138164
name = "sha512"
139165
digest_size = 64
140166
block_size = 128
141167

168+
def __eq__(self, other: typing.Any) -> bool:
169+
return isinstance(other, SHA512)
170+
142171

143172
class SHA3_224(HashAlgorithm): # noqa: N801
144173
name = "sha3-224"
145174
digest_size = 28
146175
block_size = None
147176

177+
def __eq__(self, other: typing.Any) -> bool:
178+
return isinstance(other, SHA3_224)
179+
148180

149181
class SHA3_256(HashAlgorithm): # noqa: N801
150182
name = "sha3-256"
151183
digest_size = 32
152184
block_size = None
153185

186+
def __eq__(self, other: typing.Any) -> bool:
187+
return isinstance(other, SHA3_256)
188+
154189

155190
class SHA3_384(HashAlgorithm): # noqa: N801
156191
name = "sha3-384"
157192
digest_size = 48
158193
block_size = None
159194

195+
def __eq__(self, other: typing.Any) -> bool:
196+
return isinstance(other, SHA3_384)
197+
160198

161199
class SHA3_512(HashAlgorithm): # noqa: N801
162200
name = "sha3-512"
163201
digest_size = 64
164202
block_size = None
165203

204+
def __eq__(self, other: typing.Any) -> bool:
205+
return isinstance(other, SHA3_512)
206+
166207

167208
class SHAKE128(HashAlgorithm, ExtendableOutputFunction):
168209
name = "shake128"
@@ -177,6 +218,12 @@ def __init__(self, digest_size: int):
177218

178219
self._digest_size = digest_size
179220

221+
def __eq__(self, other: typing.Any) -> bool:
222+
return (
223+
isinstance(other, SHAKE128)
224+
and self._digest_size == other._digest_size
225+
)
226+
180227
@property
181228
def digest_size(self) -> int:
182229
return self._digest_size
@@ -195,6 +242,12 @@ def __init__(self, digest_size: int):
195242

196243
self._digest_size = digest_size
197244

245+
def __eq__(self, other: typing.Any) -> bool:
246+
return (
247+
isinstance(other, SHAKE256)
248+
and self._digest_size == other._digest_size
249+
)
250+
198251
@property
199252
def digest_size(self) -> int:
200253
return self._digest_size
@@ -205,6 +258,9 @@ class MD5(HashAlgorithm):
205258
digest_size = 16
206259
block_size = 64
207260

261+
def __eq__(self, other: typing.Any) -> bool:
262+
return isinstance(other, MD5)
263+
208264

209265
class BLAKE2b(HashAlgorithm):
210266
name = "blake2b"
@@ -218,6 +274,12 @@ def __init__(self, digest_size: int):
218274

219275
self._digest_size = digest_size
220276

277+
def __eq__(self, other: typing.Any) -> bool:
278+
return (
279+
isinstance(other, BLAKE2b)
280+
and self._digest_size == other._digest_size
281+
)
282+
221283
@property
222284
def digest_size(self) -> int:
223285
return self._digest_size
@@ -235,6 +297,12 @@ def __init__(self, digest_size: int):
235297

236298
self._digest_size = digest_size
237299

300+
def __eq__(self, other: typing.Any) -> bool:
301+
return (
302+
isinstance(other, BLAKE2s)
303+
and self._digest_size == other._digest_size
304+
)
305+
238306
@property
239307
def digest_size(self) -> int:
240308
return self._digest_size
@@ -244,3 +312,6 @@ class SM3(HashAlgorithm):
244312
name = "sm3"
245313
digest_size = 32
246314
block_size = 64
315+
316+
def __eq__(self, other: typing.Any) -> bool:
317+
return isinstance(other, SM3)

tests/doubles.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
33
# for complete details.
44

5+
import typing
56

67
from cryptography.hazmat.primitives import hashes, serialization
78
from cryptography.hazmat.primitives.asymmetric import padding
@@ -40,6 +41,12 @@ class DummyHashAlgorithm(hashes.HashAlgorithm):
4041
def __init__(self, digest_size: int = 32) -> None:
4142
self._digest_size = digest_size
4243

44+
def __eq__(self, other: typing.Any) -> bool:
45+
return (
46+
isinstance(self, DummyHashAlgorithm)
47+
and self._digest_size == other._digest_size
48+
)
49+
4350
@property
4451
def digest_size(self) -> int:
4552
return self._digest_size

tests/hazmat/backends/test_openssl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55

66
import itertools
7+
import typing
78

89
import pytest
910

@@ -32,6 +33,9 @@ class DummyMGF(padding.MGF):
3233
_salt_length = 0
3334
_algorithm = hashes.SHA1()
3435

36+
def __eq__(self, other: typing.Any) -> bool:
37+
return isinstance(other, DummyMGF)
38+
3539

3640
class TestOpenSSL:
3741
def test_backend_exists(self):

0 commit comments

Comments
 (0)