diff --git a/QEfficient/finetune/experimental/core/component_registry.py b/QEfficient/finetune/experimental/core/component_registry.py index d647b73a6..7744d71e6 100644 --- a/QEfficient/finetune/experimental/core/component_registry.py +++ b/QEfficient/finetune/experimental/core/component_registry.py @@ -4,3 +4,197 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- + + +import logging +from typing import Callable, Dict, Optional, Type + +# from QEfficient.finetune.experimental.core.logger import get_logger + +# logger = get_logger() +logger = logging.getLogger(__name__) + + +def get_object(obj_dict: Dict, name: str, object_type: str, list_fn: Callable) -> Optional[Type]: + """Utility to get object from a dictionary with error handling.""" + obj = obj_dict.get(name) + if obj is None: + raise ValueError(f"Unknown {object_type}: {name}. Available: {list_fn()}") + return obj + + +class ComponentRegistry: + """Registry for managing different training components.""" + + def __init__(self): + self._optimizers: Dict[str, Type] = {} + self._schedulers: Dict[str, Type] = {} + self._datasets: Dict[str, Type] = {} + self._models: Dict[str, Type] = {} + self._data_collators: Dict[str, Type] = {} + self._metrics: Dict[str, Type] = {} + self._loss_functions: Dict[str, Type] = {} + self._callbacks: Dict[str, Type] = {} + self._hooks: Dict[str, Type] = {} + self._trainer_modules: Dict[str, Type] = {} + + def trainer_module(self, name: str, args_cls=None, required_kwargs=None): + """ + Decorator to register a trainer module with its configuration. + Each trainer module has to be binded to its args class and required kwargs. + + Args: + name: Name of the trainer type + args_cls: The arguments class for this trainer + required_kwargs: Dictionary of required keyword arguments and their default values + """ + required_kwargs = required_kwargs or {} + + def decorator(trainer_cls): + self._trainer_modules[name] = { + "trainer_cls": trainer_cls, + "args_cls": args_cls, + "required_kwargs": required_kwargs, + } + logger.info(f"Registered trainer module: {name}") + return self._trainer_modules[name] + + return decorator + + def optimizer(self, name: str): + """Decorator to register an optimizer class.""" + + def decorator(cls: Type): + self._optimizers[name] = cls + logger.info(f"Registered optimizer: {name}") + return cls + + return decorator + + def scheduler(self, name: str): + """Decorator to register a scheduler class.""" + + def decorator(cls: Type): + self._schedulers[name] = cls + logger.info(f"Registered scheduler: {name}") + return cls + + return decorator + + def dataset(self, name: str): + """Decorator to register a dataset class.""" + + def decorator(cls: Type): + self._datasets[name] = cls + logger.info(f"Registered dataset: {name}") + return cls + + return decorator + + def model(self, name: str): + """Decorator to register a model class.""" + + def decorator(cls: Type): + self._models[name] = cls + logger.info(f"Registered model: {name}") + return cls + + return decorator + + def data_collator(self, name: str): + """Decorator to register a data collator class.""" + + def decorator(fn_pointer: Type): + self._data_collators[name] = fn_pointer + logger.info(f"Registered data collator: {name}") + return fn_pointer + + return decorator + + def loss_function(self, name: str): + """Decorator to register a loss function class.""" + + def decorator(cls: Type): + self._loss_functions[name] = cls + logger.info(f"Registered loss function: {name}") + return cls + + return decorator + + def callback(self, name: str): + """Decorator to register a callback class.""" + + def decorator(cls: Type): + self._callbacks[name] = cls + logger.info(f"Registered callback: {name}") + return cls + + return decorator + + def get_trainer_module(self, name: str) -> Optional[Type]: + """Get trainer module class by name.""" + return get_object(self._trainer_modules, name, "trainer module", self.list_trainer_modules) + + def get_optimizer(self, name: str) -> Optional[Type]: + """Get optimizer class by name.""" + return get_object(self._optimizers, name, "optimizer", self.list_optimizers) + + def get_scheduler(self, name: str) -> Optional[Type]: + """Get scheduler class by name.""" + return get_object(self._schedulers, name, "scheduler", self.list_schedulers) + + def get_dataset(self, name: str) -> Optional[Type]: + """Get dataset class by name.""" + return get_object(self._datasets, name, "dataset", self.list_datasets) + + def get_model(self, name: str) -> Optional[Type]: + """Get model class by name.""" + return get_object(self._models, name, "model", self.list_models) + + def get_data_collator(self, name: str) -> Optional[Type]: + """Get data collator class by name.""" + return get_object(self._data_collators, name, "data collator", self.list_data_collators) + + def get_loss_function(self, name: str) -> Optional[Type]: + """Get loss function class by name.""" + return get_object(self._loss_functions, name, "loss function", self.list_loss_functions) + + def get_callback(self, name: str) -> Optional[Type]: + """Get callback class by name.""" + return get_object(self._callbacks, name, "callback", self.list_callbacks) + + def list_trainer_modules(self) -> list[str]: + """List all registered trainer modules.""" + return list(self._trainer_modules.keys()) + + def list_optimizers(self) -> list[str]: + """List all registered optimizers.""" + return list(self._optimizers.keys()) + + def list_schedulers(self) -> list[str]: + """List all registered schedulers.""" + return list(self._schedulers.keys()) + + def list_datasets(self) -> list[str]: + """List all registered datasets.""" + return list(self._datasets.keys()) + + def list_models(self) -> list[str]: + """List all registered models.""" + return list(self._models.keys()) + + def list_data_collators(self) -> list[str]: + """List all registered data collators.""" + return list(self._data_collators.keys()) + + def list_loss_functions(self) -> list[str]: + """List all registered loss functions.""" + return list(self._loss_functions.keys()) + + def list_callbacks(self) -> list[str]: + """List all registered callbacks.""" + return list(self._callbacks.keys()) + + +# Global registry instance +registry = ComponentRegistry() diff --git a/QEfficient/finetune/experimental/tests/test_registry.py b/QEfficient/finetune/experimental/tests/test_registry.py new file mode 100644 index 000000000..3e10aa820 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_registry.py @@ -0,0 +1,167 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import pytest + +from QEfficient.finetune.experimental.core.component_registry import ComponentRegistry, get_object, registry + + +class TestComponentRegistry: + @pytest.fixture(autouse=True) + def setUp(self): + """Set up test fixtures before each test method.""" + self.registry = ComponentRegistry() + + @pytest.mark.parametrize( + "register_method, get_method, object_name", + [ + ("trainer_module", "get_trainer_module", "test_trainer"), + ("optimizer", "get_optimizer", "test_optimizer"), + ("scheduler", "get_scheduler", "test_scheduler"), + ("dataset", "get_dataset", "test_dataset"), + ("model", "get_model", "test_model"), + ("data_collator", "get_data_collator", "test_collator"), + ("loss_function", "get_loss_function", "test_loss"), + ("callback", "get_callback", "test_callback"), + ], + ) + def test_object_success(self, register_method: str, get_method: str, object_name: str): + """Test object registration decorator.""" + + class MockObject: + pass + + # Register with decorator + getattr(self.registry, register_method)(object_name)(MockObject) + + # Verify registration + retrieved = getattr(self.registry, get_method)(object_name) + if register_method == "trainer_module": + retrieved = retrieved["trainer_cls"] + assert retrieved == MockObject + + @pytest.mark.parametrize( + "object_type, get_method", + [ + ("trainer module", "get_trainer_module"), + ("optimizer", "get_optimizer"), + ("scheduler", "get_scheduler"), + ("dataset", "get_dataset"), + ("model", "get_model"), + ("data collator", "get_data_collator"), + ("loss function", "get_loss_function"), + ("callback", "get_callback"), + ], + ) + def test_object_failure(self, object_type: str, get_method: str, object_name: str = "non_existent"): + """Test failure when retrieving non-existent object.""" + with pytest.raises(ValueError) as exc_info: + getattr(self.registry, get_method)(object_name) + + assert f"Unknown {object_type}" in str(exc_info.value) + + def test_init_empty_registries(self): + """Test that all registries are initialized as empty dictionaries.""" + assert len(self.registry._optimizers) == 0 + assert len(self.registry._schedulers) == 0 + assert len(self.registry._datasets) == 0 + assert len(self.registry._models) == 0 + assert len(self.registry._data_collators) == 0 + assert len(self.registry._metrics) == 0 + assert len(self.registry._loss_functions) == 0 + assert len(self.registry._callbacks) == 0 + assert len(self.registry._hooks) == 0 + assert len(self.registry._trainer_modules) == 0 + + def test_trainer_module_with_args_and_kwargs(self): + """Test trainer module registration with args class and required kwargs.""" + + class MockArgs: + pass + + class MockTrainer: + pass + + # Register with decorator including args class and required kwargs + self.registry.trainer_module( + "test_trainer_with_args", args_cls=MockArgs, required_kwargs={"param1": "default1", "param2": "default2"} + )(MockTrainer) + + # Verify registration details + module_info = self.registry.get_trainer_module("test_trainer_with_args") + assert module_info["trainer_cls"] == MockTrainer + assert module_info["args_cls"] == MockArgs + assert module_info["required_kwargs"] == {"param1": "default1", "param2": "default2"} + + def test_list_methods(self): + """Test all list methods return correct keys.""" + + # Register some dummy items + class DummyClass: + pass + + self.registry.optimizer("opt1")(DummyClass) + self.registry.scheduler("sched1")(DummyClass) + self.registry.dataset("ds1")(DummyClass) + self.registry.model("model1")(DummyClass) + self.registry.data_collator("coll1")(lambda x: x) + self.registry.loss_function("loss1")(DummyClass) + self.registry.callback("cb1")(DummyClass) + self.registry.trainer_module("tm1")(DummyClass) + + # Test lists + assert self.registry.list_optimizers() == ["opt1"] + assert self.registry.list_schedulers() == ["sched1"] + assert self.registry.list_datasets() == ["ds1"] + assert self.registry.list_models() == ["model1"] + assert self.registry.list_data_collators() == ["coll1"] + assert self.registry.list_loss_functions() == ["loss1"] + assert self.registry.list_callbacks() == ["cb1"] + assert self.registry.list_trainer_modules() == ["tm1"] + + def test_logging_on_registration(self, mocker): + """Test that registration logs messages.""" + mock_logger = mocker.patch("QEfficient.finetune.experimental.core.component_registry.logger") + + class MockClass: + pass + + # Test optimizer registration logging + self.registry.optimizer("test_opt")(MockClass) + mock_logger.info.assert_called_with("Registered optimizer: test_opt") + + # Reset mock + mock_logger.reset_mock() + + # Test trainer module registration logging + self.registry.trainer_module("test_tm")(MockClass) + mock_logger.info.assert_called_with("Registered trainer module: test_tm") + + +class TestGetObjectFunction: + def test_get_object_success(self): + """Test get_object function success case.""" + test_dict = {"key1": "value1", "key2": "value2"} + + result = get_object(test_dict, "key1", "test_type", lambda: ["key1", "key2"]) + assert result == "value1" + + def test_get_object_failure(self): + """Test get_object function failure case.""" + test_dict = {"key1": "value1"} + + with pytest.raises(ValueError) as exc_info: + get_object(test_dict, "nonexistent", "test_type", lambda: ["key1", "key2"]) + + assert "Unknown test_type: nonexistent" in str(exc_info.value) + assert "Available: ['key1', 'key2']" in str(exc_info.value) + + +class TestGlobalRegistry: + def test_global_registry_instance(self): + """Test that global registry instance exists and is of correct type.""" + assert isinstance(registry, ComponentRegistry)