Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/lightning/fabric/accelerators/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ def register(
data["description"] = description
data["init_params"] = init_params

def do_register(name: str, accelerator: Callable) -> Callable:
def do_register(accelerator: Callable) -> Callable:
data["accelerator"] = accelerator
data["accelerator_name"] = name
self[name] = data
return accelerator

if accelerator is not None:
return do_register(name, accelerator)
return do_register(accelerator)

return do_register

Expand Down
104 changes: 104 additions & 0 deletions tests/tests_fabric/accelerators/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,38 @@
import torch

from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry


class TestAccelerator(Accelerator):
"""Helper accelerator class for testing."""

def __init__(self, param1=None, param2=None):
self.param1 = param1
self.param2 = param2
super().__init__()

def setup_device(self, device: torch.device) -> None:
pass

def teardown(self) -> None:
pass

@staticmethod
def parse_devices(devices):
return devices

@staticmethod
def get_parallel_devices(devices):
return ["foo"] * devices

@staticmethod
def auto_device_count():
return 3

@staticmethod
def is_available():
return True


def test_accelerator_registry_with_new_accelerator():
Expand Down Expand Up @@ -71,3 +103,75 @@ def is_available():

def test_available_accelerators_in_registry():
assert ACCELERATOR_REGISTRY.available_accelerators() == {"cpu", "cuda", "mps", "tpu"}


def test_registry_as_decorator():
"""Test that the registry can be used as a decorator."""
test_registry = _AcceleratorRegistry()

# Test decorator usage
@test_registry.register("test_decorator", description="Test decorator accelerator", param1="value1", param2=42)
class DecoratorAccelerator(TestAccelerator):
pass

# Verify registration worked
assert "test_decorator" in test_registry
assert test_registry["test_decorator"]["description"] == "Test decorator accelerator"
assert test_registry["test_decorator"]["init_params"] == {"param1": "value1", "param2": 42}
assert test_registry["test_decorator"]["accelerator"] == DecoratorAccelerator
assert test_registry["test_decorator"]["accelerator_name"] == "test_decorator"

# Test that we can instantiate the accelerator
instance = test_registry.get("test_decorator")
assert isinstance(instance, DecoratorAccelerator)
assert instance.param1 == "value1"
assert instance.param2 == 42


def test_registry_as_static_method():
"""Test that the registry can be used as a static method call."""
test_registry = _AcceleratorRegistry()

class StaticMethodAccelerator(TestAccelerator):
pass

# Test static method usage
result = test_registry.register(
"test_static",
StaticMethodAccelerator,
description="Test static method accelerator",
param1="static_value",
param2=100,
)

# Verify registration worked
assert "test_static" in test_registry
assert test_registry["test_static"]["description"] == "Test static method accelerator"
assert test_registry["test_static"]["init_params"] == {"param1": "static_value", "param2": 100}
assert test_registry["test_static"]["accelerator"] == StaticMethodAccelerator
assert test_registry["test_static"]["accelerator_name"] == "test_static"
assert result == StaticMethodAccelerator # Should return the accelerator class

# Test that we can instantiate the accelerator
instance = test_registry.get("test_static")
assert isinstance(instance, StaticMethodAccelerator)
assert instance.param1 == "static_value"
assert instance.param2 == 100


def test_registry_without_parameters():
"""Test registration without init parameters."""
test_registry = _AcceleratorRegistry()

class SimpleAccelerator(TestAccelerator):
def __init__(self):
super().__init__()

test_registry.register("simple", SimpleAccelerator, description="Simple accelerator")

assert "simple" in test_registry
assert test_registry["simple"]["description"] == "Simple accelerator"
assert test_registry["simple"]["init_params"] == {}

instance = test_registry.get("simple")
assert isinstance(instance, SimpleAccelerator)
Loading