fix: caching allocator behaviour for quantization. #12172
Conversation
This fixes it for me, no more OOMs.
Nice, I thought the other code was correct as well, so it looks like we both missed it.
Hey, thanks @sayakpaul. Seems to work on `main`.
Happy to try and add a unit test. Maybe with a smaller model of known size we could add some asserts to validate the expected values for popular quant configs. One behavioural note, though: BnB seems to reserve and allocate on load, whereas TorchAO seems to do so only when moving to CUDA. Maybe this is expected, though.
This is correct. You NEED a supported accelerator for the weights to initially load with bnb. However, 4-bit quantized bnb models can then be moved to CPU. See the test in diffusers/tests/quantization/bnb/test_4bit.py (line 288 at 9918d13).
TorchAO doesn't behave this way.
Sure, let's maybe try adding one for each of bnb and torchao? Thanks a lot in advance.
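As a possible starting point for the unit test discussed above, here is a hypothetical sketch. The model class, checkpoint id, and tolerance are placeholders (a small checkpoint with known sizes would be used in practice), not an agreed-upon test:

```python
# Hypothetical test sketch: checkpoint id, model class, and the 1.2x tolerance
# are placeholders, not values from this PR.
import pytest
import torch

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
def test_bnb_load_reserves_close_to_allocated():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    quant_config = BitsAndBytesConfig(load_in_4bit=True)
    _model = SD3Transformer2DModel.from_pretrained(
        "some-org/small-test-checkpoint",  # placeholder: a small model with known sizes
        subfolder="transformer",
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )

    allocated_mb = torch.cuda.memory_allocated() / 2**20
    reserved_mb = torch.cuda.memory_reserved() / 2**20

    # With a correct allocator warmup, reserved memory should stay close to
    # allocated memory instead of ballooning (as observed on `main`).
    assert reserved_mb <= allocated_mb * 1.2
```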
What does this PR do?
Thanks @JoeGaffney for bringing up this issue in #12043 (comment). Much appreciated.
This PR fixes it.
What happened?
We had a unit mismatch issue in the previous implementation from #12043. We accumulated bytes per device: `param_byte_count = param.numel() * param.element_size()` -- this is bytes. But we then passed that byte count directly to `torch.empty` as if it were a number of elements: `torch.empty(byte_count // factor, dtype=dtype, device=device)`. `torch.empty(n, dtype=...)` interprets `n` as an element count, NOT bytes. That's too big by a factor of `element_size(target_dtype)` (and often also by `element_size(param)` if they differ).

The previous implementation from #11904 actually does it correctly. It summed elements (`math.prod(param.shape)`) and then allocated exactly that many elements with the chosen dtype. Sorry @a-r-r-o-w!

I haven't verified yet, but this might be broken in `transformers`, too, from which this is adapted.
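To make the bytes-vs-elements confusion concrete, here is a simplified, self-contained sketch. The shapes and variable names are illustrative only, not the actual loading code in diffusers:

```python
import math

import torch

# Toy "parameters": 1,048,576 fp32 values = 4 MB of weights.
params = {"weight": torch.empty(1024, 1024, dtype=torch.float32)}
target_dtype = torch.bfloat16  # dtype used for the warmup allocation

# Buggy accounting: accumulate *bytes* per device ...
byte_count = sum(p.numel() * p.element_size() for p in params.values())
# ... and then hand that number to torch.empty, which reads it as an
# *element* count. With bf16 this would warm up 8 MB instead of 2 MB here.
# torch.empty(byte_count, dtype=target_dtype)

# Correct accounting (the approach of #11904 and this PR): accumulate
# *elements* and let torch.empty derive the byte size from the dtype.
element_count = sum(math.prod(p.shape) for p in params.values())
warmup = torch.empty(element_count, dtype=target_dtype)

assert warmup.numel() == element_count
assert warmup.numel() * warmup.element_size() == 2 * element_count  # 2 MB for bf16
```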
Results

I have verified that with this PR, the non-quantized checkpoints load with similar speeds and yield similar reserved memory. But it's the quantized checkpoints where things get messy.
- Allocated: 6704.4 MB, Reserved: 23995.6 MB (`main`)
- Allocated: 6701.2 MB, Reserved: 6905.9 MB (https://github.com/huggingface/diffusers/pull/11904)
- Allocated: 6704.0 MB, Reserved: 6878.7 MB (this PR)
Script
Run with `pytest test_memory_allocation.py -vs`.

Thanks to @JoeGaffney for providing the snippet and for investigating this issue. In a separate PR, I would love to brainstorm a unit test for this with you. Even better -- if you're open to contributing it :)
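The actual snippet is collapsed under "Script" above; purely for reference, here is a minimal sketch of how Allocated/Reserved numbers like the ones in the Results section can be read out (the helper name and call site are illustrative):

```python
import torch


def report_cuda_memory(tag: str) -> None:
    """Print the currently allocated and reserved CUDA memory in MB."""
    allocated_mb = torch.cuda.memory_allocated() / 2**20
    reserved_mb = torch.cuda.memory_reserved() / 2**20
    print(f"{tag} -> Allocated: {allocated_mb:.1f} MB, Reserved: {reserved_mb:.1f} MB")


# Example: call after loading a (quantized) checkpoint to get numbers in the
# same format as the Results section above.
if torch.cuda.is_available():
    report_cuda_memory("after load")
```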
@asomoza check it when you have a moment.