Skip to content

Commit aeaf8a3

Browse files
authored
llama : support LiquidAI LFM2-MoE hybrid model (ggml-org#16464)
* llama : support LiquidAI LFM2-MoE hybrid model. Adds support for the [LiquidAI/LFM2-8B-A1B](https://huggingface.co/LiquidAI/LFM2-8B-A1B) model. For more information about the models, please read [the blog post](https://www.liquid.ai/company/news). [HF PR](huggingface/transformers#41401) [GGUFs](https://huggingface.co/LiquidAI/LFM2-8B-A1B-GGUF) * Do not use defaultdict * Address PR feedback
1 parent df1b612 commit aeaf8a3

File tree

7 files changed

+192
-15
lines changed

7 files changed

+192
-15
lines changed

convert_hf_to_gguf.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8836,6 +8836,75 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
88368836
return [(self.map_tensor_name(name), data_torch)]
88378837

88388838

8839+
@ModelBase.register("Lfm2MoeForCausalLM")
class LFM2MoeModel(TextModel):
    """Convert LiquidAI LFM2-MoE hybrid models (e.g. LFM2-8B-A1B) to GGUF.

    The HF checkpoint stores each expert's w1/w2/w3 as separate tensors; this
    class accumulates them per block and emits a single stacked tensor per
    projection, as expected by the LFM2MOE arch on the llama.cpp side.
    """

    model_arch = gguf.MODEL_ARCH.LFM2MOE

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache of expert weights awaiting merge, keyed by block id.
        # Created per instance: a class-level dict would be shared by every
        # LFM2MoeModel instance and could mix tensors across conversions.
        self._experts_cache: dict[int, dict[str, Tensor]] = {}

    def set_gguf_parameters(self):
        # set num_key_value_heads only for attention layers; shortconv layers
        # report 0 so the runtime can tell the two layer kinds apart
        self.hparams["num_key_value_heads"] = [
            self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
            for layer_type in self.hparams["layer_types"]
        ]

        super().set_gguf_parameters()

        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
        self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
        self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)

        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
        self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # conv op requires 2d tensor
        if 'conv.conv' in name:
            data_torch = data_torch.squeeze(1)

        # rename so the tensor-name mapper recognizes the expert bias
        if name.endswith(".expert_bias"):
            name = name.replace(".expert_bias", ".expert_bias.bias")

        # merge expert weights: buffer until all experts of a block arrived
        if 'experts' in name:
            n_experts = self.hparams["num_experts"]
            assert bid is not None

            expert_cache = self._experts_cache.setdefault(bid, {})
            expert_cache[name] = data_torch
            expert_weights = ["w1", "w2", "w3"]

            # not enough expert weights to merge yet
            if len(expert_cache) < n_experts * len(expert_weights):
                return []

            tensors: list[tuple[str, Tensor]] = []
            for w_name in expert_weights:
                datas: list[Tensor] = []

                for xid in range(n_experts):
                    ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{w_name}.weight"
                    datas.append(expert_cache[ename])
                    # free the buffered tensor as soon as it is consumed
                    del expert_cache[ename]

                data_torch = torch.stack(datas, dim=0)
                merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight"
                new_name = self.map_tensor_name(merged_name)
                tensors.append((new_name, data_torch))

            del self._experts_cache[bid]
            return tensors

        return [(self.map_tensor_name(name), data_torch)]

    def prepare_tensors(self):
        super().prepare_tensors()
        # every buffered expert tensor must have been merged and emitted
        assert not self._experts_cache
8907+
88398908
@ModelBase.register("Lfm2VlForConditionalGeneration")
88408909
class LFM2VLModel(MmprojModel):
88418910
def __init__(self, *args, **kwargs):

gguf-py/gguf/constants.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ class MODEL_ARCH(IntEnum):
407407
SMOLLM3 = auto()
408408
GPT_OSS = auto()
409409
LFM2 = auto()
410+
LFM2MOE = auto()
410411
DREAM = auto()
411412
SMALLTHINKER = auto()
412413
LLADA = auto()
@@ -749,6 +750,7 @@ class MODEL_TENSOR(IntEnum):
749750
MODEL_ARCH.SMOLLM3: "smollm3",
750751
MODEL_ARCH.GPT_OSS: "gpt-oss",
751752
MODEL_ARCH.LFM2: "lfm2",
753+
MODEL_ARCH.LFM2MOE: "lfm2moe",
752754
MODEL_ARCH.DREAM: "dream",
753755
MODEL_ARCH.SMALLTHINKER: "smallthinker",
754756
MODEL_ARCH.LLADA: "llada",
@@ -2698,6 +2700,29 @@ class MODEL_TENSOR(IntEnum):
26982700
MODEL_TENSOR.ATTN_OUT,
26992701
MODEL_TENSOR.OUTPUT,
27002702
],
2703+
MODEL_ARCH.LFM2MOE: [
2704+
MODEL_TENSOR.TOKEN_EMBD,
2705+
MODEL_TENSOR.TOKEN_EMBD_NORM,
2706+
MODEL_TENSOR.SHORTCONV_CONV,
2707+
MODEL_TENSOR.SHORTCONV_INPROJ,
2708+
MODEL_TENSOR.SHORTCONV_OUTPROJ,
2709+
MODEL_TENSOR.FFN_GATE,
2710+
MODEL_TENSOR.FFN_DOWN,
2711+
MODEL_TENSOR.FFN_UP,
2712+
MODEL_TENSOR.FFN_NORM,
2713+
MODEL_TENSOR.ATTN_NORM, # operator_norm
2714+
MODEL_TENSOR.ATTN_Q_NORM,
2715+
MODEL_TENSOR.ATTN_K_NORM,
2716+
MODEL_TENSOR.ATTN_Q,
2717+
MODEL_TENSOR.ATTN_K,
2718+
MODEL_TENSOR.ATTN_V,
2719+
MODEL_TENSOR.ATTN_OUT,
2720+
MODEL_TENSOR.FFN_GATE_INP,
2721+
MODEL_TENSOR.FFN_GATE_EXP,
2722+
MODEL_TENSOR.FFN_DOWN_EXP,
2723+
MODEL_TENSOR.FFN_UP_EXP,
2724+
MODEL_TENSOR.FFN_EXP_PROBS_B,
2725+
],
27012726
MODEL_ARCH.SMALLTHINKER: [
27022727
MODEL_TENSOR.TOKEN_EMBD,
27032728
MODEL_TENSOR.OUTPUT_NORM,

gguf-py/gguf/tensor_mapping.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ class TensorNameMap:
358358
"model.layers.{bid}.mlp.router", # openai-moe
359359
"model.layers.{bid}.mlp.gate.wg", # hunyuan
360360
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
361+
"model.layers.{bid}.feed_forward.gate", # lfm2moe
361362
),
362363

363364
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -367,6 +368,7 @@ class TensorNameMap:
367368
MODEL_TENSOR.FFN_EXP_PROBS_B: (
368369
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
369370
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
371+
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
370372
),
371373

372374
# Feed-forward up

src/llama-arch.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
9393
{ LLM_ARCH_SMOLLM3, "smollm3" },
9494
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
9595
{ LLM_ARCH_LFM2, "lfm2" },
96+
{ LLM_ARCH_LFM2MOE, "lfm2moe" },
9697
{ LLM_ARCH_DREAM, "dream" },
9798
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
9899
{ LLM_ARCH_LLADA, "llada" },
@@ -2104,6 +2105,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
21042105
{ LLM_TENSOR_OUTPUT, "output" },
21052106
}
21062107
},
2108+
{
2109+
LLM_ARCH_LFM2MOE,
2110+
{
2111+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
2112+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
2113+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
2114+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
2115+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
2116+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
2117+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
2118+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
2119+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
2120+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
2121+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
2122+
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
2123+
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
2124+
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
2125+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
2126+
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
2127+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
2128+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
2129+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
2130+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
2131+
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
2132+
}
2133+
},
21072134
{
21082135
LLM_ARCH_SMALLTHINKER,
21092136
{
@@ -2493,6 +2520,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
24932520
case LLM_ARCH_PLAMO2:
24942521
case LLM_ARCH_GRANITE_HYBRID:
24952522
case LLM_ARCH_LFM2:
2523+
case LLM_ARCH_LFM2MOE:
24962524
case LLM_ARCH_NEMOTRON_H:
24972525
return true;
24982526
default:

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ enum llm_arch {
9797
LLM_ARCH_SMOLLM3,
9898
LLM_ARCH_OPENAI_MOE,
9999
LLM_ARCH_LFM2,
100+
LLM_ARCH_LFM2MOE,
100101
LLM_ARCH_DREAM,
101102
LLM_ARCH_SMALLTHINKER,
102103
LLM_ARCH_LLADA,

src/llama-model.cpp

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ const char * llm_type_name(llm_type type) {
114114
case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
115115
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
116116
case LLM_TYPE_A13B: return "A13B";
117+
case LLM_TYPE_8B_A1B: return "8B.A1B";
117118
case LLM_TYPE_21B_A3B: return "21B.A3B";
118119
case LLM_TYPE_30B_A3B: return "30B.A3B";
119120
case LLM_TYPE_106B_A12B: return "106B.A12B";
@@ -1995,14 +1996,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
19951996
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
19961997
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
19971998
}
1999+
hparams.n_layer_dense_lead = hparams.n_layer;
19982000
switch (hparams.n_ff()) {
19992001
case 4608: type = LLM_TYPE_350M; break;
20002002
case 6912: type = LLM_TYPE_700M; break;
20012003
case 8192: type = LLM_TYPE_1_2B; break;
20022004
case 10752: type = LLM_TYPE_2_6B; break;
2003-
default: type = LLM_TYPE_UNKNOWN;
2005+
default: type = LLM_TYPE_UNKNOWN;
20042006
}
20052007
} break;
2008+
case LLM_ARCH_LFM2MOE:
2009+
{
2010+
ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache);
2011+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
2012+
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
2013+
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
2014+
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
2015+
2016+
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
2017+
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
2018+
}
2019+
2020+
type = LLM_TYPE_8B_A1B;
2021+
} break;
20062022
case LLM_ARCH_SMALLTHINKER:
20072023
{
20082024
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
@@ -5814,6 +5830,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
58145830
}
58155831
} break;
58165832
case LLM_ARCH_LFM2:
5833+
case LLM_ARCH_LFM2MOE:
58175834
{
58185835
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
58195836
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
@@ -5825,11 +5842,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
58255842

58265843
for (int i = 0; i < n_layer; ++i) {
58275844
auto & layer = layers[i];
5828-
// ffn is same for transformer and conv layers
5845+
5846+
const bool is_moe_layer = i >= static_cast<int>(hparams.n_layer_dense_lead);
5847+
5848+
// ffn/moe is same for transformer and conv layers
58295849
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
5830-
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
5831-
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
5832-
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
5850+
if (is_moe_layer) {
5851+
GGML_ASSERT(n_expert && n_expert_used);
5852+
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
5853+
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0);
5854+
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0);
5855+
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0);
5856+
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
5857+
} else { // dense
5858+
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
5859+
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
5860+
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
5861+
}
58335862

58345863
// for operator_norm
58355864
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
@@ -6310,7 +6339,7 @@ void llama_model::print_info() const {
63106339
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
63116340
}
63126341

6313-
if (arch == LLM_ARCH_SMALLTHINKER) {
6342+
if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
63146343
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
63156344
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
63166345
}
@@ -18602,6 +18631,8 @@ struct llm_build_lfm2 : public llm_graph_context {
1860218631
ggml_tensor * inp_out_ids = build_inp_out_ids();
1860318632

1860418633
for (int il = 0; il < n_layer; ++il) {
18634+
const bool is_moe_layer = il >= static_cast<int>(hparams.n_layer_dense_lead);
18635+
1860518636
auto * prev_cur = cur;
1860618637
cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
1860718638
cb(cur, "model.layers.{}.operator_norm", il);
@@ -18616,7 +18647,16 @@ struct llm_build_lfm2 : public llm_graph_context {
1861618647
}
1861718648

1861818649
cur = ggml_add(ctx0, prev_cur, cur);
18619-
cur = ggml_add(ctx0, cur, build_feed_forward(cur, il));
18650+
18651+
auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
18652+
cb(ffn_norm_out, "model.layers.{}.ffn_norm", il);
18653+
18654+
ggml_tensor * ffn_out = is_moe_layer ?
18655+
build_moe_feed_forward(ffn_norm_out, il) :
18656+
build_dense_feed_forward(ffn_norm_out, il);
18657+
cb(ffn_norm_out, "model.layers.{}.ffn_out", il);
18658+
18659+
cur = ggml_add(ctx0, cur, ffn_out);
1862018660
}
1862118661

1862218662
cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
@@ -18631,23 +18671,32 @@ struct llm_build_lfm2 : public llm_graph_context {
1863118671
ggml_build_forward_expand(gf, cur);
1863218672
}
1863318673

18634-
ggml_tensor * build_feed_forward(ggml_tensor * cur,
18635-
int il) const {
18636-
cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
18637-
cb(cur, "model.layers.{}.ffn_norm", il);
18674+
ggml_tensor * build_moe_feed_forward(ggml_tensor * cur,
18675+
int il) const {
18676+
return build_moe_ffn(cur,
18677+
model.layers[il].ffn_gate_inp,
18678+
model.layers[il].ffn_up_exps,
18679+
model.layers[il].ffn_gate_exps,
18680+
model.layers[il].ffn_down_exps,
18681+
model.layers[il].ffn_exp_probs_b,
18682+
n_expert, n_expert_used,
18683+
LLM_FFN_SILU, true,
18684+
false, 0.0,
18685+
static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
18686+
il);
18687+
}
1863818688

18689+
ggml_tensor * build_dense_feed_forward(ggml_tensor * cur,
18690+
int il) const {
1863918691
GGML_ASSERT(!model.layers[il].ffn_up_b);
1864018692
GGML_ASSERT(!model.layers[il].ffn_gate_b);
1864118693
GGML_ASSERT(!model.layers[il].ffn_down_b);
18642-
cur = build_ffn(cur,
18694+
return build_ffn(cur,
1864318695
model.layers[il].ffn_up, NULL, NULL,
1864418696
model.layers[il].ffn_gate, NULL, NULL,
1864518697
model.layers[il].ffn_down, NULL, NULL,
1864618698
NULL,
1864718699
LLM_FFN_SILU, LLM_FFN_PAR, il);
18648-
cb(cur, "model.layers.{}.feed_forward.w2", il);
18649-
18650-
return cur;
1865118700
}
1865218701

1865318702
ggml_tensor * build_attn_block(ggml_tensor * cur,
@@ -19817,6 +19866,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
1981719866
llm = std::make_unique<llm_build_falcon_h1>(*this, params);
1981819867
} break;
1981919868
case LLM_ARCH_LFM2:
19869+
case LLM_ARCH_LFM2MOE:
1982019870
{
1982119871
llm = std::make_unique<llm_build_lfm2>(*this, params);
1982219872
} break;
@@ -20039,6 +20089,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
2003920089
case LLM_ARCH_OPENAI_MOE:
2004020090
case LLM_ARCH_HUNYUAN_DENSE:
2004120091
case LLM_ARCH_LFM2:
20092+
case LLM_ARCH_LFM2MOE:
2004220093
case LLM_ARCH_SMALLTHINKER:
2004320094
case LLM_ARCH_GLM4_MOE:
2004420095
case LLM_ARCH_SEED_OSS:

src/llama-model.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ enum llm_type {
107107
LLM_TYPE_17B_16E, // llama4 Scout
108108
LLM_TYPE_17B_128E, // llama4 Maverick
109109
LLM_TYPE_A13B,
110+
LLM_TYPE_8B_A1B, // lfm2moe
110111
LLM_TYPE_21B_A3B, // Ernie MoE small
111112
LLM_TYPE_30B_A3B,
112113
LLM_TYPE_106B_A12B, // GLM-4.5-Air

0 commit comments

Comments
 (0)