diff --git a/distributed/client.py b/distributed/client.py index 01fe47fa9c..83bfaa336a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5456,6 +5456,86 @@ def unregister_worker_plugin(self, name, nanny=None): """ return self.sync(self._unregister_worker_plugin, name=name, nanny=nanny) + def has_plugin( + self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | Sequence + ) -> bool | dict[str, bool]: + """Check if plugin(s) are registered + + Parameters + ---------- + plugin : str | plugin object | Sequence + Plugin to check. You can use the plugin object directly or the plugin name. + For plugin objects, they must have a 'name' attribute. You can also pass + a sequence of plugin objects or names. + + Returns + ------- + bool or dict[str, bool] + If name is str: True if plugin is registered, False otherwise + If name is Sequence: dict mapping names to registration status + + Examples + -------- + >>> logging_plugin = LoggingConfigPlugin() # Has name = "logging-config" + >>> client.register_plugin(logging_plugin) + >>> client.has_plugin(logging_plugin) + True + + >>> client.has_plugin('logging-config') + True + + >>> client.has_plugin([logging_plugin, 'other-plugin']) + {'logging-config': True, 'other-plugin': False} + """ + return self.sync(self._has_plugin_async, plugin=plugin) + + async def _has_plugin_async( + self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | Sequence + ) -> bool | dict[str, bool]: + """Async implementation for checking plugin registration""" + + # Convert plugin to list of names + if isinstance(plugin, str): + names_to_check = [plugin] + return_single = True + elif isinstance(plugin, (WorkerPlugin, SchedulerPlugin, NannyPlugin)): + plugin_name = getattr(plugin, "name", None) + if plugin_name is None: + raise ValueError( + f"Plugin {funcname(type(plugin))} has no 'name' attribute. " + "Please add a 'name' attribute to your plugin class." + ) + names_to_check = [plugin_name] + return_single = True + elif isinstance(plugin, Sequence): + names_to_check = [] + for p in plugin: + if isinstance(p, str): + names_to_check.append(p) + else: + plugin_name = getattr(p, "name", None) + if plugin_name is None: + raise ValueError( + f"Plugin {funcname(type(p))} has no 'name' attribute" + ) + names_to_check.append(plugin_name) + return_single = False + else: + raise TypeError( + f"plugin must be a plugin object, name string, or Sequence. Got {type(plugin)}" + ) + + # Get status from scheduler + result = await self.scheduler.get_plugin_registration_status( + names=names_to_check + ) + + # Return single bool or dict based on input + if return_single: + return result[names_to_check[0]] + else: + return result + @property def amm(self): """Convenience accessors for the :doc:`active_memory_manager`""" diff --git a/distributed/diagnostics/tests/test_nanny_plugin.py b/distributed/diagnostics/tests/test_nanny_plugin.py index 3c481dce26..256772542e 100644 --- a/distributed/diagnostics/tests/test_nanny_plugin.py +++ b/distributed/diagnostics/tests/test_nanny_plugin.py @@ -217,3 +217,154 @@ async def test_nanny_plugin_with_broken_teardown_logs_on_close(c, s): logs = caplog.getvalue() assert "TestPlugin1 failed to teardown" in logs assert "test error" in logs + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_has_nanny_plugin_by_name(c, s, a): + """Test checking if nanny plugin is registered using string name""" + + class DuckPlugin(NannyPlugin): + name = "duck-plugin" + + def setup(self, nanny): + nanny.foo = 123 + + def teardown(self, nanny): + pass + + # Check non-existent plugin + assert not await c.has_plugin("duck-plugin") + + # Register plugin + await c.register_plugin(DuckPlugin()) + assert a.foo == 123 + + # Check using string name + assert await c.has_plugin("duck-plugin") + + # Unregister and check again + await c.unregister_worker_plugin("duck-plugin", nanny=True) + assert not await c.has_plugin("duck-plugin") + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_has_nanny_plugin_by_object(c, s, a): + """Test checking if nanny plugin is registered using plugin object""" + + class DuckPlugin(NannyPlugin): + name = "duck-plugin" + + def setup(self, nanny): + nanny.bar = 456 + + def teardown(self, nanny): + pass + + plugin = DuckPlugin() + + # Check before registration + assert not await c.has_plugin(plugin) + + # Register and check + await c.register_plugin(plugin) + assert a.bar == 456 + assert await c.has_plugin(plugin) + + # Unregister and check + await c.unregister_worker_plugin("duck-plugin", nanny=True) + assert not await c.has_plugin(plugin) + + +@gen_cluster(client=True, nthreads=[("", 1), ("", 1)], Worker=Nanny) +async def test_has_nanny_plugin_multiple_nannies(c, s, a, b): + """Test checking nanny plugin with multiple nannies""" + + class DuckPlugin(NannyPlugin): + name = "duck-plugin" + + def setup(self, nanny): + nanny.multi = "setup" + + def teardown(self, nanny): + pass + + # Check before registration + assert not await c.has_plugin("duck-plugin") + + # Register plugin (should propagate to all nannies) + await c.register_plugin(DuckPlugin()) + + # Verify both nannies have the plugin + assert a.multi == "setup" + assert b.multi == "setup" + + # Check plugin is registered + assert await c.has_plugin("duck-plugin") + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_has_nanny_plugin_custom_name_override(c, s, a): + """Test nanny plugin registered with custom name different from class name""" + + class DuckPlugin(NannyPlugin): + name = "duck-plugin" + + def setup(self, nanny): + nanny.custom = "test" + + def teardown(self, nanny): + pass + + plugin = DuckPlugin() + + # Register with custom name (overriding the class name attribute) + await c.register_plugin(plugin, name="custom-override") + + # Check with custom name works + assert await c.has_plugin("custom-override") + + # Original name won't work since we overrode it + assert not await c.has_plugin("duck-plugin") + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_has_nanny_plugin_list_check(c, s, a): + """Test checking multiple nanny plugins at once""" + + class IdempotentPlugin(NannyPlugin): + name = "idempotentplugin" + + def setup(self, nanny): + pass + + def teardown(self, nanny): + pass + + class NonIdempotentPlugin(NannyPlugin): + name = "nonidempotentplugin" + + def setup(self, nanny): + pass + + def teardown(self, nanny): + pass + + # Check multiple before registration + result = await c.has_plugin( + ["idempotentplugin", "nonidempotentplugin", "nonexistent"] + ) + assert result == { + "idempotentplugin": False, + "nonidempotentplugin": False, + "nonexistent": False, + } + + # Register first plugin + await c.register_plugin(IdempotentPlugin()) + result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin"]) + assert result == {"idempotentplugin": True, "nonidempotentplugin": False} + + # Register second plugin + await c.register_plugin(NonIdempotentPlugin()) + result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin"]) + assert result == {"idempotentplugin": True, "nonidempotentplugin": True} diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index d520524291..dd580ed09f 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -753,3 +753,138 @@ def __init__(self, instance=None): await s.register_scheduler_plugin(plugin=dumps(third)) assert "nonidempotentplugin" in s.plugins assert s.plugins["nonidempotentplugin"].instance == "third" + + +@gen_cluster(client=True) +async def test_has_scheduler_plugin_by_name(c, s, a, b): + """Test checking if scheduler plugin is registered using string name""" + + class Dummy1(SchedulerPlugin): + name = "Dummy1" + + def start(self, scheduler): + scheduler.foo = "bar" + + # Check non-existent plugin + assert not await c.has_plugin("Dummy1") + + # Register plugin + await c.register_plugin(Dummy1()) + assert s.foo == "bar" + + # Check using string name + assert await c.has_plugin("Dummy1") + + # Unregister and check again + await c.unregister_scheduler_plugin("Dummy1") + assert not await c.has_plugin("Dummy1") + + +@gen_cluster(client=True) +async def test_has_scheduler_plugin_by_object(c, s, a, b): + """Test checking if scheduler plugin is registered using plugin object""" + + class Dummy2(SchedulerPlugin): + name = "Dummy2" + + def start(self, scheduler): + scheduler.check_value = 42 + + plugin = Dummy2() + + # Check before registration + assert not await c.has_plugin(plugin) + + # Register and check + await c.register_plugin(plugin) + assert s.check_value == 42 + assert await c.has_plugin(plugin) + + # Unregister and check + await c.unregister_scheduler_plugin("Dummy2") + assert not await c.has_plugin(plugin) + + +@gen_cluster(client=True) +async def test_has_plugin_mixed_scheduler_and_worker_types(c, s, a, b): + """Test checking scheduler and worker plugins together""" + from distributed import WorkerPlugin + + class MyPlugin(SchedulerPlugin): + name = "MyPlugin" + + def start(self, scheduler): + scheduler.my_value = "scheduler" + + class MyWorkerPlugin(WorkerPlugin): + name = "MyWorkerPlugin" + + def setup(self, worker): + worker.my_value = "worker" + + sched_plugin = MyPlugin() + work_plugin = MyWorkerPlugin() + + # Register both types + await c.register_plugin(sched_plugin) + await c.register_plugin(work_plugin) + + # Verify both registered + assert s.my_value == "scheduler" + assert a.my_value == "worker" + assert b.my_value == "worker" + + # Check both with list of names + result = await c.has_plugin(["MyPlugin", "MyWorkerPlugin"]) + assert result == {"MyPlugin": True, "MyWorkerPlugin": True} + + # Check both with objects + assert await c.has_plugin(sched_plugin) + assert await c.has_plugin(work_plugin) + + # Check non-existent alongside real ones + result = await c.has_plugin(["MyPlugin", "nonexistent", "MyWorkerPlugin"]) + assert result == {"MyPlugin": True, "nonexistent": False, "MyWorkerPlugin": True} + + +@gen_cluster(client=True, nthreads=[]) +async def test_has_scheduler_plugin_no_workers(c, s): + """Test checking scheduler plugin when no workers exist""" + + class Plugin(SchedulerPlugin): + name = "plugin" + + def start(self, scheduler): + scheduler.no_worker_test = True + + # Check before registration + assert not await c.has_plugin("plugin") + + # Register plugin when no workers exist + await c.register_plugin(Plugin()) + assert s.no_worker_test is True + + # Check after registration + assert await c.has_plugin("plugin") + + +@gen_cluster(client=True) +async def test_has_scheduler_plugin_custom_name_override(c, s, a, b): + """Test scheduler plugin registered with custom name different from class name""" + + class Dummy3(SchedulerPlugin): + name = "Dummy3" + + def start(self, scheduler): + scheduler.name_test = "custom" + + plugin = Dummy3() + + # Register with custom name (overriding the class name attribute) + await c.register_plugin(plugin, name="custom-override") + + # Check with custom name works + assert await c.has_plugin("custom-override") + + # Original name won't work since we overrode it + assert not await c.has_plugin("Dummy3") diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index 001576afe3..83eac3bbb9 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -479,3 +479,107 @@ async def test_plugin_with_broken_teardown_logs_on_close(c, s): logs = caplog.getvalue() assert "TestPlugin1 failed to teardown" in logs assert "test error" in logs + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_has_worker_plugin_by_name(c, s, a): + """Test checking if worker plugin is registered using string name""" + + class MyPlugin(WorkerPlugin): + name = "MyPlugin" + + def __init__(self, data, expected_notifications=None): + self.data = data + self.expected_notifications = expected_notifications + + # Check non-existent plugin + assert not await c.has_plugin("MyPlugin") # ← await + + # Register plugin + await c.register_plugin(MyPlugin(123, None)) + + # Check using string name + assert await c.has_plugin("MyPlugin") # ← await + + # Unregister and check again + await c.unregister_worker_plugin("MyPlugin") + assert not await c.has_plugin("MyPlugin") # ← await + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_has_worker_plugin_by_object(c, s, a): + """Test checking if worker plugin is registered using plugin object""" + plugin = MyPlugin(456) + + # Check before registration + assert not await c.has_plugin(plugin) # ← await + + # Register and check + await c.register_plugin(plugin) + assert await c.has_plugin(plugin) # ← await + + # Unregister and check + await c.unregister_worker_plugin("MyPlugin") + assert not await c.has_plugin(plugin) # ← await + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_has_plugin_list(c, s, a): + """Test checking multiple plugins at once""" + plugin1 = MyPlugin(1) + + class AnotherPlugin(WorkerPlugin): + name = "AnotherPlugin" + + plugin2 = AnotherPlugin() + + # Check multiple plugins before registration + result = await c.has_plugin(["MyPlugin", "AnotherPlugin", "NonExistent"]) # ← await + assert result == { + "MyPlugin": False, + "AnotherPlugin": False, + "NonExistent": False, + } + + # Register first plugin + await c.register_plugin(plugin1) + result = await c.has_plugin(["MyPlugin", "AnotherPlugin"]) # ← await + assert result == {"MyPlugin": True, "AnotherPlugin": False} + + # Register second plugin + await c.register_plugin(plugin2) + result = await c.has_plugin(["MyPlugin", "AnotherPlugin"]) # ← await + assert result == {"MyPlugin": True, "AnotherPlugin": True} + + # Can also pass list of objects + result = await c.has_plugin([plugin1, plugin2]) # ← await + assert result == {"MyPlugin": True, "AnotherPlugin": True} + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_has_plugin_without_name_attribute(c, s, a): + """Test error when plugin has no name attribute""" + + class PluginWithoutName(WorkerPlugin): + pass # No name attribute + + plugin = PluginWithoutName() + + # Should raise error when checking + with pytest.raises(ValueError, match="has no 'name' attribute"): + await c.has_plugin(plugin) # ← await + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_has_plugin_custom_name(c, s, a): + """Test plugin registered with custom name""" + plugin = MyPlugin(789) + + # Register with custom name + await c.register_plugin(plugin, name="custom-name") + + # Check with custom name + assert await c.has_plugin("custom-name") # ← await + + # Original name won't work + assert not await c.has_plugin("MyPlugin") # ← await diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 2d5ee2c8cf..d717e8e4aa 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4039,6 +4039,7 @@ async def post(self) -> None: "unregister_worker_plugin": self.unregister_worker_plugin, "register_nanny_plugin": self.register_nanny_plugin, "unregister_nanny_plugin": self.unregister_nanny_plugin, + "get_plugin_registration_status": self.get_plugin_registration_status, "adaptive_target": self.adaptive_target, "workers_to_close": self.workers_to_close, "subscribe_worker_status": self.subscribe_worker_status, @@ -8696,6 +8697,50 @@ async def get_worker_monitor_info( ) return dict(zip(self.workers, results)) + async def get_plugin_registration_status(self, names: list[str]) -> dict[str, bool]: + """Check if plugins are registered in any plugin registry + + Checks all plugin registries (worker, scheduler, nanny) and returns True + if the plugin is found in any of them. + + Parameters + ---------- + names : list[str] + List of plugin names to check + + Returns + ------- + dict[str, bool] + Dict mapping plugin names to their registration status across all registries + """ + result = {} + for name in names: + # Check if plugin exists in any registry + result[name] = ( + name in self.worker_plugins + or name in self.plugins + or name in self.nanny_plugins + ) + return result + + async def get_worker_plugin_registration_status( + self, names: list[str] + ) -> dict[str, bool]: + """Check if worker plugins are registered""" + return {name: name in self.worker_plugins for name in names} + + async def get_scheduler_plugin_registration_status( + self, names: list[str] + ) -> dict[str, bool]: + """Check if scheduler plugins are registered""" + return {name: name in self.plugins for name in names} + + async def get_nanny_plugin_registration_status( + self, names: list[str] + ) -> dict[str, bool]: + """Check if nanny plugins are registered""" + return {name: name in self.nanny_plugins for name in names} + ########### # Cleanup # ###########