Skip to content

Commit df6672e

Browse files
authored
🚚 release (#63)
2 parents 9b38d6f + 62bb757 commit df6672e

File tree

3 files changed

+513
-31
lines changed

3 files changed

+513
-31
lines changed

‎README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ pip install netboxlabs-diode-sdk
2424
* `DIODE_SENTRY_DSN` - Optional Sentry DSN for error reporting
2525
* `DIODE_CLIENT_ID` - Client ID for OAuth2 authentication
2626
* `DIODE_CLIENT_SECRET` - Client Secret for OAuth2 authentication
27+
* `DIODE_CERT_FILE` - Path to custom certificate file for TLS connections
28+
* `DIODE_SKIP_TLS_VERIFY` - Skip TLS verification (default: `false`)
2729
* `DIODE_DRY_RUN_OUTPUT_DIR` - Directory where `DiodeDryRunClient` will write JSON files
2830

2931
### Example
@@ -77,6 +79,36 @@ if __name__ == "__main__":
7779

7880
```
7981

82+
### TLS verification and certificates
83+
84+
TLS verification is controlled by the target URL scheme:
85+
- **Secure schemes** (`grpcs://`, `https://`): TLS verification enabled
86+
- **Insecure schemes** (`grpc://`, `http://`): TLS verification disabled
87+
88+
```python
89+
# TLS verification enabled (uses system certificates)
90+
client = DiodeClient(target="grpcs://example.com", ...)
91+
92+
# TLS verification disabled
93+
client = DiodeClient(target="grpc://example.com", ...)
94+
```
95+
96+
#### Using custom certificates
97+
98+
```python
99+
# Via constructor parameter
100+
client = DiodeClient(target="grpcs://example.com", cert_file="/path/to/cert.pem", ...)
101+
102+
# Or via environment variable
103+
export DIODE_CERT_FILE=/path/to/cert.pem
104+
```
105+
106+
#### Disabling TLS verification
107+
108+
```bash
109+
export DIODE_SKIP_TLS_VERIFY=true
110+
```
111+
80112
### Dry run mode
81113

82114
`DiodeDryRunClient` generates ingestion requests without contacting a Diode server. Requests are printed to stdout by default, or written to JSON files when `output_dir` (or the `DIODE_DRY_RUN_OUTPUT_DIR` environment variable) is specified. The `app_name` parameter serves as the filename prefix; if not provided, `dryrun` is used as the default prefix. The file name is suffixed with a nanosecond-precision timestamp, resulting in the format `<app_name>_<timestamp_ns>.json`.

‎netboxlabs/diode/sdk/client.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,17 @@
2626
from netboxlabs.diode.sdk.ingester import Entity
2727
from netboxlabs.diode.sdk.version import version_semver
2828

29-
_MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES"
30-
_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
31-
_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
3229
_CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID"
3330
_CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET"
31+
_DEFAULT_STREAM = "latest"
32+
_DIODE_CERT_FILE_ENVVAR_NAME = "DIODE_CERT_FILE"
33+
_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
34+
_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
35+
_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME = "DIODE_SKIP_TLS_VERIFY"
3436
_DRY_RUN_OUTPUT_DIR_ENVVAR_NAME = "DIODE_DRY_RUN_OUTPUT_DIR"
3537
_INGEST_SCOPE = "diode:ingest"
36-
_DEFAULT_STREAM = "latest"
3738
_LOGGER = logging.getLogger(__name__)
38-
39+
_MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES"
3940

4041
def load_dryrun_entities(file_path: str | Path) -> Iterable[Entity]:
4142
"""Yield entities from a file with concatenated JSON messages."""
@@ -53,20 +54,35 @@ class DiodeClientInterface:
5354
pass
5455

5556

56-
def _load_certs() -> bytes:
57-
"""Loads cacert.pem."""
58-
with open(certifi.where(), "rb") as f:
57+
def _load_certs(cert_file: str | None = None) -> bytes:
58+
"""Loads cacert.pem or custom certificate file."""
59+
cert_path = cert_file or certifi.where()
60+
with open(cert_path, "rb") as f:
5961
return f.read()
6062

6163

64+
def _should_verify_tls(scheme: str) -> bool:
65+
"""Determine if TLS verification should be enabled based on scheme and environment variable."""
66+
# Check if scheme is insecure
67+
insecure_scheme = scheme in ["grpc", "http"]
68+
69+
# Check environment variable
70+
skip_tls_env = os.getenv(_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME, "").lower()
71+
skip_tls_from_env = skip_tls_env in ["true", "1", "yes", "on"]
72+
73+
# TLS verification is enabled by default, disabled only for insecure schemes or env var
74+
return not (insecure_scheme or skip_tls_from_env)
75+
76+
6277
def parse_target(target: str) -> tuple[str, str, bool]:
6378
"""Parse the target into authority, path and tls_verify."""
6479
parsed_target = urlparse(target)
6580

6681
if parsed_target.scheme not in ["grpc", "grpcs", "http", "https"]:
6782
raise ValueError("target should start with grpc://, grpcs://, http:// or https://")
6883

69-
tls_verify = parsed_target.scheme in ["grpcs", "https"]
84+
# Determine if TLS verification should be enabled
85+
tls_verify = _should_verify_tls(parsed_target.scheme)
7086

7187
authority = parsed_target.netloc
7288

@@ -127,15 +143,22 @@ def __init__(
127143
sentry_traces_sample_rate: float = 1.0,
128144
sentry_profiles_sample_rate: float = 1.0,
129145
max_auth_retries: int = 3,
146+
cert_file: str | None = None,
130147
):
131148
"""Initiate a new client."""
132149
log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper()
133150
logging.basicConfig(level=log_level)
134151

135-
self._max_auth_retries = _get_optional_config_value(
136-
_MAX_RETRIES_ENVVAR_NAME, max_auth_retries
152+
self._max_auth_retries = int(_get_optional_config_value(
153+
_MAX_RETRIES_ENVVAR_NAME, str(max_auth_retries)
154+
) or max_auth_retries)
155+
self._cert_file = _get_optional_config_value(
156+
_DIODE_CERT_FILE_ENVVAR_NAME, cert_file
137157
)
138158
self._target, self._path, self._tls_verify = parse_target(target)
159+
160+
# Load certificates once if needed
161+
self._certificates = _load_certs(self._cert_file) if (self._tls_verify or self._cert_file) else None
139162
self._app_name = app_name
140163
self._app_version = app_version
141164
self._platform = platform.platform()
@@ -161,12 +184,12 @@ def __init__(
161184
),
162185
)
163186

164-
if self._tls_verify:
187+
if self._tls_verify and self._certificates:
165188
_LOGGER.debug("Setting up gRPC secure channel")
166189
self._channel = grpc.secure_channel(
167190
self._target,
168191
grpc.ssl_channel_credentials(
169-
root_certificates=_load_certs(),
192+
root_certificates=self._certificates,
170193
),
171194
options=channel_opts,
172195
)
@@ -304,6 +327,7 @@ def _authenticate(self, scope: str):
304327
self._client_id,
305328
self._client_secret,
306329
scope,
330+
self._certificates,
307331
)
308332
access_token = authentication_client.authenticate()
309333
self._metadata = list(
@@ -391,20 +415,24 @@ def __init__(
391415
client_id: str,
392416
client_secret: str,
393417
scope: str,
418+
certificates: bytes | None = None,
394419
):
395420
self._target = target
396421
self._tls_verify = tls_verify
397422
self._client_id = client_id
398423
self._client_secret = client_secret
399424
self._path = path
400425
self._scope = scope
426+
self._certificates = certificates
401427

402428
def authenticate(self) -> str:
403429
"""Request an OAuth2 token using client credentials and return it."""
404-
if self._tls_verify:
430+
if self._tls_verify and self._certificates:
431+
context = ssl.create_default_context()
432+
context.load_verify_locations(cadata=self._certificates.decode('utf-8'))
405433
conn = http.client.HTTPSConnection(
406434
self._target,
407-
context=None if self._tls_verify else ssl._create_unverified_context(),
435+
context=context,
408436
)
409437
else:
410438
conn = http.client.HTTPConnection(

0 commit comments

Comments
 (0)