Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions QEfficient/finetune/experimental/core/component_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,197 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------


import logging
from typing import Callable, Dict, Optional, Type

# from QEfficient.finetune.experimental.core.logger import get_logger

# logger = get_logger()
logger = logging.getLogger(__name__)


def get_object(obj_dict: Dict, name: str, object_type: str, list_fn: Callable) -> Optional[Type]:
"""Utility to get object from a dictionary with error handling."""
obj = obj_dict.get(name)
if obj is None:
raise ValueError(f"Unknown {object_type}: {name}. Available: {list_fn()}")
return obj


class ComponentRegistry:
"""Registry for managing different training components."""

def __init__(self):
self._optimizers: Dict[str, Type] = {}
self._schedulers: Dict[str, Type] = {}
self._datasets: Dict[str, Type] = {}
self._models: Dict[str, Type] = {}
self._data_collators: Dict[str, Type] = {}
self._metrics: Dict[str, Type] = {}
self._loss_functions: Dict[str, Type] = {}
self._callbacks: Dict[str, Type] = {}
self._hooks: Dict[str, Type] = {}
self._trainer_modules: Dict[str, Type] = {}

def trainer_module(self, name: str, args_cls=None, required_kwargs=None):
"""
Decorator to register a trainer module with its configuration.
Each trainer module has to be binded to its args class and required kwargs.

Args:
name: Name of the trainer type
args_cls: The arguments class for this trainer
required_kwargs: Dictionary of required keyword arguments and their default values
"""
required_kwargs = required_kwargs or {}

def decorator(trainer_cls):
self._trainer_modules[name] = {
"trainer_cls": trainer_cls,
"args_cls": args_cls,
"required_kwargs": required_kwargs,
}
logger.info(f"Registered trainer module: {name}")
return self._trainer_modules[name]

return decorator

def optimizer(self, name: str):
"""Decorator to register an optimizer class."""

def decorator(cls: Type):
self._optimizers[name] = cls
logger.info(f"Registered optimizer: {name}")
return cls

return decorator

def scheduler(self, name: str):
"""Decorator to register a scheduler class."""

def decorator(cls: Type):
self._schedulers[name] = cls
logger.info(f"Registered scheduler: {name}")
return cls

return decorator

def dataset(self, name: str):
"""Decorator to register a dataset class."""

def decorator(cls: Type):
self._datasets[name] = cls
logger.info(f"Registered dataset: {name}")
return cls

return decorator

def model(self, name: str):
"""Decorator to register a model class."""

def decorator(cls: Type):
self._models[name] = cls
logger.info(f"Registered model: {name}")
return cls

return decorator

def data_collator(self, name: str):
"""Decorator to register a data collator class."""

def decorator(fn_pointer: Type):
self._data_collators[name] = fn_pointer
logger.info(f"Registered data collator: {name}")
return fn_pointer

return decorator

def loss_function(self, name: str):
"""Decorator to register a loss function class."""

def decorator(cls: Type):
self._loss_functions[name] = cls
logger.info(f"Registered loss function: {name}")
return cls

return decorator

def callback(self, name: str):
"""Decorator to register a callback class."""

def decorator(cls: Type):
self._callbacks[name] = cls
logger.info(f"Registered callback: {name}")
return cls

return decorator

def get_trainer_module(self, name: str) -> Optional[Type]:
"""Get trainer module class by name."""
return get_object(self._trainer_modules, name, "trainer module", self.list_trainer_modules)

def get_optimizer(self, name: str) -> Optional[Type]:
"""Get optimizer class by name."""
return get_object(self._optimizers, name, "optimizer", self.list_optimizers)

def get_scheduler(self, name: str) -> Optional[Type]:
"""Get scheduler class by name."""
return get_object(self._schedulers, name, "scheduler", self.list_schedulers)

def get_dataset(self, name: str) -> Optional[Type]:
"""Get dataset class by name."""
return get_object(self._datasets, name, "dataset", self.list_datasets)

def get_model(self, name: str) -> Optional[Type]:
"""Get model class by name."""
return get_object(self._models, name, "model", self.list_models)

def get_data_collator(self, name: str) -> Optional[Type]:
"""Get data collator class by name."""
return get_object(self._data_collators, name, "data collator", self.list_data_collators)

def get_loss_function(self, name: str) -> Optional[Type]:
"""Get loss function class by name."""
return get_object(self._loss_functions, name, "loss function", self.list_loss_functions)

def get_callback(self, name: str) -> Optional[Type]:
"""Get callback class by name."""
return get_object(self._callbacks, name, "callback", self.list_callbacks)

def list_trainer_modules(self) -> list[str]:
"""List all registered trainer modules."""
return list(self._trainer_modules.keys())

def list_optimizers(self) -> list[str]:
"""List all registered optimizers."""
return list(self._optimizers.keys())

def list_schedulers(self) -> list[str]:
"""List all registered schedulers."""
return list(self._schedulers.keys())

def list_datasets(self) -> list[str]:
"""List all registered datasets."""
return list(self._datasets.keys())

def list_models(self) -> list[str]:
"""List all registered models."""
return list(self._models.keys())

def list_data_collators(self) -> list[str]:
"""List all registered data collators."""
return list(self._data_collators.keys())

def list_loss_functions(self) -> list[str]:
"""List all registered loss functions."""
return list(self._loss_functions.keys())

def list_callbacks(self) -> list[str]:
"""List all registered callbacks."""
return list(self._callbacks.keys())


# Global registry instance
registry = ComponentRegistry()
167 changes: 167 additions & 0 deletions QEfficient/finetune/experimental/tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import pytest

from QEfficient.finetune.experimental.core.component_registry import ComponentRegistry, get_object, registry


class TestComponentRegistry:
@pytest.fixture(autouse=True)
def setUp(self):
"""Set up test fixtures before each test method."""
self.registry = ComponentRegistry()

@pytest.mark.parametrize(
"register_method, get_method, object_name",
[
("trainer_module", "get_trainer_module", "test_trainer"),
("optimizer", "get_optimizer", "test_optimizer"),
("scheduler", "get_scheduler", "test_scheduler"),
("dataset", "get_dataset", "test_dataset"),
("model", "get_model", "test_model"),
("data_collator", "get_data_collator", "test_collator"),
("loss_function", "get_loss_function", "test_loss"),
("callback", "get_callback", "test_callback"),
],
)
def test_object_success(self, register_method: str, get_method: str, object_name: str):
"""Test object registration decorator."""

class MockObject:
pass

# Register with decorator
getattr(self.registry, register_method)(object_name)(MockObject)

# Verify registration
retrieved = getattr(self.registry, get_method)(object_name)
if register_method == "trainer_module":
retrieved = retrieved["trainer_cls"]
assert retrieved == MockObject

@pytest.mark.parametrize(
"object_type, get_method",
[
("trainer module", "get_trainer_module"),
("optimizer", "get_optimizer"),
("scheduler", "get_scheduler"),
("dataset", "get_dataset"),
("model", "get_model"),
("data collator", "get_data_collator"),
("loss function", "get_loss_function"),
("callback", "get_callback"),
],
)
def test_object_failure(self, object_type: str, get_method: str, object_name: str = "non_existent"):
"""Test failure when retrieving non-existent object."""
with pytest.raises(ValueError) as exc_info:
getattr(self.registry, get_method)(object_name)

assert f"Unknown {object_type}" in str(exc_info.value)

def test_init_empty_registries(self):
"""Test that all registries are initialized as empty dictionaries."""
assert len(self.registry._optimizers) == 0
assert len(self.registry._schedulers) == 0
assert len(self.registry._datasets) == 0
assert len(self.registry._models) == 0
assert len(self.registry._data_collators) == 0
assert len(self.registry._metrics) == 0
assert len(self.registry._loss_functions) == 0
assert len(self.registry._callbacks) == 0
assert len(self.registry._hooks) == 0
assert len(self.registry._trainer_modules) == 0

def test_trainer_module_with_args_and_kwargs(self):
"""Test trainer module registration with args class and required kwargs."""

class MockArgs:
pass

class MockTrainer:
pass

# Register with decorator including args class and required kwargs
self.registry.trainer_module(
"test_trainer_with_args", args_cls=MockArgs, required_kwargs={"param1": "default1", "param2": "default2"}
)(MockTrainer)

# Verify registration details
module_info = self.registry.get_trainer_module("test_trainer_with_args")
assert module_info["trainer_cls"] == MockTrainer
assert module_info["args_cls"] == MockArgs
assert module_info["required_kwargs"] == {"param1": "default1", "param2": "default2"}

def test_list_methods(self):
"""Test all list methods return correct keys."""

# Register some dummy items
class DummyClass:
pass

self.registry.optimizer("opt1")(DummyClass)
self.registry.scheduler("sched1")(DummyClass)
self.registry.dataset("ds1")(DummyClass)
self.registry.model("model1")(DummyClass)
self.registry.data_collator("coll1")(lambda x: x)
self.registry.loss_function("loss1")(DummyClass)
self.registry.callback("cb1")(DummyClass)
self.registry.trainer_module("tm1")(DummyClass)

# Test lists
assert self.registry.list_optimizers() == ["opt1"]
assert self.registry.list_schedulers() == ["sched1"]
assert self.registry.list_datasets() == ["ds1"]
assert self.registry.list_models() == ["model1"]
assert self.registry.list_data_collators() == ["coll1"]
assert self.registry.list_loss_functions() == ["loss1"]
assert self.registry.list_callbacks() == ["cb1"]
assert self.registry.list_trainer_modules() == ["tm1"]

def test_logging_on_registration(self, mocker):
"""Test that registration logs messages."""
mock_logger = mocker.patch("QEfficient.finetune.experimental.core.component_registry.logger")

class MockClass:
pass

# Test optimizer registration logging
self.registry.optimizer("test_opt")(MockClass)
mock_logger.info.assert_called_with("Registered optimizer: test_opt")

# Reset mock
mock_logger.reset_mock()

# Test trainer module registration logging
self.registry.trainer_module("test_tm")(MockClass)
mock_logger.info.assert_called_with("Registered trainer module: test_tm")


class TestGetObjectFunction:
def test_get_object_success(self):
"""Test get_object function success case."""
test_dict = {"key1": "value1", "key2": "value2"}

result = get_object(test_dict, "key1", "test_type", lambda: ["key1", "key2"])
assert result == "value1"

def test_get_object_failure(self):
"""Test get_object function failure case."""
test_dict = {"key1": "value1"}

with pytest.raises(ValueError) as exc_info:
get_object(test_dict, "nonexistent", "test_type", lambda: ["key1", "key2"])

assert "Unknown test_type: nonexistent" in str(exc_info.value)
assert "Available: ['key1', 'key2']" in str(exc_info.value)


class TestGlobalRegistry:
def test_global_registry_instance(self):
"""Test that global registry instance exists and is of correct type."""
assert isinstance(registry, ComponentRegistry)
Loading