Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions device-discovery/device_discovery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import logging
import threading
from typing import Any

from netboxlabs.diode.sdk import DiodeClient, DiodeDryRunClient
from netboxlabs.diode.sdk import DiodeClient, DiodeDryRunClient, DiodeOTLPClient

from device_discovery.translate import translate_data
from device_discovery.version import version_semver
Expand Down Expand Up @@ -82,22 +83,28 @@ def init_client(
app_name=f"{prefix}/{APP_NAME}" if prefix else APP_NAME,
output_dir=dry_run_output_dir,
)
else:
elif client_id is not None and client_secret is not None:
self.diode_client = DiodeClient(
target=target,
app_name=f"{prefix}/{APP_NAME}" if prefix else APP_NAME,
app_version=APP_VERSION,
client_id=client_id,
client_secret=client_secret,
)
else:
self.diode_client = DiodeOTLPClient(
target=target,
app_name=f"{prefix}/{APP_NAME}" if prefix else APP_NAME,
app_version=APP_VERSION,
)

def ingest(self, hostname: str, data: dict):
def ingest(self, metadata: dict[str, Any] | None, data: dict):
"""
Ingest data using the Diode client after translating it.

Args:
----
hostname (str): The device hostname.
metadata (dict[str, Any] | None): Metadata to attach to the ingestion request.
data (dict): The data to be ingested.

Raises:
Expand All @@ -109,9 +116,17 @@ def ingest(self, hostname: str, data: dict):
raise ValueError("Diode client not initialized")

with self._lock:
response = self.diode_client.ingest(translate_data(data))
translated_entities = translate_data(data)
request_metadata = metadata or {}
response = self.diode_client.ingest(
entities=translated_entities, metadata=request_metadata
)

hostname = request_metadata.get("hostname") or "unknown-host"

if response.errors:
logger.error(f"ERROR ingestion failed for {hostname} : {response.errors}")
logger.error(
f"ERROR ingestion failed for {hostname} : {response.errors}"
)
else:
logger.info(f"Hostname {hostname}: Successful ingestion")
2 changes: 0 additions & 2 deletions device-discovery/device_discovery/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ def main():
name
for name, val in [
("--diode-target", args.diode_target),
("--diode-client-id", args.diode_client_id),
("--diode-client-secret", args.diode_client_secret),
]
if not val
]
Expand Down
3 changes: 2 additions & 1 deletion device-discovery/device_discovery/policy/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def _collect_device_data(
logger.error(
f"Policy {self.name}, Hostname {sanitized_hostname}: Error getting VLANs: {e}"
)
Client().ingest(scope.hostname, data)
metadata = {"policy_name": self.name, "hostname": sanitized_hostname}
Client().ingest(metadata, data)
discovery_success = get_metric("discovery_success")
if discovery_success:
discovery_success.add(1, {"policy": self.name})
Expand Down
6 changes: 5 additions & 1 deletion device-discovery/tests/policy/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ def test_run_device_with_discovered_driver(policy_runner, sample_scopes, sample_
# Verify driver discovery and ingestion
mock_discover.assert_called_once_with(sample_scopes[0])
mock_ingest.assert_called_once()
data = mock_ingest.call_args[0][1]
metadata_arg, data = mock_ingest.call_args[0]
assert metadata_arg == {
"policy_name": policy_runner.name,
"hostname": sample_scopes[0].hostname,
}
assert data["driver"] == "ios"
assert data["device"] == {"model": "SampleModel"}
assert data["interface"] == {"eth0": "up"}
Expand Down
62 changes: 48 additions & 14 deletions device-discovery/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def sample_data():
}


@pytest.fixture
def sample_metadata():
"""Sample metadata for testing ingestion."""
return {"policy_name": "test-policy", "hostname": "router1"}


@pytest.fixture
def mock_version_semver():
"""Mock the version_semver function."""
Expand All @@ -73,6 +79,13 @@ def mock_diode_client_class():
yield mock


@pytest.fixture
def mock_diode_otlp_client_class():
"""Mock the DiodeOTLPClient class."""
with patch("device_discovery.client.DiodeOTLPClient") as mock:
yield mock


def test_init_client(mock_diode_client_class, mock_version_semver):
"""Test the initialization of the Diode client."""
client = Client()
Expand All @@ -92,7 +105,7 @@ def test_init_client(mock_diode_client_class, mock_version_semver):
)


def test_ingest_success(mock_diode_client_class, sample_data):
def test_ingest_success(mock_diode_client_class, sample_data, sample_metadata):
"""Test successful data ingestion."""
client = Client()
client.init_client(
Expand All @@ -101,18 +114,20 @@ def test_ingest_success(mock_diode_client_class, sample_data):

mock_diode_instance = mock_diode_client_class.return_value
mock_diode_instance.ingest.return_value.errors = []
hostname = sample_data["device"]["hostname"]

metadata = sample_metadata
with patch(
"device_discovery.client.translate_data",
return_value=translate_data(sample_data),
) as mock_translate_data:
client.ingest(hostname, sample_data)
client.ingest(metadata, sample_data)
mock_translate_data.assert_called_once_with(sample_data)
mock_diode_instance.ingest.assert_called_once()
mock_diode_instance.ingest.assert_called_once_with(
entities=mock_translate_data.return_value,
metadata=metadata,
)


def test_ingest_failure(mock_diode_client_class, sample_data):
def test_ingest_failure(mock_diode_client_class, sample_data, sample_metadata):
"""Test data ingestion with errors."""
client = Client()
client.init_client(
Expand All @@ -124,25 +139,27 @@ def test_ingest_failure(mock_diode_client_class, sample_data):

mock_diode_instance = mock_diode_client_class.return_value
mock_diode_instance.ingest.return_value.errors = ["Error1", "Error2"]
hostname = sample_data["device"]["hostname"]

metadata = sample_metadata
with patch(
"device_discovery.client.translate_data",
return_value=translate_data(sample_data),
) as mock_translate_data:
client.ingest(hostname, sample_data)
client.ingest(metadata, sample_data)
mock_translate_data.assert_called_once_with(sample_data)
mock_diode_instance.ingest.assert_called_once()
mock_diode_instance.ingest.assert_called_once_with(
entities=mock_translate_data.return_value,
metadata=metadata,
)

assert len(mock_diode_instance.ingest.return_value.errors) > 0


def test_ingest_without_initialization():
def test_ingest_without_initialization(sample_metadata):
"""Test ingestion without client initialization raises ValueError."""
Client._instance = None # Reset the Client singleton instance
client = Client()
with pytest.raises(ValueError, match="Diode client not initialized"):
client.ingest("", {})
client.ingest(sample_metadata, {})


def test_client_dry_run(tmp_path, sample_data):
Expand All @@ -154,7 +171,8 @@ def test_client_dry_run(tmp_path, sample_data):
dry_run_output_dir=tmp_path,
)
hostname = sample_data["device"]["hostname"]
client.ingest(hostname, sample_data)
metadata = {"policy_name": "dry-run-policy", "hostname": hostname}
client.ingest(metadata, sample_data)
files = list(tmp_path.glob("prefix_device-discovery*.json"))

assert len(files) == 1
Expand All @@ -174,8 +192,24 @@ def test_client_dry_run_stdout(capsys, sample_data):
)

hostname = sample_data["device"]["hostname"]
client.ingest(hostname, sample_data)
metadata = {"policy_name": "dry-run-policy", "hostname": hostname}
client.ingest(metadata, sample_data)

captured = capsys.readouterr()
assert sample_data["device"]["hostname"] in captured.out
assert sample_data["interface"]["GigabitEthernet0/0"]["mac_address"] in captured.out


def test_init_client_uses_otlp_when_credentials_missing(
mock_diode_client_class, mock_diode_otlp_client_class, mock_version_semver
):
"""Ensure init_client falls back to DiodeOTLPClient when credentials are not provided."""
client = Client()
client.init_client(prefix="prefix", target="https://example.com")

assert not mock_diode_client_class.called
mock_diode_otlp_client_class.assert_called_once_with(
target="https://example.com",
app_name="prefix/device-discovery",
app_version=mock_version_semver(),
)
60 changes: 59 additions & 1 deletion worker/tests/policy/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
@pytest.fixture
def policy_runner():
"""Fixture to create a PolicyRunner instance."""
return PolicyRunner()
runner = PolicyRunner()
runner.metadata = Metadata(
name="test_backend", app_name="test_app", app_version="1.0"
)
return runner


@pytest.fixture
Expand Down Expand Up @@ -71,6 +75,15 @@ def mock_diode_client():
mock_diode_client.return_value = mock_instance
yield mock_diode_client


@pytest.fixture
def mock_diode_otlp_client():
"""Fixture to mock the DiodeOTLPClient constructor."""
with patch("worker.policy.runner.DiodeOTLPClient") as mock_diode_otlp_client:
mock_instance = MagicMock()
mock_diode_otlp_client.return_value = mock_instance
yield mock_diode_otlp_client

@pytest.fixture
def mock_diode_dry_run_client():
"""Fixture to mock the DiodeDryRunClient constructor."""
Expand Down Expand Up @@ -138,6 +151,28 @@ def test_setup_policy_runner_with_one_time_run(
assert mock_start.called
assert policy_runner.status == Status.RUNNING


def test_setup_policy_runner_uses_otlp_client(
policy_runner,
sample_policy,
mock_load_class,
mock_diode_client,
mock_diode_otlp_client,
):
"""Ensure setup falls back to DiodeOTLPClient when credentials are missing."""
otlp_config = DiodeConfig(target="http://localhost:8080", prefix="test-prefix")
with patch.object(policy_runner.scheduler, "start") as mock_start, patch.object(
policy_runner.scheduler, "add_job"
) as mock_add_job:
policy_runner.setup("policy1", otlp_config, sample_policy)

mock_start.assert_called_once()
mock_add_job.assert_called_once()

mock_load_class.assert_called_once()
assert not mock_diode_client.called
mock_diode_otlp_client.assert_called_once()

def test_setup_policy_runner_dry_run(
policy_runner,
sample_diode_dry_run_config,
Expand Down Expand Up @@ -185,6 +220,29 @@ def test_run_success(policy_runner, sample_policy, mock_diode_client, mock_backe
assert len(call_args) == 3


def test_run_passes_metadata_to_ingest(
policy_runner, sample_policy, mock_diode_client, mock_backend
):
"""Ensure run forwards policy/backend metadata to the Diode client."""
policy_runner.name = "policy-meta"
policy_runner.metadata = Metadata(
name="custom_backend", app_name="custom", app_version="0.1"
)

entity = ingester_pb2.Entity()
entity.device.name = "device-1"
mock_backend.run.return_value = [entity]
mock_diode_client.ingest.return_value.errors = []

policy_runner.run(mock_diode_client, mock_backend, sample_policy)

_, kwargs = mock_diode_client.ingest.call_args
assert kwargs["metadata"] == {
"policy_name": "policy-meta",
"worker_backend": "custom_backend",
}


def test_run_ingestion_errors(
policy_runner,
sample_policy,
Expand Down
15 changes: 14 additions & 1 deletion worker/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from worker.main import main
from worker.main import main, resolve_env_var


@pytest.fixture
Expand Down Expand Up @@ -169,3 +169,16 @@ def test_main_missing_policy(mock_parse_args):
main()
except Exception as e:
assert str(e) == "Test Exit"


def test_resolve_env_var_expands_environment(monkeypatch):
"""Ensure resolve_env_var expands placeholders using environment variables."""
monkeypatch.setenv("MY_ENDPOINT", "grpc://localhost:4317")
assert resolve_env_var("${MY_ENDPOINT}") == "grpc://localhost:4317"


def test_resolve_env_var_returns_original(monkeypatch):
"""Ensure resolve_env_var returns original string when expansion is not possible."""
monkeypatch.delenv("NOT_DEFINED", raising=False)
assert resolve_env_var("plain-value") == "plain-value"
assert resolve_env_var("${NOT_DEFINED}") == "${NOT_DEFINED}"
4 changes: 1 addition & 3 deletions worker/worker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ def main():
name
for name, val in [
("--diode-target", args.diode_target),
("--diode-client-id", args.diode_client_id),
("--diode-client-secret", args.diode_client_secret),
]
if not val
]
Expand Down Expand Up @@ -174,7 +172,7 @@ def main():
)

try:
if not config.dry_run:
if not config.dry_run and client_id is not None and client_secret is not None:
DiodeClient(
target=config.target,
app_name="validate",
Expand Down
Loading
Loading