
[trainer] bug: TorchtitanEngine silently ignores attn_type="flex" — no clear BKM for which torchtitan version to use #6182

@kahlun

Description


System Info

  • verl: v0.7.1 (latest stable release)
  • torchtitan tested: v0.2.2 (pip) and HEAD (from source)
  • GPU: NVIDIA A100 80GB, PyTorch 2.10+cu129

Problem

There is no documented BKM (best known method) for which torchtitan version
works with verl's TorchtitanEngine. I tried both obvious options and hit a
different failure with each:

With torchtitan v0.2.2 (pip install):
Fails at startup — Trainer.Config does not exist in v0.2.2.
verl v0.7.1 uses Trainer.Config (introduced to torchtitan in commit 9810191,
Feb 23 2026), but torchtitan v0.2.2 was tagged Feb 20 — 3 days before that
refactor. So pip install torchtitan is already broken with verl v0.7.1.

With torchtitan HEAD (from source):
Fails at training time, sometimes silently: the model is built with "sdpa"
attention even when attn_type="flex" (the default in TorchtitanEngineConfig).

Root cause: model_registry() is called without attn_backend=:

model_spec = model_module.model_registry(torchtitan_flavor)

Since torchtitan commit 7cec166 (Apr 17 2026, PR #2960) added attn_backend
as a parameter to model_registry(), omitting it causes torchtitan to fall back
to its own per-model default ("sdpa" for llama3/qwen3), silently ignoring
engine_config.attn_type. The mask builder then calls get_attention_masks()
with attn_type="flex" while the model was built for sdpa → crash or wrong
gradients.

There is also a second bug: the get_attention_masks() call site reads
self.trainer.model_config.layer.attention.attn_backend (wrong path, always
raises AttributeError) instead of self.engine_config.attn_type.

Questions

  1. What is the correct/supported torchtitan commit or snapshot to use with
     verl v0.7.1 or current main? Is there a pinned commit somewhere?

  2. Is a fix for the attn_backend call site welcome as a PR, or is there a
     follow-up planned as part of the roadmap (#5306)?

A minimal fix is straightforward (sketch below):

  • Pass attn_backend=self.engine_config.attn_type to model_registry()
  • Fix the second call site to read self.engine_config.attn_type
cc @acisseJZhong (roadmap owner, #5306)

I have a working fix with unit tests ready if a PR is welcome.
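
Roughly, the two changes look like this (sketch only; the surrounding context
is taken from the call sites quoted under Code pointers below, and exact
positions on current main may differ):

# transformer_impl.py, model build path: forward the configured backend
model_spec = model_module.model_registry(
    torchtitan_flavor,
    attn_backend=self.engine_config.attn_type,  # "flex" by default
)

# transformer_impl.py, get_attention_masks() call site: read the engine config
attn_type = self.engine_config.attn_type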

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to reproduce:

  1. Install verl v0.7.1 and torchtitan from source (HEAD, post-commit 7cec166)
  2. Configure a job using TorchtitanEngineConfig with the default attn_type="flex" and use_remove_padding=True (config sketch below)
  3. Run SFT with a llama3 or qwen3 model
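
For concreteness, the engine config used in step 2 is roughly the following
(field names as described above; all other fields left at their defaults):

engine_config = TorchtitanEngineConfig(
    attn_type="flex",         # the default; currently ignored by the engine
    use_remove_padding=True,  # True hits the crash path, False the silent path
)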

Expected: model trains with FlexAttention as configured.

Actual (crash path): RuntimeError or TypeError from get_attention_masks() because the
model was silently built with "sdpa" modules while masks are built for "flex".

Actual (silent wrong path): If use_remove_padding=False, training runs but
attn_type is still ignored; the attention backend differs from what was
requested.

Code pointers (verl v0.7.1, transformer_impl.py):

# Line 110: attn_backend never passed — torchtitan defaults to "sdpa" for llama3/qwen3
model_spec = model_module.model_registry(torchtitan_flavor)

# Line 112-115: override guard always False — "layer" (singular) doesn't exist, it's "layers"
if hasattr(model_spec.model, "layer"):
    model_spec.model.layer.attention.attn_backend = attn_type  # dead code, never runs

# Line 599: wrong attribute path — raises AttributeError
attn_type = self.trainer.model_config.layer.attention.attn_backend  # should be self.engine_config.attn_type

Expected behavior

TorchtitanEngine should use the attn_type value from TorchtitanEngineConfig
(default: "flex") when building the model via model_registry(), so that the
attention backend in the built model matches what get_attention_masks() is
called with downstream.
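
One way to make the mismatch loud instead of silent would be a guard after
model build (hypothetical sketch; the attribute path is assumed from the
dead-code override in Code pointers, with the "layers" spelling fix, and is
not verified against torchtitan):

# after model_spec is built, before training starts
built = model_spec.model.layers.attention.attn_backend  # assumed path
if built != self.engine_config.attn_type:
    raise RuntimeError(
        f"model built with attn_backend={built!r}, "
        f"but engine_config.attn_type is {self.engine_config.attn_type!r}"
    )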

Additionally, it would be helpful to have a documented minimum torchtitan commit or pinned snapshot that works with verl v0.7.1 / current main, since pip install torchtitan (v0.2.2) is already incompatible due to the Trainer.Config API change.
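
In the meantime, even an unofficial pin in the install docs would make setups
reproducible, e.g. (placeholder ref, purely illustrative):

pip install "torchtitan @ git+https://github.com/pytorch/torchtitan.git@<known-good-commit>"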
