Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5456,6 +5456,86 @@ def unregister_worker_plugin(self, name, nanny=None):
"""
return self.sync(self._unregister_worker_plugin, name=name, nanny=nanny)

def has_plugin(
self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | Sequence
) -> bool | dict[str, bool]:
"""Check if plugin(s) are registered

Parameters
----------
plugin : str | plugin object | Sequence
Plugin to check. You can use the plugin object directly or the plugin name.
For plugin objects, they must have a 'name' attribute. You can also pass
a sequence of plugin objects or names.

Returns
-------
bool or dict[str, bool]
If name is str: True if plugin is registered, False otherwise
If name is Sequence: dict mapping names to registration status

Examples
--------
>>> logging_plugin = LoggingConfigPlugin() # Has name = "logging-config"
>>> client.register_plugin(logging_plugin)
>>> client.has_plugin(logging_plugin)
True

>>> client.has_plugin('logging-config')
True

>>> client.has_plugin([logging_plugin, 'other-plugin'])
{'logging-config': True, 'other-plugin': False}
"""
return self.sync(self._has_plugin_async, plugin=plugin)

async def _has_plugin_async(
self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | Sequence
) -> bool | dict[str, bool]:
"""Async implementation for checking plugin registration"""

# Convert plugin to list of names
if isinstance(plugin, str):
names_to_check = [plugin]
return_single = True
elif isinstance(plugin, (WorkerPlugin, SchedulerPlugin, NannyPlugin)):
plugin_name = getattr(plugin, "name", None)
if plugin_name is None:
Comment on lines +5501 to +5503
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible simplification: if we detect a single plugin here, we can assign it to

plugin = [plugin]

and then fall through to the list case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see that the return type is different in that case (bool vs. dict[str, bool]). We could add a local unbox variable to handle this, but perhaps not worth it.

raise ValueError(
f"Plugin {funcname(type(plugin))} has no 'name' attribute. "
"Please add a 'name' attribute to your plugin class."
)
names_to_check = [plugin_name]
return_single = True
elif isinstance(plugin, Sequence):
names_to_check = []
for p in plugin:
if isinstance(p, str):
names_to_check.append(p)
else:
plugin_name = getattr(p, "name", None)
if plugin_name is None:
raise ValueError(
f"Plugin {funcname(type(p))} has no 'name' attribute"
)
names_to_check.append(plugin_name)
return_single = False
else:
raise TypeError(
f"plugin must be a plugin object, name string, or Sequence. Got {type(plugin)}"
)

# Get status from scheduler
result = await self.scheduler.get_plugin_registration_status(
names=names_to_check
)

# Return single bool or dict based on input
if return_single:
return result[names_to_check[0]]
else:
return result

@property
def amm(self):
"""Convenience accessors for the :doc:`active_memory_manager`"""
Expand Down
151 changes: 151 additions & 0 deletions distributed/diagnostics/tests/test_nanny_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,154 @@ async def test_nanny_plugin_with_broken_teardown_logs_on_close(c, s):
logs = caplog.getvalue()
assert "TestPlugin1 failed to teardown" in logs
assert "test error" in logs


@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
async def test_has_nanny_plugin_by_name(c, s, a):
"""Test checking if nanny plugin is registered using string name"""

class DuckPlugin(NannyPlugin):
name = "duck-plugin"

def setup(self, nanny):
nanny.foo = 123

def teardown(self, nanny):
pass
Comment on lines +226 to +233
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gets redefined quite a few times in here, can we move it outside the test and define it once to DRY this out a little?


# Check non-existent plugin
assert not await c.has_plugin("duck-plugin")

# Register plugin
await c.register_plugin(DuckPlugin())
assert a.foo == 123

# Check using string name
assert await c.has_plugin("duck-plugin")

# Unregister and check again
await c.unregister_worker_plugin("duck-plugin", nanny=True)
assert not await c.has_plugin("duck-plugin")


@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
async def test_has_nanny_plugin_by_object(c, s, a):
"""Test checking if nanny plugin is registered using plugin object"""

class DuckPlugin(NannyPlugin):
name = "duck-plugin"

def setup(self, nanny):
nanny.bar = 456

def teardown(self, nanny):
pass

plugin = DuckPlugin()

# Check before registration
assert not await c.has_plugin(plugin)

# Register and check
await c.register_plugin(plugin)
assert a.bar == 456
assert await c.has_plugin(plugin)

# Unregister and check
await c.unregister_worker_plugin("duck-plugin", nanny=True)
assert not await c.has_plugin(plugin)


@gen_cluster(client=True, nthreads=[("", 1), ("", 1)], Worker=Nanny)
async def test_has_nanny_plugin_multiple_nannies(c, s, a, b):
"""Test checking nanny plugin with multiple nannies"""

class DuckPlugin(NannyPlugin):
name = "duck-plugin"

def setup(self, nanny):
nanny.multi = "setup"

def teardown(self, nanny):
pass

# Check before registration
assert not await c.has_plugin("duck-plugin")

# Register plugin (should propagate to all nannies)
await c.register_plugin(DuckPlugin())

# Verify both nannies have the plugin
assert a.multi == "setup"
assert b.multi == "setup"

# Check plugin is registered
assert await c.has_plugin("duck-plugin")


@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
async def test_has_nanny_plugin_custom_name_override(c, s, a):
"""Test nanny plugin registered with custom name different from class name"""

class DuckPlugin(NannyPlugin):
name = "duck-plugin"

def setup(self, nanny):
nanny.custom = "test"

def teardown(self, nanny):
pass

plugin = DuckPlugin()

# Register with custom name (overriding the class name attribute)
await c.register_plugin(plugin, name="custom-override")

# Check with custom name works
assert await c.has_plugin("custom-override")

# Original name won't work since we overrode it
assert not await c.has_plugin("duck-plugin")


@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
async def test_has_nanny_plugin_list_check(c, s, a):
"""Test checking multiple nanny plugins at once"""

class IdempotentPlugin(NannyPlugin):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious about the naming here. Why Idempotent and NonIdempotent?

name = "idempotentplugin"

def setup(self, nanny):
pass

def teardown(self, nanny):
pass

class NonIdempotentPlugin(NannyPlugin):
name = "nonidempotentplugin"

def setup(self, nanny):
pass

def teardown(self, nanny):
pass

# Check multiple before registration
result = await c.has_plugin(
["idempotentplugin", "nonidempotentplugin", "nonexistent"]
)
assert result == {
"idempotentplugin": False,
"nonidempotentplugin": False,
"nonexistent": False,
}

# Register first plugin
await c.register_plugin(IdempotentPlugin())
result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin"])
assert result == {"idempotentplugin": True, "nonidempotentplugin": False}

# Register second plugin
await c.register_plugin(NonIdempotentPlugin())
result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin"])
assert result == {"idempotentplugin": True, "nonidempotentplugin": True}
135 changes: 135 additions & 0 deletions distributed/diagnostics/tests/test_scheduler_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,3 +753,138 @@ def __init__(self, instance=None):
await s.register_scheduler_plugin(plugin=dumps(third))
assert "nonidempotentplugin" in s.plugins
assert s.plugins["nonidempotentplugin"].instance == "third"


@gen_cluster(client=True)
async def test_has_scheduler_plugin_by_name(c, s, a, b):
"""Test checking if scheduler plugin is registered using string name"""

class Dummy1(SchedulerPlugin):
name = "Dummy1"

def start(self, scheduler):
scheduler.foo = "bar"

# Check non-existent plugin
assert not await c.has_plugin("Dummy1")

# Register plugin
await c.register_plugin(Dummy1())
assert s.foo == "bar"

# Check using string name
assert await c.has_plugin("Dummy1")

# Unregister and check again
await c.unregister_scheduler_plugin("Dummy1")
assert not await c.has_plugin("Dummy1")


@gen_cluster(client=True)
async def test_has_scheduler_plugin_by_object(c, s, a, b):
"""Test checking if scheduler plugin is registered using plugin object"""

class Dummy2(SchedulerPlugin):
name = "Dummy2"

def start(self, scheduler):
scheduler.check_value = 42

plugin = Dummy2()

# Check before registration
assert not await c.has_plugin(plugin)

# Register and check
await c.register_plugin(plugin)
assert s.check_value == 42
assert await c.has_plugin(plugin)

# Unregister and check
await c.unregister_scheduler_plugin("Dummy2")
assert not await c.has_plugin(plugin)


@gen_cluster(client=True)
async def test_has_plugin_mixed_scheduler_and_worker_types(c, s, a, b):
"""Test checking scheduler and worker plugins together"""
from distributed import WorkerPlugin

class MyPlugin(SchedulerPlugin):
name = "MyPlugin"

def start(self, scheduler):
scheduler.my_value = "scheduler"

class MyWorkerPlugin(WorkerPlugin):
name = "MyWorkerPlugin"

def setup(self, worker):
worker.my_value = "worker"

sched_plugin = MyPlugin()
work_plugin = MyWorkerPlugin()

# Register both types
await c.register_plugin(sched_plugin)
await c.register_plugin(work_plugin)

# Verify both registered
assert s.my_value == "scheduler"
assert a.my_value == "worker"
assert b.my_value == "worker"

# Check both with list of names
result = await c.has_plugin(["MyPlugin", "MyWorkerPlugin"])
assert result == {"MyPlugin": True, "MyWorkerPlugin": True}

# Check both with objects
assert await c.has_plugin(sched_plugin)
assert await c.has_plugin(work_plugin)

# Check non-existent alongside real ones
result = await c.has_plugin(["MyPlugin", "nonexistent", "MyWorkerPlugin"])
assert result == {"MyPlugin": True, "nonexistent": False, "MyWorkerPlugin": True}


@gen_cluster(client=True, nthreads=[])
async def test_has_scheduler_plugin_no_workers(c, s):
"""Test checking scheduler plugin when no workers exist"""

class Plugin(SchedulerPlugin):
name = "plugin"

def start(self, scheduler):
scheduler.no_worker_test = True

# Check before registration
assert not await c.has_plugin("plugin")

# Register plugin when no workers exist
await c.register_plugin(Plugin())
assert s.no_worker_test is True

# Check after registration
assert await c.has_plugin("plugin")


@gen_cluster(client=True)
async def test_has_scheduler_plugin_custom_name_override(c, s, a, b):
"""Test scheduler plugin registered with custom name different from class name"""

class Dummy3(SchedulerPlugin):
name = "Dummy3"

def start(self, scheduler):
scheduler.name_test = "custom"

plugin = Dummy3()

# Register with custom name (overriding the class name attribute)
await c.register_plugin(plugin, name="custom-override")

# Check with custom name works
assert await c.has_plugin("custom-override")

# Original name won't work since we overrode it
assert not await c.has_plugin("Dummy3")
Loading
Loading