@@ -25,7 +25,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -35,19 +35,12 @@
     TransformersKwargs,
     auto_docstring,
     can_return_tuple,
-    is_torch_flex_attn_available,
     is_torchdynamo_compiling,
     logging,
 )
 from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig


-if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import BlockMask
-
-    from ...integrations.flex_attention import make_flex_block_causal_mask
-
-
 logger = logging.get_logger(__name__)

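For orientation, here is a minimal standalone sketch of the helper this hunk imports, called with the same keyword arguments the diff later adds in `ChameleonModel.forward`. The config defaults, shapes, the `"sdpa"` choice, and the empty `DynamicCache` are illustrative assumptions, not part of this change:

```python
import torch
from transformers import ChameleonConfig, DynamicCache
from transformers.masking_utils import create_causal_mask

# Illustrative setup (assumptions): default Chameleon config, SDPA backend, no padding, empty cache.
config = ChameleonConfig()
config._attn_implementation = "sdpa"

batch, seq_len = 1, 4
inputs_embeds = torch.zeros(batch, seq_len, config.hidden_size)
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
cache_position = torch.arange(seq_len)

causal_mask = create_causal_mask(
    config=config,                   # selects the mask format via config._attn_implementation
    input_embeds=inputs_embeds,      # supplies dtype and query length
    attention_mask=attention_mask,   # optional 2D padding mask
    cache_position=cache_position,   # absolute positions of the current tokens
    past_key_values=DynamicCache(),  # empty cache, so there are no past tokens
    position_ids=cache_position.unsqueeze(0),
)
# Depending on backend and padding, the helper returns None (the mask can be skipped),
# a 4D mask tensor, or a flex-attention BlockMask.
print(type(causal_mask))
```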
@@ -353,15 +346,8 @@ def forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         attention_interface: Callable = eager_attention_forward
-
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and output_attentions:
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         attn_output, attn_weights = attention_interface(
             self,
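The simplification above relies on the shared attention registry: the layer keeps eager softmax attention as the default and otherwise just looks up the configured kernel in `ALL_ATTENTION_FUNCTIONS`, so the per-layer SDPA warning and fallback are no longer needed here. Below is a schematic, self-contained sketch of that dispatch pattern with illustrative function names; it is not the actual transformers registry:

```python
from typing import Callable, Dict

import torch
import torch.nn.functional as F


def eager_attention_forward(query, key, value, scaling):
    # Plain softmax attention that materializes the weight matrix (the "eager" path).
    attn_weights = (query @ key.transpose(-1, -2)) * scaling
    attn_weights = attn_weights.softmax(dim=-1)
    return attn_weights @ value, attn_weights


def sdpa_attention_forward(query, key, value, scaling):
    # Fused kernel path: attention weights are never materialized, so None is returned for them.
    attn_output = F.scaled_dot_product_attention(query, key, value, scale=scaling)
    return attn_output, None


# Registry keyed by the configured implementation name (illustrative stand-in for ALL_ATTENTION_FUNCTIONS).
ATTENTION_FUNCTIONS: Dict[str, Callable] = {"sdpa": sdpa_attention_forward}

attn_implementation = "sdpa"  # would come from config._attn_implementation
attention_interface: Callable = eager_attention_forward
if attn_implementation != "eager":
    attention_interface = ATTENTION_FUNCTIONS[attn_implementation]

query = key = value = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
attn_output, attn_weights = attention_interface(query, key, value, scaling=8**-0.5)
print(attn_output.shape, attn_weights)
```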
@@ -942,7 +928,7 @@ def forward(
             else:
                 special_image_mask = input_ids == self.vocabulary_mapping.image_token_id

-            n_image_tokens_in_text = (special_image_mask).sum()
+            n_image_tokens_in_text = special_image_mask.sum()
             special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

             image_embeds = self.get_image_features(pixel_values)
@@ -966,8 +952,13 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
         )

         # embed positions
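The `create_causal_mask` call above takes over the hand-rolled mask construction that the next hunk deletes. For reference, here is a standalone sketch of what the removed `_prepare_4d_causal_attention_mask_with_cache_position` computed, condensed from the deleted code (device handling and the SDPA-specific branches of `_update_causal_mask` are omitted):

```python
import torch


def build_4d_causal_mask(attention_mask, sequence_length, target_length, dtype, cache_position, batch_size):
    # Additive mask: 0 where attending is allowed, dtype-min where it is forbidden.
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Each query may attend to keys at or before its absolute (cache) position; strictly later keys stay masked.
    causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # contiguous copy before the in-place edit below
        mask_length = attention_mask.shape[-1]
        # Merge in the 2D padding mask: positions that are causally allowed but padded get dtype-min.
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask == 0, min_dtype
        )
    return causal_mask


mask = build_4d_causal_mask(
    attention_mask=torch.tensor([[0, 1, 1, 1]]),  # batch of 1, left-padded by one token
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    cache_position=torch.arange(4),
    batch_size=1,
)
print(mask.shape)  # torch.Size([1, 1, 4, 4])
```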
@@ -1015,131 +1006,6 @@ def forward(
             attentions=all_self_attns,
         )

-    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
-    def _update_causal_mask(
-        self,
-        attention_mask: Union[torch.Tensor, "BlockMask"],
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool = False,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and (attention_mask == 0.0).any():
-                return attention_mask
-            return None
-        if self.config._attn_implementation == "flex_attention":
-            if isinstance(attention_mask, torch.Tensor):
-                attention_mask = make_flex_block_causal_mask(attention_mask)
-            return attention_mask
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype = input_tensor.dtype
-        sequence_length = input_tensor.shape[1]
-        if using_compilable_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type in ["cuda", "xpu", "npu"]
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask
-
-    @staticmethod
-    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        """
-        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-        Args:
-            attention_mask (`torch.Tensor`):
-                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                `(batch_size, 1, query_length, key_value_length)`.
-            sequence_length (`int`):
-                The sequence length being processed.
-            target_length (`int`):
-                The target length: when generating with static cache, the mask should be as long as the static cache,
-                to account for the 0 padding, the part of the cache that is not filled yet.
-            dtype (`torch.dtype`):
-                The dtype to use for the 4D attention mask.
-            cache_position (`torch.Tensor`):
-                Indices depicting the position of the input sequence tokens in the sequence.
-            batch_size (`torch.Tensor`):
-                Batch size.
-        """
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
-                    causal_mask.device
-                )
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-
-        return causal_mask
-

 @auto_docstring(
     custom_intro="""