HDD offloading #3

Open · wants to merge 5 commits into `master`
56 changes: 56 additions & 0 deletions README.md
@@ -17,6 +17,8 @@ Simple and efficient RevNet-Library for PyTorch with XLA and DeepSpeed support and parameter offload
* [iRevNet](#irevnet)
* [Reformer](#reformer)
* [Utils](#utils)
* [HDD Offload](#hdd-offload)
* [Huggingface](#huggingface)
* [Explanation](#explanation)

## Features
@@ -264,6 +266,60 @@ assert out.size() == (16, sequence, classes)

#### Utils

##### HDD Offload

One of the core features of RevLib is that, when combined with RevNet or checkpointing, it can offload parameters to
CPU RAM. This way, GPU memory utilization stays constant with increasing depth, which allows for enormous models on
limited resources.\
Since RevLib v1.5.1, you can take this one step further and offload parameters straight to storage. This frees the CPU
RAM as well, and a 200B-parameter model is just
[1TB (120€)](https://www.digitec.ch/en/s1/product/samsung-870-evo-1000-gb-25-ssd-14598791) away!

```PYTHON
import time
import torch
from memory_profiler import memory_usage
from revlib.utils import HDDParameter


def do_it(cls):
    # 1024 matrices of shape (2 ** 14, 2 ** 14) -- roughly 1 GiB each in float32
    t = [cls(torch.randn((2 ** 14, 2 ** 14)) / 2 ** 10, requires_grad=True) for _ in range(1024)]
    start_time = time.time()
    for k in t:
        k.requires_grad_(True)
    for i in range(5):
        with torch.enable_grad():
            out = t[0][0:64] @ t[1][:, 0:128] @ t[2][256:384]
        out.mean().backward()
        # Manual SGD step: detach and switch off gradient tracking so the
        # in-place update is not recorded by autograd, then re-enable it.
        for k in t:
            grad = k.grad
            if grad is None:
                continue
            grad = grad.clone()
            k.grad = None
            k.detach_()
            k.requires_grad_(False)
            k -= grad * 0.1
            k.detach_()
            k.requires_grad_(True)
    print(f'took {time.time() - start_time:.2f}s, ', end='')


def main():
    print(max(memory_usage((lambda: None,))), "MB")  # 346.40234375 MB
    print(max(memory_usage((lambda: do_it(HDDParameter),))), "MB")  # took 106.33s, 5541.34765625 MB
    print(max(memory_usage((lambda: do_it(torch.nn.Parameter),))), "MB")  # took 11.76s, 8613.398475 MB


if __name__ == '__main__':
    main()
```

While the results above are already quite good, it gets even more impressive with structured sparsity.\
When implementing a set of weights using lists, …
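
For illustration, here is a minimal sketch, not part of this PR, of using `HDDParameter` as a drop-in replacement for
`torch.nn.Parameter` inside an ordinary module; `OffloadedLinear` and its sizes are invented for this example:

```PYTHON
import torch

from revlib.utils import HDDParameter


class OffloadedLinear(torch.nn.Module):
    """Hypothetical linear layer whose weight lives on disk between ops."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # The values are written to a temporary file on construction; only a
        # storage-less wrapper is registered as the module's parameter.
        self.weight = HDDParameter(torch.randn(out_features, in_features) / in_features ** 0.5)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # Using the weight in an op loads it from disk on the fly.
        return inp @ self.weight.t()


layer = OffloadedLinear(1024, 1024)
print(layer(torch.randn(2, 1024)).shape)  # torch.Size([2, 1024])
```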

##### Huggingface

RevLib also has its own `utils` module, which provides helpful functions such as `residual_to_momentum_net`. Using
RevLib, you can trivially convert any HuggingFace transformer into a MomentumNet without significant loss of
performance. Especially during fine-tuning, this can be a life-saver, as it allows significantly bigger models to fit
into memory without …
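As a rough sketch of that conversion: only the first parameter of `module_list_to_momentum_net` is visible in this
diff, so calling it without further arguments assumes the remaining parameters have usable defaults:

```PYTHON
import torch
from transformers import GPT2LMHeadModel

from revlib.utils import module_list_to_momentum_net

model = GPT2LMHeadModel.from_pretrained('gpt2')
# GPT-2 stores its transformer blocks in a torch.nn.ModuleList at
# model.transformer.h -- the structure module_list_to_momentum_net expects.
model.transformer.h = module_list_to_momentum_net(model.transformer.h)
```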
96 changes: 96 additions & 0 deletions revlib/utils.py
@@ -1,6 +1,10 @@
import os
import secrets
import typing

import torch.utils.checkpoint
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
from torch.utils._pytree import tree_map

from revlib.core import ReversibleSequential, MemoryModes, SingleBranchReversibleModule, split_tensor_list, MergeCalls

@@ -136,3 +140,95 @@ def module_list_to_momentum_net(module: torch.nn.ModuleList,
                   for i in range(0, len(stem) - 1, 2)]
    out_modules.append(modules[-1])
    return torch.nn.ModuleList(out_modules)


def _empty_tensor(cls: type, data: typing.Optional[torch.Tensor], requires_grad: bool = True) -> torch.Tensor:
    # Build a tensor that carries the metadata (shape, strides, dtype, ...) of
    # `data` but no storage of its own, so the subclass can supply values lazily.
    if data is None:
        data = torch.zeros(())
    if torch.torch_version.TorchVersion(torch.version.__version__) >= "1.11":
        r = torch.Tensor._make_wrapper_subclass(cls, data.size(), strides=data.stride(), device=data.device,
                                                storage_offset=data.storage_offset(), dtype=data.dtype,
                                                layout=data.layout, requires_grad=requires_grad)
    else:
        meta = data.new_empty((0,))
        meta.set_(meta.storage(), 0, data.size(), data.stride())
        r = torch.Tensor._make_subclass(cls, meta, requires_grad)
    return r


class HDDParameter(torch.nn.Parameter):
    file_name: str
    __slots__ = ['file_name']

    @staticmethod
    def __new__(cls, data=None, requires_grad=True):
        # Persist the values to a temporary file; only a storage-less wrapper
        # carrying the metadata stays in memory.
        file_name = f'.temporary_tensor_buffer_{secrets.token_urlsafe(32)}.pth'
        torch.save(data, file_name)
        r = _empty_tensor(cls, data, requires_grad)
        r.file_name = file_name
        return r

    def __repr__(self):
        return f"HDDParameter({self.data})"

    @classmethod
    def __torch_dispatch__(cls, func: typing.Callable, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Load all offloaded operands from disk and run the op in memory.
        out = func(*tree_map(_unwrap_offloaded_parameter, args), **tree_map(_unwrap_offloaded_parameter, kwargs))
        # Write the result of in-place ops (trailing '_', e.g. add_) back to disk.
        if hasattr(func, '__name__') and func.__name__ != '_' and not func.__name__.endswith('__') \
                and func.__name__.endswith('_') and isinstance(args[0], HDDParameter):
            torch.save(out, args[0].file_name)
        elif isinstance(kwargs.get('out'), HDDParameter):
            torch.save(out, kwargs['out'].file_name)
        return out

    def __del__(self):
        # Remove the backing file; guard against partially constructed objects.
        if getattr(self, 'file_name', None) is not None and os.path.exists(self.file_name):
            os.remove(self.file_name)


def _unwrap_offloaded_parameter(inp: typing.Any) -> typing.Any:
    if not isinstance(inp, HDDParameter):
        return inp
    # Materialize the parameter from its backing file on every access.
    return torch.load(inp.file_name).requires_grad_(inp.requires_grad).to(device=inp.device, non_blocking=True)


class QuantizedTensor(torch.Tensor):
    elem: torch.Tensor
    absmax: torch.Tensor
    code: torch.Tensor

    __slots__ = ['elem', 'absmax', 'code']

    @staticmethod
    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.zeros(())
        data = data.clone()
        r = _empty_tensor(cls, data, requires_grad)
        # Store the blockwise-quantized values together with the statistics
        # (absmax, code) needed to dequantize them later.
        data, (absmax, code) = quantize_blockwise(data)
        r.elem = data
        r.absmax = absmax
        r.code = code
        return r

    def __repr__(self):
        return f"QuantizedTensor({self.elem})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Dequantize all operands, run the op in full precision, and
        # re-quantize every tensor in the output.
        out = func(*tree_map(_unwrap_quantized_tensor, args), **tree_map(_unwrap_quantized_tensor, kwargs))
        return tree_map(_wrap_quantized_tensor, out)


def _unwrap_quantized_tensor(inp: typing.Any) -> typing.Any:
    if not isinstance(inp, QuantizedTensor):
        return inp
    return dequantize_blockwise(inp.elem, absmax=inp.absmax, code=inp.code)


def _wrap_quantized_tensor(inp: typing.Any) -> typing.Any:
    if not isinstance(inp, torch.Tensor):
        return inp
    return QuantizedTensor(inp)
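
For reference, a minimal usage sketch of `QuantizedTensor`; it assumes a `bitsandbytes` release whose
`quantize_blockwise` returns `(tensor, (absmax, code))`, matching the unpacking above:

```PYTHON
import torch

from revlib.utils import QuantizedTensor

weight = QuantizedTensor(torch.randn(1024, 1024), requires_grad=False)
# Every op dispatches through __torch_dispatch__: operands are dequantized
# blockwise, the op runs in full precision, and all outputs are re-quantized.
out = weight @ torch.randn(1024, 16)
print(type(out).__name__, out.shape)  # QuantizedTensor torch.Size([1024, 16])
```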
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
     name='revlib',
     license='BSD',
     description='Simple and efficient RevNet-Library for PyTorch with XLA and DeepSpeed support and parameter offload',
-    version='v1.5.0',
+    version='v1.5.1',
     long_description=README,
     url='https://github.com/clashluke/revlib',
     packages=setuptools.find_packages(),