Skip to content

Commit c4fa7db

Browse files
committed
Added component registry and factory functionality.
Signed-off-by: meetkuma <[email protected]>
1 parent ea26341 commit c4fa7db

File tree

2 files changed

+358
-0
lines changed

2 files changed

+358
-0
lines changed

QEfficient/finetune/experimental/core/component_registry.py

Lines changed: 193 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,3 +4,196 @@
44
# SPDX-License-Identifier: BSD-3-Clause
55
#
66
# -----------------------------------------------------------------------------
7+
8+
9+
import logging
10+
from typing import Tuple, Dict, Optional, Type, List, Callable
11+
12+
# from QEfficient.finetune.experimental.core.logger import get_logger
13+
14+
# logger = get_logger()
15+
# Standard-library logger named after this module; used for the registration
# messages emitted by ComponentRegistry.
logger = logging.getLogger(__name__)
16+
17+
18+
def get_object(obj_dict: Dict, name: str, object_type: str, list_fn: Callable) -> Optional[Type]:
    """Look up ``name`` in ``obj_dict`` and fail loudly when it is not there.

    Args:
        obj_dict: Mapping of registered names to objects.
        name: Key to look up.
        object_type: Human-readable label inserted into the error message.
        list_fn: Zero-argument callable whose result is reported as the
            available names; it is only invoked when the lookup fails.

    Returns:
        The value stored under ``name``.

    Raises:
        ValueError: If ``name`` is absent, or the value stored under it is ``None``.
    """
    found = obj_dict.get(name)
    if found is not None:
        return found
    raise ValueError(f"Unknown {object_type}: {name}. Available: {list_fn()}")
24+
25+
class ComponentRegistry:
    """Registry for managing different training components.

    Each component kind (optimizer, scheduler, dataset, ...) has its own
    name -> object dictionary. Components are added through decorator
    factories (``registry.optimizer("adamw")(cls)``), fetched with the
    ``get_*`` methods and enumerated with the ``list_*`` methods.
    Registering a name that already exists replaces the previous entry.
    """

    def __init__(self):
        self._optimizers: Dict[str, Type] = {}
        self._schedulers: Dict[str, Type] = {}
        self._datasets: Dict[str, Type] = {}
        self._models: Dict[str, Type] = {}
        self._data_collators: Dict[str, Type] = {}
        # NOTE(review): _metrics and _hooks are initialised here but have no
        # register/get/list methods on this class yet.
        self._metrics: Dict[str, Type] = {}
        self._loss_functions: Dict[str, Type] = {}
        self._callbacks: Dict[str, Type] = {}
        self._hooks: Dict[str, Type] = {}
        # Each entry is a dict with keys 'trainer_cls', 'args_cls' and
        # 'required_kwargs', not a bare class.
        self._trainer_modules: Dict[str, Dict] = {}

    def _register(self, store: Dict, kind: str, name: str) -> Callable:
        """Build a decorator that stores its argument in ``store`` under ``name``.

        Args:
            store: The per-kind dictionary to write into.
            kind: Human-readable component kind, used in the log message.
            name: Key under which the decorated object is registered.

        Returns:
            A decorator that registers the object, logs the registration and
            returns the object unchanged.
        """

        def decorator(obj):
            store[name] = obj
            logger.info(f"Registered {kind}: {name}")
            return obj

        return decorator

    def trainer_module(self, name: str, args_cls=None, required_kwargs=None):
        """
        Decorator to register a trainer module with its configuration.
        Each trainer module has to be bound to its args class and required kwargs.

        Args:
            name: Name of the trainer type
            args_cls: The arguments class for this trainer
            required_kwargs: Dictionary of required keyword arguments and their default values

        Returns:
            A decorator taking the trainer class. It stores an entry dict
            (``trainer_cls``, ``args_cls``, ``required_kwargs``) under ``name``
            and returns that entry.
        """
        # Copy so that later mutation of the caller's dict cannot silently
        # change what is stored in the registry.
        required_kwargs = dict(required_kwargs) if required_kwargs else {}

        def decorator(trainer_cls):
            entry = {
                "trainer_cls": trainer_cls,
                "args_cls": args_cls,
                "required_kwargs": required_kwargs,
            }
            self._trainer_modules[name] = entry
            logger.info(f"Registered trainer module: {name}")
            # NOTE(review): this returns the entry dict rather than trainer_cls,
            # so with ``@registry.trainer_module(...)`` syntax the decorated
            # name ends up bound to the dict. Kept unchanged for compatibility;
            # confirm whether any caller relies on it.
            return entry

        return decorator

    def optimizer(self, name: str):
        """Decorator to register an optimizer class."""
        return self._register(self._optimizers, "optimizer", name)

    def scheduler(self, name: str):
        """Decorator to register a scheduler class."""
        return self._register(self._schedulers, "scheduler", name)

    def dataset(self, name: str):
        """Decorator to register a dataset class."""
        return self._register(self._datasets, "dataset", name)

    def model(self, name: str):
        """Decorator to register a model class."""
        return self._register(self._models, "model", name)

    def data_collator(self, name: str):
        """Decorator to register a data collator (class or plain callable)."""
        return self._register(self._data_collators, "data collator", name)

    def loss_function(self, name: str):
        """Decorator to register a loss function class."""
        return self._register(self._loss_functions, "loss function", name)

    def callback(self, name: str):
        """Decorator to register a callback class."""
        return self._register(self._callbacks, "callback", name)

    def get_trainer_module(self, name: str) -> Dict:
        """Get the trainer module entry dict by name; raises ValueError if unknown."""
        return get_object(self._trainer_modules, name, "trainer module", self.list_trainer_modules)

    def get_optimizer(self, name: str) -> Optional[Type]:
        """Get optimizer class by name; raises ValueError if unknown."""
        return get_object(self._optimizers, name, "optimizer", self.list_optimizers)

    def get_scheduler(self, name: str) -> Optional[Type]:
        """Get scheduler class by name; raises ValueError if unknown."""
        return get_object(self._schedulers, name, "scheduler", self.list_schedulers)

    def get_dataset(self, name: str) -> Optional[Type]:
        """Get dataset class by name; raises ValueError if unknown."""
        return get_object(self._datasets, name, "dataset", self.list_datasets)

    def get_model(self, name: str) -> Optional[Type]:
        """Get model class by name; raises ValueError if unknown."""
        return get_object(self._models, name, "model", self.list_models)

    def get_data_collator(self, name: str) -> Optional[Type]:
        """Get data collator by name; raises ValueError if unknown."""
        return get_object(self._data_collators, name, "data collator", self.list_data_collators)

    def get_loss_function(self, name: str) -> Optional[Type]:
        """Get loss function class by name; raises ValueError if unknown."""
        return get_object(self._loss_functions, name, "loss function", self.list_loss_functions)

    def get_callback(self, name: str) -> Optional[Type]:
        """Get callback class by name; raises ValueError if unknown."""
        return get_object(self._callbacks, name, "callback", self.list_callbacks)

    # The list_* methods use typing.List rather than the builtin generic
    # ``list[str]``: the latter is evaluated when the method is defined and
    # fails on Python < 3.9, and the rest of this file uses typing generics.

    def list_trainer_modules(self) -> List[str]:
        """List all registered trainer modules."""
        return list(self._trainer_modules)

    def list_optimizers(self) -> List[str]:
        """List all registered optimizers."""
        return list(self._optimizers)

    def list_schedulers(self) -> List[str]:
        """List all registered schedulers."""
        return list(self._schedulers)

    def list_datasets(self) -> List[str]:
        """List all registered datasets."""
        return list(self._datasets)

    def list_models(self) -> List[str]:
        """List all registered models."""
        return list(self._models)

    def list_data_collators(self) -> List[str]:
        """List all registered data collators."""
        return list(self._data_collators)

    def list_loss_functions(self) -> List[str]:
        """List all registered loss functions."""
        return list(self._loss_functions)

    def list_callbacks(self) -> List[str]:
        """List all registered callbacks."""
        return list(self._callbacks)
196+
197+
198+
# Global registry instance
199+
# Created once when this module is imported; the accompanying test module
# imports it by this name.
registry = ComponentRegistry()
Lines changed: 165 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,165 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# -----------------------------------------------------------------------------
7+
8+
import pytest
9+
10+
from QEfficient.finetune.experimental.core.component_registry import ComponentRegistry, registry, get_object
11+
12+
13+
class TestComponentRegistry:
    """Tests for registering, retrieving and listing components on a fresh registry."""

    @pytest.fixture(autouse=True)
    def setUp(self):
        """Set up test fixtures before each test method."""
        self.registry = ComponentRegistry()

    @pytest.mark.parametrize(
        "register_method, get_method, object_name",
        [
            ("trainer_module", "get_trainer_module", "test_trainer"),
            ("optimizer", "get_optimizer", "test_optimizer"),
            ("scheduler", "get_scheduler", "test_scheduler"),
            ("dataset", "get_dataset", "test_dataset"),
            ("model", "get_model", "test_model"),
            ("data_collator", "get_data_collator", "test_collator"),
            ("loss_function", "get_loss_function", "test_loss"),
            ("callback", "get_callback", "test_callback"),
        ],
    )
    def test_object_success(self, register_method: str, get_method: str, object_name: str):
        """Test that an object registered through a decorator can be retrieved by name."""

        class MockObject:
            pass

        # Register with decorator
        getattr(self.registry, register_method)(object_name)(MockObject)

        # Verify registration
        retrieved = getattr(self.registry, get_method)(object_name)
        if register_method == "trainer_module":
            # Trainer modules are stored as an entry dict
            # (trainer_cls / args_cls / required_kwargs), not as the bare
            # class, so compare against the class held inside the entry.
            retrieved = retrieved["trainer_cls"]
        assert retrieved == MockObject

    @pytest.mark.parametrize(
        "object_type, get_method",
        [
            ("trainer module", "get_trainer_module"),
            ("optimizer", "get_optimizer"),
            ("scheduler", "get_scheduler"),
            ("dataset", "get_dataset"),
            ("model", "get_model"),
            ("data collator", "get_data_collator"),
            ("loss function", "get_loss_function"),
            ("callback", "get_callback"),
        ],
    )
    def test_object_failure(self, object_type: str, get_method: str, object_name: str = "non_existent"):
        """Test failure when retrieving non-existent object."""
        with pytest.raises(ValueError) as exc_info:
            getattr(self.registry, get_method)(object_name)

        assert f"Unknown {object_type}" in str(exc_info.value)

    def test_init_empty_registries(self):
        """Test that all registries are initialized as empty dictionaries."""
        assert len(self.registry._optimizers) == 0
        assert len(self.registry._schedulers) == 0
        assert len(self.registry._datasets) == 0
        assert len(self.registry._models) == 0
        assert len(self.registry._data_collators) == 0
        assert len(self.registry._metrics) == 0
        assert len(self.registry._loss_functions) == 0
        assert len(self.registry._callbacks) == 0
        assert len(self.registry._hooks) == 0
        assert len(self.registry._trainer_modules) == 0

    def test_trainer_module_with_args_and_kwargs(self):
        """Test trainer module registration with args class and required kwargs."""

        class MockArgs:
            pass

        class MockTrainer:
            pass

        # Register with decorator including args class and required kwargs
        self.registry.trainer_module(
            "test_trainer_with_args", args_cls=MockArgs, required_kwargs={"param1": "default1", "param2": "default2"}
        )(MockTrainer)

        # Verify registration details
        module_info = self.registry.get_trainer_module("test_trainer_with_args")
        assert module_info["trainer_cls"] == MockTrainer
        assert module_info["args_cls"] == MockArgs
        assert module_info["required_kwargs"] == {"param1": "default1", "param2": "default2"}

    def test_list_methods(self):
        """Test all list methods return correct keys."""

        # Register some dummy items
        class DummyClass:
            pass

        self.registry.optimizer("opt1")(DummyClass)
        self.registry.scheduler("sched1")(DummyClass)
        self.registry.dataset("ds1")(DummyClass)
        self.registry.model("model1")(DummyClass)
        self.registry.data_collator("coll1")(lambda x: x)
        self.registry.loss_function("loss1")(DummyClass)
        self.registry.callback("cb1")(DummyClass)
        self.registry.trainer_module("tm1")(DummyClass)

        # Test lists
        assert self.registry.list_optimizers() == ["opt1"]
        assert self.registry.list_schedulers() == ["sched1"]
        assert self.registry.list_datasets() == ["ds1"]
        assert self.registry.list_models() == ["model1"]
        assert self.registry.list_data_collators() == ["coll1"]
        assert self.registry.list_loss_functions() == ["loss1"]
        assert self.registry.list_callbacks() == ["cb1"]
        assert self.registry.list_trainer_modules() == ["tm1"]

    def test_logging_on_registration(self, mocker):
        """Test that registration logs messages."""
        # Patch the module-level logger that the registry methods write to.
        mock_logger = mocker.patch("QEfficient.finetune.experimental.core.component_registry.logger")

        class MockClass:
            pass

        # Test optimizer registration logging
        self.registry.optimizer("test_opt")(MockClass)
        mock_logger.info.assert_called_with("Registered optimizer: test_opt")

        # Reset mock
        mock_logger.reset_mock()

        # Test trainer module registration logging
        self.registry.trainer_module("test_tm")(MockClass)
        mock_logger.info.assert_called_with("Registered trainer module: test_tm")
141+
142+
143+
class TestGetObjectFunction:
    """Direct tests of the module-level ``get_object`` helper."""

    def test_get_object_success(self):
        """A key present in the dictionary yields its stored value."""
        lookup = {"key1": "value1", "key2": "value2"}

        assert get_object(lookup, "key1", "test_type", lambda: ["key1", "key2"]) == "value1"

    def test_get_object_failure(self):
        """A missing key raises ValueError naming the type, the key and the available names."""
        lookup = {"key1": "value1"}

        with pytest.raises(ValueError) as exc_info:
            get_object(lookup, "nonexistent", "test_type", lambda: ["key1", "key2"])

        message = str(exc_info.value)
        assert "Unknown test_type: nonexistent" in message
        assert "Available: ['key1', 'key2']" in message
160+
161+
162+
class TestGlobalRegistry:
    """Checks on the module-level ``registry`` object."""

    def test_global_registry_instance(self):
        """The imported ``registry`` must be a ComponentRegistry instance."""
        is_component_registry = isinstance(registry, ComponentRegistry)
        assert is_component_registry

0 commit comments

Comments
 (0)