26
26
from netboxlabs .diode .sdk .ingester import Entity
27
27
from netboxlabs .diode .sdk .version import version_semver
28
28
29
- _MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES"
30
- _DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
31
- _DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
32
29
_CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID"
33
30
_CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET"
31
+ _DEFAULT_STREAM = "latest"
32
+ _DIODE_CERT_FILE_ENVVAR_NAME = "DIODE_CERT_FILE"
33
+ _DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
34
+ _DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
35
+ _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME = "DIODE_SKIP_TLS_VERIFY"
34
36
_DRY_RUN_OUTPUT_DIR_ENVVAR_NAME = "DIODE_DRY_RUN_OUTPUT_DIR"
35
37
_INGEST_SCOPE = "diode:ingest"
36
- _DEFAULT_STREAM = "latest"
37
38
_LOGGER = logging .getLogger (__name__ )
38
-
39
+ _MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES"
39
40
40
41
def load_dryrun_entities (file_path : str | Path ) -> Iterable [Entity ]:
41
42
"""Yield entities from a file with concatenated JSON messages."""
@@ -53,20 +54,35 @@ class DiodeClientInterface:
53
54
pass
54
55
55
56
56
- def _load_certs () -> bytes :
57
- """Loads cacert.pem."""
58
- with open (certifi .where (), "rb" ) as f :
57
+ def _load_certs (cert_file : str | None = None ) -> bytes :
58
+ """Loads cacert.pem or custom certificate file."""
59
+ cert_path = cert_file or certifi .where ()
60
+ with open (cert_path , "rb" ) as f :
59
61
return f .read ()
60
62
61
63
64
+ def _should_verify_tls (scheme : str ) -> bool :
65
+ """Determine if TLS verification should be enabled based on scheme and environment variable."""
66
+ # Check if scheme is insecure
67
+ insecure_scheme = scheme in ["grpc" , "http" ]
68
+
69
+ # Check environment variable
70
+ skip_tls_env = os .getenv (_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME , "" ).lower ()
71
+ skip_tls_from_env = skip_tls_env in ["true" , "1" , "yes" , "on" ]
72
+
73
+ # TLS verification is enabled by default, disabled only for insecure schemes or env var
74
+ return not (insecure_scheme or skip_tls_from_env )
75
+
76
+
62
77
def parse_target (target : str ) -> tuple [str , str , bool ]:
63
78
"""Parse the target into authority, path and tls_verify."""
64
79
parsed_target = urlparse (target )
65
80
66
81
if parsed_target .scheme not in ["grpc" , "grpcs" , "http" , "https" ]:
67
82
raise ValueError ("target should start with grpc://, grpcs://, http:// or https://" )
68
83
69
- tls_verify = parsed_target .scheme in ["grpcs" , "https" ]
84
+ # Determine if TLS verification should be enabled
85
+ tls_verify = _should_verify_tls (parsed_target .scheme )
70
86
71
87
authority = parsed_target .netloc
72
88
@@ -127,15 +143,22 @@ def __init__(
127
143
sentry_traces_sample_rate : float = 1.0 ,
128
144
sentry_profiles_sample_rate : float = 1.0 ,
129
145
max_auth_retries : int = 3 ,
146
+ cert_file : str | None = None ,
130
147
):
131
148
"""Initiate a new client."""
132
149
log_level = os .getenv (_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME , "INFO" ).upper ()
133
150
logging .basicConfig (level = log_level )
134
151
135
- self ._max_auth_retries = _get_optional_config_value (
136
- _MAX_RETRIES_ENVVAR_NAME , max_auth_retries
152
+ self ._max_auth_retries = int (_get_optional_config_value (
153
+ _MAX_RETRIES_ENVVAR_NAME , str (max_auth_retries )
154
+ ) or max_auth_retries )
155
+ self ._cert_file = _get_optional_config_value (
156
+ _DIODE_CERT_FILE_ENVVAR_NAME , cert_file
137
157
)
138
158
self ._target , self ._path , self ._tls_verify = parse_target (target )
159
+
160
+ # Load certificates once if needed
161
+ self ._certificates = _load_certs (self ._cert_file ) if (self ._tls_verify or self ._cert_file ) else None
139
162
self ._app_name = app_name
140
163
self ._app_version = app_version
141
164
self ._platform = platform .platform ()
@@ -161,12 +184,12 @@ def __init__(
161
184
),
162
185
)
163
186
164
- if self ._tls_verify :
187
+ if self ._tls_verify and self . _certificates :
165
188
_LOGGER .debug ("Setting up gRPC secure channel" )
166
189
self ._channel = grpc .secure_channel (
167
190
self ._target ,
168
191
grpc .ssl_channel_credentials (
169
- root_certificates = _load_certs () ,
192
+ root_certificates = self . _certificates ,
170
193
),
171
194
options = channel_opts ,
172
195
)
@@ -304,6 +327,7 @@ def _authenticate(self, scope: str):
304
327
self ._client_id ,
305
328
self ._client_secret ,
306
329
scope ,
330
+ self ._certificates ,
307
331
)
308
332
access_token = authentication_client .authenticate ()
309
333
self ._metadata = list (
@@ -391,20 +415,24 @@ def __init__(
391
415
client_id : str ,
392
416
client_secret : str ,
393
417
scope : str ,
418
+ certificates : bytes | None = None ,
394
419
):
395
420
self ._target = target
396
421
self ._tls_verify = tls_verify
397
422
self ._client_id = client_id
398
423
self ._client_secret = client_secret
399
424
self ._path = path
400
425
self ._scope = scope
426
+ self ._certificates = certificates
401
427
402
428
def authenticate (self ) -> str :
403
429
"""Request an OAuth2 token using client credentials and return it."""
404
- if self ._tls_verify :
430
+ if self ._tls_verify and self ._certificates :
431
+ context = ssl .create_default_context ()
432
+ context .load_verify_locations (cadata = self ._certificates .decode ('utf-8' ))
405
433
conn = http .client .HTTPSConnection (
406
434
self ._target ,
407
- context = None if self . _tls_verify else ssl . _create_unverified_context () ,
435
+ context = context ,
408
436
)
409
437
else :
410
438
conn = http .client .HTTPConnection (
0 commit comments