Skip to content

Commit 82e022a

Browse files
committed
fix: improve type hints for import_key
1 parent 212ea38 commit 82e022a

File tree

7 files changed

+64
-24
lines changed

7 files changed

+64
-24
lines changed

src/joserfc/_rfc7517/models.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ def import_from_dict(cls, value: DictKey) -> t.Any:
4242

4343
@classmethod
4444
@abstractmethod
45-
def import_from_bytes(cls, value: bytes, password: t.Optional[t.Any] = None) -> t.Any:
45+
def import_from_bytes(cls, value: bytes, password: t.Any = None) -> t.Any:
4646
pass
4747

4848
@staticmethod
4949
def as_bytes(
5050
key: GenericKey,
51-
encoding: t.Optional[t.Literal["PEM", "DER"]] = None,
52-
private: t.Optional[bool] = None,
53-
password: t.Optional[str] = None,
51+
encoding: t.Literal["PEM", "DER"] | None = None,
52+
private: bool | None = None,
53+
password: str | None = None,
5454
) -> bytes:
5555
raise NotImplementedError()
5656

@@ -88,7 +88,7 @@ def __init__(
8888
self,
8989
raw_value: NativePrivateKey | NativePublicKey,
9090
original_value: t.Any,
91-
parameters: t.Optional[KeyParameters] = None,
91+
parameters: KeyParameters | None = None,
9292
):
9393
self._raw_value = raw_value
9494
self.original_value = original_value
@@ -160,7 +160,7 @@ def public_key(self) -> NativePublicKey:
160160
raise NotImplementedError()
161161

162162
@property
163-
def private_key(self) -> t.Optional[NativePrivateKey]:
163+
def private_key(self) -> NativePrivateKey | None:
164164
raise NotImplementedError()
165165

166166
def thumbprint(self) -> str:
@@ -177,7 +177,7 @@ def thumbprint_uri(self) -> str:
177177
value = self.thumbprint()
178178
return concat_thumbprint_uri(value, self.thumbprint_digest_method)
179179

180-
def as_dict(self, private: t.Optional[bool] = None, **params: t.Any) -> DictKey:
180+
def as_dict(self, private: bool | None = None, **params: t.Any) -> DictKey:
181181
"""Output this key to a JWK format (in dict). By default, it will return
182182
the ``dict_value`` of this key.
183183
@@ -270,8 +270,8 @@ def validate_dict_key(cls, data: DictKey) -> None:
270270
def import_key(
271271
cls: t.Type[GenericKey],
272272
value: AnyKey,
273-
parameters: t.Optional[KeyParameters] = None,
274-
password: t.Optional[t.Any] = None,
273+
parameters: KeyParameters | None = None,
274+
password: t.Any = None,
275275
) -> GenericKey:
276276
if isinstance(value, dict):
277277
cls.validate_dict_key(value)
@@ -285,7 +285,7 @@ def import_key(
285285
def generate_key(
286286
cls: t.Type[GenericKey],
287287
size_or_crv: t.Any,
288-
parameters: t.Optional[KeyParameters] = None,
288+
parameters: KeyParameters | None = None,
289289
private: bool = True,
290290
auto_kid: bool = False,
291291
) -> GenericKey:
@@ -321,16 +321,16 @@ def raw_value(self) -> t.Union[NativePublicKey, NativePrivateKey]:
321321

322322
def as_bytes(
323323
self,
324-
encoding: t.Optional[t.Literal["PEM", "DER"]] = None,
325-
private: t.Optional[bool] = None,
326-
password: t.Optional[str] = None,
324+
encoding: t.Literal["PEM", "DER"] | None = None,
325+
private: bool | None = None,
326+
password: str | None = None,
327327
) -> bytes:
328328
return self.binding.as_bytes(self, encoding, private, password)
329329

330-
def as_pem(self, private: t.Optional[bool] = None, password: t.Optional[str] = None) -> bytes:
330+
def as_pem(self, private: bool | None = None, password: str | None = None) -> bytes:
331331
return self.as_bytes(private=private, password=password)
332332

333-
def as_der(self, private: t.Optional[bool] = None, password: t.Optional[str] = None) -> bytes:
333+
def as_der(self, private: bool | None = None, password: str | None = None) -> bytes:
334334
return self.as_bytes(encoding="DER", private=private, password=password)
335335

336336

src/joserfc/_rfc7518/ec_key.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ..errors import InvalidExchangeKeyError
1818
from .._rfc7517.models import CurveKey
1919
from .._rfc7517.pem import CryptographyBinding
20-
from .._rfc7517.types import KeyParameters
20+
from .._rfc7517.types import KeyParameters, AnyKey
2121
from ..util import base64_to_int, int_to_base64
2222
from ..registry import KeyParameter
2323

@@ -142,6 +142,15 @@ def curve_name(self) -> str:
142142
def curve_key_size(self) -> int:
143143
return self.raw_value.curve.key_size
144144

145+
@classmethod
146+
def import_key(
147+
cls: t.Any,
148+
value: AnyKey,
149+
parameters: KeyParameters | None = None,
150+
password: t.Any = None,
151+
) -> "ECKey":
152+
return super(ECKey, cls).import_key(value, parameters, password)
153+
145154
@classmethod
146155
def generate_key(
147156
cls,

src/joserfc/_rfc7518/oct_key.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
)
1111
from ..registry import KeyParameter
1212
from .._rfc7517.models import SymmetricKey, NativeKeyBinding
13-
from .._rfc7517.types import KeyParameters, DictKey
13+
from .._rfc7517.types import KeyParameters, DictKey, AnyKey
1414

1515

1616
POSSIBLE_UNSAFE_KEYS = (
@@ -50,6 +50,15 @@ class OctKey(SymmetricKey):
5050
#: https://www.rfc-editor.org/rfc/rfc7518#section-6.4
5151
value_registry = {"k": KeyParameter("Key Value", "str", True, True)}
5252

53+
@classmethod
54+
def import_key(
55+
cls: Any,
56+
value: AnyKey,
57+
parameters: KeyParameters | None = None,
58+
password: Any = None,
59+
) -> "OctKey":
60+
return super(OctKey, cls).import_key(value, parameters, password)
61+
5362
@classmethod
5463
def generate_key(
5564
cls,

src/joserfc/_rfc7518/rsa_key.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22
import warnings
3-
from typing import TypedDict
3+
from typing import TypedDict, Any
44
from functools import cached_property
55
from cryptography.hazmat.primitives.asymmetric.rsa import (
66
generate_private_key,
@@ -18,7 +18,7 @@
1818
from ..errors import SecurityWarning
1919
from .._rfc7517.models import AsymmetricKey
2020
from .._rfc7517.pem import CryptographyBinding
21-
from .._rfc7517.types import KeyParameters
21+
from .._rfc7517.types import KeyParameters, AnyKey
2222
from ..util import int_to_base64, base64_to_int
2323

2424

@@ -132,6 +132,15 @@ def private_key(self) -> RSAPrivateKey | None:
132132
return self.raw_value
133133
return None
134134

135+
@classmethod
136+
def import_key(
137+
cls: Any,
138+
value: AnyKey,
139+
parameters: KeyParameters | None = None,
140+
password: Any = None,
141+
) -> "RSAKey":
142+
return super(RSAKey, cls).import_key(value, parameters, password)
143+
135144
@classmethod
136145
def generate_key(
137146
cls,

src/joserfc/_rfc8037/okp_key.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
NoEncryption,
1414
)
1515
from .._rfc7517.models import CurveKey
16-
from .._rfc7517.types import KeyParameters
16+
from .._rfc7517.types import KeyParameters, AnyKey
1717
from .._rfc7517.pem import CryptographyBinding
1818
from ..errors import InvalidExchangeKeyError
1919
from ..util import to_bytes, urlsafe_b64decode, urlsafe_b64encode
@@ -120,6 +120,15 @@ def private_key(self) -> PrivateOKPKey | None:
120120
def curve_name(self) -> str:
121121
return get_key_curve(self.raw_value)
122122

123+
@classmethod
124+
def import_key(
125+
cls: t.Any,
126+
value: AnyKey,
127+
parameters: KeyParameters | None = None,
128+
password: t.Any = None,
129+
) -> "OKPKey":
130+
return super(OKPKey, cls).import_key(value, parameters, password)
131+
123132
@classmethod
124133
def generate_key(
125134
cls,

src/joserfc/jwk.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ def import_key(data: AnyKey, key_type: t.Literal["EC"], parameters: KeyParameter
115115
def import_key(data: AnyKey, key_type: t.Literal["OKP"], parameters: KeyParameters | None = None) -> OKPKey: ...
116116

117117

118+
@t.overload
119+
def import_key(data: DictKey, key_type: None = None, parameters: KeyParameters | None = None) -> Key: ...
120+
121+
118122
def import_key(
119123
data: AnyKey,
120124
key_type: t.Literal["oct", "RSA", "EC", "OKP"] | None = None,

tests/jwk/test_key_methods.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ def test_find_correct_key_with_use(self):
142142
key = OctKey.generate_key()
143143
dict_key = key.as_dict()
144144

145-
key1: OctKey = OctKey.import_key(dict_key, {"use": "enc"})
146-
key2: OctKey = OctKey.import_key(dict_key, {"use": "sig"})
145+
key1 = OctKey.import_key(dict_key, {"use": "enc"})
146+
key2 = OctKey.import_key(dict_key, {"use": "sig"})
147147
self.assertEqual(key1.kid, key2.kid)
148148

149149
key_set = KeySet([key1, key2])
@@ -165,8 +165,8 @@ def test_find_correct_key_with_alg(self):
165165
key = OctKey.generate_key()
166166
dict_key = key.as_dict()
167167

168-
key1: OctKey = OctKey.import_key(dict_key, {"alg": "HS256"})
169-
key2: OctKey = OctKey.import_key(dict_key, {"alg": "dir"})
168+
key1 = OctKey.import_key(dict_key, {"alg": "HS256"})
169+
key2 = OctKey.import_key(dict_key, {"alg": "dir"})
170170

171171
self.assertEqual(key1.kid, key2.kid)
172172

0 commit comments

Comments
 (0)