
Commit ccca85e

zucchini-nlp authored and zaristei committed
[attention] fix test for packed padfree masking (huggingface#39582)
* fix most tests
* skip a few more tests
* address comments
* fix chameleon tests
* forgot to uncomment
* qwen has its own tests with images, rename it as well
1 parent 98fa4e6 commit ccca85e

File tree

17 files changed: +153 -250 lines changed


src/transformers/models/bark/modeling_bark.py

Lines changed: 25 additions & 63 deletions
@@ -408,69 +408,31 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, new_embeddings):
         self.input_embeds_layer = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, cache_position=None, **kwargs):
-        # Overwritten -- bark has a model-specific hack
-        input_embeds = kwargs.get("input_embeds", None)
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if cache_position[0] != 0:
-            # Omit tokens covered by past_key_values
-            seq_len = input_ids.shape[1]
-            past_length = past_key_values.get_seq_length()
-
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-
-            input_ids = input_ids[:, remove_prefix_length:]
-
-            # input_embeds have already been used and is not required anymore
-            input_embeds = None
-        else:
-            if input_embeds is not None and kwargs.get("use_cache"):
-                seq_len = input_embeds.shape[1]
-            else:
-                seq_len = input_ids.shape[1]
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        attention_mask=None,
+        input_embeds=None,
+        past_key_values=None,
+        position_ids=None,
+        use_cache=None,
+        cache_position=None,
+        **kwargs,
+    ):
+        # Overwritten -- bark uses `input_embeds`, not `inputs_embeds`

-        # ensure that attention_mask and position_ids shapes are aligned with the weird Bark hack of reducing
-        # sequence length on the first forward pass
-        if attention_mask is not None:
-            attention_mask = attention_mask[:, :seq_len]
-        if position_ids is not None:
-            position_ids = position_ids[:, :seq_len]
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-        else:
-            position_ids = None
-
-        if input_embeds is not None and kwargs.get("use_cache"):
-            return {
-                "input_ids": None,
-                "input_embeds": input_embeds,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "position_ids": position_ids,
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-            }
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "cache_position": cache_position,
-        }
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=input_embeds,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        model_inputs["input_embeds"] = model_inputs.pop("inputs_embeds", None)
+        return model_inputs

     @auto_docstring
     def forward(
@@ -546,7 +508,7 @@ def forward(
             return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

-        past_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values
+        past_length = past_key_values.get_seq_length() if past_key_values is not None else 0

         if position_ids is None:
            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
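
Note: the refactor above drops Bark's hand-rolled cache trimming and position-id bookkeeping and delegates to the generic `prepare_inputs_for_generation`, keeping only the kwarg rename. A minimal, self-contained sketch of that rename-and-delegate pattern; `BasePreparer` and `BarkLikePreparer` are hypothetical names used purely for illustration, standing in for `GenerationMixin` and the Bark module.

class BasePreparer:
    def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
        # The generic implementation knows only the standard `inputs_embeds` name.
        return {"input_ids": input_ids, "inputs_embeds": inputs_embeds, **kwargs}


class BarkLikePreparer(BasePreparer):
    def prepare_inputs_for_generation(self, input_ids, input_embeds=None, **kwargs):
        # Feed the model-specific `input_embeds` in under the standard name, reuse
        # the generic logic, then rename the key back for the model's forward.
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, inputs_embeds=input_embeds, **kwargs
        )
        model_inputs["input_embeds"] = model_inputs.pop("inputs_embeds", None)
        return model_inputs


print(BarkLikePreparer().prepare_inputs_for_generation([1, 2, 3], use_cache=True))
# {'input_ids': [1, 2, 3], 'use_cache': True, 'input_embeds': None}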

src/transformers/models/chameleon/modeling_chameleon.py

Lines changed: 10 additions & 144 deletions
@@ -25,7 +25,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -35,19 +35,12 @@
     TransformersKwargs,
     auto_docstring,
     can_return_tuple,
-    is_torch_flex_attn_available,
     is_torchdynamo_compiling,
     logging,
 )
 from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig


-if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import BlockMask
-
-    from ...integrations.flex_attention import make_flex_block_causal_mask
-
-
 logger = logging.get_logger(__name__)


@@ -353,15 +346,8 @@ def forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         attention_interface: Callable = eager_attention_forward
-
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and output_attentions:
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         attn_output, attn_weights = attention_interface(
             self,
@@ -942,7 +928,7 @@ def forward(
            else:
                special_image_mask = input_ids == self.vocabulary_mapping.image_token_id

-            n_image_tokens_in_text = (special_image_mask).sum()
+            n_image_tokens_in_text = special_image_mask.sum()
            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

            image_embeds = self.get_image_features(pixel_values)
@@ -966,8 +952,13 @@ def forward(
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
         )

         # embed positions
@@ -1015,131 +1006,6 @@ def forward(
             attentions=all_self_attns,
         )

-    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
-    def _update_causal_mask(
-        self,
-        attention_mask: Union[torch.Tensor, "BlockMask"],
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool = False,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and (attention_mask == 0.0).any():
-                return attention_mask
-            return None
-        if self.config._attn_implementation == "flex_attention":
-            if isinstance(attention_mask, torch.Tensor):
-                attention_mask = make_flex_block_causal_mask(attention_mask)
-            return attention_mask
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype = input_tensor.dtype
-        sequence_length = input_tensor.shape[1]
-        if using_compilable_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type in ["cuda", "xpu", "npu"]
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask
-
-    @staticmethod
-    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        """
-        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-        Args:
-            attention_mask (`torch.Tensor`):
-                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                `(batch_size, 1, query_length, key_value_length)`.
-            sequence_length (`int`):
-                The sequence length being processed.
-            target_length (`int`):
-                The target length: when generating with static cache, the mask should be as long as the static cache,
-                to account for the 0 padding, the part of the cache that is not filled yet.
-            dtype (`torch.dtype`):
-                The dtype to use for the 4D attention mask.
-            cache_position (`torch.Tensor`):
-                Indices depicting the position of the input sequence tokens in the sequence.
-            batch_size (`torch.Tensor`):
-                Batch size.
-        """
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
-                    causal_mask.device
-                )
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-
-        return causal_mask
-

 @auto_docstring(
     custom_intro="""
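
Note: the deleted `_update_causal_mask` / `_prepare_4d_causal_attention_mask_with_cache_position` pair expanded a 2D padding mask into an additive 4D causal mask inside the model; `create_causal_mask` (called with the keyword arguments shown in the hunk above) now centralizes that logic. A simplified, standalone sketch of the mask shape involved, assuming a plain 2D padding mask and no KV cache; it is illustrative only, not the library helper.

import torch


def simple_4d_causal_mask(attention_mask_2d: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    batch_size, seq_len = attention_mask_2d.shape
    min_dtype = torch.finfo(dtype).min
    # Future (strictly upper-triangular) positions get the large negative value.
    causal = torch.triu(torch.full((seq_len, seq_len), min_dtype, dtype=dtype), diagonal=1)
    causal = causal[None, None, :, :].expand(batch_size, 1, seq_len, seq_len).clone()
    # Also mask out padding key positions (0 entries in the 2D mask).
    padding = attention_mask_2d[:, None, None, :] == 0
    return causal.masked_fill(padding, min_dtype)


mask_2d = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])  # first row is left-padded
print(simple_4d_causal_mask(mask_2d).shape)  # torch.Size([2, 1, 4, 4])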

tests/models/chameleon/test_modeling_chameleon.py

Lines changed: 9 additions & 1 deletion
@@ -270,7 +270,7 @@ def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
         input_ids[input_ids == self.image_token_id] = self.pad_token_id
         input_ids[:, : self.image_seq_length] = self.image_token_id
-        attention_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
+        attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
         pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])

         config = self.get_config()
@@ -325,6 +325,14 @@ def test_disk_offload_safetensors(self):
     def test_model_is_small(self):
         pass

+    @unittest.skip("Chameleon applies key/query norm which doesn't work with packing")
+    def test_eager_padding_matches_padding_free_with_position_ids(self):
+        pass
+
+    @unittest.skip("Chameleon applies key/query norm which doesn't work with packing")
+    def test_sdpa_padding_matches_padding_free_with_position_ids(self):
+        pass
+
     def test_mismatching_num_image_tokens(self):
         """
         Tests that VLMs through an error with explicit message saying what is wrong
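
Note: the first hunk swaps a shape-based `torch.tril` mask for one derived from the pad token, which is what the padding-free tests expect. A toy sketch (illustrative values only) of how the two differ:

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0],
                          [3, 4, 6, 8]])

old_mask = torch.tril(torch.ones_like(input_ids))  # lower-triangular, ignores token content
new_mask = input_ids.ne(pad_token_id).long()       # 1 for real tokens, 0 for padding

print(old_mask[0])  # tensor([1, 0, 0, 0]) -- most of the row wrongly treated as padding
print(new_mask[0])  # tensor([1, 1, 1, 0]) -- only the actual pad token is masked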

tests/models/emu3/test_modeling_emu3.py

Lines changed: 2 additions & 2 deletions
@@ -89,7 +89,7 @@ def __init__(

     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-        attention_mask = input_ids.ne(1).to(torch_device)
+        attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)

         config = self.get_config()

@@ -234,9 +234,9 @@ def prepare_config_and_inputs(self):
         config = self.get_config()

         input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size)
-        attention_mask = input_ids.ne(1).to(torch_device)
         input_ids[input_ids == self.image_token_id] = self.pad_token_id
         input_ids[:, : self.image_seq_length] = self.image_token_id
+        attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)

         pixel_values = floats_tensor(
             [
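
Note: the second hunk moves the mask construction to after the image-token/pad-token rewriting and replaces the hard-coded pad id of 1. A toy sketch (illustrative values only) of why the order and the real pad id matter:

import torch

pad_token_id, image_token_id, image_seq_length = 0, 99, 2
input_ids = torch.tensor([[5, 99, 7, 3, 0]])

# Old tester: mask computed up front, against a hard-coded pad id of 1.
early_mask = input_ids.ne(1)

# The tester then rewrites tokens: stray image tokens -> pad, prefix -> image tokens.
input_ids[input_ids == image_token_id] = pad_token_id
input_ids[:, :image_seq_length] = image_token_id

# New tester: mask computed last, against the real pad id.
final_mask = input_ids.ne(pad_token_id)

print(early_mask)  # tensor([[True, True, True, True,  True]]) -- pad position left unmasked
print(final_mask)  # tensor([[True, True, True, True, False]]) -- pad position masked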

tests/models/fuyu/test_modeling_fuyu.py

Lines changed: 8 additions & 0 deletions
@@ -214,6 +214,14 @@ def test_model_parallelism(self):
     def test_generate_continue_from_inputs_embeds():
         pass

+    @unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing")
+    def test_eager_padding_matches_padding_free_with_position_ids(self):
+        pass
+
+    @unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing")
+    def test_sdpa_padding_matches_padding_free_with_position_ids(self):
+        pass
+

     @slow
     @require_torch_accelerator
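
Note: the skipped `*_padding_matches_padding_free_with_position_ids` tests compare a padded batch against the same sequences packed into a single row with per-sequence position ids. A toy sketch (illustrative values, not the test code) of the packed inputs such a test builds:

import torch

pad_token_id = 0
padded_input_ids = torch.tensor([[11, 12, 13, 0],
                                 [21, 22, 0, 0]])
attention_mask = padded_input_ids.ne(pad_token_id)

# Drop the padding and concatenate everything into one "pad-free" row.
packed_input_ids = padded_input_ids[attention_mask].unsqueeze(0)

# Position ids restart at 0 for each original sequence, so the per-sequence
# causal structure can be recovered without any padding tokens.
lengths = attention_mask.sum(dim=1).tolist()
packed_position_ids = torch.cat([torch.arange(n) for n in lengths]).unsqueeze(0)

print(packed_input_ids)     # tensor([[11, 12, 13, 21, 22]])
print(packed_position_ids)  # tensor([[0, 1, 2, 0, 1]])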

tests/models/gemma3/test_modeling_gemma3.py

Lines changed: 8 additions & 0 deletions
@@ -143,6 +143,14 @@ def test_eager_matches_fa2_generate(self):
     def test_multi_gpu_data_parallel_forward(self):
         pass

+    @unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
+    def test_eager_padding_matches_padding_free_with_position_ids(self):
+        pass
+
+    @unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
+    def test_sdpa_padding_matches_padding_free_with_position_ids(self):
+        pass
+

 class Gemma3Vision2TextModelTester:
     def __init__(
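
Note: padding-free attention needs per-sequence boundaries rather than a padding mask. A standalone sketch (not the library implementation) of recovering cumulative sequence lengths from packed position ids, where a new sequence starts wherever the position id drops back to 0; this is the bookkeeping the padding-free tests exercise:

import torch

position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])

starts = torch.nonzero(position_ids == 0).flatten()                     # tensor([0, 3, 5])
cu_seqlens = torch.cat([starts, torch.tensor([position_ids.numel()])])  # cumulative offsets
seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]

print(cu_seqlens)   # tensor([0, 3, 5, 9])
print(seq_lengths)  # tensor([3, 2, 4])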

tests/models/kosmos2/test_modeling_kosmos2.py

Lines changed: 8 additions & 0 deletions
@@ -465,6 +465,14 @@ def test_eager_matches_sdpa_inference(
     ):
         pass

+    @unittest.skip("KOSMOS-2 doesn't support padding")
+    def test_eager_padding_matches_padding_free_with_position_ids(self):
+        pass
+
+    @unittest.skip("KOSMOS-2 doesn't support padding")
+    def test_sdpa_padding_matches_padding_free_with_position_ids(self):
+        pass
+
     @pytest.mark.generate
     def test_left_padding_compatibility(self):
         # Overwrite because Kosmos-2 need to padd pixel values and pad image-attn-mask
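
Note: the `test_left_padding_compatibility` override in the trailing context relies on position ids that account for left padding. A small sketch of the cumsum pattern (the same one removed from Bark earlier in this commit) that derives such position ids from an attention mask:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Count real tokens from the left, then neutralize the padding positions.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])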
