Skip to content

Commit 0e42e7e

Browse files
committed
WIP: Add get_from_list to credentials
1 parent 3cda828 commit 0e42e7e

File tree

1 file changed

+31
-14
lines changed

1 file changed

+31
-14
lines changed

logprep/util/credentials.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def from_target(cls, target_url: str) -> "Credentials | None":
163163
return credentials
164164

165165
@classmethod
166-
def from_endpoint(cls, target_endpoint: str) -> "Credentials | None":
166+
def from_endpoint(cls, target_endpoint: str) -> "list[Credentials] | Credentials | None":
167167
"""Factory method to create a credentials object based on the credentials stored in the
168168
environment variable :code:`LOGPREP_CREDENTIALS_FILE`.
169169
Based on these credentials the expected authentication method is chosen and represented
@@ -187,9 +187,14 @@ def from_endpoint(cls, target_endpoint: str) -> "Credentials | None":
187187
endpoint_credentials = credentials_file.input.get("endpoints")
188188
if endpoint_credentials is None:
189189
return None
190-
credential_mapping: dict | None = endpoint_credentials.get(target_endpoint)
190+
credential_mapping: list | dict | None = endpoint_credentials.get(target_endpoint)
191+
192+
credentials: list[Credentials] | Credentials | None = None
193+
if isinstance(credential_mapping, dict):
194+
credentials = cls.from_dict(credential_mapping)
195+
elif isinstance(credential_mapping, list):
196+
credentials = cls.from_list(credential_mapping)
191197

192-
credentials = cls.from_dict(credential_mapping)
193198
return credentials
194199

195200
@staticmethod
@@ -251,6 +256,16 @@ def _resolve_secret_content(credential_mapping: dict):
251256
credential_mapping.pop(f"{credential_type}_file")
252257
credential_mapping.update(secret_content)
253258

259+
@classmethod
260+
def from_list(cls, credential_mapping: list[dict | None]) -> "list[Credentials] | None":
261+
creds: list[Credentials] = []
262+
for credential in credential_mapping:
263+
cred = cls.from_dict(credential)
264+
if isinstance(cred, Credentials):
265+
creds.append(cred)
266+
267+
return creds
268+
254269
@classmethod
255270
def from_dict(cls, credential_mapping: dict | None) -> "Credentials | None":
256271
"""matches the given credentials of the credentials mapping
@@ -392,11 +407,11 @@ class AccessToken:
392407

393408
token: str = field(validator=validators.instance_of(str), repr=False)
394409
"""token used for authentication against the target"""
395-
expiry_time: datetime = field(
410+
expiry_time: datetime | None = field(
396411
validator=validators.instance_of((datetime, type(None))), init=False
397412
)
398413
"""time when token is expired"""
399-
refresh_token: str = field(
414+
refresh_token: str | None = field(
400415
validator=validators.instance_of((str, type(None))), default=None, repr=False
401416
)
402417
"""is used incase the token is expired"""
@@ -416,7 +431,7 @@ def __str__(self) -> str:
416431
@property
417432
def is_expired(self) -> bool:
418433
"""Checks if the token is already expired."""
419-
if self.expires_in == 0:
434+
if self.expires_in == 0 or not self.expiry_time:
420435
return False
421436
return datetime.now() > self.expiry_time
422437

@@ -427,7 +442,9 @@ class Credentials:
427442

428443
_logger = logging.getLogger("Credentials")
429444

430-
_session: Session = field(validator=validators.instance_of((Session, type(None))), default=None)
445+
_session: Session | None = field(
446+
validator=validators.instance_of((Session, type(None))), default=None
447+
)
431448

432449
def get_session(self):
433450
"""returns session with retry configuration"""
@@ -549,14 +566,14 @@ class OAuth2PasswordFlowCredentials(Credentials):
549566
"""the username for the token request"""
550567
timeout: int = field(validator=validators.instance_of(int), default=1)
551568
"""The timeout for the token request. Defaults to 1 second."""
552-
client_id: str = field(validator=validators.instance_of((str, type(None))), default=None)
569+
client_id: str | None = field(validator=validators.instance_of((str, type(None))), default=None)
553570
"""The client id for the token request. This is used to identify the client. (Optional)"""
554-
client_secret: str = field(
571+
client_secret: str | None = field(
555572
validator=validators.instance_of((str, type(None))), default=None, repr=False
556573
)
557574
"""The client secret for the token request.
558575
This is used to authenticate the client. (Optional)"""
559-
_token: AccessToken = field(
576+
_token: AccessToken | None = field(
560577
validator=validators.instance_of((AccessToken, type(None))),
561578
init=False,
562579
repr=False,
@@ -573,7 +590,7 @@ def get_session(self) -> Session:
573590
}
574591
session.headers["Authorization"] = f"Bearer {self._get_token(payload)}"
575592

576-
if self._token.is_expired and self._token.refresh_token is not None:
593+
if self._token and self._token.is_expired and self._token.refresh_token is not None:
577594
session = Session()
578595
payload = {
579596
"grant_type": "refresh_token",
@@ -639,7 +656,7 @@ class OAuth2ClientFlowCredentials(Credentials):
639656
"""The client secret for the token request. This is used to authenticate the client."""
640657
timeout: int = field(validator=validators.instance_of(int), default=1)
641658
"""The timeout for the token request. Defaults to 1 second."""
642-
_token: AccessToken = field(
659+
_token: AccessToken | None = field(
643660
validator=validators.instance_of((AccessToken, type(None))), init=False, repr=False
644661
)
645662

@@ -659,7 +676,7 @@ def get_session(self) -> Session:
659676
660677
"""
661678
session = super().get_session()
662-
if "Authorization" in session.headers and self._token.is_expired:
679+
if "Authorization" in session.headers and (not self._token or self._token.is_expired):
663680
session.close()
664681
session = Session()
665682
if self._no_authorization_header(session):
@@ -710,7 +727,7 @@ class MTLSCredentials(Credentials):
710727
"""path to the client key"""
711728
cert: str = field(validator=validators.instance_of(str))
712729
"""path to the client certificate"""
713-
ca_cert: str = field(validator=validators.instance_of((str, type(None))), default=None)
730+
ca_cert: str | None = field(validator=validators.instance_of((str, type(None))), default=None)
714731
"""path to a certification authority certificate"""
715732

716733
def get_session(self):

0 commit comments

Comments
 (0)