diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0775290de9e2..93d9886d85ea 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -441,6 +441,10 @@ title: Encoder Decoder Models - local: model_doc/ernie title: ERNIE + - local: model_doc/ernie4_5 + title: Ernie4_5 + - local: model_doc/ernie4_5_moe + title: Ernie4_5_MoE - local: model_doc/ernie_m title: ErnieM - local: model_doc/esm diff --git a/docs/source/en/model_doc/ernie4_5.md b/docs/source/en/model_doc/ernie4_5.md new file mode 100644 index 000000000000..b350b9d429ae --- /dev/null +++ b/docs/source/en/model_doc/ernie4_5.md @@ -0,0 +1,99 @@ + + +
+
+ PyTorch + FlashAttention + SDPA + Tensor parallelism +
+
+ +# Ernie 4.5 + +## Overview + +The Ernie 4.5 model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu. +This family of models contains multiple different architectures and model sizes. This model in specific targets the base text +model without mixture of experts (moe) with 0.3B parameters in total. It uses the standard [Llama](./llama.md) at its core. + +Other models from the family can be found at [Ernie 4.5 MoE](./ernie4_5_moe.md). + +
+ +
+ + +## Usage Tips + +### Generate text + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_name = "baidu/ERNIE-4.5-0.3B-PT" + +# load the tokenizer and the model +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="auto", + torch_dtype=torch.bfloat16, +) + +# prepare the model input +inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt") +prompt = "Hey, are you conscious? Can you talk to me?" +messages = [ + {"role": "user", "content": prompt} +] +text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True +) +model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device) + +# conduct text completion +generated_ids = model.generate( + **model_inputs, + max_new_tokens=32, +) +output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() + +# decode the generated ids +generate_text = tokenizer.decode(output_ids, skip_special_tokens=True) +``` + +This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV). +The original code can be found [here](https://github.com/PaddlePaddle/ERNIE). + + +## Ernie4_5Config + +[[autodoc]] Ernie4_5Config + +## Ernie4_5Model + +[[autodoc]] Ernie4_5Model + - forward + +## Ernie4_5ForCausalLM + +[[autodoc]] Ernie4_5ForCausalLM + - forward diff --git a/docs/source/en/model_doc/ernie4_5_moe.md b/docs/source/en/model_doc/ernie4_5_moe.md new file mode 100644 index 000000000000..9d8703e5929d --- /dev/null +++ b/docs/source/en/model_doc/ernie4_5_moe.md @@ -0,0 +1,183 @@ + + +
+
+ PyTorch + FlashAttention + SDPA + Tensor parallelism +
+
+ +# Ernie 4.5 MoE + +## Overview + +The Ernie 4.5 MoE model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu. +This family of models contains multiple different architectures and model sizes. This model in specific targets the base text +model with mixture of experts (moe) - one with 21B total, 3B active parameters and another one with 300B total, 47B active parameters. +It uses the standard [Llama](./llama.md) at its core combined with a specialized MoE based on [Mixtral](./mixtral.md) with additional shared +experts. + +Other models from the family can be found at [Ernie 4.5](./ernie4_5.md). + +
+ +
+ + +## Usage Tips + +### Generate text + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_name = "baidu/ERNIE-4.5-21B-A3B-PT" + +# load the tokenizer and the model +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="auto", + torch_dtype=torch.bfloat16, +) + +# prepare the model input +inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt") +prompt = "Hey, are you conscious? Can you talk to me?" +messages = [ + {"role": "user", "content": prompt} +] +text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True +) +model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device) + +# conduct text completion +generated_ids = model.generate( + **model_inputs, + max_new_tokens=32, +) +output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() + +# decode the generated ids +generate_text = tokenizer.decode(output_ids, skip_special_tokens=True) +``` + +### Distributed Generation with Tensor Parallelism + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_name = "baidu/ERNIE-4.5-21B-A3B-PT" + +# load the tokenizer and the model +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="auto", + torch_dtype=torch.bfloat16, + tp_plan="auto", +) + +# prepare the model input +inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt") +prompt = "Hey, are you conscious? Can you talk to me?" +messages = [ + {"role": "user", "content": prompt} +] +text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True +) +model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device) + +# conduct text completion +generated_ids = model.generate( + **model_inputs, + max_new_tokens=32, +) +output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() + +# decode the generated ids +generate_text = tokenizer.decode(output_ids, skip_special_tokens=True) +``` + +### Quantization with Bitsandbytes + +```python +import torch +from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer + +model_name = "baidu/ERNIE-4.5-21B-A3B-PT" + +# load the tokenizer and the model +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_4bit=True), +) + +# prepare the model input +inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt") +prompt = "Hey, are you conscious? Can you talk to me?" +messages = [ + {"role": "user", "content": prompt} +] +text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True +) +model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device) + +# conduct text completion +generated_ids = model.generate( + **model_inputs, + max_new_tokens=32, +) +output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() + +# decode the generated ids +generate_text = tokenizer.decode(output_ids, skip_special_tokens=True) +``` + +This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV). +The original code can be found [here](https://github.com/PaddlePaddle/ERNIE). + + +## Ernie4_5_MoEConfig + +[[autodoc]] Ernie4_5_MoEConfig + +## Ernie4_5_MoEModel + +[[autodoc]] Ernie4_5_MoEModel + - forward + +## Ernie4_5_MoEForCausalLM + +[[autodoc]] Ernie4_5_MoEForCausalLM + - forward + - generate diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e20f4c0fe728..44229e21c40a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2977,6 +2977,17 @@ def _tie_or_clone_weights(self, output_embeddings, input_embeddings): else: output_embeddings.weight = input_embeddings.weight + # Passing hooks over to the embeddings if needed + # (currently limited to tensor parallel hooks and flags only) + if hasattr(input_embeddings, "_is_hooked") and getattr(input_embeddings, "_hf_tp_plan", None): + output_embeddings._is_hooked = input_embeddings._is_hooked + output_embeddings._hf_tp_plan = input_embeddings._hf_tp_plan + output_embeddings._forward_hooks = input_embeddings._forward_hooks + output_embeddings._forward_pre_hooks = input_embeddings._forward_pre_hooks + output_embeddings.__repr__ = ( + lambda: f"{output_embeddings.__repr__()}\nTP Plan: {output_embeddings._hf_tp_plan}" + ) + if getattr(output_embeddings, "bias", None) is not None: output_embeddings.bias.data = nn.functional.pad( output_embeddings.bias.data, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index e5a615f63396..b4642b9cafb2 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -128,6 +128,8 @@ ("encoder-decoder", "EncoderDecoderConfig"), ("eomt", "EomtConfig"), ("ernie", "ErnieConfig"), + ("ernie4_5", "Ernie4_5Config"), + ("ernie4_5_moe", "Ernie4_5_MoEConfig"), ("ernie_m", "ErnieMConfig"), ("esm", "EsmConfig"), ("falcon", "FalconConfig"), @@ -520,6 +522,8 @@ ("encoder-decoder", "Encoder decoder"), ("eomt", "EoMT"), ("ernie", "ERNIE"), + ("ernie4_5", "Ernie4_5"), + ("ernie4_5_moe", "Ernie4_5_MoE"), ("ernie_m", "ErnieM"), ("esm", "ESM"), ("falcon", "Falcon"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 08d1113d409e..36aaec4d53c2 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -119,6 +119,8 @@ ("emu3", "Emu3Model"), ("encodec", "EncodecModel"), ("ernie", "ErnieModel"), + ("ernie4_5", "Ernie4_5Model"), + ("ernie4_5_moe", "Ernie4_5_MoEModel"), ("ernie_m", "ErnieMModel"), ("esm", "EsmModel"), ("falcon", "FalconModel"), @@ -594,6 +596,8 @@ ("electra", "ElectraForCausalLM"), ("emu3", "Emu3ForCausalLM"), ("ernie", "ErnieForCausalLM"), + ("ernie4_5", "Ernie4_5ForCausalLM"), + ("ernie4_5_moe", "Ernie4_5_MoEForCausalLM"), ("falcon", "FalconForCausalLM"), ("falcon_h1", "FalconH1ForCausalLM"), ("falcon_mamba", "FalconMambaForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 0a66a1b97915..eda3d29ae777 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -212,6 +212,8 @@ ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)), ("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("ernie4_5", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)), ("esm", ("EsmTokenizer", None)), ("falcon", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/ernie4_5/__init__.py b/src/transformers/models/ernie4_5/__init__.py new file mode 100644 index 000000000000..5d6e69432c9a --- /dev/null +++ b/src/transformers/models/ernie4_5/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_ernie4_5 import * + from .modeling_ernie4_5 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/ernie4_5/configuration_ernie4_5.py b/src/transformers/models/ernie4_5/configuration_ernie4_5.py new file mode 100644 index 000000000000..e6e2795b5daa --- /dev/null +++ b/src/transformers/models/ernie4_5/configuration_ernie4_5.py @@ -0,0 +1,202 @@ +# Copyright (c) 2025 Baidu, Inc. and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ernie 4.5 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class Ernie4_5Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Ernie4_5Model`]. It is used to instantiate an Ernie 4.5 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Ernie 4.5 0.3B. + e.g. [baidu/ERNIE-4.5-0.3B-PT](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 103424): + Vocabulary size of the Ernie 4.5 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Ernie4_5Model`] + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 18): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 500000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in any of the projections including mlp and attention for example. + head_dim (`int`, *optional*, defaults to 128): + The attention head dimension. If None, it will default to hidden_size // num_attention_heads + + ```python + >>> from transformers import Ernie4_5Model, Ernie4_5Config + + >>> # Initializing a Ernie4_5 0.3B style configuration + >>> configuration = Ernie4_5Config() + + >>> # Initializing a model from the 0.3B style configuration + >>> model = Ernie4_5Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "ernie4_5" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Ernie4_5Model` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=103424, + hidden_size=1024, + intermediate_size=3072, + num_hidden_layers=18, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=True, + rope_theta=500000.0, + rope_scaling=None, + use_bias=False, + head_dim=128, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.use_bias = use_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Ernie4_5Config"] diff --git a/src/transformers/models/ernie4_5/convert_ernie4_5_tokenizer.py b/src/transformers/models/ernie4_5/convert_ernie4_5_tokenizer.py new file mode 100644 index 000000000000..25994bb1436f --- /dev/null +++ b/src/transformers/models/ernie4_5/convert_ernie4_5_tokenizer.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025 HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from transformers import LlamaTokenizer, LlamaTokenizerFast + + +DEFAULT_CHAT_TEMPLATE = '{%- if not add_generation_prompt is defined -%}\n {%- set add_generation_prompt = true -%}\n{%- endif -%}\n{%- if not cls_token is defined -%}\n {%- set cls_token = "<|begin_of_sentence|>" -%}\n{%- endif -%}\n{%- if not sep_token is defined -%}\n {%- set sep_token = "<|end_of_sentence|>" -%}\n{%- endif -%}\n{{- cls_token -}}\n{%- for message in messages -%}\n {%- if message["role"] == "user" -%}\n {{- "User: " + message["content"] + "\n" -}}\n {%- elif message["role"] == "assistant" -%}\n {{- "Assistant: " + message["content"] + sep_token -}}\n {%- elif message["role"] == "system" -%}\n {{- message["content"] + "\n" -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{- "Assistant: " -}}\n{%- endif -%}' +DEFAULT_TEXT_ADD_TOKENS = [ + "", + "", + "", + "", +] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--repo_name", + help="Name of the repo where the tokenizer is located at.", + default="baidu/ERNIE-4.5-0.3B-Base-PT", + ) + parser.add_argument( + "--push_to_hub", + help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.", + action="store_true", + default=False, + ) + parser.add_argument( + "--output_dir", + help="Location to write the tokenizer", + ) + args = parser.parse_args() + + hf_tok = LlamaTokenizer.from_pretrained( + args.repo_name, + pad_token="", + cls_token="<|begin_of_sentence|>", + sep_token="<|end_of_sentence|>", + mask_token="", + add_bos_token=False, + add_prefix_space=False, + chat_template=DEFAULT_CHAT_TEMPLATE, + legacy=True, + ) + hf_tok.model_max_length = 131072 + hf_tok.init_kwargs.pop("auto_map", None) + # special tokens which we need to map as additional special tokens instead + hf_tok.init_kwargs.pop("header_start_token", None) + hf_tok.init_kwargs.pop("header_end_token", None) + hf_tok.init_kwargs.pop("sys_start_token", None) + hf_tok.init_kwargs.pop("sys_end_token", None) + for token in DEFAULT_TEXT_ADD_TOKENS: + hf_tok.add_tokens([token], special_tokens=True) + + # save slow model and convert on load time + hf_tok.save_pretrained("/tmp/ernie4_5_tokenizer") + hf_tok_fast = LlamaTokenizerFast.from_pretrained("/tmp/ernie4_5_tokenizer", from_slow=True) + hf_tok_fast.save_pretrained(args.output_dir, push_to_hub=args.push_to_hub) diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py new file mode 100644 index 000000000000..aab89aa0a261 --- /dev/null +++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -0,0 +1,503 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/ernie4_5/modular_ernie4_5.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_ernie4_5.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright (c) 2025 Baidu, Inc. and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import check_model_inputs +from .configuration_ernie4_5 import Ernie4_5Config + + +class Ernie4_5RotaryEmbedding(nn.Module): + def __init__(self, config: Ernie4_5Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + # keeping it in full precision + return cos, sin + + +class Ernie4_5MLP(nn.Module): + def __init__(self, config: Ernie4_5Config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # glm rope style (with full dim) and full precision + original_dtype = q.dtype + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + q_embed = (q.float() * cos) + (rotate_half(q).float() * sin) + k_embed = (k.float() * cos) + (rotate_half(k).float() * sin) + + return q_embed.to(original_dtype), k_embed.to(original_dtype) + + +class Ernie4_5Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Ernie4_5Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + + self.attention_dropout = 0.0 + self.is_causal = True + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class Ernie4_5RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Ernie4_5RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Ernie4_5DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Ernie4_5Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Ernie4_5Attention(config=config, layer_idx=layer_idx) + + self.mlp = Ernie4_5MLP(config) + self.input_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class Ernie4_5PreTrainedModel(PreTrainedModel): + config: Ernie4_5Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Ernie4_5DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _supports_static_cache = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Ernie4_5DecoderLayer, + "attentions": Ernie4_5Attention, + } + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Ernie4_5RMSNorm): + module.weight.data.fill_(1.0) + + +@auto_docstring +class Ernie4_5Model(Ernie4_5PreTrainedModel): + def __init__(self, config: Ernie4_5Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Ernie4_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Ernie4_5RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Ernie4_5ForCausalLM(Ernie4_5PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Ernie4_5Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["Ernie4_5ForCausalLM", "Ernie4_5Model", "Ernie4_5PreTrainedModel"] diff --git a/src/transformers/models/ernie4_5/modular_ernie4_5.py b/src/transformers/models/ernie4_5/modular_ernie4_5.py new file mode 100644 index 000000000000..f76c7c6bdae7 --- /dev/null +++ b/src/transformers/models/ernie4_5/modular_ernie4_5.py @@ -0,0 +1,123 @@ +# Copyright (c) 2025 Baidu, Inc. and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Ernie 4.5 model""" + +import torch +from torch import nn + +from ...modeling_rope_utils import dynamic_rope_update +from ...utils import auto_docstring, can_return_tuple +from ..glm.modeling_glm import rotate_half +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaMLP, + LlamaRotaryEmbedding, +) +from .configuration_ernie4_5 import Ernie4_5Config + + +class Ernie4_5RotaryEmbedding(LlamaRotaryEmbedding): + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + # keeping it in full precision + return cos, sin + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # glm rope style (with full dim) and full precision + original_dtype = q.dtype + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + q_embed = (q.float() * cos) + (rotate_half(q).float() * sin) + k_embed = (k.float() * cos) + (rotate_half(k).float() * sin) + + return q_embed.to(original_dtype), k_embed.to(original_dtype) + + +class Ernie4_5MLP(LlamaMLP): + def __init__(self, config: Ernie4_5Config): + super().__init__() + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + + +class Ernie4_5Attention(LlamaAttention): + def __init__(self, config: Ernie4_5Config, layer_idx: int): + super().__init__(config, layer_idx) + + self.attention_dropout = 0.0 + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + + +class Ernie4_5ForCausalLM(LlamaForCausalLM): + @can_return_tuple + @auto_docstring + def forward(self, **super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + super().forward(**super_kwargs) + + +__all__ = [ + "Ernie4_5ForCausalLM", + "Ernie4_5Model", # noqa: F822 + "Ernie4_5PreTrainedModel", # noqa: F822 +] diff --git a/src/transformers/models/ernie4_5_moe/__init__.py b/src/transformers/models/ernie4_5_moe/__init__.py new file mode 100644 index 000000000000..eb30318fa25e --- /dev/null +++ b/src/transformers/models/ernie4_5_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_ernie4_5_moe import * + from .modeling_ernie4_5_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py new file mode 100644 index 000000000000..cec4f4661a77 --- /dev/null +++ b/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py @@ -0,0 +1,254 @@ +# Copyright (c) 2025 Baidu, Inc. and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ernie 4.5 MoE model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Ernie4_5_MoEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Ernie4_5_MoEModel`]. It is used to instantiate a + Ernie 4.5 MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of [baidu/ERNIE-4.5-21B-A3B-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 103424): + Vocabulary size of the Ernie 4.5 MoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Ernie4_5_MoEModel`] + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + hidden_size (`int`, *optional*, defaults to 2560): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 12288): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 20): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 500000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in any of the projections including mlp and attention for example. + moe_intermediate_size (`int`, *optional*, defaults to 1536): + Intermediate size of the routed expert. + moe_k (`int`, *optional*, defaults to 6): + Number of selected experts. + moe_num_experts (`int`, *optional*, defaults to 64): + Number of routed experts. + moe_num_shared_experts (`int`, *optional*, defaults to 2): + The number of experts that are shared for all MoE forwards. + moe_layer_start_index (`int`, *optional*, defaults to 1): + The first index at which MoE layers start to appear. + moe_layer_end_index (`int`, *optional*, defaults to -1): + The last possible index for a MoE layer. + moe_layer_interval (`int`, *optional*, defaults to 1): + The intervals between MoE layers to appear. + moe_norm_min (`float`, *optional*, defaults to 1e-12): + Minimum division value during routing normalization. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from transformers import Ernie4_5_MoEModel, Ernie4_5_MoEConfig + + >>> # Initializing a Ernie4_5_MoE style configuration + >>> configuration = Ernie4_5_MoEConfig() + + >>> # Initializing a model from the ERNIE-4.5-21B-A3B style configuration + >>> model = Ernie4_5_MoEModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "ernie4_5_moe" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_experts": "moe_num_experts", "num_experts_per_tok": "moe_k"} + + # Default tensor parallel plan for base model `Ernie4_5_MoE` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + # sequence parallel is pretty slow + # "norm.weight": "sequence_parallel", + # "layers.*.input_layernorm.weight": "sequence_parallel", + # "layers.*.post_attention_layernorm.weight": "sequence_parallel", + "layers.*.mlp.shared_experts.gate_proj": "local_colwise", + "layers.*.mlp.shared_experts.up_proj": "local_colwise", + "layers.*.mlp.shared_experts.down_proj": "local_rowwise", + "layers.*.mlp.experts.*.gate_proj": "local_colwise", + "layers.*.mlp.experts.*.up_proj": "local_colwise", + "layers.*.mlp.experts.*.down_proj": "local_rowwise", + "layers.*.mlp.experts": "local", + "layers.*.mlp.gate_proj": "local_colwise", + "layers.*.mlp.up_proj": "local_colwise", + "layers.*.mlp.down_proj": "local_rowwise", + "layers.*.mlp": "gather", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=103424, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + hidden_size=2560, + intermediate_size=12288, + num_hidden_layers=28, + num_attention_heads=20, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=True, + rope_theta=500000.0, + rope_scaling=None, + use_bias=False, + moe_intermediate_size=1536, + moe_k=6, + moe_num_experts=64, + moe_num_shared_experts=2, + moe_layer_start_index=1, + moe_layer_end_index=-1, + moe_layer_interval=1, + moe_norm_min=1e-12, + output_router_logits=False, + router_aux_loss_coef=0.001, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.use_bias = use_bias + + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # MoE arguments + self.moe_intermediate_size = moe_intermediate_size + self.moe_k = moe_k + self.moe_num_experts = moe_num_experts + self.moe_num_shared_experts = moe_num_shared_experts + self.moe_layer_start_index = moe_layer_start_index + self.moe_layer_end_index = self.num_hidden_layers - 1 if moe_layer_end_index == -1 else moe_layer_end_index + self.moe_layer_interval = moe_layer_interval + self.moe_norm_min = moe_norm_min + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Ernie4_5_MoEConfig"] diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py new file mode 100644 index 000000000000..89b9b7bb1705 --- /dev/null +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -0,0 +1,779 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_ernie4_5_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright (c) 2025 Baidu, Inc. and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import OutputRecorder, check_model_inputs +from .configuration_ernie4_5_moe import Ernie4_5_MoEConfig + + +@use_kernel_forward_from_hub("RMSNorm") +class Ernie4_5_MoERMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Ernie4_5_MoERMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Ernie4_5_MoEMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Ernie4_5_MoERotaryEmbedding(nn.Module): + def __init__(self, config: Ernie4_5_MoEConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + # keeping it in full precision + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # glm rope style (with full dim) and full precision + original_dtype = q.dtype + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + q_embed = (q.float() * cos) + (rotate_half(q).float() * sin) + k_embed = (k.float() * cos) + (rotate_half(k).float() * sin) + + return q_embed.to(original_dtype), k_embed.to(original_dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Ernie4_5_MoEAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Ernie4_5_MoEConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + + self.attention_dropout = 0.0 + self.is_causal = True + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Ernie4_5_MoEStatics(nn.Module): + """ + Stores MoE (Mixture of Experts) statistics + - Bias for the gating + - Additionally, usage per expert in the original codebase + """ + + def __init__(self, config): + super().__init__() + + num_experts_groups = 1 + num_experts = config.moe_num_experts + + self.e_score_correction_bias = nn.Parameter( + torch.zeros(num_experts_groups, num_experts, dtype=torch.float32), + requires_grad=False, + ) + + def forward(self, hidden_states): + # NOTE: This is a workaround to enable TP with a module that only has parameters + # + # Otherwise, it stays as `DTensor` when called in the "super" forward + # 1. All other tensors are local (`torch.Tensor`) + # 2. Isolate does not work on `nn.Module` which only has parameters + return hidden_states + self.e_score_correction_bias.squeeze() + + +class Ernie4_5_MoESparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + + Ernie 4.5 MoE's original formula is based on case (2) with + (optional) shared experts and a corrections bias during gating. + """ + + def __init__(self, config): + super().__init__() + self.num_experts = config.moe_num_experts + self.top_k = config.moe_k + + # correction bias (yes it seems to be a typo with statics <> statistics) + self.moe_statics = Ernie4_5_MoEStatics(config) + + # gating + self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) + self.experts = nn.ModuleList( + [Ernie4_5_MoEMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)] + ) + self.norm_min = config.moe_norm_min + + # (optional) shared experts for all forwards + self.shared_experts = None + if config.moe_num_shared_experts > 0: + self.shared_experts = Ernie4_5_MoEMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # (Optional) shared experts + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + + device_type = ( + hidden_states.device.type + if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states.float()) + + # NOTE: we are using the original code base at + # https://github.com/PaddlePaddle/Paddle/blob/9b40438ce0f6d76b4f08a7837dd1e28b26cf8ee6/python/paddle/incubate/nn/functional/moe_gate_dispatch.py#L109-L116 + # this might differ from the remote version regarding the bias (see `Ernie4_5_MoEStatics`) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights = self.moe_statics(routing_weights) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = routing_weights / torch.clamp( + routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min + ) + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hitted: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + # Add (optional) shared experts to the result + if self.shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class Ernie4_5_MoEDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Ernie4_5_MoEAttention(config, layer_idx) + + if ( + ((layer_idx + 1) % config.moe_layer_interval == 0) + and layer_idx >= config.moe_layer_start_index + and layer_idx <= config.moe_layer_end_index + ): + self.mlp = Ernie4_5_MoESparseMoeBlock(config) + else: + self.mlp = Ernie4_5_MoEMLP(config) + + self.input_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps) + self.post_attention_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[tuple[torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # For the MoE layers, we need to unpack + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + +@auto_docstring +class Ernie4_5_MoEPreTrainedModel(PreTrainedModel): + config: Ernie4_5_MoEConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Ernie4_5_MoEDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(Ernie4_5_MoESparseMoeBlock, index=1), + "hidden_states": Ernie4_5_MoEDecoderLayer, + "attentions": Ernie4_5_MoEAttention, + } + _keep_in_fp32_modules_strict = ["gate", "moe_statics"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Ernie4_5_MoERMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Ernie4_5_MoEStatics): + module.e_score_correction_bias.data.zero_() + + +@auto_docstring +class Ernie4_5_MoEModel(Ernie4_5_MoEPreTrainedModel): + def __init__(self, config: Ernie4_5_MoEConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Ernie4_5_MoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Ernie4_5_MoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Ernie4_5_MoERotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class Ernie4_5_MoEForCausalLM(Ernie4_5_MoEPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Ernie4_5_MoEModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.moe_num_experts + self.num_experts_per_tok = config.moe_k + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +__all__ = ["Ernie4_5_MoEForCausalLM", "Ernie4_5_MoEModel", "Ernie4_5_MoEPreTrainedModel"] diff --git a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py new file mode 100644 index 000000000000..daf122929bc7 --- /dev/null +++ b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py @@ -0,0 +1,333 @@ +# Copyright (c) 2025 Baidu, Inc. and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Ernie 4.5 MoE model.""" + +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask +from ...modeling_outputs import MoeModelOutputWithPast +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import check_model_inputs +from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half # noqa: F401 +from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm +from ..mixtral.modeling_mixtral import ( + MixtralForCausalLM, + MixtralModel, + MixtralPreTrainedModel, +) +from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeMLP +from .configuration_ernie4_5_moe import Ernie4_5_MoEConfig + + +logger = logging.get_logger(__name__) + + +class Ernie4_5_MoERMSNorm(LlamaRMSNorm): + pass + + +class Ernie4_5_MoEMLP(Qwen3MoeMLP): + def __init__(self, config, intermediate_size=None): + super().__init__(config, intermediate_size) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + + +class Ernie4_5_MoERotaryEmbedding(Ernie4_5RotaryEmbedding): + pass + + +class Ernie4_5_MoEAttention(LlamaAttention): + def __init__(self, config: Ernie4_5_MoEConfig, layer_idx: int): + super().__init__(config, layer_idx) + + self.attention_dropout = 0.0 + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + + +class Ernie4_5_MoEStatics(nn.Module): + """ + Stores MoE (Mixture of Experts) statistics + - Bias for the gating + - Additionally, usage per expert in the original codebase + """ + + def __init__(self, config): + super().__init__() + + num_experts_groups = 1 + num_experts = config.moe_num_experts + + self.e_score_correction_bias = nn.Parameter( + torch.zeros(num_experts_groups, num_experts, dtype=torch.float32), + requires_grad=False, + ) + + def forward(self, hidden_states): + # NOTE: This is a workaround to enable TP with a module that only has parameters + # + # Otherwise, it stays as `DTensor` when called in the "super" forward + # 1. All other tensors are local (`torch.Tensor`) + # 2. Isolate does not work on `nn.Module` which only has parameters + return hidden_states + self.e_score_correction_bias.squeeze() + + +class Ernie4_5_MoESparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + + Ernie 4.5 MoE's original formula is based on case (2) with + (optional) shared experts and a corrections bias during gating. + """ + + def __init__(self, config): + super().__init__() + self.num_experts = config.moe_num_experts + self.top_k = config.moe_k + + # correction bias (yes it seems to be a typo with statics <> statistics) + self.moe_statics = Ernie4_5_MoEStatics(config) + + # gating + self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) + self.experts = nn.ModuleList( + [Ernie4_5_MoEMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)] + ) + self.norm_min = config.moe_norm_min + + # (optional) shared experts for all forwards + self.shared_experts = None + if config.moe_num_shared_experts > 0: + self.shared_experts = Ernie4_5_MoEMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # (Optional) shared experts + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + + device_type = ( + hidden_states.device.type + if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states.float()) + + # NOTE: we are using the original code base at + # https://github.com/PaddlePaddle/Paddle/blob/9b40438ce0f6d76b4f08a7837dd1e28b26cf8ee6/python/paddle/incubate/nn/functional/moe_gate_dispatch.py#L109-L116 + # this might differ from the remote version regarding the bias (see `Ernie4_5_MoEStatics`) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights = self.moe_statics(routing_weights) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = routing_weights / torch.clamp( + routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min + ) + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hitted: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + # Add (optional) shared experts to the result + if self.shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class Ernie4_5_MoEDecoderLayer(Qwen3MoeDecoderLayer, nn.Module): + def __init__(self, config, layer_idx): + nn.Module().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Ernie4_5_MoEAttention(config, layer_idx) + + if ( + ((layer_idx + 1) % config.moe_layer_interval == 0) + and layer_idx >= config.moe_layer_start_index + and layer_idx <= config.moe_layer_end_index + ): + self.mlp = Ernie4_5_MoESparseMoeBlock(config) + else: + self.mlp = Ernie4_5_MoEMLP(config) + + self.input_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps) + self.post_attention_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps) + + +@auto_docstring +class Ernie4_5_MoEPreTrainedModel(MixtralPreTrainedModel): + _keep_in_fp32_modules_strict = ["gate", "moe_statics"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Ernie4_5_MoERMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Ernie4_5_MoEStatics): + module.e_score_correction_bias.data.zero_() + + +@auto_docstring +class Ernie4_5_MoEModel(MixtralModel): + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Ernie4_5_MoEForCausalLM(MixtralForCausalLM, Ernie4_5_MoEPreTrainedModel): + def __init__(self, config): + Ernie4_5_MoEPreTrainedModel().__init__(config) + self.model = Ernie4_5_MoEModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.moe_num_experts + self.num_experts_per_tok = config.moe_k + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward(self, **super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + super().forward(**super_kwargs) + + +__all__ = [ + "Ernie4_5_MoEForCausalLM", + "Ernie4_5_MoEModel", + "Ernie4_5_MoEPreTrainedModel", +] diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py index b13f824bf7f4..b6479524a948 100644 --- a/tests/causal_lm_tester.py +++ b/tests/causal_lm_tester.py @@ -104,9 +104,11 @@ def __init__( is_decoder=False, scope=None, expert_interval=1, + moe_layer_start_index=0, moe_intermediate_size=12, shared_expert_intermediate_size=36, shared_expert_gate=True, + moe_num_shared_experts=2, num_experts_per_tok=2, num_experts=8, mamba_n_groups=1, @@ -146,9 +148,11 @@ def __init__( self.head_dim = self.hidden_size // self.num_attention_heads self.is_decoder = is_decoder self.expert_interval = expert_interval + self.moe_layer_start_index = moe_layer_start_index self.moe_intermediate_size = moe_intermediate_size self.shared_expert_intermediate_size = shared_expert_intermediate_size self.shared_expert_gate = shared_expert_gate + self.moe_num_shared_experts = moe_num_shared_experts self.num_experts_per_tok = num_experts_per_tok self.num_experts = num_experts self.mamba_n_groups = mamba_n_groups diff --git a/tests/models/ernie4_5/__init__.py b/tests/models/ernie4_5/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/ernie4_5/test_modeling_ernie4_5.py b/tests/models/ernie4_5/test_modeling_ernie4_5.py new file mode 100644 index 000000000000..1c5bffa2c6d6 --- /dev/null +++ b/tests/models/ernie4_5/test_modeling_ernie4_5.py @@ -0,0 +1,122 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Ernie4.5 model.""" + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import ( + Expectations, + cleanup, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +if is_torch_available(): + import torch + + from transformers import ( + AutoTokenizer, + Ernie4_5Config, + Ernie4_5ForCausalLM, + Ernie4_5Model, + ) + from transformers.models.ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding + + +class Ernie4_5ModelTester(CausalLMModelTester): + if is_torch_available(): + config_class = Ernie4_5Config + base_model_class = Ernie4_5Model + causal_lm_class = Ernie4_5ForCausalLM + + +@require_torch +class Ernie4_5ModelTest(CausalLMModelTest, unittest.TestCase): + all_model_classes = ( + ( + Ernie4_5Model, + Ernie4_5ForCausalLM, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "feature-extraction": Ernie4_5Model, + "text-generation": Ernie4_5ForCausalLM, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + model_tester_class = Ernie4_5ModelTester + rotary_embedding_layer = Ernie4_5RotaryEmbedding # Enables RoPE tests if set + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = Ernie4_5ForCausalLM if is_torch_available() else None + + +@require_torch_accelerator +class Ernie4_5IntegrationTest(unittest.TestCase): + def setup(self): + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + def test_ernie4_5_0p3B(self): + """ + An integration test for Ernie 4.5 0.3B. + """ + expected_texts = Expectations( + { + ("cuda", None): "User: Hey, are you conscious? Can you talk to me?\nAssistant: Hey! I'm here to help you with whatever you need. Are you feeling a bit overwhelmed or stressed? I'm here to listen and provide support.", + } + ) # fmt: skip + EXPECTED_TEXT = expected_texts.get_expectation() + + tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", revision="refs/pr/3") + model = Ernie4_5ForCausalLM.from_pretrained( + "baidu/ERNIE-4.5-0.3B-PT", + revision="refs/pr/3", + device_map="auto", + torch_dtype=torch.bfloat16, + ) + + prompt = "Hey, are you conscious? Can you talk to me?" + messages = [{"role": "user", "content": prompt}] + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device) + + generated_ids = model.generate( + model_inputs.input_ids, + max_new_tokens=128, + do_sample=False, + ) + + generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip("\n") + self.assertEqual(generated_text, EXPECTED_TEXT) diff --git a/tests/models/ernie4_5_moe/__init__.py b/tests/models/ernie4_5_moe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py b/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py new file mode 100644 index 000000000000..63fb00745c62 --- /dev/null +++ b/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py @@ -0,0 +1,199 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Ernie4.5 MoE model.""" + +import tempfile +import unittest + +import pytest + +from transformers import Ernie4_5_MoEConfig, is_torch_available +from transformers.testing_utils import ( + cleanup, + is_flaky, + require_bitsandbytes, + require_flash_attn, + require_torch, + require_torch_gpu, + require_torch_large_accelerator, + require_torch_multi_accelerator, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + + from transformers import ( + AutoTokenizer, + Ernie4_5_MoEForCausalLM, + Ernie4_5_MoEModel, + ) +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +class Ernie4_5_MoEModelTester(CausalLMModelTester): + config_class = Ernie4_5_MoEConfig + if is_torch_available(): + base_model_class = Ernie4_5_MoEModel + causal_lm_class = Ernie4_5_MoEForCausalLM + + +@require_torch +class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase): + all_model_classes = ( + ( + Ernie4_5_MoEModel, + Ernie4_5_MoEForCausalLM, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "feature-extraction": Ernie4_5_MoEModel, + "text-generation": Ernie4_5_MoEForCausalLM, + } + if is_torch_available() + else {} + ) + + test_headmasking = False + test_pruning = False + test_all_params_have_gradient = False + model_tester_class = Ernie4_5_MoEModelTester + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @is_flaky() + @slow + def test_flash_attn_2_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(reason="Model does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager" + ) + model.to(torch_device) + + dummy_input = inputs_dict[model_class.main_input_name] + dummy_input = dummy_input.to(torch_device) + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + # higher tolerance, not sure where it stems from + assert torch.allclose(logits_fa, logits, atol=1e-2, rtol=1e-2) + + # Ignore copy + def test_load_balancing_loss(self): + r""" + Let's make sure we can actually compute the loss and do a backward on it. + """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.num_experts = 8 + config.expert_interval = 2 + config.output_router_logits = True + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + model = Ernie4_5_MoEForCausalLM(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask) + self.assertEqual(result.router_logits[0].shape, (91, config.num_experts)) + torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2) + + # First, we make sure that adding padding tokens doesn't change the loss + # loss(input_ids, attention_mask=None) == loss(input_ids + padding, attention_mask=attention_mask_with_padding) + pad_length = 1000 + # Add padding tokens (assume that pad_token_id=1) to input_ids + padding_block = torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(torch_device) + padded_input_ids = torch.cat((padding_block, input_ids), dim=1) # this is to simulate padding to the left + padded_attention_mask = padded_input_ids.ne(1).to(torch_device) + + padded_result = model(padded_input_ids, attention_mask=padded_attention_mask) + torch.testing.assert_close(result.aux_loss.cpu(), padded_result.aux_loss.cpu(), rtol=1e-4, atol=1e-4) + + # We make sure that the loss of including padding tokens != the loss without padding tokens + # if attention_mask=None --> we don't exclude padding tokens + include_padding_result = model(padded_input_ids, attention_mask=None) + + # This is to mimic torch.testing.assert_not_close + self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) + + +# Run on runners with larger accelerators (for example A10 instead of T4) with a lot of CPU RAM (e.g. g5-12xlarge) +@require_torch_multi_accelerator +@require_torch_large_accelerator +@require_torch +class Ernie4_5_MoEIntegrationTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = None + + @classmethod + def tearDownClass(cls): + del cls.model + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @classmethod + def get_model(cls): + if cls.model is None: + cls.model = Ernie4_5_MoEForCausalLM.from_pretrained( + "baidu/ERNIE-4.5-21B-A3B-PT", + revision="refs/pr/11", + device_map="auto", + load_in_4bit=True, + ) + + return cls.model + + @require_bitsandbytes + @slow + def test_model_21b_a3b_generation(self): + EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: Yes, I am conscious and I can communicate with you. How can I assist you with any questions or information you need?" # fmt: skip + + model = self.get_model() + tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT", revision="refs/pr/11") + prompt = "Hey, are you conscious? Can you talk to me?" + messages = [{"role": "user", "content": prompt}] + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device) + + generated_ids = model.generate( + model_inputs.input_ids, + max_new_tokens=32, + do_sample=False, + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip("\n") + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 5a24bcecee1e..5589c8cc0d61 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -258,10 +258,10 @@ def _test_eager_matches_sdpa_inference( model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="sdpa") except ValueError: model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs) - model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype) + model_sdpa = model_sdpa.eval().to(torch_device) model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager") - model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype) + model_eager = model_eager.eval().to(torch_device) set_model_for_less_flaky_test(model_eager) set_model_for_less_flaky_test(model_sdpa) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 2307b962044f..08e3f26245a3 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -32,6 +32,8 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING SPECIAL_CASES_TO_ALLOW = { + "Ernie4_5Config": ["tie_word_embeddings"], + "Ernie4_5_MoEConfig": ["tie_word_embeddings"], "Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"], # used internally during generation to provide the custom logit processors with their necessary information "DiaConfig": [