4 changes: 2 additions & 2 deletions paddleformers/nn/criterion/interface.py
@@ -41,8 +41,8 @@ class CriterionLayer(nn.Layer):
def __init__(self, config, return_tuple=True, ignore_eos_token=False, use_infohub=False, **kwargs):
super().__init__()
self.config = config
- self.dpo_config = copy.deepcopy(config.get("dpo_config", None))
- self.kto_config = copy.deepcopy(config.get("kto_config", None))
+ self.dpo_config = copy.deepcopy(config.dpo_config) if hasattr(config, "dpo_config") else None
Collaborator: Why change it this way? The original code was fine.

+ self.kto_config = copy.deepcopy(config.kto_config) if hasattr(config, "kto_config") else None
self.ignored_index = getattr(config, "ignored_index", -100)
self.use_filtered_label_loss = config.get("use_filtered_label_loss", False)
self.loss_subbatch_sequence_length = config.get("loss_subbatch_sequence_length", -1)
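The change the reviewer questions swaps dict-style `config.get(...)` access for attribute access guarded by `hasattr`. A minimal sketch of the behavioural difference, using a hypothetical `AttrConfig` stand-in (not the real config class): `.get()` assumes the config object implements a dict-like interface, while the attribute pattern only assumes attributes and degrades to `None` when the field is absent.

```python
# Minimal illustration (AttrConfig is a stand-in, not the real PretrainedConfig):
# an attribute-style config has no .get() method, so dict-style access raises
# AttributeError, while the hasattr/attribute pattern falls back to None.
import copy


class AttrConfig:
    """Stand-in for an attribute-style config object."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


config = AttrConfig(dpo_config={"beta": 0.1})

# Attribute-style access: works whether or not the field is present.
dpo = copy.deepcopy(config.dpo_config) if hasattr(config, "dpo_config") else None
kto = copy.deepcopy(config.kto_config) if hasattr(config, "kto_config") else None
print(dpo, kto)  # {'beta': 0.1} None

# Dict-style access: fails on objects that do not implement .get().
try:
    config.get("dpo_config", None)
except AttributeError as err:
    print(type(err).__name__)  # AttributeError
```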
7 changes: 3 additions & 4 deletions paddleformers/trainer/integrations.py
@@ -17,7 +17,6 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/integrations.py

import importlib
- import json
import numbers
import os
import tempfile
@@ -131,9 +130,9 @@ def on_train_begin(self, args, state, control, **kwargs):
if isinstance(model, PretrainedModel) and model.constructed_from_pretrained_config():
model.config.architectures = [model.__class__.__name__]
self.vdl_writer.add_text("model_config", str(model.config))
- elif hasattr(model, "init_config") and model.init_config is not None:
- model_config_json = json.dumps(model.get_model_config(), ensure_ascii=False, indent=2)
- self.vdl_writer.add_text("model_config", model_config_json)
+ # elif hasattr(model, "init_config") and model.init_config is not None:
Collaborator: Why was this commented out?

+ # model_config_json = json.dumps(model.get_model_config(), ensure_ascii=False, indent=2)
+ # self.vdl_writer.add_text("model_config", model_config_json)

if hasattr(self.vdl_writer, "add_hparams"):
self.vdl_writer.add_hparams(args.to_sanitized_dict(), metrics_list=[])
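The commented-out branch serialized `model.get_model_config()` to JSON before logging it to VisualDL, which is what the reviewer asks about. A hedged sketch of how that behaviour could be kept behind a defensive guard instead of being removed; `model_config_text` and `DummyModel` are hypothetical helpers, not part of the PR.

```python
# Sketch only: reproduce the JSON serialization of the removed branch, but skip
# models that do not expose init_config / get_model_config instead of deleting
# the logging outright.
import json
from typing import Optional


def model_config_text(model) -> Optional[str]:
    """Return a pretty JSON dump of the model config, or None when unavailable."""
    if getattr(model, "init_config", None) is None:
        return None
    if not hasattr(model, "get_model_config"):
        return None
    return json.dumps(model.get_model_config(), ensure_ascii=False, indent=2)


class DummyModel:
    init_config = {"hidden_size": 8}

    def get_model_config(self):
        return dict(self.init_config)


text = model_config_text(DummyModel())
if text is not None:
    print(text)  # this string is what would go to vdl_writer.add_text("model_config", ...)
```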
3 changes: 3 additions & 0 deletions paddleformers/transformers/auto/modeling.py
@@ -217,6 +217,9 @@ def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file
try:
model_class = getattr(import_class, init_class)
return model_class
+ except AttributeError:
Collaborator: If this is a temporary workaround, please add a comment here explaining it.

+ model_class = getattr(import_class, init_class + "Fleet")
+ return model_class
except AttributeError as err:
try:
new_import_class = importlib.import_module(f"paddleformers.transformers.{class_name}")
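The added `except AttributeError` branch retries the class lookup with a "Fleet" suffix before falling through to the outer handler. An illustrative sketch of that lookup-with-fallback pattern, with the comment the reviewer asks for; `resolve_model_class` and `fake_module` are hypothetical names, not the AutoModel API.

```python
# Sketch of the lookup-with-fallback pattern added above.
import types


def resolve_model_class(module, init_class: str):
    """Return `init_class` from `module`, falling back to the Fleet-suffixed variant."""
    try:
        return getattr(module, init_class)
    except AttributeError:
        # Temporary bridge: some architectures are currently only exported with a
        # "Fleet" suffix (e.g. Glm4MoeForCausalLMFleet).
        return getattr(module, init_class + "Fleet")


fake_module = types.ModuleType("fake_module")
fake_module.Glm4MoeForCausalLMFleet = object
print(resolve_model_class(fake_module, "Glm4MoeForCausalLM"))  # <class 'object'>
```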
51 changes: 49 additions & 2 deletions paddleformers/transformers/glm4_moe/modeling.py
@@ -13,8 +13,9 @@
# limitations under the License.

from copy import deepcopy
+ from dataclasses import dataclass
from functools import partial
- from typing import Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

import paddle
import paddle.distributed as dist
@@ -23,6 +24,9 @@
from paddle.distributed.fleet.utils import recompute
from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp
from paddle.nn import functional as F
+ from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
+
+ from paddleformers.transformers.gpt_provider import GPTModelProvider

from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
from ...nn.attention.utils import repeat_kv
@@ -44,6 +48,24 @@
from ..moe_layer import MoEFlexTokenLayer
from .configuration import Glm4MoeConfig

+ if TYPE_CHECKING:
+     from paddlefleet.transformer import LayerSpec
+
+
+ @dataclass
+ class GLMMoEModelProvider(GPTModelProvider):
+     """Base provider for GLM MoE Models."""
+
+     transformer_layer_spec: Union[
+         "LayerSpec", Callable[["GPTModelProvider"], "LayerSpec"]
+     ] = get_gpt_decoder_block_spec
+
+     moe_router_load_balancing_type: str = "seq_aux_loss"
+
+     gated_linear_unit: bool = True
+
+     bias_activation_fusion: bool = True


def eager_attention_forward(
module: nn.Layer,
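`GLMMoEModelProvider` follows the provider pattern used elsewhere in this PR: a dataclass of hyperparameters plus a `transformer_layer_spec` factory, built via `from_config(...)` and turned into a model via `provide()` (both calls appear in `Glm4MoeForCausalLMFleet.__new__` below). A toy sketch of that pattern with simplified stand-ins (`ToyProvider`, `default_block_spec` are hypothetical; the real provider inherits paddlefleet's `TransformerConfig`/`GPTModelProvider`):

```python
# Toy provider: a dataclass of hyperparameters plus a layer-spec factory. Callers
# build it from a config dict and then ask it to provide() something concrete.
from dataclasses import dataclass
from typing import Callable, Union


def default_block_spec(provider: "ToyProvider") -> str:
    # Stand-in for get_gpt_decoder_block_spec(...): describes the decoder stack.
    return f"{provider.num_layers} x decoder block (glu={provider.gated_linear_unit})"


@dataclass
class ToyProvider:
    num_layers: int = 2
    gated_linear_unit: bool = True
    transformer_layer_spec: Union[str, Callable[["ToyProvider"], str]] = default_block_spec

    @classmethod
    def from_config(cls, config: dict) -> "ToyProvider":
        # Keep only the keys that are provider fields; ignore the rest of the config.
        return cls(**{k: v for k, v in config.items() if k in cls.__dataclass_fields__})

    def provide(self) -> str:
        spec = self.transformer_layer_spec
        return spec(self) if callable(spec) else spec


provider = ToyProvider.from_config({"num_layers": 4, "vocab_size": 32000})
print(provider.provide())  # 4 x decoder block (glu=True)
```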
@@ -1241,6 +1263,11 @@ def forward(self, x, position_ids):
return cos.cast(dtype=x.dtype), sin.cast(dtype=x.dtype)


+ @register_base_model
+ class Glm4MoeModelFleet(Glm4MoePreTrainedModel):
Collaborator: Is the Fleet suffix really needed?

+     pass


@register_base_model
class Glm4MoeModel(Glm4MoePreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"]
@@ -1430,6 +1457,15 @@ def forward(
)


+ class Glm4MoeForCausalLMFleet(Glm4MoePreTrainedModel):
+     is_fleet = True
+
+     def __new__(cls, config):
+         model_provider_class = GLMMoEModelProvider
+         model_provider = model_provider_class.from_config(config)
+         return model_provider.provide()


class Glm4MoeForCausalLM(Glm4MoePreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
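`Glm4MoeForCausalLMFleet` works because `__new__` returns whatever the provider builds: when `__new__` returns an instance of an unrelated class, Python skips the declared class's `__init__`, so calling the class behaves like a factory function. A minimal sketch of that trick (`FactoryModel`/`ProvidedModel` are stand-ins, not PR code):

```python
# Factory-via-__new__: returning an instance of another class from __new__ means
# the declared class's __init__ is never called on the result.
class ProvidedModel:
    def __init__(self, name: str):
        self.name = name


class FactoryModel:
    def __new__(cls, config: dict):
        # Build and return something else entirely; FactoryModel's __init__ is skipped.
        return ProvidedModel(config["name"])


model = FactoryModel({"name": "toy"})
print(type(model).__name__, model.name)  # ProvidedModel toy
```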
@@ -1602,6 +1638,10 @@ def forward(self, args):
return ret


+ class Glm4MoeForCausalLMPipeFleet(GeneralModelForCausalLMPipe):
+     pass


class Glm4MoeForCausalLMPipe(GeneralModelForCausalLMPipe):
config_class = Glm4MoeConfig
_decoder_layer_cls = Glm4MoeDecoderLayer
@@ -1617,4 +1657,11 @@ class Glm4MoeForCausalLMPipe(GeneralModelForCausalLMPipe):
_gen_inv_aoa_config = Glm4MoeForCausalLM._gen_inv_aoa_config


__all__ = ["Glm4MoeForCausalLMPipe", "Glm4MoeModel", "Glm4MoeForCausalLM"]
__all__ = [
"Glm4MoeForCausalLMPipeFleet",
"Glm4MoeModelFleet",
"Glm4MoeForCausalLMFleet",
"Glm4MoeForCausalLMPipe",
"Glm4MoeModel",
"Glm4MoeForCausalLM",
]
paddleformers/transformers/gpt_provider.py
@@ -23,17 +23,26 @@
from typing import Any, Callable, Literal, Optional, Union

import paddle
- from model_provider import ModelProviderMixin
- from paddlefleet import parallel_state
- from paddlefleet.models.gpt import GPTModel
+ from paddlefleet import LayerSpec, parallel_state
+ from paddlefleet.models.gpt import GPTModel as FleetGPTModel
from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
- from paddlefleet.spec_utils import LayerSpec
from paddlefleet.transformer.transformer_config import TransformerConfig
- from vocab_utils import calculate_padded_vocab_size

+ from paddleformers.transformers.model_utils import PretrainedModel
+
+ from .model_provider import ModelProviderMixin
+ from .vocab_utils import calculate_padded_vocab_size

logger = logging.getLogger(__name__)


+ class GPTModel(FleetGPTModel, PretrainedModel):
+     pass
+
+
+ # GPTModel = FleetGPTModel


def local_layer_spec(config: "GPTModelProvider") -> LayerSpec:
"""Create a local layer specification without Transformer Engine.

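The new `GPTModel(FleetGPTModel, PretrainedModel)` shim is an empty multiple-inheritance subclass: it changes neither parent, it simply yields one class that has both paddlefleet's model behaviour and PaddleFormers' `PretrainedModel` interface. A toy illustration of the pattern (`ToyBackendModel`/`ToyPretrainedMixin`/`ToyGPTModel` are stand-ins, not the real classes):

```python
# Empty subclass combining a backend model with an interface mixin.
class ToyBackendModel:
    def forward(self, x):
        return x * 2


class ToyPretrainedMixin:
    def save_pretrained(self, path: str) -> str:
        return f"saved {type(self).__name__} to {path}"


class ToyGPTModel(ToyBackendModel, ToyPretrainedMixin):
    pass


m = ToyGPTModel()
print(m.forward(3))                   # 6  (backend behaviour)
print(m.save_pretrained("/tmp/toy"))  # saved ToyGPTModel to /tmp/toy  (mixin API)
print([c.__name__ for c in type(m).__mro__])  # MRO: ToyGPTModel, backend, mixin, object
```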
@@ -44,7 +53,7 @@ def local_layer_spec(config: "GPTModelProvider") -> LayerSpec:
LayerSpec: Module specification for local implementation layers
"""
return get_gpt_layer_local_spec(
- num_experts=config.moe_num_experts,
+ num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
normalization=config.normalization,
@@ -64,7 +73,7 @@ class GPTModelProvider(TransformerConfig, ModelProviderMixin[GPTModel]):
parallel_output: bool = True
share_embeddings_and_output_weights: bool = True
make_vocab_size_divisible_by: int = 128
- position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute"
+ position_embedding_type: Literal["learned_absolute", "rope"] = "rope"
rotary_base: int = 10000
rotary_percent: float = 1.0
seq_len_interpolation_factor: Optional[float] = None
@@ -92,7 +101,7 @@ class GPTModelProvider(TransformerConfig, ModelProviderMixin[GPTModel]):
moe_grouped_gemm: bool = False
use_qk_norm: bool = False
fp8: Optional[str] = None
normalization: str = "LayerNorm"
normalization: str = "RMSNorm"

# Multi-token prediction
mtp_enabled: bool = False
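The provider defaults change here (`position_embedding_type` to "rope", `normalization` to "RMSNorm"). Because these are dataclass fields, callers that relied on the old behaviour can still request it explicitly; a toy sketch of that backwards-compatibility argument (`ToyProviderDefaults` is a stand-in, not the real `GPTModelProvider`, which also carries all `TransformerConfig` fields):

```python
# Changing dataclass defaults only affects callers that omit the field.
from dataclasses import dataclass
from typing import Literal


@dataclass
class ToyProviderDefaults:
    position_embedding_type: Literal["learned_absolute", "rope"] = "rope"
    normalization: str = "RMSNorm"


print(ToyProviderDefaults())  # new defaults: rope / RMSNorm
print(ToyProviderDefaults(position_embedding_type="learned_absolute",
                          normalization="LayerNorm"))  # explicit old behaviour
```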