diff --git a/README.md b/README.md
index 598a8b2..bcf34c3 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,8 @@ Simple and efficient RevNet-Library for PyTorch with XLA and DeepSpeed support a
   * [iRevNet](#irevnet)
   * [Reformer](#reformer)
   * [Utils](#utils)
+    * [HDD Offload](#hdd-offload)
+    * [Huggingface](#huggingface)
   * [Explanation](#explanation)
 
 ## Features
@@ -264,6 +266,60 @@ assert out.size() == (16, sequence, classes)
 
 #### Utils
 
+##### HDD Offload
+
+One of the core features of RevLib is that, when used together with RevNet or checkpointing, it can offload parameters
+to CPU RAM. This way, GPU memory utilization stays truly constant with increasing depth, which allows for enormous
+models on limited resources.\
+Since RevLib v1.5.1, you can take it one step further and offload parameters straight onto storage. This frees even the
+CPU RAM, and 200B models are just
+[1TB (120€)](https://www.digitec.ch/en/s1/product/samsung-870-evo-1000-gb-25-ssd-14598791) away!
+
+```PYTHON
+import time
+import torch
+from memory_profiler import memory_usage
+from revlib.utils import HDDParameter
+
+
+def do_it(cls):
+    # 1024 matrices of 2^14 x 2^14 floats; HDDParameter keeps them on disk instead of in RAM.
+    t = [cls(torch.randn((2 ** 14, 2 ** 14)) / 2 ** 10, requires_grad=True) for _ in range(1024)]
+    start_time = time.time()
+    for k in t:
+        k.requires_grad_(True)
+    for i in range(5):
+        with torch.enable_grad():
+            out = t[0][0:64] @ t[1][:, 0:128] @ t[2][256:384]
+        out.mean().backward()
+        # Manual SGD step: detach each parameter so the in-place update is allowed.
+        for k in t:
+            grad = k.grad
+            if grad is None:
+                continue
+            grad = grad.clone()
+            k.grad = None
+            k.detach_()
+            k.requires_grad_(False)
+            k -= grad * 0.1
+            k.detach_()
+            k.requires_grad_(True)
+    print(f'took {time.time() - start_time:.2f}s, ', end='')
+
+
+def main():
+    print(max(memory_usage((lambda: None,))), "MB")  # 346.40234375 MB
+    print(max(memory_usage((lambda: do_it(HDDParameter),))), "MB")  # took 106.33s, 5541.34765625 MB
+    print(max(memory_usage((lambda: do_it(torch.nn.Parameter),))), "MB")  # took 11.76s, 8613.398475 MB
+
+
+if __name__ == '__main__':
+    main()
+```
+
+While the results above are already quite good, it gets even more impressive with structured sparsity.\
+When implementing a set of weights using lists,
+
+##### Huggingface
+
 RevLib also has its own `utils` module which provides helpful functions such as `residual_to_momentum_net`. Using RevLib,
 you can trivially convert any HuggingFace transformer into a MomentumNet without significant loss of performance.
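+For example, a conversion could look like the sketch below (a sketch only: it assumes the remaining
+parameters of `module_list_to_momentum_net`, defined in `revlib/utils.py` further down, have usable defaults):
+
+```PYTHON
+import transformers
+from revlib.utils import module_list_to_momentum_net
+
+# BERT keeps its transformer blocks in a ModuleList, so the list can be swapped
+# in-place for its MomentumNet counterpart.
+model = transformers.AutoModel.from_pretrained("bert-base-uncased")
+model.encoder.layer = module_list_to_momentum_net(model.encoder.layer)
+```
+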
 Especially during fine-tuning, this can be a life-saver, as it allows significantly bigger models to fit into memory without
diff --git a/revlib/utils.py b/revlib/utils.py
index 7489557..807279e 100644
--- a/revlib/utils.py
+++ b/revlib/utils.py
@@ -1,6 +1,10 @@
+import os
+import secrets
 import typing
 
 import torch.utils.checkpoint
+from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+from torch.utils._pytree import tree_map
 
 from revlib.core import ReversibleSequential, MemoryModes, SingleBranchReversibleModule, split_tensor_list, MergeCalls
 
@@ -136,3 +140,95 @@ def module_list_to_momentum_net(module: torch.nn.ModuleList,
                             for i in range(0, len(stem) - 1, 2)]
     out_modules.append(modules[-1])
     return torch.nn.ModuleList(out_modules)
+
+
+def _empty_tensor(cls: type, data: typing.Optional[torch.Tensor], requires_grad=True) -> torch.Tensor:
+    # Build a storage-less shell of `cls` that mirrors `data`'s metadata (shape, strides, dtype, device).
+    if data is None:
+        data = torch.zeros(())
+    if torch.torch_version.TorchVersion(torch.version.__version__) >= "1.11":
+        r = torch.Tensor._make_wrapper_subclass(cls, data.size(), strides=data.stride(), device=data.device,
+                                                storage_offset=data.storage_offset(), dtype=data.dtype,
+                                                layout=data.layout, requires_grad=requires_grad)
+    else:
+        meta = data.new_empty((0,))
+        meta.set_(meta.storage(), 0, data.size(), data.stride())
+        r = torch.Tensor._make_subclass(cls, meta, requires_grad)
+    return r
+
+
+class HDDParameter(torch.nn.Parameter):
+    file_name: str
+    __slots__ = ['file_name']
+
+    @staticmethod
+    def __new__(cls, data=None, requires_grad=True):
+        # Persist the values to disk; the parameter itself stays a storage-less shell.
+        file_name = f'.temporary_tensor_buffer_{secrets.token_urlsafe(32)}.pth'
+        torch.save(data, file_name)
+        r = _empty_tensor(cls, data, requires_grad)
+        r.file_name = file_name
+        return r
+
+    def __repr__(self):
+        return f"HDDParameter({self.data})"
+
+    @classmethod
+    def __torch_dispatch__(cls, func: typing.Callable, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        # Load the offloaded values from disk for the duration of the op.
+        out = func(*tree_map(_unwrap_offloaded_parameter, args), **tree_map(_unwrap_offloaded_parameter, kwargs))
+        # In-place ops (trailing underscore) and `out=` variants have to be written back to disk.
+        if hasattr(func, '__name__') and func.__name__ != '_' and func.__name__.endswith('_') \
+                and isinstance(args[0], HDDParameter):
+            torch.save(out, args[0].file_name)
+        elif isinstance(kwargs.get('out'), HDDParameter):
+            torch.save(out, kwargs['out'].file_name)
+        return out
+
+    def __del__(self):
+        os.remove(self.file_name)
+
+
+def _unwrap_offloaded_parameter(inp: typing.Any) -> typing.Any:
+    if not isinstance(inp, HDDParameter):
+        return inp
+    return torch.load(inp.file_name).requires_grad_(inp.requires_grad).to(device=inp.device, non_blocking=True)
+
+
+class QuantizedTensor(torch.Tensor):
+    elem: torch.Tensor
+    absmax: torch.Tensor
+    code: torch.Tensor
+
+    __slots__ = ['elem', 'absmax', 'code']
+
+    @staticmethod
+    def __new__(cls, data=None, requires_grad=True):
+        if data is None:
+            data = torch.zeros(())
+        data = data.clone()
+        r = _empty_tensor(cls, data, requires_grad)
+        # Keep only the blockwise-quantized values plus the statistics needed to restore them.
+        data, (absmax, code) = quantize_blockwise(data)
+        r.elem = data
+        r.absmax = absmax
+        r.code = code
+        return r
+
+    def __repr__(self):
+        return f"QuantizedTensor({self.elem})"
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        # Dequantize the inputs, run the op at full precision, then re-quantize all tensor outputs.
+        out = func(*tree_map(_unwrap_quantized_tensor, args), **tree_map(_unwrap_quantized_tensor, kwargs))
+        return tree_map(_wrap_quantized_tensor, out)
+
+
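+# Minimal usage sketch (an illustration, not part of the public API; assumes CUDA
+# tensors, since bitsandbytes' blockwise kernels target CUDA):
+#
+#     weight = QuantizedTensor(torch.randn((4096, 4096), device='cuda'))
+#     out = weight @ torch.randn((4096, 16), device='cuda')  # dequantize -> matmul -> re-quantize
+
+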
+def _unwrap_quantized_tensor(inp: typing.Any) -> typing.Any:
+    if not isinstance(inp, QuantizedTensor):
+        return inp
+    return dequantize_blockwise(inp.elem, absmax=inp.absmax, code=inp.code)
+
+
+def _wrap_quantized_tensor(inp: typing.Any) -> typing.Any:
+    if not isinstance(inp, torch.Tensor):
+        return inp
+    return QuantizedTensor(inp)
diff --git a/setup.py b/setup.py
index 56c71f9..d0ead11 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
     name='revlib',
     license='BSD',
     description='Simple and efficient RevNet-Library for PyTorch with XLA and DeepSpeed support and parameter offload',
-    version='v1.5.0',
+    version='v1.5.1',
     long_description=README,
     url='https://github.com/clashluke/revlib',
     packages=setuptools.find_packages(),