fix: remove extra parameter in accelerator registry decorator #20975
base: master
Conversation
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@           Coverage Diff            @@
##           master   #20975    +/-  ##
=======================================
  Coverage      87%      87%
=======================================
  Files         268      268
  Lines       23460    23460
=======================================
  Hits        20399    20399
  Misses       3061     3061
@@ -73,14 +73,14 @@ def register(
         data["description"] = description
         data["init_params"] = init_params

-        def do_register(name: str, accelerator: Callable) -> Callable:
+        def do_register(accelerator: Callable) -> Callable:
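For context on the one-line change, here is a minimal, simplified stand-in for the registry (illustration only, not the actual Lightning implementation) showing why the extra `name` parameter breaks the decorator path: when `register()` is called without an accelerator class, it returns `do_register` to be used as a decorator, and Python then invokes that function with a single argument, the decorated class.

# Simplified stand-in for the accelerator registry (illustration only, not the actual Lightning code).
from typing import Callable, Optional


class _ToyRegistry(dict):
    def register(self, name: str, accelerator: Optional[Callable] = None, **init_params) -> Callable:
        data = {"init_params": init_params}

        # The decorator path calls this with a single argument: the decorated class.
        # With the old two-parameter signature, do_register(name, accelerator),
        # a decorator call like `@registry.register("foo", param1=1)` fails with
        # "TypeError: do_register() missing 1 required positional argument".
        def do_register(accelerator: Callable) -> Callable:
            data["accelerator"] = accelerator
            data["accelerator_name"] = name
            self[name] = data
            return accelerator

        if accelerator is not None:
            # Direct-call path: registry.register("foo", MyAccelerator, param1=1)
            return do_register(accelerator)

        # Decorator path: @registry.register("foo", param1=1) returns do_register,
        # which Python then calls with the decorated class only.
        return do_register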
Let's please add a test to reproduce your reported case and demonstrate that this change fixes it 🦩
The tests have been updated. I'd be glad if you could take a look at them.
@Borda @YgLK here is a new version of https://github.com/Lightning-AI/pytorch-lightning/blob/master/tests/tests_fabric/accelerators/test_registry.py that includes the relevant testing. In particular:

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import pytest
import torch

from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.utilities.exceptions import MisconfigurationException


class TestAccelerator(Accelerator):
    """Helper accelerator class for testing."""

    def __init__(self, param1=None, param2=None):
        self.param1 = param1
        self.param2 = param2
        super().__init__()

    def setup_device(self, device: torch.device) -> None:
        pass

    def teardown(self) -> None:
        pass

    @staticmethod
    def parse_devices(devices):
        return devices

    @staticmethod
    def get_parallel_devices(devices):
        return ["foo"] * devices

    @staticmethod
    def auto_device_count():
        return 3

    @staticmethod
    def is_available():
        return True


def test_accelerator_registry_with_new_accelerator():
    accelerator_name = "custom_accelerator"
    accelerator_description = "Custom Accelerator"

    class CustomAccelerator(Accelerator):
        def __init__(self, param1, param2):
            self.param1 = param1
            self.param2 = param2
            super().__init__()

        def setup_device(self, device: torch.device) -> None:
            pass

        def get_device_stats(self, device: torch.device) -> dict[str, Any]:
            pass

        def teardown(self) -> None:
            pass

        @staticmethod
        def parse_devices(devices):
            return devices

        @staticmethod
        def get_parallel_devices(devices):
            return ["foo"] * devices

        @staticmethod
        def auto_device_count():
            return 3

        @staticmethod
        def is_available():
            return True

    ACCELERATOR_REGISTRY.register(
        accelerator_name, CustomAccelerator, description=accelerator_description, param1="abc", param2=123
    )

    assert accelerator_name in ACCELERATOR_REGISTRY
    assert ACCELERATOR_REGISTRY[accelerator_name]["description"] == accelerator_description
    assert ACCELERATOR_REGISTRY[accelerator_name]["init_params"] == {"param1": "abc", "param2": 123}
    assert ACCELERATOR_REGISTRY[accelerator_name]["accelerator_name"] == accelerator_name
    assert isinstance(ACCELERATOR_REGISTRY.get(accelerator_name), CustomAccelerator)

    ACCELERATOR_REGISTRY.remove(accelerator_name)
    assert accelerator_name not in ACCELERATOR_REGISTRY


def test_available_accelerators_in_registry():
    assert ACCELERATOR_REGISTRY.available_accelerators() == {"cpu", "cuda", "mps", "tpu"}


def test_registry_as_decorator():
    """Test that the registry can be used as a decorator."""
    test_registry = _AcceleratorRegistry()

    # Test decorator usage
    @test_registry.register("test_decorator", description="Test decorator accelerator", param1="value1", param2=42)
    class DecoratorAccelerator(TestAccelerator):
        pass

    # Verify registration worked
    assert "test_decorator" in test_registry
    assert test_registry["test_decorator"]["description"] == "Test decorator accelerator"
    assert test_registry["test_decorator"]["init_params"] == {"param1": "value1", "param2": 42}
    assert test_registry["test_decorator"]["accelerator"] == DecoratorAccelerator
    assert test_registry["test_decorator"]["accelerator_name"] == "test_decorator"

    # Test that we can instantiate the accelerator
    instance = test_registry.get("test_decorator")
    assert isinstance(instance, DecoratorAccelerator)
    assert instance.param1 == "value1"
    assert instance.param2 == 42


def test_registry_as_static_method():
    """Test that the registry can be used as a static method call."""
    test_registry = _AcceleratorRegistry()

    class StaticMethodAccelerator(TestAccelerator):
        pass

    # Test static method usage
    result = test_registry.register(
        "test_static",
        StaticMethodAccelerator,
        description="Test static method accelerator",
        param1="static_value",
        param2=100,
    )

    # Verify registration worked
    assert "test_static" in test_registry
    assert test_registry["test_static"]["description"] == "Test static method accelerator"
    assert test_registry["test_static"]["init_params"] == {"param1": "static_value", "param2": 100}
    assert test_registry["test_static"]["accelerator"] == StaticMethodAccelerator
    assert test_registry["test_static"]["accelerator_name"] == "test_static"
    assert result == StaticMethodAccelerator  # Should return the accelerator class

    # Test that we can instantiate the accelerator
    instance = test_registry.get("test_static")
    assert isinstance(instance, StaticMethodAccelerator)
    assert instance.param1 == "static_value"
    assert instance.param2 == 100


def test_registry_without_parameters():
    """Test registration without init parameters."""
    test_registry = _AcceleratorRegistry()

    class SimpleAccelerator(TestAccelerator):
        def __init__(self):
            super().__init__()

    test_registry.register("simple", SimpleAccelerator, description="Simple accelerator")

    assert "simple" in test_registry
    assert test_registry["simple"]["description"] == "Simple accelerator"
    assert test_registry["simple"]["init_params"] == {}

    instance = test_registry.get("simple")
    assert isinstance(instance, SimpleAccelerator)
@SkafteNicki thanks for the help - it looks great. I updated the tests in 79c5c42.
What does this PR do?
Fixes #20973
📚 Documentation preview 📚: https://pytorch-lightning--20975.org.readthedocs.build/en/20975/
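Purely for illustration (not part of this PR), a short sketch of the decorator-style registration the fix is meant to support; the accelerator name and parameters below are made up:

# Hypothetical example: the name "my_accel" and the parameters are illustrative only.
from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator


# With the one-parameter do_register(), the decorator form receives only the class
# being decorated, so a registration like this is expected to work.
@ACCELERATOR_REGISTRY.register("my_accel", description="Example accelerator", param1="abc")
class MyAccelerator(Accelerator):
    # Abstract methods (setup_device, teardown, parse_devices, ...) omitted for brevity;
    # see the test file above for a complete subclass. Registration stores the class and
    # init params; instantiation only happens later via ACCELERATOR_REGISTRY.get(...).
    ...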