diff --git a/paddleformers/nn/criterion/interface.py b/paddleformers/nn/criterion/interface.py
index 07ca0637c9c..797fee81f53 100644
--- a/paddleformers/nn/criterion/interface.py
+++ b/paddleformers/nn/criterion/interface.py
@@ -41,8 +41,8 @@ class CriterionLayer(nn.Layer):
     def __init__(self, config, return_tuple=True, ignore_eos_token=False, use_infohub=False, **kwargs):
         super().__init__()
         self.config = config
-        self.dpo_config = copy.deepcopy(config.get("dpo_config", None))
-        self.kto_config = copy.deepcopy(config.get("kto_config", None))
+        self.dpo_config = copy.deepcopy(config.dpo_config) if hasattr(config, "dpo_config") else None
+        self.kto_config = copy.deepcopy(config.kto_config) if hasattr(config, "kto_config") else None
         self.ignored_index = getattr(config, "ignored_index", -100)
         self.use_filtered_label_loss = config.get("use_filtered_label_loss", False)
         self.loss_subbatch_sequence_length = config.get("loss_subbatch_sequence_length", -1)
diff --git a/paddleformers/trainer/integrations.py b/paddleformers/trainer/integrations.py
index 94503e7d32b..a0dacfacf6c 100644
--- a/paddleformers/trainer/integrations.py
+++ b/paddleformers/trainer/integrations.py
@@ -17,7 +17,6 @@
 # https://github.com/huggingface/transformers/blob/main/src/transformers/integrations.py
 
 import importlib
-import json
 import numbers
 import os
 import tempfile
@@ -131,9 +130,9 @@ def on_train_begin(self, args, state, control, **kwargs):
         if isinstance(model, PretrainedModel) and model.constructed_from_pretrained_config():
             model.config.architectures = [model.__class__.__name__]
             self.vdl_writer.add_text("model_config", str(model.config))
-        elif hasattr(model, "init_config") and model.init_config is not None:
-            model_config_json = json.dumps(model.get_model_config(), ensure_ascii=False, indent=2)
-            self.vdl_writer.add_text("model_config", model_config_json)
+        # elif hasattr(model, "init_config") and model.init_config is not None:
+        #     model_config_json = json.dumps(model.get_model_config(), ensure_ascii=False, indent=2)
+        #     self.vdl_writer.add_text("model_config", model_config_json)
 
         if hasattr(self.vdl_writer, "add_hparams"):
             self.vdl_writer.add_hparams(args.to_sanitized_dict(), metrics_list=[])
diff --git a/paddleformers/transformers/auto/modeling.py b/paddleformers/transformers/auto/modeling.py
index 4d740f40d0a..a9951d7bc75 100644
--- a/paddleformers/transformers/auto/modeling.py
+++ b/paddleformers/transformers/auto/modeling.py
@@ -217,6 +217,11 @@ def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file
         try:
             model_class = getattr(import_class, init_class)
             return model_class
+        except AttributeError:
+            pass
+        try:
+            model_class = getattr(import_class, init_class + "Fleet")
+            return model_class
         except AttributeError as err:
             try:
                 new_import_class = importlib.import_module(f"paddleformers.transformers.{class_name}")
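Note on the auto/modeling.py hunk: the Fleet lookup is written as a separate try block rather than a second `except AttributeError` clause on the same try, because Python only ever runs the first matching except clause; a sibling clause would be unreachable, and a failed Fleet lookup would escape past the existing `paddleformers.transformers` fallback. A minimal sketch of the intended resolution order (`resolve_model_class` is an illustrative name, not part of the patch):

    def resolve_model_class(import_class, init_class):
        # 1) Exact architecture name from the config, e.g. "Glm4MoeForCausalLM".
        try:
            return getattr(import_class, init_class)
        except AttributeError:
            pass
        # 2) PaddleFleet variant, e.g. "Glm4MoeForCausalLMFleet".
        try:
            return getattr(import_class, init_class + "Fleet")
        except AttributeError:
            # 3) The real method then retries via paddleformers.transformers.<class_name>.
            raise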
diff --git a/paddleformers/transformers/glm4_moe/modeling.py b/paddleformers/transformers/glm4_moe/modeling.py
index 2aa01b526fb..05ad1a4ec1e 100644
--- a/paddleformers/transformers/glm4_moe/modeling.py
+++ b/paddleformers/transformers/glm4_moe/modeling.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 from copy import deepcopy
+from dataclasses import dataclass
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
 
 import paddle
 import paddle.distributed as dist
@@ -23,6 +24,9 @@
 from paddle.distributed.fleet.utils import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp
 from paddle.nn import functional as F
+from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
+
+from paddleformers.transformers.gpt_provider import GPTModelProvider
 
 from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
 from ...nn.attention.utils import repeat_kv
@@ -44,6 +48,24 @@
 from ..moe_layer import MoEFlexTokenLayer
 from .configuration import Glm4MoeConfig
 
+if TYPE_CHECKING:
+    from paddlefleet.transformer import LayerSpec
+
+
+@dataclass
+class GLMMoEModelProvider(GPTModelProvider):
+    """Base provider for GLM MoE Models."""
+
+    transformer_layer_spec: Union[
+        "LayerSpec", Callable[["GPTModelProvider"], "LayerSpec"]
+    ] = get_gpt_decoder_block_spec
+
+    moe_router_load_balancing_type: str = "seq_aux_loss"
+
+    gated_linear_unit: bool = True
+
+    bias_activation_fusion: bool = True
+
 
 def eager_attention_forward(
     module: nn.Layer,
@@ -1241,6 +1263,11 @@ def forward(self, x, position_ids):
         return cos.cast(dtype=x.dtype), sin.cast(dtype=x.dtype)
 
 
+@register_base_model
+class Glm4MoeModelFleet(Glm4MoePreTrainedModel):
+    pass
+
+
 @register_base_model
 class Glm4MoeModel(Glm4MoePreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"]
@@ -1430,6 +1457,15 @@ def forward(
         )
 
 
+class Glm4MoeForCausalLMFleet(Glm4MoePreTrainedModel):
+    is_fleet = True
+
+    def __new__(cls, config):
+        model_provider_class = GLMMoEModelProvider
+        model_provider = model_provider_class.from_config(config)
+        return model_provider.provide()
+
+
 class Glm4MoeForCausalLM(Glm4MoePreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
@@ -1602,6 +1638,10 @@ def forward(self, args):
         return ret
 
 
+class Glm4MoeForCausalLMPipeFleet(GeneralModelForCausalLMPipe):
+    pass
+
+
 class Glm4MoeForCausalLMPipe(GeneralModelForCausalLMPipe):
     config_class = Glm4MoeConfig
     _decoder_layer_cls = Glm4MoeDecoderLayer
@@ -1617,4 +1657,11 @@ class Glm4MoeForCausalLMPipe(GeneralModelForCausalLMPipe):
     _gen_inv_aoa_config = Glm4MoeForCausalLM._gen_inv_aoa_config
 
 
-__all__ = ["Glm4MoeForCausalLMPipe", "Glm4MoeModel", "Glm4MoeForCausalLM"]
+__all__ = [
+    "Glm4MoeForCausalLMPipeFleet",
+    "Glm4MoeModelFleet",
+    "Glm4MoeForCausalLMFleet",
+    "Glm4MoeForCausalLMPipe",
+    "Glm4MoeModel",
+    "Glm4MoeForCausalLM",
+]
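Note on the new Fleet entry points: `Glm4MoeForCausalLMFleet` overrides `__new__`, so constructing it returns whatever `GLMMoEModelProvider.provide()` builds rather than an instance of the wrapper class itself. A hedged usage sketch (the default-constructed config is a stand-in; real callers would load a pretrained config):

    from paddleformers.transformers.glm4_moe.configuration import Glm4MoeConfig
    from paddleformers.transformers.glm4_moe.modeling import Glm4MoeForCausalLMFleet

    config = Glm4MoeConfig()  # stand-in; real usage loads a pretrained config
    model = Glm4MoeForCausalLMFleet(config)
    # __new__ delegates to GLMMoEModelProvider.from_config(config).provide(),
    # so isinstance(model, Glm4MoeForCausalLMFleet) need not hold.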
diff --git a/examples/experiments/paddlefleet/gpt_provider.py b/paddleformers/transformers/gpt_provider.py
similarity index 95%
rename from examples/experiments/paddlefleet/gpt_provider.py
rename to paddleformers/transformers/gpt_provider.py
index 2c854eb9798..27f3594351e 100644
--- a/examples/experiments/paddlefleet/gpt_provider.py
+++ b/paddleformers/transformers/gpt_provider.py
@@ -23,17 +23,26 @@
 from typing import Any, Callable, Literal, Optional, Union
 
 import paddle
-from model_provider import ModelProviderMixin
-from paddlefleet import parallel_state
-from paddlefleet.models.gpt import GPTModel
+from paddlefleet import LayerSpec, parallel_state
+from paddlefleet.models.gpt import GPTModel as FleetGPTModel
 from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
-from paddlefleet.spec_utils import LayerSpec
 from paddlefleet.transformer.transformer_config import TransformerConfig
-from vocab_utils import calculate_padded_vocab_size
+
+from paddleformers.transformers.model_utils import PretrainedModel
+
+from .model_provider import ModelProviderMixin
+from .vocab_utils import calculate_padded_vocab_size
 
 logger = logging.getLogger(__name__)
 
 
+class GPTModel(FleetGPTModel, PretrainedModel):
+    pass
+
+
+# GPTModel = FleetGPTModel
+
+
 def local_layer_spec(config: "GPTModelProvider") -> LayerSpec:
     """Create a local layer specification without Transformer Engine.
 
@@ -44,7 +53,7 @@ def local_layer_spec(config: "GPTModelProvider") -> LayerSpec:
         LayerSpec: Module specification for local implementation layers
     """
     return get_gpt_layer_local_spec(
-        num_experts=config.moe_num_experts,
+        num_experts=config.num_moe_experts,
         moe_grouped_gemm=config.moe_grouped_gemm,
         qk_layernorm=config.qk_layernorm,
         normalization=config.normalization,
@@ -64,7 +73,7 @@ class GPTModelProvider(TransformerConfig, ModelProviderMixin[GPTModel]):
     parallel_output: bool = True
     share_embeddings_and_output_weights: bool = True
     make_vocab_size_divisible_by: int = 128
-    position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute"
+    position_embedding_type: Literal["learned_absolute", "rope"] = "rope"
     rotary_base: int = 10000
     rotary_percent: float = 1.0
     seq_len_interpolation_factor: Optional[float] = None
@@ -92,7 +101,7 @@ class GPTModelProvider(TransformerConfig, ModelProviderMixin[GPTModel]):
     moe_grouped_gemm: bool = False
     use_qk_norm: bool = False
     fp8: Optional[str] = None
-    normalization: str = "LayerNorm"
+    normalization: str = "RMSNorm"
 
     # Multi-token prediction
     mtp_enabled: bool = False
diff --git a/examples/experiments/paddlefleet/model_provider.py b/paddleformers/transformers/model_provider.py
similarity index 100%
rename from examples/experiments/paddlefleet/model_provider.py
rename to paddleformers/transformers/model_provider.py
diff --git a/examples/experiments/paddlefleet/vocab_utils.py b/paddleformers/transformers/vocab_utils.py
similarity index 100%
rename from examples/experiments/paddlefleet/vocab_utils.py
rename to paddleformers/transformers/vocab_utils.py
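Note on the relocated helpers: `calculate_padded_vocab_size` (now importable from `paddleformers.transformers.vocab_utils`) pads the embedding table so it shards evenly under tensor parallelism, in line with `make_vocab_size_divisible_by: int = 128` above. A sketch of the rounding such a helper typically performs, assuming it follows the usual Megatron-style formula (the actual implementation may differ):

    import math

    def padded_vocab_size(vocab_size: int, divisible_by: int = 128, tp_size: int = 1) -> int:
        """Round vocab_size up to the next multiple of divisible_by * tp_size."""
        multiple = divisible_by * tp_size
        return math.ceil(vocab_size / multiple) * multiple

    # e.g. padded_vocab_size(151851) == 151936 with the defaults above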