Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 40 additions & 18 deletions src/llmcompressor/pipelines/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def from_dataloader(
for key, value in batch.items():
if mask_padding and (key == "input_ids") and "attention_mask" in batch:
value = cls._mask_padding(value, batch["attention_mask"])
values[key] = IntermediateValue(value=value, device=model_device)
values[key] = cls._offload_value(value, offload_device, model_device)

batch_intermediates.append(values)

Expand Down Expand Up @@ -114,7 +114,8 @@ def update(self, batch_index: int, values: Dict[str, Any]):
:param batch_index: index of batch whose values will be updated
:param values: dictionary mapping keys to values used for update
"""
intermediates = {k: self._offload_value(v) for k, v in values.items()}
device = self.offload_device
intermediates = {k: self._offload_value(v, device) for k, v in values.items()}
self.batch_intermediates[batch_index].update(intermediates)

def delete(self, batch_index: int, consumed_names: Optional[List[str]] = None):
Expand Down Expand Up @@ -189,59 +190,80 @@ def __iter__(self) -> Generator[Any, None, None]:
def __len__(self) -> int:
return len(self.batch_intermediates)

def _onload_value(self, intermediate: IntermediateValue) -> Any:
@classmethod
def _onload_value(cls, intermediate: IntermediateValue) -> Any:
"""
Onload a value's tensors to the onload device

:param intermediate: intermediates value representation to onload
:return: original value with tensors onloaded to the onload device
"""
value = intermediate.value
device = intermediate.device

match value:
case torch.Tensor():
return value.to(device=device)
case list():
return [self._onload_value(v) for v in value]
return [cls._onload_value(v) for v in value]
case tuple():
return tuple(self._onload_value(v) for v in value)
return tuple(cls._onload_value(v) for v in value)
case dict():
return {k: self._onload_value(v) for k, v in value.items()}
return {k: cls._onload_value(v) for k, v in value.items()}
case _ if is_dataclass(value):
for field in fields(value):
v = getattr(value, field.name)
setattr(value, field.name, self._onload_value(v))
setattr(value, field.name, cls._onload_value(v))
return value
case _:
# handles primitive values that should be returned as is.
# without this, a MatchError would be raised for unhandled types.
return value

def _offload_value(self, value: Any) -> IntermediateValue:
@classmethod
def _offload_value(
cls,
value: Any,
offload_device: torch.device | None,
onload_device: Optional[torch.device] = None,
) -> IntermediateValue:
"""
Offload a value's tensors to the offload device

:param value: value to offload
:param offload_device: device to offload `torch.Tensor` values to
:param onload_device: device used when onloading `torch.Tensor` values.
If None is provided, use the tensor's current device
:return: Instance of IntermediateValue representing the offloaded value
"""
kwargs = {"offload_device": offload_device, "onload_device": onload_device}
match value:
case torch.Tensor():
return IntermediateValue(
value=(
value
if self.offload_device is None
else value.to(device=self.offload_device)
),
device=value.device,
value=value.to(device=offload_device),
device=(onload_device if onload_device else value.device),
)
case list():
return IntermediateValue(
value=[self._offload_value(v) for v in value],
value=[cls._offload_value(v, **kwargs) for v in value],
device=None,
)
case tuple():
return IntermediateValue(
value=tuple(self._offload_value(v) for v in value),
value=tuple(cls._offload_value(v, **kwargs) for v in value),
device=None,
)
case dict():
return IntermediateValue(
value={k: self._offload_value(v) for k, v in value.items()},
value={
k: cls._offload_value(v, **kwargs) for k, v in value.items()
},
device=None,
)
case _ if is_dataclass(value):
for field in fields(value):
v = getattr(value, field.name)
setattr(value, field.name, self._offload_value(v))
setattr(value, field.name, cls._offload_value(v, **kwargs))
return IntermediateValue(value=value, device=None)
case _:
# handles primitive values and provides a warning for unsupported types.
Expand Down
106 changes: 52 additions & 54 deletions tests/llmcompressor/pipelines/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from dataclasses import dataclass
from dataclasses import dataclass, fields, is_dataclass

import pytest
import torch
from torch.utils.data import DataLoader, StackDataset

from llmcompressor.pipelines.cache import IntermediatesCache, IntermediateValue
from llmcompressor.pipelines.cache import IntermediatesCache


@dataclass
class SampleDataclass:
    """Minimal dataclass fixture used to verify recursive offload/onload of dataclass fields."""

    a: torch.Tensor
    b: int


@pytest.fixture
Expand All @@ -28,6 +34,14 @@ def sample_cache(sample_dataloader):
)


# Representative payloads for the round-trip (offload -> onload) tests:
# a CPU tensor, a dataclass containing a tensor, a bare dtype (non-tensor
# torch object), and a primitive list.
values_to_test = [
    torch.randn(2, 3).to("cpu"),
    SampleDataclass(a=torch.randn(2, 3), b=42),
    torch.float32,
    [1, 2, 3],
]


@pytest.mark.unit
def test_initialization(sample_dataloader):
cache = IntermediatesCache.from_dataloader(
Expand Down Expand Up @@ -95,62 +109,22 @@ def test_mask_padding():


@pytest.mark.unit
@pytest.mark.parametrize("value", values_to_test)
def test_from_dataloader(value):
    """Round-trip a single-sample dataset through the cache and check that
    the fetched value structurally equals the original."""
    dataset = StackDataset(value=[value])
    # collate_fn returns the sample dict unchanged so the cache sees raw values
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
    cache = IntermediatesCache.from_dataloader(dataloader)

    onloaded = cache.fetch(0, ["value"])["value"]
    assert deep_equal(onloaded, value)


@pytest.mark.unit
def test_offload_and_onload_dataclass():
cache = IntermediatesCache([], torch.device("cpu"))

# Create a sample dataclass instance
sample_data = SampleDataclass(a=torch.randn(2, 3), b=42)

# Test dataclass offloading
offloaded = cache._offload_value(sample_data)
assert isinstance(offloaded, IntermediateValue)
assert isinstance(offloaded.value, SampleDataclass)
assert isinstance(offloaded.value.a, IntermediateValue)
assert isinstance(offloaded.value.b, IntermediateValue)

# Test dataclass onloading
onloaded = cache._onload_value(offloaded)
assert onloaded == sample_data


@pytest.mark.unit
@pytest.mark.parametrize("value", values_to_test)
def test_offload_and_onload(value):
    """Offloading to CPU and onloading back should reproduce the value exactly."""
    offloaded = IntermediatesCache._offload_value(value, torch.device("cpu"))
    onloaded = IntermediatesCache._onload_value(offloaded)
    assert deep_equal(onloaded, value)


@pytest.mark.unit
Expand Down Expand Up @@ -190,3 +164,27 @@ def test_device_handling(sample_dataloader):
# Verify tensors are loaded back to GPU when fetched
fetched = cache.fetch(0, ["hidden_states"])
assert fetched["hidden_states"].device.type == "cuda"


def deep_equal(a, b) -> bool:
if type(a) != type(b):
return False

match a:
case torch.Tensor():
return torch.equal(a, b)
case list() | tuple():
if len(a) != len(b):
return False
return all(deep_equal(_a, _b) for _a, _b in zip(a, b))
case dict():
if a.keys() != b.keys():
return False
return all(deep_equal(a[key], b[key]) for key in a.keys())
case _ if is_dataclass(a):
a_dict = {field.name: getattr(a, field.name) for field in fields(a)}
b_dict = {field.name: getattr(b, field.name) for field in fields(b)}

return deep_equal(a_dict, b_dict)
case _:
return a == b