Merged
83 commits
055fa74
init
vasqu Jul 3, 2025
cf31849
copied from remote
vasqu Jul 3, 2025
69d942c
add proper structure and llama like structure
vasqu Jul 3, 2025
0fa32aa
fixup
vasqu Jul 3, 2025
2fc44ee
revert to state that works
vasqu Jul 3, 2025
39176fd
get closer to llama
vasqu Jul 3, 2025
f9d5789
slow and steady
vasqu Jul 3, 2025
de53a33
some removal
vasqu Jul 3, 2025
999e7e9
masks work
vasqu Jul 3, 2025
1816efb
it is indeed the rope implementation, how dafuq does it mesh with the…
vasqu Jul 4, 2025
d0c5877
nice
vasqu Jul 4, 2025
0361822
getting closer
vasqu Jul 4, 2025
0682085
closer to transformers style
vasqu Jul 4, 2025
a551f22
let's simplify this, batching works now
vasqu Jul 4, 2025
4a9c20b
simplified
vasqu Jul 4, 2025
393c2c7
working version with modular
vasqu Jul 4, 2025
936aa04
it is indeed the rotation per weights, make it complete llama style
vasqu Jul 4, 2025
b33edd1
cleanup conversion, next to look at -> tokenizer
vasqu Jul 4, 2025
4bcb7f0
remove llama artefacts
vasqu Jul 4, 2025
4129506
fix modeling tests (common ones)
vasqu Jul 4, 2025
e837bf5
style
vasqu Jul 4, 2025
be5a7b0
integration test + first look into tokenization (will need more work,…
vasqu Jul 7, 2025
f98c8d9
style
vasqu Jul 7, 2025
de0389f
working moe version, based on remote
vasqu Jul 7, 2025
1a41719
lets keep it simple and go step by step - transformers annotations fo…
vasqu Jul 7, 2025
484da46
more cleanup
vasqu Jul 7, 2025
5627e5a
refactor namings and remove addition forXXX classes
vasqu Jul 7, 2025
a146b13
our moe won't cut it it seems, correction bias seems to be missing in…
vasqu Jul 7, 2025
d1b1144
tokenization change (remote)
vasqu Jul 8, 2025
54749e5
our moe version works when adding normalization :D
vasqu Jul 8, 2025
c81082f
cleanup moe
vasqu Jul 8, 2025
0efb8a7
nits
vasqu Jul 8, 2025
8132cf4
cleanup modeling -> let's get to modular next
vasqu Jul 8, 2025
0f47d74
style
vasqu Jul 8, 2025
be22c9a
modular v1
vasqu Jul 8, 2025
522b8d9
minor things + attempt at conversion (which doesn't work)
vasqu Jul 8, 2025
016bc8e
no conversion follow glm, fixup modular and other nits
vasqu Jul 9, 2025
ac62ca7
modular cleanup
vasqu Jul 9, 2025
1f185e2
fixes
vasqu Jul 9, 2025
d51b257
tests, tests, tests + some moe dtype forcing
vasqu Jul 9, 2025
72f1627
simplify modular, fix fatal fa2 bug, remaining tests
vasqu Jul 9, 2025
ba4b05e
fix import issue?
vasqu Jul 9, 2025
cce6eaa
some initial docs, fix bnb faulty behavior --> needs to fix some test…
vasqu Jul 9, 2025
64fc88a
fix sdpa test, load on init dtype only
vasqu Jul 9, 2025
e9eab1f
Merge branch 'main' into ernie4_5
vasqu Jul 10, 2025
cf37555
fixup post merge
vasqu Jul 10, 2025
76dc3f9
style
vasqu Jul 10, 2025
654bcd3
fix doc links
vasqu Jul 10, 2025
c595c3e
Merge branch 'main' into ernie4_5
vasqu Jul 10, 2025
694837a
Merge branch 'main' into ernie4_5
vasqu Jul 10, 2025
40a5ade
tokenization cleanup beginnings
vasqu Jul 14, 2025
b0bb530
simplify tokenizer by a lot as its basically llama
vasqu Jul 14, 2025
090f78e
tokenizer is full llama with different defaults + extra special tokens
vasqu Jul 15, 2025
2910168
sync og special tokens of ernie
vasqu Jul 15, 2025
90a6933
fix decoding with numbers (also in remote done what a timing), begin …
vasqu Jul 15, 2025
68e3c94
align with remote and preserve special tokens, adjust tests to ernie …
vasqu Jul 15, 2025
98a3759
nits
vasqu Jul 15, 2025
1eb7b9e
docs
vasqu Jul 16, 2025
f6d768e
Merge branch 'main' into ernie4_5
vasqu Jul 16, 2025
157c622
my daily post merge it is
vasqu Jul 16, 2025
4bcd77a
check
vasqu Jul 16, 2025
8253906
tokenization update with explanations and conversion script
vasqu Jul 17, 2025
c91193a
review on modular (til), revert some tokenizer things i did prior, re…
vasqu Jul 17, 2025
c22e0c2
Merge branch 'main' into ernie4_5
vasqu Jul 17, 2025
b8506a4
post merge fixes
vasqu Jul 17, 2025
49b4639
fixup tokenization, llama fast is the way to go
vasqu Jul 17, 2025
ad76c7c
more fixups
vasqu Jul 17, 2025
c08509d
check
vasqu Jul 17, 2025
3c0014a
import fixes
vasqu Jul 17, 2025
2948aed
correction bias following the paddle code
vasqu Jul 18, 2025
81bafc6
fix
vasqu Jul 18, 2025
f42e312
Merge branch 'main' into ernie4_5
vasqu Jul 18, 2025
508b683
fix TP plan, fix correction bias sharding during forward
vasqu Jul 18, 2025
624820c
Merge branch 'main' into ernie4_5
vasqu Jul 18, 2025
b8d8f5c
style
vasqu Jul 18, 2025
c72e4d9
whoops
vasqu Jul 18, 2025
b9f8db7
fix tied weights
vasqu Jul 21, 2025
680353e
Merge branch 'main' into ernie4_5
vasqu Jul 21, 2025
60ebe11
docs and last nit
vasqu Jul 21, 2025
a463588
license
vasqu Jul 21, 2025
416cda6
Merge branch 'main' into ernie4_5
vasqu Jul 21, 2025
f31f41d
flasky tests
vasqu Jul 21, 2025
961423d
move repo id, update when merged on the hub
vasqu Jul 21, 2025
4 changes: 4 additions & 0 deletions docs/source/en/_toctree.yml
@@ -441,6 +441,10 @@
title: Encoder Decoder Models
- local: model_doc/ernie
title: ERNIE
- local: model_doc/ernie4_5
title: Ernie4_5
- local: model_doc/ernie4_5_moe
title: Ernie4_5_MoE
- local: model_doc/ernie_m
title: ErnieM
- local: model_doc/esm
99 changes: 99 additions & 0 deletions docs/source/en/model_doc/ernie4_5.md
@@ -0,0 +1,99 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="Tensor parallelism" src="https://img.shields.io/badge/Tensor%20parallelism-06b6d4?style=flat&logoColor=white">
</div>
</div>

# Ernie 4.5

## Overview

The Ernie 4.5 model was released as part of the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) by Baidu.
This family of models contains multiple different architectures and model sizes. This checkpoint specifically covers the dense
text-only base model (no mixture of experts) with 0.3B parameters in total, using the standard [Llama](./llama.md) architecture at its core.

Other models from the family can be found at [Ernie 4.5 MoE](./ernie4_5_moe.md).

<div class="flex justify-center">
<img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/>
</div>


## Usage Tips

### Generate text

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-0.3B-PT"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```
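A note on tokenization: as the commit history above indicates, the Ernie 4.5 tokenizer ended up being Llama's fast tokenizer with different defaults and extra special tokens, which is why the auto mappings further down in this PR register `LlamaTokenizerFast` for both Ernie 4.5 model types.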

This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV).
The original code can be found [here](https://github.com/PaddlePaddle/ERNIE).


## Ernie4_5Config

[[autodoc]] Ernie4_5Config

## Ernie4_5Model

[[autodoc]] Ernie4_5Model
- forward

## Ernie4_5ForCausalLM

[[autodoc]] Ernie4_5ForCausalLM
- forward
183 changes: 183 additions & 0 deletions docs/source/en/model_doc/ernie4_5_moe.md
@@ -0,0 +1,183 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="Tensor parallelism" src="https://img.shields.io/badge/Tensor%20parallelism-06b6d4?style=flat&logoColor=white">
</div>
</div>

# Ernie 4.5 MoE

## Overview

The Ernie 4.5 MoE model was released as part of the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) by Baidu.
This family of models contains multiple different architectures and model sizes. These checkpoints specifically cover the text-only
mixture-of-experts (MoE) base models: one with 21B total (3B active) parameters and another with 300B total (47B active) parameters.
They use the standard [Llama](./llama.md) architecture at their core, combined with a specialized MoE block based on
[Mixtral](./mixtral.md) that adds shared experts (see the sketch below).

Other models from the family can be found at [Ernie 4.5](./ernie4_5.md).

<div class="flex justify-center">
<img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/>
</div>
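
To make the expert routing concrete, here is a minimal, hypothetical sketch of a Mixtral-style top-k MoE block with an always-on shared expert, in the spirit of what the commit history describes (top-k routing with renormalized weights, plus shared experts). All class names and sizes below are illustrative assumptions, not the actual transformers implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SharedExpertMoESketch(nn.Module):
    """Sketch: top-k routed experts (Mixtral-style) plus an always-on shared expert."""

    def __init__(self, hidden_size=64, intermediate_size=128, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_size, intermediate_size),
                    nn.SiLU(),
                    nn.Linear(intermediate_size, hidden_size),
                )
                for _ in range(num_experts)
            ]
        )
        # the shared expert runs on every token and is added to the routed output
        self.shared_expert = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.SiLU(),
            nn.Linear(intermediate_size, hidden_size),
        )

    def forward(self, hidden_states):
        batch, seq_len, hidden = hidden_states.shape
        tokens = hidden_states.view(-1, hidden)

        # route every token to its top-k experts and renormalize the kept weights
        routing_weights = F.softmax(self.gate(tokens), dim=-1)
        topk_weights, topk_idx = routing_weights.topk(self.top_k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        routed = torch.zeros_like(tokens)
        for expert_id, expert in enumerate(self.experts):
            token_ids, slot = (topk_idx == expert_id).nonzero(as_tuple=True)
            if token_ids.numel() > 0:
                routed[token_ids] += topk_weights[token_ids, slot].unsqueeze(-1) * expert(tokens[token_ids])

        out = routed + self.shared_expert(tokens)  # shared expert always contributes
        return out.view(batch, seq_len, hidden)


moe = SharedExpertMoESketch()
print(moe(torch.randn(2, 5, 64)).shape)  # torch.Size([2, 5, 64])
```

The renormalization of the selected routing weights corresponds to the "works when adding normalization" observation in the commit log; the production model additionally applies a correction bias to the router scores (see the "correction bias following the paddle code" commit), which this sketch omits.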


## Usage Tips

### Generate text

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-21B-A3B-PT"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```

### Distributed Generation with Tensor Parallelism

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-21B-A3B-PT"

# load the tokenizer and the model
# (tp_plan handles device placement, so device_map is not passed alongside it)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    tp_plan="auto",
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```
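Note that `tp_plan="auto"` expects to run inside a distributed job with one process per GPU; in practice this means launching the script with a tool such as `torchrun --nproc-per-node 4 script.py` rather than plain `python script.py`.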

### Quantization with Bitsandbytes

```python
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-21B-A3B-PT"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```
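As a rule of thumb, loading in 4-bit via `BitsAndBytesConfig(load_in_4bit=True)` cuts weight memory to roughly a quarter of the bf16 footprint, trading some generation quality and throughput for the ability to fit the 21B checkpoint on smaller GPUs.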

This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV).
The original code can be found [here](https://github.com/PaddlePaddle/ERNIE).


## Ernie4_5_MoEConfig

[[autodoc]] Ernie4_5_MoEConfig

## Ernie4_5_MoEModel

[[autodoc]] Ernie4_5_MoEModel
- forward

## Ernie4_5_MoEForCausalLM

[[autodoc]] Ernie4_5_MoEForCausalLM
- forward
- generate
11 changes: 11 additions & 0 deletions src/transformers/modeling_utils.py
@@ -2977,6 +2977,17 @@ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
else:
output_embeddings.weight = input_embeddings.weight

# Passing hooks over to the embeddings if needed
# (currently limited to tensor parallel hooks and flags only)
if hasattr(input_embeddings, "_is_hooked") and getattr(input_embeddings, "_hf_tp_plan", None):
output_embeddings._is_hooked = input_embeddings._is_hooked
output_embeddings._hf_tp_plan = input_embeddings._hf_tp_plan
output_embeddings._forward_hooks = input_embeddings._forward_hooks
output_embeddings._forward_pre_hooks = input_embeddings._forward_pre_hooks
output_embeddings.__repr__ = (
lambda: f"{output_embeddings.__repr__()}\nTP Plan: {output_embeddings._hf_tp_plan}"
)
Review comment on lines +2980 to +2989 (Collaborator): "okay! makes sense!"

if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = nn.functional.pad(
output_embeddings.bias.data,
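For context, tying the output embeddings to the input embeddings makes both modules share a single `nn.Parameter`; the lines added above then also carry the tensor-parallel hooks and flags over to the tied output embedding, since hooks are registered per module. A minimal illustration of what tying means, not the transformers code path itself:

```python
import torch.nn as nn

embed = nn.Embedding(100, 16)              # input embedding
lm_head = nn.Linear(16, 100, bias=False)   # output embedding / LM head

# tying: both modules now hold the very same Parameter object, so anything
# attached to only one module (e.g. TP hooks) is lost on the other unless it
# is explicitly copied over, which is what the diff above does.
lm_head.weight = embed.weight
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()
```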
4 changes: 4 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -128,6 +128,8 @@
("encoder-decoder", "EncoderDecoderConfig"),
("eomt", "EomtConfig"),
("ernie", "ErnieConfig"),
("ernie4_5", "Ernie4_5Config"),
("ernie4_5_moe", "Ernie4_5_MoEConfig"),
("ernie_m", "ErnieMConfig"),
("esm", "EsmConfig"),
("falcon", "FalconConfig"),
@@ -520,6 +522,8 @@
("encoder-decoder", "Encoder decoder"),
("eomt", "EoMT"),
("ernie", "ERNIE"),
("ernie4_5", "Ernie4_5"),
("ernie4_5_moe", "Ernie4_5_MoE"),
("ernie_m", "ErnieM"),
("esm", "ESM"),
("falcon", "Falcon"),
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -119,6 +119,8 @@
("emu3", "Emu3Model"),
("encodec", "EncodecModel"),
("ernie", "ErnieModel"),
("ernie4_5", "Ernie4_5Model"),
("ernie4_5_moe", "Ernie4_5_MoEModel"),
("ernie_m", "ErnieMModel"),
("esm", "EsmModel"),
("falcon", "FalconModel"),
@@ -594,6 +596,8 @@
("electra", "ElectraForCausalLM"),
("emu3", "Emu3ForCausalLM"),
("ernie", "ErnieForCausalLM"),
("ernie4_5", "Ernie4_5ForCausalLM"),
("ernie4_5_moe", "Ernie4_5_MoEForCausalLM"),
("falcon", "FalconForCausalLM"),
("falcon_h1", "FalconH1ForCausalLM"),
("falcon_mamba", "FalconMambaForCausalLM"),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -212,6 +212,8 @@
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("ernie4_5", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
("esm", ("EsmTokenizer", None)),
("falcon", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
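With the four mappings above in place, the `Auto*` classes can resolve the new model type directly from a checkpoint's `config.json`. A small sketch of the expected resolution (assuming the hub checkpoint declares `model_type: ernie4_5`):

```python
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("baidu/ERNIE-4.5-0.3B-PT")        # -> Ernie4_5Config
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT")  # -> LlamaTokenizerFast
print(type(config).__name__, type(tokenizer).__name__)
```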
27 changes: 27 additions & 0 deletions src/transformers/models/ernie4_5/__init__.py
@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_ernie4_5 import *
from .modeling_ernie4_5 import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
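The `_LazyModule` indirection above keeps `import transformers` cheap: `configuration_ernie4_5` and `modeling_ernie4_5` are only imported when one of their attributes is first accessed. A sketch of the intended usage:

```python
# resolving the attribute triggers the lazy import of configuration_ernie4_5
from transformers.models.ernie4_5 import Ernie4_5Config
```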