Cache packed sequence metadata to reduce D2H syncs across layers#4173
Cache packed sequence metadata to reduce D2H syncs across layers#4173ichbinhandsome wants to merge 3447 commits intounslothai:mainfrom
Conversation
* Guard optional trl.experimental.openenv usage in RL patches * Simplify optional trl.openenv import handling * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…3790) * Fix is_contiguous() method call and remove duplicate imports - Fix bug in rope_embedding.py where is_contiguous was used without parentheses, causing the method object (always truthy) to be evaluated instead of calling the method. This fixes issue unslothai#3781 where fast rope backpropagation was broken for zero strided/non-contiguous tensors. - Remove duplicate `import torch` in rl.py (lines 20 and 25) - Remove duplicate `import functools` and `import types` in vision.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Fix Boolean value of Tensor ambiguity error in mistral.py Replace `or` operator with explicit `is None` check when getting n_items from kwargs. The `or` operator fails when the value is a Tensor because Python cannot determine the boolean value of a multi-element tensor. Fixes unslothai#3766 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Update rope_embedding.py --------- Co-authored-by: yurekami <yurekami@users.noreply.github.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: Daniel Han <danielhanchen@gmail.com>
…lothai#3794) Add "corda" as an allowed value for the init_lora_weights parameter in FastLanguageModel.get_peft_model() and FastBaseModel.get_peft_model(). This enables users to use CorDA (Correlation-aware Decomposed Adaptation) initialization from PEFT, which provides an alternative LoRA initialization strategy for improved finetuning performance. Fixes unslothai#3693 Signed-off-by: majiayu000 <1835304752@qq.com>
for more information, see https://pre-commit.ci
…lothai#3811) * Fix correctness bugs in rl.py, rl_replacements.py, and vision.py 1. rl_replacements.py (lines 864, 870): Fixed undefined `nanmin`/`nanmax` functions by using `.nan_to_num(nan=inf/-inf).min()/.max()` pattern. PyTorch doesn't have torch.nanmin/nanmax, so we replace NaN values before computing min/max. 2. vision.py (line 150): Fixed bug where code checked for "input" key but then accessed kwargs["input_ids"] instead of kwargs["input"]. 3. vision.py (line 159): Fixed bug where literal string "key" was used instead of the variable `key` when accessing kwargs. 4. rl.py (lines 903, 905): Fixed non-existent `MathError` exception by replacing with `ValueError`. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1. cohere.py:347-348 - Fixed wrong variable names in QK normalization. Used `Q`/`K` but variables were named `Qn`/`Kn`. This caused NameError when `use_qk_norm=True` (e.g., c4ai-command-r-plus models). 2. cohere.py:482 - Fixed wrong object reference in inference loop. Used `self.mlp` but should be `decoder_layer.mlp` since we're iterating through decoder layers. Caused AttributeError during inference. 3. falcon_h1.py:459,461 - Fixed wrong attribute names in inference path. Used `post_attention_layernorm` and `mlp` but Falcon H1 uses `pre_ff_layernorm` and `feed_forward`. Caused AttributeError during generation. 4. qwen3_moe.py:210 - Fixed wrong module path with incorrect capitalization. Used `transformers.models.Qwen3Moe` but should be `transformers.models.qwen3_moe`. Caused AttributeError when patching rotary embeddings. 5. qwen3_moe.py:239 - Fixed wrong model_patcher class. Used `FastQwen3Model` but should be `FastQwen3MoeModel` for MoE models. Caused incorrect patching for Qwen3 MoE models. 6. hf_hub.py:21-22 - Fixed floor division and missing return for billion values. Used `//` instead of `/` for millions, and had no return for values >= 1B. Caused incorrect formatting and None return for large numbers. 7. save.py:550 - Fixed self-assignment that did nothing. `sharded_ram_usage = sharded_ram_usage` should be `= max_shard_size`. Caused integer shard sizes to be ignored. 8. rl.py:562-567 - Fixed orphan string not included in length_check. The elif branch for max_seq_length validation was a standalone string expression, not concatenated to length_check. Caused silent skip of the max_seq_length > model_max_seq_length warning. 9. granite.py:49-52 - Fixed wrong model name and version in error message. Said "Gemma2" and "4.42.3" but should be "Granite" and "4.45.0".
…tmul Fix 3D tensor support for bitsandbytes 8-bit matmul in forward pass
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
FIX: weight tying for LoRA embeddings and lm_head
Gemma3 models have a large vocabulary (262144 tokens) which causes training loss to explode when using int8 embedding quantization. This fix auto-detects Gemma3 models and switches from int8-int4 (phone-deployment) to int4 weight-only QAT for stable training.
…lity Fix Gemma3 QAT training instability with int8-int4 scheme
When users load a model with fast_inference=False but then try to use vLLM-style arguments with fast_generate, they previously got confusing errors. This adds a wrapper that detects common mistakes and provides helpful guidance: - Using sampling_params: explains to use HF generate args instead - Using lora_request: explains LoRA weights are already merged - Passing text strings: shows how to tokenize input first Changes: - Add make_fast_generate_wrapper to _utils.py - Apply wrapper in llama.py when fast_inference=False - Apply wrapper in vision.py when fast_inference=False
for more information, see https://pre-commit.ci
…apper-helpful-errors Add helpful error messages for fast_generate when fast_inference=False
for more information, see https://pre-commit.ci
…curl Make llama.cpp CURL dependency optional when building from source
* Fix lm_head lora save * Fix _need_to_train_embeddings guard for lm_head LoRA targets When lm_head is already in final_modules as a LoRA target, the _need_to_train_embeddings block should not also add it to modules_to_save. This prevents dual-wrapping (LoRA + modules_to_save on the same module) which causes assertion failures downstream. Check if embed_tokens/lm_head are already being trained as LoRA targets before adding them to modules_to_save. Also prevents duplicate entries with elif guards. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* add intel support for torch210 * fix for typo
…support (unslothai#4138) * fix: update GGUF save paths to use ~/.unsloth/llama.cpp with Windows support * fix: quote LLAMA_CPP_DEFAULT_DIR in fallback shell commands to handle paths with spaces * refactor: deduplicate platform-specific build instructions in quantization error message * chore: remove accidentally committed PR description file * Fix import safety and f-string bugs in save.py - H4: Add defensive try/except for LLAMA_CPP_DEFAULT_DIR and IS_WINDOWS imports with fallback defaults, so save.py works even if zoo PR unslothai#526 is not merged yet - H5: Fix Kaggle error path using plain "Error: {e}" instead of f"Error: {e}", so the actual exception is shown to users * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com> Co-authored-by: Daniel Han <danielhanchen@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fixup mapper issues and resolve properly * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fix broken wandb import crashing unsloth startup When wandb is installed but broken (e.g., wandb < 0.19.11 with protobuf >= 6.0), the import chain unsloth -> trl -> transformers -> is_wandb_available() -> import wandb crashes with: ImportError: cannot import name 'Imports' from 'wandb.proto.wandb_telemetry_pb2' This happens because transformers' is_wandb_available() has no try/except around `import wandb`. The error propagates up and kills `from unsloth import FastLanguageModel` even though wandb is optional. Add disable_broken_wandb() following the same pattern as disable_torchcodec_if_broken(). It proactively tries importing wandb during early init, and if the import fails, patches is_wandb_available() to return False and sets WANDB_DISABLED=true. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…slothai#4148) trl/trainer/callbacks.py imports is_wandb_available from accelerate.utils, not from transformers. The original fix in unslothai#4147 only patched the transformers version, so `from trl import GRPOTrainer` still crashed via the callbacks.py -> accelerate -> wandb path. Must patch both the source module (accelerate.utils.imports) AND the re-export namespace (accelerate.utils) since Python's `from accelerate.utils import X` reads from the latter, which holds its own cached reference.
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a caching mechanism to optimize the performance of models utilizing packed sequences. By intelligently caching metadata related to sequence lengths and attention masks, it drastically reduces the number of costly Device-to-Host synchronizations that previously occurred in every layer of a multi-layer model. This optimization leads to substantial speedups in both forward and backward passes, improving overall training and inference efficiency. Highlights
Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ab927312f7
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
unsloth/utils/packing.py
Outdated
| if _PACKED_INFO_CACHE["seq_lengths"] is seq_lengths: | ||
| return _PACKED_INFO_CACHE["result"] |
There was a problem hiding this comment.
Include device in packed-info cache key
The new cache in get_packed_info_from_kwargs returns a previously computed (lengths, cu_seqlens, max_seqlen) based only on packed_seq_lengths object identity, but those tensors are device-specific. When a forward spans layers on different devices (for example with a device_map/pipeline-parallel setup), later layers can receive cached cu_seqlens from the wrong device, which causes runtime failures in varlen attention paths that require all inputs on the same device. The cache lookup should also key on device (or store per-device results) before reusing the tuple.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Code Review
This pull request introduces a caching mechanism to optimize performance by reducing D2H synchronizations. However, the current implementation has critical thread-safety issues due to the use of global state without proper synchronization, which can lead to race conditions and data corruption in concurrent environments. Specifically, cache updates are not atomic, and one caching function ignores the target device, potentially causing crashes in multi-GPU or pipeline-parallel configurations. It is crucial to address these issues by ensuring thread safety and including all relevant parameters in the cache keys, possibly by making cache updates atomic through storing entries as tuples.
unsloth/utils/packing.py
Outdated
| _PACKED_INFO_CACHE: dict = {"seq_lengths": None, "result": None} | ||
|
|
||
| # Cache for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers | ||
| _SDPA_MASK_CACHE: dict = {"seq_lengths": None, "params": None, "mask": None} | ||
|
|
||
| # Cache for build_xformers_block_causal_mask to avoid repeated D2H sync across layers | ||
| _XFORMERS_BLOCK_MASK_CACHE: dict = {"seq_lengths": None, "params": None, "mask": None} |
There was a problem hiding this comment.
The use of global dictionaries (_PACKED_INFO_CACHE, _SDPA_MASK_CACHE, _XFORMERS_BLOCK_MASK_CACHE) for caching metadata across forward passes is not thread-safe. In concurrent environments, multiple threads will overwrite these global variables without synchronization, leading to race conditions, silent data corruption, or incorrect model behavior. Specifically, updating dictionary keys in separate statements is not atomic. Consider using thread-local storage (threading.local()) or appropriate locking mechanisms, or storing the entire cache entry as a single tuple to ensure atomicity and thread safety.
| _PACKED_INFO_CACHE: dict = {"seq_lengths": None, "result": None} | |
| # Cache for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers | |
| _SDPA_MASK_CACHE: dict = {"seq_lengths": None, "params": None, "mask": None} | |
| # Cache for build_xformers_block_causal_mask to avoid repeated D2H sync across layers | |
| _XFORMERS_BLOCK_MASK_CACHE: dict = {"seq_lengths": None, "params": None, "mask": None} | |
| _PACKED_INFO_CACHE: dict = {"entry": (None, None)} | |
| # Cache for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers | |
| _SDPA_MASK_CACHE: dict = {"entry": (None, None, None)} | |
| # Cache for build_xformers_block_causal_mask to avoid repeated D2H sync across layers | |
| _XFORMERS_BLOCK_MASK_CACHE: dict = {"entry": (None, None, None)} |
unsloth/utils/packing.py
Outdated
| params = (sliding_window,) | ||
| if ( | ||
| _XFORMERS_BLOCK_MASK_CACHE["seq_lengths"] is seq_lengths | ||
| and _XFORMERS_BLOCK_MASK_CACHE["params"] == params | ||
| ): | ||
| return _XFORMERS_BLOCK_MASK_CACHE["mask"] | ||
|
|
||
| lengths_tensor = seq_lengths.to("cpu", torch.int32) | ||
| if lengths_tensor.numel() == 0: | ||
| return None | ||
| lengths = tuple(int(x) for x in lengths_tensor.tolist()) | ||
| mask = _get_cached_block_mask(lengths, sliding_window) | ||
|
|
||
| _XFORMERS_BLOCK_MASK_CACHE["seq_lengths"] = seq_lengths | ||
| _XFORMERS_BLOCK_MASK_CACHE["params"] = params | ||
| _XFORMERS_BLOCK_MASK_CACHE["mask"] = mask |
There was a problem hiding this comment.
To ensure thread-safety as discussed for the cache definitions, please update this function to use the atomic cache update pattern. The cache check and update should be modified to work with a single tuple entry.
| params = (sliding_window,) | |
| if ( | |
| _XFORMERS_BLOCK_MASK_CACHE["seq_lengths"] is seq_lengths | |
| and _XFORMERS_BLOCK_MASK_CACHE["params"] == params | |
| ): | |
| return _XFORMERS_BLOCK_MASK_CACHE["mask"] | |
| lengths_tensor = seq_lengths.to("cpu", torch.int32) | |
| if lengths_tensor.numel() == 0: | |
| return None | |
| lengths = tuple(int(x) for x in lengths_tensor.tolist()) | |
| mask = _get_cached_block_mask(lengths, sliding_window) | |
| _XFORMERS_BLOCK_MASK_CACHE["seq_lengths"] = seq_lengths | |
| _XFORMERS_BLOCK_MASK_CACHE["params"] = params | |
| _XFORMERS_BLOCK_MASK_CACHE["mask"] = mask | |
| params = (sliding_window,) | |
| cached_seq_lengths, cached_params, cached_mask = _XFORMERS_BLOCK_MASK_CACHE["entry"] | |
| if ( | |
| cached_seq_lengths is seq_lengths | |
| and cached_params == params | |
| ): | |
| return cached_mask | |
| lengths_tensor = seq_lengths.to("cpu", torch.int32) | |
| if lengths_tensor.numel() == 0: | |
| return None | |
| lengths = tuple(int(x) for x in lengths_tensor.tolist()) | |
| mask = _get_cached_block_mask(lengths, sliding_window) | |
| _XFORMERS_BLOCK_MASK_CACHE["entry"] = (seq_lengths, params, mask) |
unsloth/utils/packing.py
Outdated
| if _PACKED_INFO_CACHE["seq_lengths"] is seq_lengths: | ||
| return _PACKED_INFO_CACHE["result"] | ||
|
|
||
| lengths = seq_lengths.to(device = device, dtype = torch.int32, non_blocking = True) | ||
| cu_seqlens = torch.empty(lengths.numel() + 1, dtype = torch.int32, device = device) | ||
| cu_seqlens[0] = 0 | ||
| cu_seqlens = torch.zeros(lengths.numel() + 1, dtype = torch.int32, device = device) | ||
| torch.cumsum(lengths, dim = 0, dtype = torch.int32, out = cu_seqlens[1:]) | ||
|
|
||
| max_seqlen = int(lengths.max().item()) | ||
| return lengths, cu_seqlens, max_seqlen | ||
| result = (lengths, cu_seqlens, max_seqlen) | ||
| _PACKED_INFO_CACHE["seq_lengths"] = seq_lengths | ||
| _PACKED_INFO_CACHE["result"] = result |
There was a problem hiding this comment.
The cache check in get_packed_info_from_kwargs currently ignores the device argument, which can lead to returning cached tensors on the wrong device and cause a PyTorch RuntimeError in multi-GPU or pipeline-parallel setups. The cache key should include the device to ensure correctness across different execution contexts. Additionally, consider the thread-safety implications of updating cache entries, as non-atomic operations can lead to race conditions.
| if _PACKED_INFO_CACHE["seq_lengths"] is seq_lengths: | |
| return _PACKED_INFO_CACHE["result"] | |
| lengths = seq_lengths.to(device = device, dtype = torch.int32, non_blocking = True) | |
| cu_seqlens = torch.empty(lengths.numel() + 1, dtype = torch.int32, device = device) | |
| cu_seqlens[0] = 0 | |
| cu_seqlens = torch.zeros(lengths.numel() + 1, dtype = torch.int32, device = device) | |
| torch.cumsum(lengths, dim = 0, dtype = torch.int32, out = cu_seqlens[1:]) | |
| max_seqlen = int(lengths.max().item()) | |
| return lengths, cu_seqlens, max_seqlen | |
| result = (lengths, cu_seqlens, max_seqlen) | |
| _PACKED_INFO_CACHE["seq_lengths"] = seq_lengths | |
| _PACKED_INFO_CACHE["result"] = result | |
| if _PACKED_INFO_CACHE.get("seq_lengths") is seq_lengths and _PACKED_INFO_CACHE.get("device") == device: | |
| return _PACKED_INFO_CACHE["result"] | |
| lengths = seq_lengths.to(device = device, dtype = torch.int32, non_blocking = True) | |
| cu_seqlens = torch.zeros(lengths.numel() + 1, dtype = torch.int32, device = device) | |
| torch.cumsum(lengths, dim = 0, dtype = torch.int32, out = cu_seqlens[1:]) | |
| max_seqlen = int(lengths.max().item()) | |
| result = (lengths, cu_seqlens, max_seqlen) | |
| _PACKED_INFO_CACHE["seq_lengths"] = seq_lengths | |
| _PACKED_INFO_CACHE["device"] = device | |
| _PACKED_INFO_CACHE["result"] = result |
* Refactor loss computation to include completion_mask * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
) * Fix gpt temporary patch for grpo to happen after compile * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Refactor loss computation to include completion_mask * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes for trl 0.28 and above Remove sync/reload weights calls , remove vllm.LLM instantiation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor loss computation to include completion_mask * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes for trl 0.28 and above Remove sync/reload weights calls , remove vllm.LLM instantiation * patch rpc in openenv for newer trl * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pluesclues <136766175+pluesclues@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
updates: - [github.com/astral-sh/ruff-pre-commit: v0.15.4 → v0.15.5](astral-sh/ruff-pre-commit@v0.15.4...v0.15.5) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
for more information, see https://pre-commit.ci
43ace06 to
85cebe7
Compare
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 2711dbc81a
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| _SDPA_MASK_CACHE[device] = { | ||
| "seq_lengths": seq_lengths, | ||
| "params": params, | ||
| "mask": result, | ||
| } |
There was a problem hiding this comment.
Avoid pinning full SDPA masks in global cache
This stores the full packed SDPA mask tensor in a module-global dict, so the latest (1,1,T,T) mask for a device stays strongly referenced until another packed SDPA call replaces it. With long packed sequences this can be a very large allocation, and if execution later switches to non-packed/xformers/flash paths (so this function is not called again) the memory remains live and can cause avoidable OOMs in later steps. A per-forward or explicitly evictable cache would keep the sync optimization without holding large masks indefinitely.
Useful? React with 👍 / 👎.
Added per-forward-pass caching to eliminate redundant D2H copies and
cudaStreamSynchronizecalls across layers.When packing (padding-free) is enabled, three functions are called on every layer of the model during the forward pass:
get_packed_info_from_kwargs: callslengths.max().item()— triggers D2H copy + syncbuild_sdpa_packed_attention_mask(SDPA backend): callsseq_lengths.sum().item()andseq_lengths.tolist()— triggers 2 D2H copies + syncsbuild_xformers_block_causal_mask(XFormers backend): callsseq_lengths.to("cpu")— triggers D2H copy + syncFor a model with N layers, this results in N unnecessary D2H synchronizations per function, even though the packed sequence metadata (
seq_lengths) is identical across all layers within the same forward pass.Solution
Cache the output of each function using Python object identity (is) comparison on the
seq_lengthstensor. Since the sameseq_lengthstensor object is passed to all layers within a single forward pass, subsequent layers hit the cache and skip the D2H operations entirely. A new batch produces a newseq_lengthstensor object, which naturally invalidates the cache.This reduces D2H synchronizations per forward pass:
get_packed_info_from_kwargsbuild_sdpa_packed_attention_maskbuild_xformers_block_causal_maskNsys profiling traces:
Performance
With the caching strategy, CudaStreamSync only appears in the first layer and in the following layers it disappears. We achieve around 43.3% speedup for forward, 5.8% speedup for backward, 14.3% speedup for each batch for Qwen3 14B QLoRA SFT.