From 47cd076785941d100a15d33e0d2b663651974487 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Fri, 14 Jan 2022 15:42:17 +0100
Subject: [PATCH 1/5] feat(utils): add BNB quantization, HDD Offloading

---
 revlib/utils.py | 82 +++++++++++++++++++++++++++++++++++++++++++++++++
 setup.py        |  2 +-
 2 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/revlib/utils.py b/revlib/utils.py
index 7489557..a1fac8c 100644
--- a/revlib/utils.py
+++ b/revlib/utils.py
@@ -1,6 +1,10 @@
+import os
+import secrets
 import typing
 
 import torch.utils.checkpoint
+from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+from torch.utils._pytree import tree_map
 
 from revlib.core import ReversibleSequential, MemoryModes, SingleBranchReversibleModule, split_tensor_list, MergeCalls
 
@@ -136,3 +140,81 @@ def module_list_to_momentum_net(module: torch.nn.ModuleList,
                     for i in range(0, len(stem) - 1, 2)]
     out_modules.append(modules[-1])
     return torch.nn.ModuleList(out_modules)
+
+
+class HDDParameter(torch.nn.Parameter):
+    file_name: str
+    __slots__ = ['file_name']
+
+    @staticmethod
+    def __new__(cls, data=None, requires_grad=True):
+        if data is None:
+            data = torch.zeros(())
+        meta = data.new_empty((0,))
+        meta.set_(meta.storage(), 0, data.size(), data.stride())
+        r = torch.Tensor._make_wrapper_subclass(cls, data.size(), strides=data.stride(), device=data.device,
+                                                storage_offset=data.storage_offset(), dtype=data.dtype,
+                                                layout=data.layout, requires_grad=requires_grad)
+        file_name = f'.temporary_tensor_buffer_{secrets.token_urlsafe(32)}.pth'
+        torch.save(data, file_name)
+        r.file_name = file_name
+        return r
+
+    def __repr__(self):
+        return f"OffloadedParameter({self.data})"
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        return func(*tree_map(_unwrap_offloaded_parameter, args), **tree_map(_unwrap_offloaded_parameter, kwargs))
+
+    def __del__(self):
+        os.remove(self.file_name)
+
+
+def _unwrap_offloaded_parameter(inp: typing.Any) -> typing.Any:
+    if not isinstance(inp, HDDParameter):
+        return inp
+    return torch.load(inp.file_name).requires_grad_(inp.requires_grad)
+
+
+class QuantizedTensor(torch.Tensor):
+    elem: torch.Tensor
+    absmax: torch.Tensor
+    code: torch.Tensor
+
+    __slots__ = ['elem', 'absmax', 'code']
+
+    @staticmethod
+    def __new__(cls, data=None, requires_grad=True):
+        if data is None:
+            data = torch.zeros(())
+        meta = data.new_empty((0,))
+        meta.set_(meta.storage(), 0, data.size(), data.stride())
+        r = torch.Tensor._make_wrapper_subclass(cls, data.size(), strides=data.stride(), device=data.device,
+                                                storage_offset=data.storage_offset(), dtype=data.dtype,
+                                                layout=data.layout, requires_grad=requires_grad)
+        data, (absmax, code) = quantize_blockwise(data)
+        r.elem = data
+        r.absmax = absmax
+        r.code = code
+        return r
+
+    def __repr__(self):
+        return f"QuantizedTensor({self.elem})"
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        out = func(*tree_map(_unwrap_offloaded_parameter, args), **tree_map(_unwrap_offloaded_parameter, kwargs))
+        return tree_map(_wrap_quantized_tensor, out)
+
+
+def _unwrap_quantized_tensor(inp: typing.Any) -> typing.Any:
+    if not isinstance(inp, QuantizedTensor):
+        return inp
+    return dequantize_blockwise(inp.elem, absmax=inp.absmax, code=inp.code)
+
+
+def _wrap_quantized_tensor(inp: typing.Any) -> typing.Any:
+    if not isinstance(inp, torch.Tensor):
+        return inp
+    return QuantizedTensor(inp)
diff --git a/setup.py b/setup.py
index 56c71f9..d0ead11 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ name='revlib',
     license='BSD',
     description='Simple and efficient RevNet-Library for PyTorch with XLA and DeepSpeed support and parameter offload',
-    version='v1.5.0',
+    version='v1.5.1',
     long_description=README,
     url='https://github.com/clashluke/revlib',
     packages=setuptools.find_packages(),

From dc59ee1c6d3dd73cfb8ae436f58929a23916b25b Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Fri, 14 Jan 2022 16:22:51 +0100
Subject: [PATCH 2/5] feat(utils): support offloading from CUDA

---
 revlib/utils.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/revlib/utils.py b/revlib/utils.py
index a1fac8c..e0965f6 100644
--- a/revlib/utils.py
+++ b/revlib/utils.py
@@ -150,8 +150,6 @@ class HDDParameter(torch.nn.Parameter):
     def __new__(cls, data=None, requires_grad=True):
         if data is None:
             data = torch.zeros(())
-        meta = data.new_empty((0,))
-        meta.set_(meta.storage(), 0, data.size(), data.stride())
         r = torch.Tensor._make_wrapper_subclass(cls, data.size(), strides=data.stride(), device=data.device,
                                                 storage_offset=data.storage_offset(), dtype=data.dtype,
                                                 layout=data.layout, requires_grad=requires_grad)
@@ -174,7 +172,7 @@ def __del__(self):
 def _unwrap_offloaded_parameter(inp: typing.Any) -> typing.Any:
     if not isinstance(inp, HDDParameter):
         return inp
-    return torch.load(inp.file_name).requires_grad_(inp.requires_grad)
+    return torch.load(inp.file_name).requires_grad_(inp.requires_grad).to(device=inp.device, non_blocking=True)
@@ -188,8 +186,7 @@ class QuantizedTensor(torch.Tensor):
     def __new__(cls, data=None, requires_grad=True):
         if data is None:
             data = torch.zeros(())
-        meta = data.new_empty((0,))
-        meta.set_(meta.storage(), 0, data.size(), data.stride())
+        data = data.clone()
         r = torch.Tensor._make_wrapper_subclass(cls, data.size(), strides=data.stride(), device=data.device,
                                                 storage_offset=data.storage_offset(), dtype=data.dtype,
                                                 layout=data.layout, requires_grad=requires_grad)
@@ -204,7 +201,7 @@ def __repr__(self):
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        out = func(*tree_map(_unwrap_offloaded_parameter, args), **tree_map(_unwrap_offloaded_parameter, kwargs))
+        out = func(*tree_map(_unwrap_quantized_tensor, args), **tree_map(_unwrap_quantized_tensor, kwargs))
         return tree_map(_wrap_quantized_tensor, out)

From bd4038abfcea8d7cee2667fecb17986781892b39 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Fri, 14 Jan 2022 17:03:52 +0100
Subject: [PATCH 3/5] feat(utils): kind of support torch 1.10

---
 revlib/utils.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/revlib/utils.py b/revlib/utils.py
index e0965f6..3dcfd09 100644
--- a/revlib/utils.py
+++ b/revlib/utils.py
@@ -142,19 +142,29 @@ def module_list_to_momentum_net(module: torch.nn.ModuleList,
     return torch.nn.ModuleList(out_modules)
+def _empty_tensor(cls: type, data: typing.Optional[torch.Tensor]) -> torch.Tensor:
+    if data is None:
+        data = torch.zeros(())
+    if torch.torch_version.TorchVersion(torch.version.__version__) >= 1.11:
+        r = torch.Tensor._make_wrapper_subclass(cls, data.size(), strides=data.stride(), device=data.device,
+                                                storage_offset=data.storage_offset(), dtype=data.dtype,
+                                                layout=data.layout, requires_grad=requires_grad)
+    else:
+        meta = data.new_empty((0,))
+        meta.set_(meta.storage(), 0, data.size(), data.stride())
+        r = torch.Tensor._make_subclass(cls, meta, data.requires_grad)
+    return r
+
+
 class HDDParameter(torch.nn.Parameter):
     file_name: str
     __slots__ = ['file_name']
 
     @staticmethod
     def __new__(cls, data=None, requires_grad=True):
-        if data is None:
-            data = torch.zeros(())
-        r = torch.Tensor._make_wrapper_subclass(cls, data.size(), strides=data.stride(), device=data.device,
-                                                storage_offset=data.storage_offset(), dtype=data.dtype,
-                                                layout=data.layout, requires_grad=requires_grad)
         file_name = f'.temporary_tensor_buffer_{secrets.token_urlsafe(32)}.pth'
         torch.save(data, file_name)
+        r = _empty_tensor(cls, data)
         r.file_name = file_name
         return r
@@ -187,10 +197,8 @@ def __new__(cls, data=None, requires_grad=True):
         if data is None:
             data = torch.zeros(())
         data = data.clone()
-        r = torch.Tensor._make_wrapper_subclass(cls, data.size(), strides=data.stride(), device=data.device,
-                                                storage_offset=data.storage_offset(), dtype=data.dtype,
-                                                layout=data.layout, requires_grad=requires_grad)
         data, (absmax, code) = quantize_blockwise(data)
+        r = _empty_tensor(cls, data)
         r.elem = data
         r.absmax = absmax
         r.code = code
         return r

From 03a26d3e373f15d765d70770913098470abb029e Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Fri, 14 Jan 2022 17:51:41 +0100
Subject: [PATCH 4/5] feat(utils): kind of support torch 1.10

---
 revlib/utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/revlib/utils.py b/revlib/utils.py
index 3dcfd09..8e015f7 100644
--- a/revlib/utils.py
+++ b/revlib/utils.py
@@ -142,10 +142,10 @@ def module_list_to_momentum_net(module: torch.nn.ModuleList,
     return torch.nn.ModuleList(out_modules)
-def _empty_tensor(cls: type, data: typing.Optional[torch.Tensor]) -> torch.Tensor:
+def _empty_tensor(cls: type, data: typing.Optional[torch.Tensor], requires_grad=True) -> torch.Tensor:
     if data is None:
         data = torch.zeros(())
-    if torch.torch_version.TorchVersion(torch.version.__version__) >= 1.11:
+    if torch.torch_version.TorchVersion(torch.version.__version__) >= "1.11":
         r = torch.Tensor._make_wrapper_subclass(cls, data.size(), strides=data.stride(), device=data.device,
                                                 storage_offset=data.storage_offset(), dtype=data.dtype,
                                                 layout=data.layout, requires_grad=requires_grad)
     else:
         meta = data.new_empty((0,))
@@ -164,7 +164,7 @@ class HDDParameter(torch.nn.Parameter):
     def __new__(cls, data=None, requires_grad=True):
         file_name = f'.temporary_tensor_buffer_{secrets.token_urlsafe(32)}.pth'
         torch.save(data, file_name)
-        r = _empty_tensor(cls, data)
+        r = _empty_tensor(cls, data, requires_grad)
         r.file_name = file_name
         return r
@@ -197,8 +197,8 @@ def __new__(cls, data=None, requires_grad=True):
         if data is None:
             data = torch.zeros(())
         data = data.clone()
+        r = _empty_tensor(cls, data, requires_grad)
         data, (absmax, code) = quantize_blockwise(data)
-        r = _empty_tensor(cls, data)
         r.elem = data
         r.absmax = absmax
         r.code = code

From 1638c7d1d014192b6d2dbd4a6a22b60c2f739aff Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Wed, 19 Jan 2022 14:47:39 +0100
Subject: [PATCH 5/5] feat(utils): kind of support torch 1.10

---
 README.md       | 56 +++++++++++++++++++++++++++++++++++++++++++++++++
 revlib/utils.py | 13 ++++++++++--
 2 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 598a8b2..bcf34c3 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,8 @@ Simple and efficient RevNet-Library for PyTorch with XLA and DeepSpeed support a
 * [iRevNet](#irevnet)
 * [Reformer](#reformer)
 * [Utils](#utils)
+  * [HDD Offload](#hdd-offload)
+  * [Huggingface](#huggingface)
 * [Explanation](#explanation)
 
 ## Features
@@ -264,6 +266,60 @@ assert out.size() == (16, sequence, classes)
 
 #### Utils
 
+##### HDD Offload
+
+One of the core features of RevLib is that, when used together with RevNet or checkpointing, it can offload parameters
+onto RAM. This way, the GPU memory utilization stays truly constant with increasing depth, which allows for enormous
+models on limited resources.\
+Since RevLib v1.5.1, you can take this one step further and offload straight onto storage. This way, the CPU RAM
+is freed as well, and 200B models are just
+[1TB (120€)](https://www.digitec.ch/en/s1/product/samsung-870-evo-1000-gb-25-ssd-14598791) away!
+
+```PYTHON
+import time
+import torch
+from memory_profiler import memory_usage
+from revlib.utils import HDDParameter
+
+
+def do_it(cls):
+    t = [cls(torch.randn((2 ** 14, 2 ** 14)) / 2 ** 10, requires_grad=True) for _ in range(1024)]
+    start_time = time.time()
+    for k in t:
+        k.requires_grad_(True)
+    for i in range(5):
+        with torch.enable_grad():
+            out = t[0][0:64] @ t[1][:, 0:128] @ t[2][256:384]
+            out.mean().backward()
+        for k in t:
+            grad = k.grad
+            if grad is None:
+                continue
+            grad = grad.clone()
+            k.grad = None
+            k.detach_()
+            k.requires_grad_(False)
+            k -= grad * 0.1
+            k.detach_()
+            k.requires_grad_(True)
+    print(f'took {time.time() - start_time:.2f}s, ', end='')
+
+
+def main():
+    print(max(memory_usage((lambda: None,))), "MB")  # 346.40234375 MB
+    print(max(memory_usage((lambda: do_it(HDDParameter),))), "MB")  # took 106.33s, 5541.34765625 MB
+    print(max(memory_usage((lambda: do_it(torch.nn.Parameter),))), "MB")  # took 11.76s, 8613.398475 MB
+
+
+if __name__ == '__main__':
+    main()
+```
+
+While the results above are already pretty nice, it gets even more impressive with structured sparsity.\
+When implementing a set of weights using lists,
+
+##### Huggingface
+
 RevLib also has its own `utils` module which provides helpful functions as `residual_to_momentum_net`. Using RevLib,
 you can trivially convert any HuggingFace transformer into a MomentumNet without significant loss of performance.
 Especially during fine-tuning, this can be a life-saver, as it allows for significantly bigger models to fit into memory without
diff --git a/revlib/utils.py b/revlib/utils.py
index 8e015f7..807279e 100644
--- a/revlib/utils.py
+++ b/revlib/utils.py
@@ -172,8 +172,15 @@ def __repr__(self):
         return f"OffloadedParameter({self.data})"
 
     @classmethod
-    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        return func(*tree_map(_unwrap_offloaded_parameter, args), **tree_map(_unwrap_offloaded_parameter, kwargs))
+    def __torch_dispatch__(cls, func: typing.Callable, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        out = func(*tree_map(_unwrap_offloaded_parameter, args), **tree_map(_unwrap_offloaded_parameter, kwargs))
+        if hasattr(func, '__name__') and func.__name__ != '_' and func.__name__.endswith('_'):
+            torch.save(out, args[0].file_name)
+        elif 'out' in kwargs:
+            torch.save(out, kwargs['out'])
+        return out
 
     def __del__(self):
         os.remove(self.file_name)
@@ -209,6 +216,8 @@ def __repr__(self):
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
         out = func(*tree_map(_unwrap_quantized_tensor, args), **tree_map(_unwrap_quantized_tensor, kwargs))
         return tree_map(_wrap_quantized_tensor, out)
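
Not part of the patches above: a minimal sketch of the blockwise quantize/dequantize round trip that the new `QuantizedTensor` wrapper builds on, useful when reviewing the `quantize_blockwise` calls in patch 1. It assumes a bitsandbytes build whose `quantize_blockwise` returns `(tensor, (absmax, code))`, as these patches expect, and that supports the chosen device; the tensor shape and names are illustrative only.

```python
# Sketch only: the bitsandbytes primitives used by QuantizedTensor.__new__ and
# _unwrap_quantized_tensor, shown as a plain round trip outside the wrapper class.
import torch
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise

weight = torch.randn(4096, 4096)                         # fp32 weight to be stored in 8 bit
quantized, (absmax, code) = quantize_blockwise(weight)   # uint8 codes + per-block scales + codebook
restored = dequantize_blockwise(quantized, absmax=absmax, code=code)

print(quantized.dtype, quantized.numel())                # one byte per element instead of four
print((weight - restored).abs().max())                   # small blockwise quantization error
```

Each `QuantizedTensor` stores exactly these three tensors (`elem`, `absmax`, `code`) and rebuilds the full-precision view on every dispatched op, so the memory saving applies while the parameter is at rest rather than during compute.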