[Ernie 4.5] Add ernie text models #39228
Merged
83 commits
- 055fa74 init (vasqu)
- cf31849 copied from remote (vasqu)
- 69d942c add proper structure and llama like structure (vasqu)
- 0fa32aa fixup (vasqu)
- 2fc44ee revert to state that works (vasqu)
- 39176fd get closer to llama (vasqu)
- f9d5789 slow and steady (vasqu)
- de53a33 some removal (vasqu)
- 999e7e9 masks work (vasqu)
- 1816efb it is indeed the rope implementation, how dafuq does it mesh with the… (vasqu)
- d0c5877 nice (vasqu)
- 0361822 getting closer (vasqu)
- 0682085 closer to transformers style (vasqu)
- a551f22 let's simplify this, batching works now (vasqu)
- 4a9c20b simplified (vasqu)
- 393c2c7 working version with modular (vasqu)
- 936aa04 it is indeed the rotation per weights, make it complete llama style (vasqu)
- b33edd1 cleanup conversion, next to look at -> tokenizer (vasqu)
- 4bcb7f0 remove llama artefacts (vasqu)
- 4129506 fix modeling tests (common ones) (vasqu)
- e837bf5 style (vasqu)
- be5a7b0 integration test + first look into tokenization (will need more work,… (vasqu)
- f98c8d9 style (vasqu)
- de0389f working moe version, based on remote (vasqu)
- 1a41719 lets keep it simple and go step by step - transformers annotations fo… (vasqu)
- 484da46 more cleanup (vasqu)
- 5627e5a refactor namings and remove addition forXXX classes (vasqu)
- a146b13 our moe won't cut it it seems, correction bias seems to be missing in… (vasqu)
- d1b1144 tokenization change (remote) (vasqu)
- 54749e5 our moe version works when adding normalization :D (vasqu)
- c81082f cleanup moe (vasqu)
- 0efb8a7 nits (vasqu)
- 8132cf4 cleanup modeling -> let's get to modular next (vasqu)
- 0f47d74 style (vasqu)
- be22c9a modular v1 (vasqu)
- 522b8d9 minor things + attempt at conversion (which doesn't work) (vasqu)
- 016bc8e no conversion follow glm, fixup modular and other nits (vasqu)
- ac62ca7 modular cleanup (vasqu)
- 1f185e2 fixes (vasqu)
- d51b257 tests, tests, tests + some moe dtype forcing (vasqu)
- 72f1627 simplify modular, fix fatal fa2 bug, remaining tests (vasqu)
- ba4b05e fix import issue? (vasqu)
- cce6eaa some initial docs, fix bnb faulty behavior --> needs to fix some test… (vasqu)
- 64fc88a fix sdpa test, load on init dtype only (vasqu)
- e9eab1f Merge branch 'main' into ernie4_5 (vasqu)
- cf37555 fixup post merge (vasqu)
- 76dc3f9 style (vasqu)
- 654bcd3 fix doc links (vasqu)
- c595c3e Merge branch 'main' into ernie4_5 (vasqu)
- 694837a Merge branch 'main' into ernie4_5 (vasqu)
- 40a5ade tokenization cleanup beginnings (vasqu)
- b0bb530 simplify tokenizer by a lot as its basically llama (vasqu)
- 090f78e tokenizer is full llama with different defaults + extra special tokens (vasqu)
- 2910168 sync og special tokens of ernie (vasqu)
- 90a6933 fix decoding with numbers (also in remote done what a timing), begin … (vasqu)
- 68e3c94 align with remote and preserve special tokens, adjust tests to ernie … (vasqu)
- 98a3759 nits (vasqu)
- 1eb7b9e docs (vasqu)
- f6d768e Merge branch 'main' into ernie4_5 (vasqu)
- 157c622 my daily post merge it is (vasqu)
- 4bcd77a check (vasqu)
- 8253906 tokenization update with explanations and conversion script (vasqu)
- c91193a review on modular (til), revert some tokenizer things i did prior, re… (vasqu)
- c22e0c2 Merge branch 'main' into ernie4_5 (vasqu)
- b8506a4 post merge fixes (vasqu)
- 49b4639 fixup tokenization, llama fast is the way to go (vasqu)
- ad76c7c more fixups (vasqu)
- c08509d check (vasqu)
- 3c0014a import fixes (vasqu)
- 2948aed correction bias following the paddle code (vasqu)
- 81bafc6 fix (vasqu)
- f42e312 Merge branch 'main' into ernie4_5 (vasqu)
- 508b683 fix TP plan, fix correction bias sharding during forward (vasqu)
- 624820c Merge branch 'main' into ernie4_5 (vasqu)
- b8d8f5c style (vasqu)
- c72e4d9 whoops (vasqu)
- b9f8db7 fix tied weights (vasqu)
- 680353e Merge branch 'main' into ernie4_5 (vasqu)
- 60ebe11 docs and last nit (vasqu)
- a463588 license (vasqu)
- 416cda6 Merge branch 'main' into ernie4_5 (vasqu)
- f31f41d flasky tests (vasqu)
- 961423d move repo id, update when merged on the hub (vasqu)
@@ -0,0 +1,99 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="Tensor parallelism" src="https://img.shields.io/badge/Tensor%20parallelism-06b6d4?style=flat&logoColor=white">
    </div>
</div>

# Ernie 4.5

## Overview
The Ernie 4.5 model was released as part of the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) by Baidu.
This family contains multiple architectures and model sizes. This page covers the base text model without mixture of
experts (MoE), with 0.3B parameters in total. It uses the standard [Llama](./llama.md) architecture at its core.

Other models from the family can be found at [Ernie 4.5 MoE](./ernie4_5_moe.md).

<div class="flex justify-center">
    <img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/>
</div>
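
Since the model follows the standard Llama layout, its config exposes the usual causal-LM fields. Below is a quick illustrative check; it is a sketch, and the attribute names are assumed from the standard Llama-style config rather than taken from this PR:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("baidu/ERNIE-4.5-0.3B-PT")

# assumed Llama-style attribute names
print(type(config).__name__)       # expected: Ernie4_5Config
print(config.hidden_size)          # model width
print(config.num_hidden_layers)    # number of decoder layers
print(config.num_attention_heads)  # attention heads per layer
```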

## Usage Tips

### Generate text
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-0.3B-PT"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
# strip the prompt tokens from the output before decoding
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```

This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV).
The original code can be found [here](https://github.com/PaddlePaddle/ERNIE).

## Ernie4_5Config

[[autodoc]] Ernie4_5Config

## Ernie4_5Model

[[autodoc]] Ernie4_5Model
    - forward

## Ernie4_5ForCausalLM

[[autodoc]] Ernie4_5ForCausalLM
    - forward
@@ -0,0 +1,183 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="Tensor parallelism" src="https://img.shields.io/badge/Tensor%20parallelism-06b6d4?style=flat&logoColor=white">
    </div>
</div>

# Ernie 4.5 MoE

## Overview
The Ernie 4.5 MoE model was released as part of the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) by Baidu.
This family contains multiple architectures and model sizes. This page covers the base text models with mixture of experts
(MoE): one with 21B total and 3B active parameters, and another with 300B total and 47B active parameters. They use the
standard [Llama](./llama.md) architecture at their core, combined with a specialized MoE based on [Mixtral](./mixtral.md)
with additional shared experts.
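
The commit log calls out two details that made the MoE port match the original: normalizing the routing weights ("our moe version works when adding normalization") and a router correction bias taken from the paddle code. The snippet below is a minimal, self-contained sketch of a Mixtral-style sparse MoE with additional shared experts; it is illustrative only, the class and parameter names are invented here, and the actual Ernie4_5_MoE modules differ:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedExpertsMoESketch(nn.Module):
    """Illustrative sparse MoE with shared experts (not the actual Ernie implementation)."""

    def __init__(self, hidden_size=64, intermediate_size=128, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

        def ffn():
            return nn.Sequential(
                nn.Linear(hidden_size, intermediate_size),
                nn.SiLU(),
                nn.Linear(intermediate_size, hidden_size),
            )

        self.experts = nn.ModuleList(ffn() for _ in range(num_experts))
        self.shared_expert = ffn()  # always active, independent of routing

    def forward(self, hidden_states):  # (batch, seq, hidden)
        # per the commit log, the real router also adds a learned correction
        # bias to these scores before selecting experts; omitted here
        router_probs = F.softmax(self.gate(hidden_states), dim=-1)
        weights, selected = torch.topk(router_probs, self.top_k, dim=-1)
        # renormalize the top-k weights to sum to 1 (the "normalization"
        # the commit log found necessary for the port to match)
        weights = weights / weights.sum(dim=-1, keepdim=True)

        routed = torch.zeros_like(hidden_states)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = selected[..., k] == e
                if mask.any():
                    routed[mask] += weights[..., k][mask].unsqueeze(-1) * expert(hidden_states[mask])

        # shared experts are added on top of the routed output
        return routed + self.shared_expert(hidden_states)

moe = SharedExpertsMoESketch()
out = moe(torch.randn(2, 5, 64))
print(out.shape)  # torch.Size([2, 5, 64])
```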
Other models from the family can be found at [Ernie 4.5](./ernie4_5.md).

<div class="flex justify-center">
    <img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/>
</div>

## Usage Tips

### Generate text
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-21B-A3B-PT"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
# strip the prompt tokens from the output before decoding
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```

### Distributed Generation with Tensor Parallelism
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-21B-A3B-PT"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    tp_plan="auto",
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
# strip the prompt tokens from the output before decoding
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```
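
Note that `tp_plan="auto"` shards the model across multiple processes, one per GPU, so the script is typically launched with `torchrun` (for example `torchrun --nproc-per-node 8 script.py`; the invocation is an assumption here, adjust it to your setup).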

### Quantization with Bitsandbytes
```python
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-21B-A3B-PT"

# load the tokenizer and the model (weights quantized to 4-bit on load)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# prepare the model input
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32,
)
# strip the prompt tokens from the output before decoding
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(generated_text)
```
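
As a rough guide, `load_in_4bit=True` stores the linear weights in 4 bits and dequantizes them on the fly for compute, so weight memory drops to roughly a quarter of the bfloat16 footprint (a ballpark estimate that ignores activations and quantization constants).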

This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV).
The original code can be found [here](https://github.com/PaddlePaddle/ERNIE).

## Ernie4_5_MoEConfig

[[autodoc]] Ernie4_5_MoEConfig

## Ernie4_5_MoEModel

[[autodoc]] Ernie4_5_MoEModel
    - forward

## Ernie4_5_MoEForCausalLM

[[autodoc]] Ernie4_5_MoEForCausalLM
    - forward
    - generate
@@ -0,0 +1,27 @@
```python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_ernie4_5 import *
    from .modeling_ernie4_5 import *
else:
    import sys

    # defer the heavy submodule imports until first attribute access
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
```
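
The lazy module keeps importing transformers cheap: a submodule is only imported when one of its attributes is first accessed. A small usage sketch (assuming the merged package path `transformers.models.ernie4_5`):

```python
# importing the package does not yet import the heavy modeling code
from transformers.models import ernie4_5

# first attribute access makes _LazyModule import modeling_ernie4_5
model_cls = ernie4_5.Ernie4_5ForCausalLM
print(model_cls.__name__)  # Ernie4_5ForCausalLM
```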
okay! makes sense!