From 11648d55e5571fb193091550da00ac392247374d Mon Sep 17 00:00:00 2001 From: ddh0 Date: Fri, 1 Aug 2025 23:48:55 -0500 Subject: [PATCH 01/16] initial PR commit --- gguf-py/gguf/constants.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 5707085cb6687..3bc928bc20e51 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -382,6 +382,7 @@ class MODEL_ARCH(IntEnum): DREAM = auto() SMALLTHINKER = auto() LLADA = auto() + GLM4_MOE = auto() class VISION_PROJECTOR_TYPE(IntEnum): From 69d1c58e8c4e3ad98335533c76931eb8ef6d486d Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 01:24:56 -0500 Subject: [PATCH 02/16] add GGUF constants --- gguf-py/gguf/constants.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 3bc928bc20e51..d139369702618 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -705,6 +705,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DREAM: "dream", MODEL_ARCH.SMALLTHINKER: "smallthinker", MODEL_ARCH.LLADA: "llada", + MODEL_ARCH.GLM4_MOE: "glm4_moe", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -2542,6 +2543,27 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.GLM4_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, # AKA "e_score_correction_bias" in transformers + ], # TODO } From 2586ae5af7b88462f125c665f8ff7aeea0099f00 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 02:07:32 -0500 Subject: [PATCH 03/16] initial GLM-4.5 integration --- src/llama-arch.h | 1 + src/llama-graph.cpp | 12 ++++++++---- src/llama-model.cpp | 3 +++ src/llama-model.h | 2 ++ 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/llama-arch.h b/src/llama-arch.h index 9b8bd65b2322f..140ae6788865c 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -66,6 +66,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, + LLM_ATCH_GLM4_MOE, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 491a26b6346de..b5f428315d5ad 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -749,8 +749,10 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 FFNs seem to have numerical issues with half-precision accumulators + // -- ref: https://github.com/ggml-org/llama.cpp/pull/13101 + // (GLM4_MOE uses some GLM4 FFNs, so we need to match it too) ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -1391,8 +1393,10 @@ ggml_tensor * llm_graph_context::build_attn( if (wo) { cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 FFNs seem to have numerical issues with half-precision accumulators + // -- ref: https://github.com/ggml-org/llama.cpp/pull/13101 + // (GLM4_MOE uses some GLM4 FFNs, so we need to match it too) ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e3f12edd9bd56..3109ee3515ae5 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -111,6 +111,8 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; + case LLM_TYPE_355B_A32B: return "355B.A32B (GLM-4.5)"; + case LLM_TYPE_106B_A12B: return "106B.A12B (GLM-4.5)"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -18153,6 +18155,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GLM4: + case LLM_ARCH_GLM4_MOE: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_GRANITE_HYBRID: diff --git a/src/llama-model.h b/src/llama-model.h index 094e23808a813..1e5057bcc8a33 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -103,6 +103,8 @@ enum llm_type { LLM_TYPE_30B_A3B, LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big + LLM_TYPE_355B_A32B, // GLM-4.5 + LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_E2B, LLM_TYPE_E4B, }; From 2c6e198e669717dd37911cd58e8c789830dfd3e8 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 02:11:06 -0500 Subject: [PATCH 04/16] fix typo `LLM_ATCH_GLM4_MOE` --> `LLM_ARCH_GLM4_MOE` --- src/llama-arch.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-arch.h b/src/llama-arch.h index 140ae6788865c..d424f8cc1a0f1 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -66,7 +66,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, - LLM_ATCH_GLM4_MOE, + LLM_ARCH_GLM4_MOE, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, From dbe9f10b1ad3e89e78a36efa0214cb1f30500204 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 02:24:37 -0500 Subject: [PATCH 05/16] add glm4_moe tensor mapping --- src/llama-arch.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ba7bf9598670f..4bee9d7723938 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1391,6 +1391,31 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GLM4_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + }, + }, { LLM_ARCH_BITNET, { From 5f9e4e1425467eac8d53987bf0924e8415223413 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 02:48:08 -0500 Subject: [PATCH 06/16] add `attn_k_norm` and `attn_q_norm` tensors for GLM-4.5 --- gguf-py/gguf/constants.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d139369702618..efd443bc18ed2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2549,7 +2549,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, # not used in the 106B.A12B model MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, # not used in the 106B.A12B model MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.FFN_NORM, From 41169a8729cb7ff8a2eae48e7a8d8f10c3d19304 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 04:04:59 -0500 Subject: [PATCH 07/16] more consistent organization --- gguf-py/gguf/constants.py | 48 +++++++++++++++++++-------------------- src/llama-arch.cpp | 8 ++++--- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index efd443bc18ed2..8fcad32fa4d77 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -357,6 +357,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK2 = auto() CHATGLM = auto() GLM4 = auto() + GLM4_MOE = auto() BITNET = auto() T5 = auto() T5ENCODER = auto() @@ -382,7 +383,6 @@ class MODEL_ARCH(IntEnum): DREAM = auto() SMALLTHINKER = auto() LLADA = auto() - GLM4_MOE = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -2126,6 +2126,29 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_POST_NORM, MODEL_TENSOR.FFN_POST_NORM, ], + MODEL_ARCH.GLM4_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_K_NORM, # not always present + MODEL_TENSOR.ATTN_Q_NORM, # not always present + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, # AKA "e_score_correction_bias" in transformers + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, @@ -2543,29 +2566,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], - MODEL_ARCH.GLM4_MOE: [ - MODEL_TENSOR.TOKEN_EMBD, - MODEL_TENSOR.OUTPUT_NORM, - MODEL_TENSOR.OUTPUT, - MODEL_TENSOR.ATTN_NORM, - MODEL_TENSOR.ATTN_Q, - MODEL_TENSOR.ATTN_Q_NORM, # not used in the 106B.A12B model - MODEL_TENSOR.ATTN_K, - MODEL_TENSOR.ATTN_K_NORM, # not used in the 106B.A12B model - MODEL_TENSOR.ATTN_V, - MODEL_TENSOR.ATTN_OUT, - MODEL_TENSOR.FFN_NORM, - MODEL_TENSOR.FFN_GATE, - MODEL_TENSOR.FFN_DOWN, - MODEL_TENSOR.FFN_UP, - MODEL_TENSOR.FFN_GATE_EXP, - MODEL_TENSOR.FFN_DOWN_EXP, - MODEL_TENSOR.FFN_UP_EXP, - MODEL_TENSOR.FFN_GATE_SHEXP, - MODEL_TENSOR.FFN_DOWN_SHEXP, - MODEL_TENSOR.FFN_UP_SHEXP, - MODEL_TENSOR.FFN_EXP_PROBS_B, # AKA "e_score_correction_bias" in transformers - ], # TODO } diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4bee9d7723938..8f10749b90bda 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1395,15 +1395,19 @@ static const std::map> LLM_TENSOR_N LLM_ARCH_GLM4_MOE, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, @@ -1412,8 +1416,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, }, }, { From 3cf2e4a8d4c60c2fe60b5d9f6d12c64bf42c9dbb Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 04:09:04 -0500 Subject: [PATCH 08/16] more consistent organization (cont.) --- gguf-py/gguf/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8fcad32fa4d77..ca9824dce1aee 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -679,6 +679,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", + MODEL_ARCH.GLM4_MOE: "glm4_moe", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", @@ -705,7 +706,6 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DREAM: "dream", MODEL_ARCH.SMALLTHINKER: "smallthinker", MODEL_ARCH.LLADA: "llada", - MODEL_ARCH.GLM4_MOE: "glm4_moe", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { From 428f079e42ad9cebf756867c83918516732bba07 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 15:25:45 -0500 Subject: [PATCH 09/16] llama-hparams : group MoE-specific params together --- src/llama-hparams.h | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 8b7e2a1130755..5c07adf29b616 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -43,8 +43,6 @@ struct llama_hparams { uint32_t n_rot; uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head - uint32_t n_expert = 0; - uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA @@ -61,14 +59,17 @@ struct llama_hparams { std::array n_head_kv_arr; std::array n_ff_arr; - uint32_t n_layer_dense_lead = 0; uint32_t n_lora_q = 0; uint32_t n_lora_kv = 0; - uint32_t n_ff_exp = 0; - uint32_t n_ff_shexp = 0; - uint32_t n_expert_shared = 0; uint32_t n_norm_groups = 0; + // these params are specific to MoE models + uint32_t n_expert = 0; + uint32_t n_expert_used = 0; + uint32_t n_expert_shared = 0; + uint32_t n_layer_dense_lead = 0; + uint32_t n_ff_exp = 0; + uint32_t n_ff_shexp = 0; float expert_weights_scale = 0.0; bool expert_weights_norm = false; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; From 64fbb24081fdd73105b89c2a665e1ccc83f7e95c Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 15:34:37 -0500 Subject: [PATCH 10/16] dummy graph --- src/llama-model.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d4f1d438538e1..b1883f1b44d19 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13566,6 +13566,12 @@ struct llm_build_glm4 : public llm_graph_context { } }; +struct llm_build_glm4_moe : public llm_graph_context { + llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + // TODO + }; +}; + struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -17878,6 +17884,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GLM4_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_BITNET: { llm = std::make_unique(*this, params); From d85099ecae72c1912f0cffd7d6c18e468cdd7a7a Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 17:17:52 -0500 Subject: [PATCH 11/16] support loading GLM4 hparams --- src/llama-model.cpp | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b1883f1b44d19..2d157e184f9b6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1436,6 +1436,36 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM4_MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + + GGML_ASSERT(hparams.n_expert_shared == 1); + GGML_ASSERT(hparams.expert_weights_scale > 0.0); + + // NOTE: currently only two models use this arch - we need to update the switch + // statement below if more are released + + switch (hparams.n_expert) { + // ref: https://github.com/ggml-org/llama.cpp/pull/15026#issue-3285604563 + case 128: { + type = LLM_TYPE_106B_A12B; + hparams.use_kq_norm = false; + }; break; + case 160: { + type = LLM_TYPE_355B_A32B; + hparams.use_kq_norm = true; + }; break; + default: { + type = LLM_TYPE_UNKNOWN; + hparams.use_kq_norm = false; + }; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); From 261775d1aec19a32d84ebc324b836b6a5829d713 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 17:38:09 -0500 Subject: [PATCH 12/16] add "glm4_moe" LLM_ARCH --- src/llama-arch.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 8f10749b90bda..7551852b3ddeb 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -62,6 +62,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_GLM4_MOE, "glm4_moe" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ -1398,8 +1399,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, // optional + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, // optional { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, From a30d9a673da9a2299810e6036bd14f23ca4677cb Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 20:25:33 -0500 Subject: [PATCH 13/16] implement load_tensors for GLM4_MOE --- src/llama-model.cpp | 55 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2d157e184f9b6..449245ccf5b04 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4459,6 +4459,59 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GLM4_MOE: + { + const auto tn = LLM_TN(arch); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // self-attention + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional QK norms + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + // pre-FFN norm + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (static_cast(i) < hparams.n_layer_dense_lead) { + // this layer uses a dense FFN block + const int64_t n_ff_dense = hparams.n_ff(i); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_dense}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_dense, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_dense}, 0); + } else { + // this layer uses a MoE FFN block (1 group of conditional experts + 1 shared expert) + const int64_t n_ff_exp = hparams.n_ff_exp; + + // router input and expert biases + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); + + // conditional experts + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // shared expert + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_exp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp}, 0); + } + } + } break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -13599,7 +13652,7 @@ struct llm_build_glm4 : public llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { // TODO - }; + } }; struct llm_build_nemotron : public llm_graph_context { From d048901d904c15c83eec2857d3af117ad08ff6dc Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sat, 2 Aug 2025 20:26:40 -0500 Subject: [PATCH 14/16] remove trailing whitespace --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 449245ccf5b04..f7d0dba5738c3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18419,4 +18419,4 @@ bool llama_model_is_diffusion(const llama_model * model) { const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; -} +} \ No newline at end of file From 3e18442ff588ff3de0f04752001f3420a04335c3 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sun, 3 Aug 2025 01:20:42 -0500 Subject: [PATCH 15/16] llama-model : implement GLM4 MoE inference graph --- src/llama-model.cpp | 137 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 136 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f7d0dba5738c3..77fb489300397 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13651,7 +13651,142 @@ struct llm_build_glm4 : public llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - // TODO + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_rot = hparams.n_rot; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_rot == n_embd_head / 2); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_unified(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention block + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // optional QK norm + if (hparams.use_kq_norm) { + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + } + + // reshape QKV + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // apply RoPE + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_roped", il); + cb(Kcur, "Kcur_roped", il); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + // first residual + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // pre-ffn RMSnorm + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // + if (static_cast(il) < hparams.n_layer_dense_lead) { + // dense FFN + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_dense_out", il); + } else { + // shared expert + ggml_tensor * shexp_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shexp_out, "ffn_shexp_out", il); + + // conditional experts + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, + true, // norm_topk_prob + true, // use expert bias + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, // IMPORTANT -- MUST USE SIGMOID + il); + cb(moe_out, "ffn_moe_out", il); + + // combine output from shared and routed experts + cur = ggml_add(ctx0, moe_out, shexp_out); + cb(cur, "ffn_moe_combined", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); + } + + // second residual + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + // output norm + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "output_norm", -1); + res->t_embd = cur; + + // final output + cur = build_lora_mm(model.output, cur); + cb(cur, "output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); } }; From 61b844200dac7fa7eada0d2f5c0eb06997998959 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Sun, 3 Aug 2025 02:41:27 -0500 Subject: [PATCH 16/16] add `LLM_KV_ATTENTION_USE_KQ_NORM` for GLM4_MOE - remove `ffn_norm` per CISC - re-organize some small things --- gguf-py/gguf/constants.py | 7 +++---- src/llama-arch.cpp | 2 +- src/llama-arch.h | 1 + 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index ca9824dce1aee..28693f20f7dfb 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2137,16 +2137,15 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, - MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_GATE, - MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_GATE_EXP, - MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_GATE_SHEXP, - MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_EXP_PROBS_B, # AKA "e_score_correction_bias" in transformers ], MODEL_ARCH.BITNET: [ diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 7551852b3ddeb..97805b7e29d0c 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -164,6 +164,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_USE_KQ_NORM, "%s.attention.use_kq_norm" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -1405,7 +1406,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index d424f8cc1a0f1..c11d3c6c35806 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -168,6 +168,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_USE_KQ_NORM, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS,