Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pip install netboxlabs-diode-sdk
* `DIODE_SENTRY_DSN` - Optional Sentry DSN for error reporting
* `DIODE_CLIENT_ID` - Client ID for OAuth2 authentication
* `DIODE_CLIENT_SECRET` - Client Secret for OAuth2 authentication
* `DIODE_CERT_FILE` - Path to custom certificate file for TLS connections
* `DIODE_SKIP_TLS_VERIFY` - Skip TLS verification (default: `false`)
* `DIODE_DRY_RUN_OUTPUT_DIR` - Directory where `DiodeDryRunClient` will write JSON files

### Example
Expand Down Expand Up @@ -77,6 +79,36 @@ if __name__ == "__main__":

```

### TLS verification and certificates

TLS verification is controlled by the target URL scheme:
- **Secure schemes** (`grpcs://`, `https://`): TLS verification enabled
- **Insecure schemes** (`grpc://`, `http://`): TLS verification disabled

```python
# TLS verification enabled (uses system certificates)
client = DiodeClient(target="grpcs://example.com", ...)

# TLS verification disabled
client = DiodeClient(target="grpc://example.com", ...)
```

#### Using custom certificates

```python
# Via constructor parameter
client = DiodeClient(target="grpcs://example.com", cert_file="/path/to/cert.pem", ...)

# Or via environment variable
export DIODE_CERT_FILE=/path/to/cert.pem
```

#### Disabling TLS verification

```bash
export DIODE_SKIP_TLS_VERIFY=true
```

### Dry run mode

`DiodeDryRunClient` generates ingestion requests without contacting a Diode server. Requests are printed to stdout by default, or written to JSON files when `output_dir` (or the `DIODE_DRY_RUN_OUTPUT_DIR` environment variable) is specified. The `app_name` parameter serves as the filename prefix; if not provided, `dryrun` is used as the default prefix. The file name is suffixed with a nanosecond-precision timestamp, resulting in the format `<app_name>_<timestamp_ns>.json`.
Expand Down
58 changes: 43 additions & 15 deletions netboxlabs/diode/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@
from netboxlabs.diode.sdk.ingester import Entity
from netboxlabs.diode.sdk.version import version_semver

_MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES"
_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
_CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID"
_CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET"
_DEFAULT_STREAM = "latest"
_DIODE_CERT_FILE_ENVVAR_NAME = "DIODE_CERT_FILE"
_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME = "DIODE_SKIP_TLS_VERIFY"
_DRY_RUN_OUTPUT_DIR_ENVVAR_NAME = "DIODE_DRY_RUN_OUTPUT_DIR"
_INGEST_SCOPE = "diode:ingest"
_DEFAULT_STREAM = "latest"
_LOGGER = logging.getLogger(__name__)

_MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES"

def load_dryrun_entities(file_path: str | Path) -> Iterable[Entity]:
"""Yield entities from a file with concatenated JSON messages."""
Expand All @@ -53,20 +54,35 @@ class DiodeClientInterface:
pass


def _load_certs() -> bytes:
"""Loads cacert.pem."""
with open(certifi.where(), "rb") as f:
def _load_certs(cert_file: str | None = None) -> bytes:
"""Loads cacert.pem or custom certificate file."""
cert_path = cert_file or certifi.where()
with open(cert_path, "rb") as f:
return f.read()


def _should_verify_tls(scheme: str) -> bool:
"""Determine if TLS verification should be enabled based on scheme and environment variable."""
# Check if scheme is insecure
insecure_scheme = scheme in ["grpc", "http"]

# Check environment variable
skip_tls_env = os.getenv(_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME, "").lower()
skip_tls_from_env = skip_tls_env in ["true", "1", "yes", "on"]

# TLS verification is enabled by default, disabled only for insecure schemes or env var
return not (insecure_scheme or skip_tls_from_env)


def parse_target(target: str) -> tuple[str, str, bool]:
"""Parse the target into authority, path and tls_verify."""
parsed_target = urlparse(target)

if parsed_target.scheme not in ["grpc", "grpcs", "http", "https"]:
raise ValueError("target should start with grpc://, grpcs://, http:// or https://")

tls_verify = parsed_target.scheme in ["grpcs", "https"]
# Determine if TLS verification should be enabled
tls_verify = _should_verify_tls(parsed_target.scheme)

authority = parsed_target.netloc

Expand Down Expand Up @@ -127,15 +143,22 @@ def __init__(
sentry_traces_sample_rate: float = 1.0,
sentry_profiles_sample_rate: float = 1.0,
max_auth_retries: int = 3,
cert_file: str | None = None,
):
"""Initiate a new client."""
log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper()
logging.basicConfig(level=log_level)

self._max_auth_retries = _get_optional_config_value(
_MAX_RETRIES_ENVVAR_NAME, max_auth_retries
self._max_auth_retries = int(_get_optional_config_value(
_MAX_RETRIES_ENVVAR_NAME, str(max_auth_retries)
) or max_auth_retries)
self._cert_file = _get_optional_config_value(
_DIODE_CERT_FILE_ENVVAR_NAME, cert_file
)
self._target, self._path, self._tls_verify = parse_target(target)

# Load certificates once if needed
self._certificates = _load_certs(self._cert_file) if (self._tls_verify or self._cert_file) else None
self._app_name = app_name
self._app_version = app_version
self._platform = platform.platform()
Expand All @@ -161,12 +184,12 @@ def __init__(
),
)

if self._tls_verify:
if self._tls_verify and self._certificates:
_LOGGER.debug("Setting up gRPC secure channel")
self._channel = grpc.secure_channel(
self._target,
grpc.ssl_channel_credentials(
root_certificates=_load_certs(),
root_certificates=self._certificates,
),
options=channel_opts,
)
Expand Down Expand Up @@ -304,6 +327,7 @@ def _authenticate(self, scope: str):
self._client_id,
self._client_secret,
scope,
self._certificates,
)
access_token = authentication_client.authenticate()
self._metadata = list(
Expand Down Expand Up @@ -391,20 +415,24 @@ def __init__(
client_id: str,
client_secret: str,
scope: str,
certificates: bytes | None = None,
):
self._target = target
self._tls_verify = tls_verify
self._client_id = client_id
self._client_secret = client_secret
self._path = path
self._scope = scope
self._certificates = certificates

def authenticate(self) -> str:
"""Request an OAuth2 token using client credentials and return it."""
if self._tls_verify:
if self._tls_verify and self._certificates:
context = ssl.create_default_context()
context.load_verify_locations(cadata=self._certificates.decode('utf-8'))
conn = http.client.HTTPSConnection(
self._target,
context=None if self._tls_verify else ssl._create_unverified_context(),
context=context,
)
else:
conn = http.client.HTTPConnection(
Expand Down
Loading
Loading