Skip to content

fix: caching allocator behaviour for quantization. #12172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 18, 2025
Merged

Conversation

sayakpaul
Copy link
Member

What does this PR do?

Thanks @JoeGaffney for bringing up this issue in #12043 (comment). Much appreciated.

This PR fixes it.

What happened?

Expand

We had a unit mismatch issue in the previous implementation from #12043. We accumulated bytes per device:
param_byte_count = param.numel() * param.element_size() -- this is bytes.

But we then passed that byte count directly to torch.empty as if it were a number of elements: torch.empty(byte_count // factor, dtype=dtype, device=device). torch.empty(n, dtype=...) interprets n as an element count, NOT bytes. That’s too big by a factor of element_size(target_dtype) (and often also by element_size(param) if they differ).

The previous implementation from #11904 actually does it correctly. It summed elements (math.prod(param.shape)) and then allocated exactly that many elements with the chosen dtype. Sorry @a-r-r-o-w!

I haven't verified yet, but this might be broken in transformers, too, from which this is adapted.

Results

I have verified that with this PR, the non-quantized checkpoints load with similar speeds and yield similar reserved memory. But it's the quantized checkpoints where things get messy.

Allocated: 6704.4 MB, Reserved: 23995.6 MB (`main`)
Allocated: 6701.2 MB, Reserved: 6905.9 MB (https://github.com/huggingface/diffusers/pull/11904)
Allocated: 6704.0 MB, Reserved: 6878.7 MB (this PR)
Script
import pytest
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_bnb_quantized_model_warmup():
    model_id = "black-forest-labs/FLUX.1-dev"
    torch_dtype = torch.bfloat16

    # Quantization config for 4-bit BNB
    quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch_dtype)

    # Load model (actual warmup path triggered internally)
    model = FluxTransformer2DModel.from_pretrained(
        model_id, subfolder="transformer", quantization_config=quant_config, torch_dtype=torch_dtype
    )
    # model = FluxTransformer2DModel.from_pretrained(
    #     model_id, subfolder="transformer", device_map="cuda", torch_dtype=torch_dtype
    # )

    # Check memory stats
    torch.cuda.reset_peak_memory_stats()
    mem_alloc = torch.cuda.memory_allocated()
    mem_reserved = torch.cuda.memory_reserved()
    print(f"Allocated: {mem_alloc/1e6:.1f} MB, Reserved: {mem_reserved/1e6:.1f} MB")

    # Assert some reasonable range 
    assert mem_alloc > 0, "Model should allocate some GPU memory"
    assert mem_reserved > 0, "Warmup should reserve some GPU memory"

Run with pytest test_memory_allocation.py -vs.

Thanks to @JoeGaffney for providing the snippet and for investigating this issue. In a separate PR, I would love to brainstorm a unit test for this with you. Even better -- if you're open to contributing it :)

@asomoza check it when you have a moment.

@sayakpaul sayakpaul requested a review from a-r-r-o-w August 18, 2025 05:02
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@asomoza
Copy link
Member

asomoza commented Aug 18, 2025

this fixes it for me, no more OOMs

Copy link
Member

@a-r-r-o-w a-r-r-o-w left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, i thought the other code was correct as well, so looks like we both missed it

@sayakpaul sayakpaul merged commit e824660 into main Aug 18, 2025
35 checks passed
@sayakpaul sayakpaul deleted the fix-quantizer-warmup branch August 18, 2025 07:46
@JoeGaffney
Copy link

Hey thanks @sayakpaul

Seems to work on main.

BitsAndBytesConfig {
  "_load_in_4bit": true,
  "_load_in_8bit": false,
  "bnb_4bit_compute_dtype": "bfloat16",
  "bnb_4bit_quant_storage": "uint8",
  "bnb_4bit_quant_type": "nf4",
  "bnb_4bit_use_double_quant": false,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": true,
  "load_in_8bit": false,
  "quant_method": "bitsandbytes"
}

Before moving to GPU Allocated: 6704.6 MB, Reserved: 6830.4 MB
After moving to GPU Allocated: 6704.6 MB, Reserved: 6830.4 MB
After moving to CPU Allocated: 0.0 MB, Reserved: 0.0 MB
PASSED

TorchAoConfig {
  "modules_to_not_convert": null,
  "quant_method": "torchao",
  "quant_type": "int8_weight_only",
  "quant_type_kwargs": {}
}

Before moving to GPU Allocated: 0.0 MB, Reserved: 0.0 MB
After moving to GPU Allocated: 12014.9 MB, Reserved: 12306.1 MB
After moving to CPU Allocated: 0.0 MB, Reserved: 0.0 MB
PASSED

Happy to try and add a unit test. Maybe with some smaller model with known sizes we could do some asserts to validate some values of popular quant configs.

One behaviour thing here though is that BnB seems to Reserve and allocate on load. But TorchAO seems to be only when moving to cuda. Maybe this is expected though.

@sayakpaul
Copy link
Member Author

One behaviour thing here though is that BnB seems to Reserve and allocate on load. But TorchAO seems to be only when moving to cuda. Maybe this is expected though.

This is correct. You NEED a supported accelerator for the weights to initially load for bnb. However, for 4bit quantized models in bnb, you can move them to CPU. Test:

def test_device_assignment(self):

TorchAO doesn't have this nature.

Happy to try and add a unit test. Maybe with some smaller model with known sizes we could do some asserts to validate some values of popular quant configs.

Sure, let's maybe try adding one for each of bnb and torchao? Thanks a lot in advance.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants