FleetModel Dpo, AutoModel => FleetModel. #3024
Changes from all commits
ef9f520
958568e
b6f6b0f
fab4e3b
076aedf
27eb0d2
d148f4b
822119e
4e7c827
40b927c
4a5a543
bab7f56
25f03c7
468ebc3
05108cf
ac134c8
5bac4ff
4668620
f58c6da
@@ -17,7 +17,6 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/integrations.py

import importlib
import json
import numbers
import os
import tempfile

@@ -131,9 +130,9 @@ def on_train_begin(self, args, state, control, **kwargs):
        if isinstance(model, PretrainedModel) and model.constructed_from_pretrained_config():
            model.config.architectures = [model.__class__.__name__]
            self.vdl_writer.add_text("model_config", str(model.config))
        # elif hasattr(model, "init_config") and model.init_config is not None:
Collaborator: Why is this being commented out?
        #     model_config_json = json.dumps(model.get_model_config(), ensure_ascii=False, indent=2)
        #     self.vdl_writer.add_text("model_config", model_config_json)

        if hasattr(self.vdl_writer, "add_hparams"):
            self.vdl_writer.add_hparams(args.to_sanitized_dict(), metrics_list=[])
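One plausible reading of this change (not stated in the PR) is that Fleet-provided models do not expose get_model_config(), so the branch was disabled rather than adapted. If that is the reason, a guarded helper could keep the VisualDL logging for models that do support it. This is only a sketch; maybe_log_model_config is a hypothetical name, and the motivation above is an assumption:

import json

def maybe_log_model_config(model, vdl_writer):
    # Hypothetical helper: log the model config to VisualDL only when the model
    # actually exposes get_model_config(), so Fleet-provided models that lack it
    # are skipped instead of raising AttributeError.
    if getattr(model, "init_config", None) is None:
        return
    get_config = getattr(model, "get_model_config", None)
    if not callable(get_config):
        return
    model_config_json = json.dumps(get_config(), ensure_ascii=False, indent=2)
    vdl_writer.add_text("model_config", model_config_json)

The elif branch in on_train_begin could then delegate to such a helper instead of being commented out.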
@@ -217,6 +217,9 @@ def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file
        try:
            model_class = getattr(import_class, init_class)
            return model_class
        except AttributeError:
Collaborator: If this is meant as a temporary workaround, please add a comment here saying so.
            model_class = getattr(import_class, init_class + "Fleet")
            return model_class
        except AttributeError as err:
            try:
                new_import_class = importlib.import_module(f"paddleformers.transformers.{class_name}")
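A side note on the control flow here: in a Python try statement only the first matching except clause runs, so with the new except AttributeError clause in place the existing except AttributeError as err clause below it becomes unreachable, and an AttributeError raised by the Fleet lookup itself propagates out of this try statement (unless an outer handler not shown in the diff catches it) rather than reaching the module-import fallback. A sketch of a nested version that keeps all three lookups in order follows; the names mirror the diff, but the body of the final fallback is truncated there, so its last line is an assumption:

import importlib

def resolve_model_class(import_class, init_class, class_name):
    # Try the plain class name first.
    try:
        return getattr(import_class, init_class)
    except AttributeError:
        # Then the Fleet-suffixed variant, e.g. Glm4MoeForCausalLM -> Glm4MoeForCausalLMFleet.
        try:
            return getattr(import_class, init_class + "Fleet")
        except AttributeError:
            # Last resort, mirroring the existing branch in the diff (its body is
            # truncated above, so this final getattr is an assumption).
            new_import_class = importlib.import_module(f"paddleformers.transformers.{class_name}")
            return getattr(new_import_class, init_class)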
@@ -13,8 +13,9 @@
# limitations under the License.

from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

import paddle
import paddle.distributed as dist

@@ -23,6 +24,9 @@
from paddle.distributed.fleet.utils import recompute
from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp
from paddle.nn import functional as F
from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec

from paddleformers.transformers.gpt_provider import GPTModelProvider

from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
from ...nn.attention.utils import repeat_kv

@@ -44,6 +48,24 @@
from ..moe_layer import MoEFlexTokenLayer
from .configuration import Glm4MoeConfig

if TYPE_CHECKING:
    from paddlefleet.transformer import LayerSpec


@dataclass
class GLMMoEModelProvider(GPTModelProvider):
    """Base provider for GLM MoE Models."""

    transformer_layer_spec: Union[
        "LayerSpec", Callable[["GPTModelProvider"], "LayerSpec"]
    ] = get_gpt_decoder_block_spec

    moe_router_load_balancing_type: str = "seq_aux_loss"

    gated_linear_unit: bool = True

    bias_activation_fusion: bool = True
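For readers unfamiliar with the provider pattern: the annotation above says transformer_layer_spec may hold either a ready-made LayerSpec or a callable that takes the provider and returns one, and the default here is the function get_gpt_decoder_block_spec. How paddlefleet resolves that field internally is not shown in this diff; the sketch below (a hypothetical ProviderSketch class, not PaddleFormers code) only illustrates the call-if-callable resolution that such a field implies:

from dataclasses import dataclass
from typing import Any, Callable, Union


@dataclass
class ProviderSketch:
    # Hypothetical stand-in for GLMMoEModelProvider: the field holds either a
    # ready-made layer spec or a callable that builds one from the provider
    # (get_gpt_decoder_block_spec would be the callable case here).
    transformer_layer_spec: Union[Any, Callable[..., Any]] = None
    gated_linear_unit: bool = True

    def resolved_layer_spec(self):
        spec = self.transformer_layer_spec
        # If a callable was stored as the default, call it with the provider so
        # the resulting spec can depend on fields such as gated_linear_unit.
        return spec(self) if callable(spec) else spec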
def eager_attention_forward(
    module: nn.Layer,

@@ -1241,6 +1263,11 @@ def forward(self, x, position_ids):
        return cos.cast(dtype=x.dtype), sin.cast(dtype=x.dtype)


@register_base_model
class Glm4MoeModelFleet(Glm4MoePreTrainedModel):
Collaborator: Is the "Fleet" suffix needed here?
    pass


@register_base_model
class Glm4MoeModel(Glm4MoePreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"]
@@ -1430,6 +1457,15 @@ def forward(
        )


class Glm4MoeForCausalLMFleet(Glm4MoePreTrainedModel):
    is_fleet = True

    def __new__(cls, config):
        model_provider_class = GLMMoEModelProvider
        model_provider = model_provider_class.from_config(config)
        return model_provider.provide()
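Glm4MoeForCausalLMFleet uses __new__ as a factory: model_provider.provide() presumably returns an object that is not an instance of this class, and when __new__ returns a foreign object Python skips __init__ entirely, so the class acts only as a constructor shim that the init_class + "Fleet" fallback in the AutoModel change above can resolve. A minimal, self-contained illustration of that language behavior (the dict stands in for the provider-built model; this is not PaddleFormers code):

class Factory:
    # Minimal demonstration of the __new__-as-factory pattern used by
    # Glm4MoeForCausalLMFleet: returning an object of a different type from
    # __new__ means __init__ of this class is never run.
    def __new__(cls, config):
        return {"built_from": config}  # stand-in for model_provider.provide()

    def __init__(self, config):
        raise AssertionError("never reached: the returned object is not a Factory")


obj = Factory({"hidden_size": 8})
assert isinstance(obj, dict) and obj["built_from"] == {"hidden_size": 8}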
class Glm4MoeForCausalLM(Glm4MoePreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

@@ -1602,6 +1638,10 @@ def forward(self, args):
        return ret


class Glm4MoeForCausalLMPipeFleet(GeneralModelForCausalLMPipe):
    pass


class Glm4MoeForCausalLMPipe(GeneralModelForCausalLMPipe):
    config_class = Glm4MoeConfig
    _decoder_layer_cls = Glm4MoeDecoderLayer

@@ -1617,4 +1657,11 @@ class Glm4MoeForCausalLMPipe(GeneralModelForCausalLMPipe):
    _gen_inv_aoa_config = Glm4MoeForCausalLM._gen_inv_aoa_config


__all__ = ["Glm4MoeForCausalLMPipe", "Glm4MoeModel", "Glm4MoeForCausalLM"]
__all__ = [
    "Glm4MoeForCausalLMPipeFleet",
    "Glm4MoeModelFleet",
    "Glm4MoeForCausalLMFleet",
    "Glm4MoeForCausalLMPipe",
    "Glm4MoeModel",
    "Glm4MoeForCausalLM",
]
Review comment: Why change it this way? The original form was fine.