
Cache packed sequence metadata to reduce D2H syncs across layers#4173

Closed
ichbinhandsome wants to merge 3447 commits into unslothai:main from ichbinhandsome:pakcing_cache_optimization
Conversation

@ichbinhandsome

Added per-forward-pass caching to eliminate redundant D2H copies and cudaStreamSynchronize calls across layers.

When packing (padding-free) is enabled, three functions are called on every layer of the model during the forward pass:

  • get_packed_info_from_kwargs: calls lengths.max().item() — triggers D2H copy + sync
  • build_sdpa_packed_attention_mask (SDPA backend): calls seq_lengths.sum().item() and seq_lengths.tolist() — triggers 2 D2H copies + syncs
  • build_xformers_block_causal_mask (XFormers backend): calls seq_lengths.to("cpu") — triggers D2H copy + sync

For a model with N layers, this results in N unnecessary D2H synchronizations per function, even though the packed sequence metadata (seq_lengths) is identical across all layers within the same forward pass.
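To make the per-layer cost concrete, here is an illustrative, GPU-free sketch (not Unsloth code): a stand-in object counts one "sync" every time a layer recomputes the max, the way `lengths.max().item()` would force a D2H copy + `cudaStreamSynchronize` on a real CUDA tensor.

```python
# Hypothetical stand-in for a GPU lengths tensor: on a real CUDA tensor,
# every .item()-style readback would be a D2H copy + stream synchronize.
class FakeLengths:
    def __init__(self, lengths):
        self.lengths = lengths
        self.sync_count = 0  # number of simulated D2H syncs

    def max_item(self):
        self.sync_count += 1  # real code: lengths.max().item() -> sync
        return max(self.lengths)

def forward_without_cache(seq_lengths, n_layers):
    # Each layer recomputes max_seqlen from the very same tensor object.
    for _ in range(n_layers):
        max_seqlen = seq_lengths.max_item()
    return max_seqlen

lengths = FakeLengths([3, 7, 5])
forward_without_cache(lengths, n_layers=40)
print(lengths.sync_count)  # 40 syncs for a 40-layer model
```

With the caching described below, only the first layer would pay this cost.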

Solution

Cache the output of each function using Python object identity (is) comparison on the seq_lengths tensor. Since the same seq_lengths tensor object is passed to all layers within a single forward pass, subsequent layers hit the cache and skip the D2H operations entirely. A new batch produces a new seq_lengths tensor object, which naturally invalidates the cache.

This reduces D2H synchronizations per forward pass:

| Function | Before | After |
| --- | --- | --- |
| get_packed_info_from_kwargs | N | 1 |
| build_sdpa_packed_attention_mask | 2N | 2 |
| build_xformers_block_causal_mask | N | 1 |

Nsys profiling traces:

  • Without this PR: [nsys trace screenshot]
  • With this PR: [nsys trace screenshot]

Performance

With the caching strategy, cudaStreamSynchronize appears only in the first layer and disappears from all subsequent layers. We measure roughly a 43.3% speedup for the forward pass, 5.8% for the backward pass, and 14.3% per batch for Qwen3 14B QLoRA SFT.

Fizza-Mukhtar and others added 30 commits December 28, 2025 21:23
* Guard optional trl.experimental.openenv usage in RL patches

* Simplify optional trl.openenv import handling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…3790)

* Fix is_contiguous() method call and remove duplicate imports

- Fix bug in rope_embedding.py where is_contiguous was used without
  parentheses, causing the method object (always truthy) to be evaluated
  instead of calling the method. This fixes issue unslothai#3781 where fast rope
  backpropagation was broken for zero strided/non-contiguous tensors.

- Remove duplicate `import torch` in rl.py (lines 20 and 25)
- Remove duplicate `import functools` and `import types` in vision.py
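The `is_contiguous` fix above is a classic Python truthiness pitfall; a minimal standalone illustration (not the Unsloth code itself):

```python
# Referencing a method without calling it yields a bound-method object,
# which is always truthy -- so `if x.is_contiguous:` never goes false.
class FakeTensor:
    def is_contiguous(self):
        return False  # e.g. a zero-strided / non-contiguous tensor

t = FakeTensor()
assert bool(t.is_contiguous)        # bug: the method object is truthy
assert t.is_contiguous() is False   # fix: actually call the method
```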

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix Boolean value of Tensor ambiguity error in mistral.py

Replace `or` operator with explicit `is None` check when getting
n_items from kwargs. The `or` operator fails when the value is a
Tensor because Python cannot determine the boolean value of a
multi-element tensor.

Fixes unslothai#3766
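A self-contained illustration of the failure mode (a stub class mimics the tensor; this is not the Unsloth code):

```python
# Stand-in mimicking torch.Tensor: truth-testing a multi-element tensor
# raises "Boolean value of Tensor with more than one element is ambiguous".
class MultiElementTensor:
    def __bool__(self):
        raise RuntimeError(
            "Boolean value of Tensor with more than one element is ambiguous"
        )

kwargs = {"num_items_in_batch": MultiElementTensor()}

# Buggy pattern: `or` truth-tests the tensor and raises.
try:
    n_items = kwargs.get("num_items_in_batch") or 0
    raised = False
except RuntimeError:
    raised = True
assert raised

# Fixed pattern: an explicit `is None` check never truth-tests the value.
n_items = kwargs.get("num_items_in_batch")
if n_items is None:
    n_items = 0
assert isinstance(n_items, MultiElementTensor)
```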

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Update rope_embedding.py

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
…lothai#3794)

Add "corda" as an allowed value for the init_lora_weights parameter
in FastLanguageModel.get_peft_model() and FastBaseModel.get_peft_model().

This enables users to use CorDA (Correlation-aware Decomposed Adaptation)
initialization from PEFT, which provides an alternative LoRA initialization
strategy for improved finetuning performance.

Fixes unslothai#3693

Signed-off-by: majiayu000 <1835304752@qq.com>
…lothai#3811)

* Fix correctness bugs in rl.py, rl_replacements.py, and vision.py

1. rl_replacements.py (lines 864, 870): Fixed undefined `nanmin`/`nanmax`
   functions by using `.nan_to_num(nan=inf/-inf).min()/.max()` pattern.
   PyTorch doesn't have torch.nanmin/nanmax, so we replace NaN values
   before computing min/max.

2. vision.py (line 150): Fixed bug where code checked for "input" key
   but then accessed kwargs["input_ids"] instead of kwargs["input"].

3. vision.py (line 159): Fixed bug where literal string "key" was used
   instead of the variable `key` when accessing kwargs.

4. rl.py (lines 903, 905): Fixed non-existent `MathError` exception
   by replacing with `ValueError`.
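The `nanmin`/`nanmax` replacement in item 1 can be sketched in pure Python (mirroring the `x.nan_to_num(nan=float("inf")).min()` idea, since PyTorch has no `torch.nanmin`/`torch.nanmax`):

```python
import math

# Map NaN to +inf before min (so NaNs never win), and to -inf before max.
def nanmin(values):
    return min(math.inf if math.isnan(v) else v for v in values)

def nanmax(values):
    return max(-math.inf if math.isnan(v) else v for v in values)

vals = [2.0, float("nan"), -1.0, 5.0]
assert nanmin(vals) == -1.0
assert nanmax(vals) == 5.0
```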

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1. cohere.py:347-348 - Fixed wrong variable names in QK normalization.
   Used `Q`/`K` but variables were named `Qn`/`Kn`. This caused NameError
   when `use_qk_norm=True` (e.g., c4ai-command-r-plus models).

2. cohere.py:482 - Fixed wrong object reference in inference loop.
   Used `self.mlp` but should be `decoder_layer.mlp` since we're
   iterating through decoder layers. Caused AttributeError during inference.

3. falcon_h1.py:459,461 - Fixed wrong attribute names in inference path.
   Used `post_attention_layernorm` and `mlp` but Falcon H1 uses
   `pre_ff_layernorm` and `feed_forward`. Caused AttributeError during generation.

4. qwen3_moe.py:210 - Fixed wrong module path with incorrect capitalization.
   Used `transformers.models.Qwen3Moe` but should be `transformers.models.qwen3_moe`.
   Caused AttributeError when patching rotary embeddings.

5. qwen3_moe.py:239 - Fixed wrong model_patcher class.
   Used `FastQwen3Model` but should be `FastQwen3MoeModel` for MoE models.
   Caused incorrect patching for Qwen3 MoE models.

6. hf_hub.py:21-22 - Fixed floor division and missing return for billion values.
   Used `//` instead of `/` for millions, and had no return for values >= 1B.
   Caused incorrect formatting and None return for large numbers.

7. save.py:550 - Fixed self-assignment that did nothing.
   `sharded_ram_usage = sharded_ram_usage` should be `= max_shard_size`.
   Caused integer shard sizes to be ignored.

8. rl.py:562-567 - Fixed orphan string not included in length_check.
   The elif branch for max_seq_length validation was a standalone string
   expression, not concatenated to length_check. Caused silent skip of
   the max_seq_length > model_max_seq_length warning.

9. granite.py:49-52 - Fixed wrong model name and version in error message.
   Said "Gemma2" and "4.42.3" but should be "Granite" and "4.45.0".
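For item 6, a hypothetical reconstruction of the corrected formatter (the real hf_hub.py code may differ; this only shows the two fixes: true division instead of `//`, and an explicit return for billion-scale values):

```python
# Hypothetical parameter-count formatter illustrating the hf_hub.py fixes.
def format_param_count(n):
    if n >= 1_000_000_000:
        return f"{n / 1_000_000_000:.1f}B"  # was missing -> returned None
    if n >= 1_000_000:
        return f"{n / 1_000_000:.1f}M"      # was // -> truncated fraction
    return str(n)

assert format_param_count(14_000_000_000) == "14.0B"
assert format_param_count(1_500_000) == "1.5M"
```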
…tmul

Fix 3D tensor support for bitsandbytes 8-bit matmul in forward pass
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
FIX: weight tying for LoRA embeddings and lm_head
Gemma3 models have a large vocabulary (262144 tokens) which causes
training loss to explode when using int8 embedding quantization.

This fix auto-detects Gemma3 models and switches from int8-int4
(phone-deployment) to int4 weight-only QAT for stable training.
…lity

Fix Gemma3 QAT training instability with int8-int4 scheme
When users load a model with fast_inference=False but then try to use
vLLM-style arguments with fast_generate, they previously got confusing
errors. This adds a wrapper that detects common mistakes and provides
helpful guidance:

- Using sampling_params: explains to use HF generate args instead
- Using lora_request: explains LoRA weights are already merged
- Passing text strings: shows how to tokenize input first

Changes:
- Add make_fast_generate_wrapper to _utils.py
- Apply wrapper in llama.py when fast_inference=False
- Apply wrapper in vision.py when fast_inference=False
…apper-helpful-errors

Add helpful error messages for fast_generate when fast_inference=False
…curl

Make llama.cpp CURL dependency optional when building from source
Datta0 and others added 11 commits March 3, 2026 06:30
* Fix lm_head lora save

* Fix _need_to_train_embeddings guard for lm_head LoRA targets

When lm_head is already in final_modules as a LoRA target, the
_need_to_train_embeddings block should not also add it to
modules_to_save. This prevents dual-wrapping (LoRA + modules_to_save
on the same module) which causes assertion failures downstream.

Check if embed_tokens/lm_head are already being trained as LoRA
targets before adding them to modules_to_save. Also prevents
duplicate entries with elif guards.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* add intel support for torch210

* fix for typo
…support (unslothai#4138)

* fix: update GGUF save paths to use ~/.unsloth/llama.cpp with Windows support

* fix: quote LLAMA_CPP_DEFAULT_DIR in fallback shell commands to handle paths with spaces

* refactor: deduplicate platform-specific build instructions in quantization error message

* chore: remove accidentally committed PR description file

* Fix import safety and f-string bugs in save.py

- H4: Add defensive try/except for LLAMA_CPP_DEFAULT_DIR and IS_WINDOWS imports
  with fallback defaults, so save.py works even if zoo PR unslothai#526 is not merged yet
- H5: Fix Kaggle error path using plain "Error: {e}" instead of f"Error: {e}",
  so the actual exception is shown to users

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fixup mapper issues and resolve properly

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fix broken wandb import crashing unsloth startup

When wandb is installed but broken (e.g., wandb < 0.19.11 with
protobuf >= 6.0), the import chain unsloth -> trl -> transformers ->
is_wandb_available() -> import wandb crashes with:

  ImportError: cannot import name 'Imports' from
  'wandb.proto.wandb_telemetry_pb2'

This happens because transformers' is_wandb_available() has no
try/except around `import wandb`. The error propagates up and kills
`from unsloth import FastLanguageModel` even though wandb is optional.

Add disable_broken_wandb() following the same pattern as
disable_torchcodec_if_broken(). It proactively tries importing wandb
during early init, and if the import fails, patches
is_wandb_available() to return False and sets WANDB_DISABLED=true.
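A runnable sketch of that guard, with the import injected as a callable and a fake namespace standing in for transformers' utilities so it runs anywhere (names are illustrative):

```python
import os
import types

# Fake stand-in for the transformers utility module being patched.
transformers_utils = types.SimpleNamespace(is_wandb_available=lambda: True)

def disable_broken_wandb(import_wandb):
    # Probe the optional import early; real code does `import wandb`.
    try:
        import_wandb()
    except Exception:
        # Broken install: report wandb as unavailable, disable reporting.
        transformers_utils.is_wandb_available = lambda: False
        os.environ["WANDB_DISABLED"] = "true"

def broken_import():
    raise ImportError(
        "cannot import name 'Imports' from 'wandb.proto.wandb_telemetry_pb2'"
    )

disable_broken_wandb(broken_import)
assert transformers_utils.is_wandb_available() is False
assert os.environ["WANDB_DISABLED"] == "true"
```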

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…slothai#4148)

trl/trainer/callbacks.py imports is_wandb_available from
accelerate.utils, not from transformers. The original fix in unslothai#4147
only patched the transformers version, so `from trl import GRPOTrainer`
still crashed via the callbacks.py -> accelerate -> wandb path.

Must patch both the source module (accelerate.utils.imports) AND the
re-export namespace (accelerate.utils) since Python's
`from accelerate.utils import X` reads from the latter, which holds
its own cached reference.
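The "patch both namespaces" point can be demonstrated with two stand-in modules (pure Python, no accelerate needed): `from pkg.utils import f` binds the name held in `pkg.utils`, independent of the source module it was originally imported from.

```python
import types

imports_mod = types.ModuleType("accelerate_utils_imports")
imports_mod.is_wandb_available = lambda: True

utils_mod = types.ModuleType("accelerate_utils")
# Simulates `from .imports import is_wandb_available` inside accelerate.utils:
utils_mod.is_wandb_available = imports_mod.is_wandb_available

# Patching only the source module leaves the re-exported reference stale.
imports_mod.is_wandb_available = lambda: False
assert utils_mod.is_wandb_available() is True   # still the old function!

# Fix: patch the re-export namespace as well.
utils_mod.is_wandb_available = lambda: False
assert utils_mod.is_wandb_available() is False
```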
@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a caching mechanism to optimize the performance of models utilizing packed sequences. By intelligently caching metadata related to sequence lengths and attention masks, it drastically reduces the number of costly Device-to-Host synchronizations that previously occurred in every layer of a multi-layer model. This optimization leads to substantial speedups in both forward and backward passes, improving overall training and inference efficiency.

Highlights

  • D2H Synchronization Reduction: Implemented per-forward-pass caching for get_packed_info_from_kwargs, build_sdpa_packed_attention_mask, and build_xformers_block_causal_mask to eliminate redundant Device-to-Host (D2H) copies and cudaStreamSynchronize calls across model layers.
  • Caching Strategy: Caches are invalidated naturally by new seq_lengths tensor objects for each new batch, ensuring correctness while maximizing cache hits for subsequent layers within the same forward pass by using Python object identity (is) comparison.
  • Performance Improvement: Achieved significant speedups, including approximately 43.3% for forward pass, 5.8% for backward pass, and 14.3% per batch for Qwen3 14B QLoRA SFT.
Changelog
  • unsloth/utils/packing.py
    • Introduced global cache dictionaries for packed info, SDPA masks, and XFormers block masks.
    • Modified get_packed_info_from_kwargs to check and utilize the _PACKED_INFO_CACHE, and to store results.
    • Updated cu_seqlens initialization from torch.empty to torch.zeros in get_packed_info_from_kwargs.
    • Modified build_xformers_block_causal_mask to incorporate caching logic using _XFORMERS_BLOCK_MASK_CACHE.
    • Modified build_sdpa_packed_attention_mask to integrate caching logic using _SDPA_MASK_CACHE.
Activity
  • Pull request created by ichbinhandsome.
  • Initial code changes implemented to introduce caching for packed sequence metadata.
  • Performance benchmarks provided in the PR description demonstrating significant speedups.


@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ab927312f7


Comment on lines +236 to +237
if _PACKED_INFO_CACHE["seq_lengths"] is seq_lengths:
    return _PACKED_INFO_CACHE["result"]


P1 Badge Include device in packed-info cache key

The new cache in get_packed_info_from_kwargs returns a previously computed (lengths, cu_seqlens, max_seqlen) based only on packed_seq_lengths object identity, but those tensors are device-specific. When a forward spans layers on different devices (for example with a device_map/pipeline-parallel setup), later layers can receive cached cu_seqlens from the wrong device, which causes runtime failures in varlen attention paths that require all inputs on the same device. The cache lookup should also key on device (or store per-device results) before reusing the tuple.



@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a caching mechanism to optimize performance by reducing D2H synchronizations. However, the current implementation has critical thread-safety issues due to the use of global state without proper synchronization, which can lead to race conditions and data corruption in concurrent environments. Specifically, cache updates are not atomic, and one caching function ignores the target device, potentially causing crashes in multi-GPU or pipeline-parallel configurations. It is crucial to address these issues by ensuring thread safety and including all relevant parameters in the cache keys, possibly by making cache updates atomic through storing entries as tuples.

Comment on lines +40 to +46
_PACKED_INFO_CACHE: dict = {"seq_lengths": None, "result": None}

# Cache for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers
_SDPA_MASK_CACHE: dict = {"seq_lengths": None, "params": None, "mask": None}

# Cache for build_xformers_block_causal_mask to avoid repeated D2H sync across layers
_XFORMERS_BLOCK_MASK_CACHE: dict = {"seq_lengths": None, "params": None, "mask": None}

Severity: high (security)

The use of global dictionaries (_PACKED_INFO_CACHE, _SDPA_MASK_CACHE, _XFORMERS_BLOCK_MASK_CACHE) for caching metadata across forward passes is not thread-safe. In concurrent environments, multiple threads will overwrite these global variables without synchronization, leading to race conditions, silent data corruption, or incorrect model behavior. Specifically, updating dictionary keys in separate statements is not atomic. Consider using thread-local storage (threading.local()) or appropriate locking mechanisms, or storing the entire cache entry as a single tuple to ensure atomicity and thread safety.

Suggested change:

    _PACKED_INFO_CACHE: dict = {"entry": (None, None)}

    # Cache for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers
    _SDPA_MASK_CACHE: dict = {"entry": (None, None, None)}

    # Cache for build_xformers_block_causal_mask to avoid repeated D2H sync across layers
    _XFORMERS_BLOCK_MASK_CACHE: dict = {"entry": (None, None, None)}

Comment on lines +261 to +276
params = (sliding_window,)
if (
    _XFORMERS_BLOCK_MASK_CACHE["seq_lengths"] is seq_lengths
    and _XFORMERS_BLOCK_MASK_CACHE["params"] == params
):
    return _XFORMERS_BLOCK_MASK_CACHE["mask"]

lengths_tensor = seq_lengths.to("cpu", torch.int32)
if lengths_tensor.numel() == 0:
    return None
lengths = tuple(int(x) for x in lengths_tensor.tolist())
mask = _get_cached_block_mask(lengths, sliding_window)

_XFORMERS_BLOCK_MASK_CACHE["seq_lengths"] = seq_lengths
_XFORMERS_BLOCK_MASK_CACHE["params"] = params
_XFORMERS_BLOCK_MASK_CACHE["mask"] = mask

high

To ensure thread-safety as discussed for the cache definitions, please update this function to use the atomic cache update pattern. The cache check and update should be modified to work with a single tuple entry.

Suggested change:

    params = (sliding_window,)
    cached_seq_lengths, cached_params, cached_mask = _XFORMERS_BLOCK_MASK_CACHE["entry"]
    if (
        cached_seq_lengths is seq_lengths
        and cached_params == params
    ):
        return cached_mask

    lengths_tensor = seq_lengths.to("cpu", torch.int32)
    if lengths_tensor.numel() == 0:
        return None
    lengths = tuple(int(x) for x in lengths_tensor.tolist())
    mask = _get_cached_block_mask(lengths, sliding_window)

    _XFORMERS_BLOCK_MASK_CACHE["entry"] = (seq_lengths, params, mask)

Comment on lines +236 to +246
+ if _PACKED_INFO_CACHE["seq_lengths"] is seq_lengths:
+     return _PACKED_INFO_CACHE["result"]
+
  lengths = seq_lengths.to(device = device, dtype = torch.int32, non_blocking = True)
- cu_seqlens = torch.empty(lengths.numel() + 1, dtype = torch.int32, device = device)
- cu_seqlens[0] = 0
+ cu_seqlens = torch.zeros(lengths.numel() + 1, dtype = torch.int32, device = device)
  torch.cumsum(lengths, dim = 0, dtype = torch.int32, out = cu_seqlens[1:])

  max_seqlen = int(lengths.max().item())
- return lengths, cu_seqlens, max_seqlen
+ result = (lengths, cu_seqlens, max_seqlen)
+ _PACKED_INFO_CACHE["seq_lengths"] = seq_lengths
+ _PACKED_INFO_CACHE["result"] = result

Severity: medium (security)

The cache check in get_packed_info_from_kwargs currently ignores the device argument, which can lead to returning cached tensors on the wrong device and cause a PyTorch RuntimeError in multi-GPU or pipeline-parallel setups. The cache key should include the device to ensure correctness across different execution contexts. Additionally, consider the thread-safety implications of updating cache entries, as non-atomic operations can lead to race conditions.

Suggested change:

    if _PACKED_INFO_CACHE.get("seq_lengths") is seq_lengths and _PACKED_INFO_CACHE.get("device") == device:
        return _PACKED_INFO_CACHE["result"]

    lengths = seq_lengths.to(device = device, dtype = torch.int32, non_blocking = True)
    cu_seqlens = torch.zeros(lengths.numel() + 1, dtype = torch.int32, device = device)
    torch.cumsum(lengths, dim = 0, dtype = torch.int32, out = cu_seqlens[1:])

    max_seqlen = int(lengths.max().item())
    result = (lengths, cu_seqlens, max_seqlen)
    _PACKED_INFO_CACHE["seq_lengths"] = seq_lengths
    _PACKED_INFO_CACHE["device"] = device
    _PACKED_INFO_CACHE["result"] = result

sstamenk and others added 9 commits March 7, 2026 01:33
* Refactor loss computation to include completion_mask

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
)

* Fix gpt temporary patch for grpo to happen after compile

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Refactor loss computation to include completion_mask

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes for trl 0.28 and above

Remove sync/reload weights calls , remove vllm.LLM instantiation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor loss computation to include completion_mask

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes for trl 0.28 and above

Remove sync/reload weights calls , remove vllm.LLM instantiation

* patch rpc in openenv for newer trl

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pluesclues <136766175+pluesclues@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.15.4 → v0.15.5](astral-sh/ruff-pre-commit@v0.15.4...v0.15.5)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@ichbinhandsome force-pushed the pakcing_cache_optimization branch from 43ace06 to 85cebe7 on March 11, 2026 19:07

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 2711dbc81a


Comment on lines +343 to +347
_SDPA_MASK_CACHE[device] = {
    "seq_lengths": seq_lengths,
    "params": params,
    "mask": result,
}


P2 Badge Avoid pinning full SDPA masks in global cache

This stores the full packed SDPA mask tensor in a module-global dict, so the latest (1,1,T,T) mask for a device stays strongly referenced until another packed SDPA call replaces it. With long packed sequences this can be a very large allocation, and if execution later switches to non-packed/xformers/flash paths (so this function is not called again) the memory remains live and can cause avoidable OOMs in later steps. A per-forward or explicitly evictable cache would keep the sync optimization without holding large masks indefinitely.

