diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py
index dd600a0f7..ea0d5f254 100644
--- a/src/llmcompressor/pipelines/cache.py
+++ b/src/llmcompressor/pipelines/cache.py
@@ -83,7 +83,7 @@ def from_dataloader(
             for key, value in batch.items():
                 if mask_padding and (key == "input_ids") and "attention_mask" in batch:
                     value = cls._mask_padding(value, batch["attention_mask"])
-                values[key] = IntermediateValue(value=value, device=model_device)
+                values[key] = cls._offload_value(value, offload_device, model_device)
 
             batch_intermediates.append(values)
 
@@ -114,7 +114,8 @@ def update(self, batch_index: int, values: Dict[str, Any]):
         :param batch_index: index of batch whose values will be updated
         :param values: dictionary mapping keys to values used for update
         """
-        intermediates = {k: self._offload_value(v) for k, v in values.items()}
+        device = self.offload_device
+        intermediates = {k: self._offload_value(v, device) for k, v in values.items()}
         self.batch_intermediates[batch_index].update(intermediates)
 
     def delete(self, batch_index: int, consumed_names: Optional[List[str]] = None):
@@ -189,7 +190,14 @@ def __iter__(self) -> Generator[Any, None, None]:
     def __len__(self) -> int:
         return len(self.batch_intermediates)
 
-    def _onload_value(self, intermediate: IntermediateValue) -> Any:
+    @classmethod
+    def _onload_value(cls, intermediate: IntermediateValue) -> Any:
+        """
+        Onload a value's tensors to the onload device
+
+        :param intermediate: intermediates value representation to onload
+        :return: original value with tensors onloaded to the onload device
+        """
         value = intermediate.value
         device = intermediate.device
 
@@ -197,51 +205,65 @@
             case torch.Tensor():
                 return value.to(device=device)
             case list():
-                return [self._onload_value(v) for v in value]
+                return [cls._onload_value(v) for v in value]
             case tuple():
-                return tuple(self._onload_value(v) for v in value)
+                return tuple(cls._onload_value(v) for v in value)
             case dict():
-                return {k: self._onload_value(v) for k, v in value.items()}
+                return {k: cls._onload_value(v) for k, v in value.items()}
             case _ if is_dataclass(value):
                 for field in fields(value):
                     v = getattr(value, field.name)
-                    setattr(value, field.name, self._onload_value(v))
+                    setattr(value, field.name, cls._onload_value(v))
                 return value
             case _:
                 # handles primitive values that should be returned as is.
                 # without this, a MatchError would be raised for unhandled types.
                 return value
 
-    def _offload_value(self, value: Any) -> IntermediateValue:
+    @classmethod
+    def _offload_value(
+        cls,
+        value: Any,
+        offload_device: torch.device | None,
+        onload_device: Optional[torch.device] = None,
+    ) -> IntermediateValue:
+        """
+        Offload a value's tensors to the offload device
+
+        :param value: value to offload
+        :param offload_device: device to offload `torch.Tensor` values to
+        :param onload_device: device used when onloading `torch.Tensor` values.
+            If None is provided, use the tensor's current device
+        :return: Instance of IntermediateValue representing the offloaded value
+        """
+        kwargs = {"offload_device": offload_device, "onload_device": onload_device}
         match value:
             case torch.Tensor():
                 return IntermediateValue(
-                    value=(
-                        value
-                        if self.offload_device is None
-                        else value.to(device=self.offload_device)
-                    ),
-                    device=value.device,
+                    value=value.to(device=offload_device),
+                    device=(onload_device if onload_device else value.device),
                 )
             case list():
                 return IntermediateValue(
-                    value=[self._offload_value(v) for v in value],
+                    value=[cls._offload_value(v, **kwargs) for v in value],
                     device=None,
                 )
             case tuple():
                 return IntermediateValue(
-                    value=tuple(self._offload_value(v) for v in value),
+                    value=tuple(cls._offload_value(v, **kwargs) for v in value),
                     device=None,
                 )
             case dict():
                 return IntermediateValue(
-                    value={k: self._offload_value(v) for k, v in value.items()},
+                    value={
+                        k: cls._offload_value(v, **kwargs) for k, v in value.items()
+                    },
                     device=None,
                 )
             case _ if is_dataclass(value):
                 for field in fields(value):
                     v = getattr(value, field.name)
-                    setattr(value, field.name, self._offload_value(v))
+                    setattr(value, field.name, cls._offload_value(v, **kwargs))
                 return IntermediateValue(value=value, device=None)
             case _:
                 # handles primitive values and provides a warning for unsupported types.
diff --git a/tests/llmcompressor/pipelines/test_cache.py b/tests/llmcompressor/pipelines/test_cache.py
index eda040d52..ff86f9400 100644
--- a/tests/llmcompressor/pipelines/test_cache.py
+++ b/tests/llmcompressor/pipelines/test_cache.py
@@ -1,10 +1,16 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, fields, is_dataclass
 
 import pytest
 import torch
 from torch.utils.data import DataLoader, StackDataset
 
-from llmcompressor.pipelines.cache import IntermediatesCache, IntermediateValue
+from llmcompressor.pipelines.cache import IntermediatesCache
+
+
+@dataclass
+class SampleDataclass:
+    a: torch.Tensor
+    b: int
 
 
 @pytest.fixture
@@ -28,6 +34,14 @@ def sample_cache(sample_dataloader):
     )
 
 
+values_to_test = [
+    torch.randn(2, 3).to("cpu"),
+    SampleDataclass(a=torch.randn(2, 3), b=42),
+    torch.float32,
+    [1, 2, 3],
+]
+
+
 @pytest.mark.unit
 def test_initialization(sample_dataloader):
     cache = IntermediatesCache.from_dataloader(
@@ -95,62 +109,22 @@ def test_mask_padding():
 
 
 @pytest.mark.unit
-def test_offload_and_onload_tensor():
-    cache = IntermediatesCache([], torch.device("cpu"))
-
-    # Test tensor offloading
-    original_tensor = torch.randn(2, 3).to("cpu")
-    offloaded = cache._offload_value(original_tensor)
+@pytest.mark.parametrize("value", values_to_test)
+def test_from_dataloader(value):
+    dataset = StackDataset(value=[value])
+    dataloader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
+    cache = IntermediatesCache.from_dataloader(dataloader)
 
-    assert isinstance(offloaded, IntermediateValue)
-    assert isinstance(offloaded.value, torch.Tensor)
-    assert offloaded.device == original_tensor.device
-
-    # Test tensor onloading
-    onloaded = cache._onload_value(offloaded)
-    assert torch.equal(onloaded, original_tensor)
-
-
-@dataclass
-class SampleDataclass:
-    a: torch.Tensor
-    b: int
+    onloaded = cache.fetch(0, ["value"])["value"]
+    assert deep_equal(onloaded, value)
 
 
 @pytest.mark.unit
-def test_offload_and_onload_dataclass():
-    cache = IntermediatesCache([], torch.device("cpu"))
-
-    # Create a sample dataclass instance
-    sample_data = SampleDataclass(a=torch.randn(2, 3), b=42)
-
-    # Test dataclass offloading
-    offloaded = cache._offload_value(sample_data)
-    assert isinstance(offloaded, IntermediateValue)
-    assert isinstance(offloaded.value, SampleDataclass)
-    assert isinstance(offloaded.value.a, IntermediateValue)
-    assert isinstance(offloaded.value.b, IntermediateValue)
-
-    # Test dataclass onloading
-    onloaded = cache._onload_value(offloaded)
-    assert onloaded == sample_data
-
-
-@pytest.mark.unit
-def test_offload_and_onload_dtype():
-    cache = IntermediatesCache([], torch.device("cpu"))
-
-    # Create a sample dataclass instance
-    sample_data = torch.float32
-
-    # Test dataclass offloading
-    offloaded = cache._offload_value(sample_data)
-    assert isinstance(offloaded, IntermediateValue)
-    assert isinstance(offloaded.value, torch.dtype)
-
-    # Test dataclass onloading
-    onloaded = cache._onload_value(offloaded)
-    assert onloaded == sample_data
+@pytest.mark.parametrize("value", values_to_test)
+def test_offload_and_onload(value):
+    offloaded = IntermediatesCache._offload_value(value, torch.device("cpu"))
+    onloaded = IntermediatesCache._onload_value(offloaded)
+    assert deep_equal(onloaded, value)
 
 
 @pytest.mark.unit
@@ -190,3 +164,27 @@ def test_device_handling(sample_dataloader):
     # Verify tensors are loaded back to GPU when fetched
     fetched = cache.fetch(0, ["hidden_states"])
     assert fetched["hidden_states"].device.type == "cuda"
+
+
+def deep_equal(a, b) -> bool:
+    if type(a) != type(b):
+        return False
+
+    match a:
+        case torch.Tensor():
+            return torch.equal(a, b)
+        case list() | tuple():
+            if len(a) != len(b):
+                return False
+            return all(deep_equal(_a, _b) for _a, _b in zip(a, b))
+        case dict():
+            if a.keys() != b.keys():
+                return False
+            return all(deep_equal(a[key], b[key]) for key in a.keys())
+        case _ if is_dataclass(a):
+            a_dict = {field.name: getattr(a, field.name) for field in fields(a)}
+            b_dict = {field.name: getattr(b, field.name) for field in fields(b)}
+
+            return deep_equal(a_dict, b_dict)
+        case _:
+            return a == b