From b6242d90586a51fee86d0f9597243f3f318120ee Mon Sep 17 00:00:00 2001 From: YgLK Date: Fri, 11 Jul 2025 22:49:49 +0200 Subject: [PATCH 1/4] fix: remove extra parameter in accelerator registry decorator --- src/lightning/fabric/accelerators/registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/accelerators/registry.py b/src/lightning/fabric/accelerators/registry.py index 4959a0fb9426a..539b7aa8a01dc 100644 --- a/src/lightning/fabric/accelerators/registry.py +++ b/src/lightning/fabric/accelerators/registry.py @@ -73,14 +73,14 @@ def register( data["description"] = description data["init_params"] = init_params - def do_register(name: str, accelerator: Callable) -> Callable: + def do_register(accelerator: Callable) -> Callable: data["accelerator"] = accelerator data["accelerator_name"] = name self[name] = data return accelerator if accelerator is not None: - return do_register(name, accelerator) + return do_register(accelerator) return do_register From 79c5c422fb1714323023b966c87abb5822f98b45 Mon Sep 17 00:00:00 2001 From: YgLK Date: Sun, 10 Aug 2025 00:38:07 +0200 Subject: [PATCH 2/4] tests: add registry decorator tests --- .../accelerators/test_registry.py | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index 8036a6f45b8a0..5be1f4683729a 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -16,6 +16,38 @@ import torch from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator +from lightning.fabric.accelerators.registry import _AcceleratorRegistry + + +class TestAccelerator(Accelerator): + """Helper accelerator class for testing.""" + + def __init__(self, param1=None, param2=None): + self.param1 = param1 + self.param2 = param2 + super().__init__() + + def setup_device(self, device: torch.device) -> None: + pass + + def teardown(self) -> None: + pass + + @staticmethod + def parse_devices(devices): + return devices + + @staticmethod + def get_parallel_devices(devices): + return ["foo"] * devices + + @staticmethod + def auto_device_count(): + return 3 + + @staticmethod + def is_available(): + return True def test_accelerator_registry_with_new_accelerator(): @@ -71,3 +103,75 @@ def is_available(): def test_available_accelerators_in_registry(): assert ACCELERATOR_REGISTRY.available_accelerators() == {"cpu", "cuda", "mps", "tpu"} + + +def test_registry_as_decorator(): + """Test that the registry can be used as a decorator.""" + test_registry = _AcceleratorRegistry() + + # Test decorator usage + @test_registry.register("test_decorator", description="Test decorator accelerator", param1="value1", param2=42) + class DecoratorAccelerator(TestAccelerator): + pass + + # Verify registration worked + assert "test_decorator" in test_registry + assert test_registry["test_decorator"]["description"] == "Test decorator accelerator" + assert test_registry["test_decorator"]["init_params"] == {"param1": "value1", "param2": 42} + assert test_registry["test_decorator"]["accelerator"] == DecoratorAccelerator + assert test_registry["test_decorator"]["accelerator_name"] == "test_decorator" + + # Test that we can instantiate the accelerator + instance = test_registry.get("test_decorator") + assert isinstance(instance, DecoratorAccelerator) + assert instance.param1 == "value1" + assert instance.param2 == 42 + + +def test_registry_as_static_method(): + """Test that the registry can be used as a static method call.""" + test_registry = _AcceleratorRegistry() + + class StaticMethodAccelerator(TestAccelerator): + pass + + # Test static method usage + result = test_registry.register( + "test_static", + StaticMethodAccelerator, + description="Test static method accelerator", + param1="static_value", + param2=100 + ) + + # Verify registration worked + assert "test_static" in test_registry + assert test_registry["test_static"]["description"] == "Test static method accelerator" + assert test_registry["test_static"]["init_params"] == {"param1": "static_value", "param2": 100} + assert test_registry["test_static"]["accelerator"] == StaticMethodAccelerator + assert test_registry["test_static"]["accelerator_name"] == "test_static" + assert result == StaticMethodAccelerator # Should return the accelerator class + + # Test that we can instantiate the accelerator + instance = test_registry.get("test_static") + assert isinstance(instance, StaticMethodAccelerator) + assert instance.param1 == "static_value" + assert instance.param2 == 100 + + +def test_registry_without_parameters(): + """Test registration without init parameters.""" + test_registry = _AcceleratorRegistry() + + class SimpleAccelerator(TestAccelerator): + def __init__(self): + super().__init__() + + test_registry.register("simple", SimpleAccelerator, description="Simple accelerator") + + assert "simple" in test_registry + assert test_registry["simple"]["description"] == "Simple accelerator" + assert test_registry["simple"]["init_params"] == {} + + instance = test_registry.get("simple") + assert isinstance(instance, SimpleAccelerator) From 9a4d7924b03b2284368b83b17876b580c363ed1a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 9 Aug 2025 22:52:58 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../accelerators/test_registry.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index 5be1f4683729a..b88ecf1db1e57 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -21,7 +21,7 @@ class TestAccelerator(Accelerator): """Helper accelerator class for testing.""" - + def __init__(self, param1=None, param2=None): self.param1 = param1 self.param2 = param2 @@ -108,19 +108,19 @@ def test_available_accelerators_in_registry(): def test_registry_as_decorator(): """Test that the registry can be used as a decorator.""" test_registry = _AcceleratorRegistry() - + # Test decorator usage @test_registry.register("test_decorator", description="Test decorator accelerator", param1="value1", param2=42) class DecoratorAccelerator(TestAccelerator): pass - + # Verify registration worked assert "test_decorator" in test_registry assert test_registry["test_decorator"]["description"] == "Test decorator accelerator" assert test_registry["test_decorator"]["init_params"] == {"param1": "value1", "param2": 42} assert test_registry["test_decorator"]["accelerator"] == DecoratorAccelerator assert test_registry["test_decorator"]["accelerator_name"] == "test_decorator" - + # Test that we can instantiate the accelerator instance = test_registry.get("test_decorator") assert isinstance(instance, DecoratorAccelerator) @@ -131,19 +131,19 @@ class DecoratorAccelerator(TestAccelerator): def test_registry_as_static_method(): """Test that the registry can be used as a static method call.""" test_registry = _AcceleratorRegistry() - + class StaticMethodAccelerator(TestAccelerator): pass - + # Test static method usage result = test_registry.register( - "test_static", - StaticMethodAccelerator, - description="Test static method accelerator", - param1="static_value", - param2=100 + "test_static", + StaticMethodAccelerator, + description="Test static method accelerator", + param1="static_value", + param2=100, ) - + # Verify registration worked assert "test_static" in test_registry assert test_registry["test_static"]["description"] == "Test static method accelerator" @@ -151,7 +151,7 @@ class StaticMethodAccelerator(TestAccelerator): assert test_registry["test_static"]["accelerator"] == StaticMethodAccelerator assert test_registry["test_static"]["accelerator_name"] == "test_static" assert result == StaticMethodAccelerator # Should return the accelerator class - + # Test that we can instantiate the accelerator instance = test_registry.get("test_static") assert isinstance(instance, StaticMethodAccelerator) @@ -162,16 +162,16 @@ class StaticMethodAccelerator(TestAccelerator): def test_registry_without_parameters(): """Test registration without init parameters.""" test_registry = _AcceleratorRegistry() - + class SimpleAccelerator(TestAccelerator): def __init__(self): super().__init__() - + test_registry.register("simple", SimpleAccelerator, description="Simple accelerator") - + assert "simple" in test_registry assert test_registry["simple"]["description"] == "Simple accelerator" assert test_registry["simple"]["init_params"] == {} - + instance = test_registry.get("simple") assert isinstance(instance, SimpleAccelerator) From c33bb1def3c179f8198d8483266993f8c3aa2bec Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 11 Aug 2025 13:47:12 +0200 Subject: [PATCH 4/4] chlog --- src/lightning/fabric/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 18537ca15e2fc..2b1d651d1c027 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raise ValueError when seed is `out-of-bounds` or `cannot be cast to int` ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029)) +- fix: remove extra `name` parameter in accelerator registry decorator ([#20975](https://github.com/Lightning-AI/pytorch-lightning/pull/20975)) + + --- ## [2.5.2] - 2025-3-20