
Commit 6c6e397

Authored by wdl339, with CISC and ggerganov
model : add support for SmallThinker series (#14898)
* support smallthinker
* support 20b softmax, 4b no sliding window
* new build_moe_ffn_from_probs, and can run 4b
* fix 4b rope bug
* fix python type check
* remove is_moe judge
* remove set_dense_start_swa_pattern function and modify set_swa_pattern function
* trim trailing whitespace
* remove get_vocab_base of SmallThinkerModel in convert_hf_to_gguf.py (Co-authored-by: Sigbjørn Skjæret <[email protected]>)
* better whitespace: apply suggestions from code review (Co-authored-by: Sigbjørn Skjæret <[email protected]>)
* use GGML_ASSERT for expert count validation (Co-authored-by: Sigbjørn Skjæret <[email protected]>)
* improve null pointer check for probs (Co-authored-by: Sigbjørn Skjæret <[email protected]>)
* use template parameter for SWA attention logic
* better whitespace (Co-authored-by: Georgi Gerganov <[email protected]>)
* move the creation of inp_out_ids before the layer loop
* remove redundant judge for probs

Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent afc0e89 commit 6c6e397

File tree

10 files changed: +443, -6 lines


convert_hf_to_gguf.py

Lines changed: 82 additions & 0 deletions
@@ -7589,6 +7589,88 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("SmallThinkerForCausalLM")
+class SmallThinkerModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.SMALLTHINKER
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+        if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+            self.gguf_writer.add_feed_forward_length(moe_intermediate_size)
+            logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
+        if (self.hparams.get('moe_primary_router_apply_softmax')):
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        # YaRN is not enabled by default
+        # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+        sliding_window_layout = self.hparams.get("sliding_window_layout")
+        if sliding_window_layout:
+            for i in sliding_window_layout:
+                if i != 0:
+                    sliding_window = self.hparams.get("sliding_window_size")
+                    if sliding_window:
+                        self.gguf_writer.add_sliding_window(sliding_window)
+                    break
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("experts") != -1:
+            n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down", "gate", "up"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
 ###### CONVERSION LOGIC ######
 
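The new SmallThinkerModel class buffers each layer's per-expert down/gate/up matrices and only emits them once all n_experts * 3 tensors for that block have arrived, stacking them into a single 3D tensor per projection. A minimal standalone sketch of that stacking step, using toy shapes and a hypothetical map_name helper in place of the real map_tensor_name / gguf-py machinery:

import torch

n_experts, n_ff, n_embd, bid = 4, 8, 16, 0

# toy per-expert tensors, keyed by their original HF names
experts = {
    f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w}.weight":
        torch.randn(n_ff, n_embd) if w != "down" else torch.randn(n_embd, n_ff)
    for xid in range(n_experts)
    for w in ("down", "gate", "up")
}

def map_name(merged_name: str) -> str:
    # hypothetical stand-in for the converter's map_tensor_name lookup
    w = merged_name.rsplit(".", 2)[-2]
    return f"blk.{bid}.ffn_{w}_exps.weight"

merged = []
for w in ("down", "gate", "up"):
    # gather the per-expert 2D matrices in expert order and stack them along a new leading dim
    datas = [experts[f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w}.weight"]
             for xid in range(n_experts)]
    data = torch.stack(datas, dim=0)  # -> [n_experts, rows, cols]
    merged.append((map_name(f"model.layers.{bid}.block_sparse_moe.experts.{w}.weight"), data))

for name, t in merged:
    print(name, tuple(t.shape))  # e.g. blk.0.ffn_up_exps.weight (4, 8, 16)

Stacking along a leading expert dimension is what lets the merged ffn_*_exps tensors be indexed per selected expert at runtime.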

gguf-py/gguf/constants.py

Lines changed: 20 additions & 0 deletions
@@ -376,6 +376,7 @@ class MODEL_ARCH(IntEnum):
     SMOLLM3 = auto()
     LFM2 = auto()
     DREAM = auto()
+    SMALLTHINKER = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -695,6 +696,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.SMOLLM3: "smollm3",
     MODEL_ARCH.LFM2: "lfm2",
     MODEL_ARCH.DREAM: "dream",
+    MODEL_ARCH.SMALLTHINKER: "smallthinker",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2483,6 +2485,24 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
     ],
+    MODEL_ARCH.SMALLTHINKER: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     # TODO
 }

gguf-py/gguf/tensor_mapping.py

Lines changed: 7 additions & 0 deletions
@@ -317,6 +317,7 @@ class TensorNameMap:
         "model.layers.{bid}.feed_forward.router",              # llama4 jamba
         "encoder.layers.{bid}.mlp.router.layer",                # nomic-bert-moe
         "model.layers.{bid}.mlp.gate.wg",                       # hunyuan
+        "model.layers.{bid}.block_sparse_moe.primary_router",   # smallthinker
     ),
 
     MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -362,6 +363,7 @@ class TensorNameMap:
         "transformer.h.{bid}.mlp.c_fc_1",                      # exaone
         "model.layers.{bid}.feed_forward.up_proj",             # llama4 jamba granite-hybrid
         "transformer_encoder.{bid}.ffn.w12",                   # neobert
+        "model.layers.{bid}.block_sparse_moe.up",              # smallthinker
     ),
 
     MODEL_TENSOR.FFN_UP_EXP: (
@@ -372,6 +374,7 @@
         "model.layers.{bid}.block_sparse_moe.experts.w3",      # phimoe (merged)
         "model.layers.{bid}.feed_forward.experts.up_proj",     # llama4
         "encoder.layers.{bid}.mlp.experts.mlp.w1",             # nomic-bert-moe
+        "model.layers.{bid}.block_sparse_moe.experts.up",      # smallthinker
     ),
 
     MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -401,6 +404,7 @@
         "model.layers.{bid}.residual_mlp.w1",                  # arctic
         "transformer.h.{bid}.mlp.c_fc_0",                      # exaone
         "model.layers.{bid}.feed_forward.gate_proj",           # llama4 jamba granite-hybrid
+        "model.layers.{bid}.block_sparse_moe.gate",            # smallthinker
     ),
 
     MODEL_TENSOR.FFN_GATE_EXP: (
@@ -410,6 +414,7 @@
         "model.layers.{bid}.mlp.experts.gate_proj",            # qwen2moe olmoe (merged) ernie4.5-moe
         "model.layers.{bid}.block_sparse_moe.experts.w1",      # phimoe (merged)
         "model.layers.{bid}.feed_forward.experts.gate_proj",   # llama4
+        "model.layers.{bid}.block_sparse_moe.experts.gate",    # smallthinker
     ),
 
     MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -448,6 +453,7 @@
         "model.layers.h.{bid}.mlp.c_proj",                     # exaone
         "model.layers.{bid}.feed_forward.down_proj",           # llama4 jamba granite-hybrid
         "transformer_encoder.{bid}.ffn.w3",                    # neobert
+        "model.layers.{bid}.block_sparse_moe.down",            # smallthinker
     ),
 
     MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -459,6 +465,7 @@
         "model.layers.{bid}.block_sparse_moe.experts.w2",      # phimoe (merged)
         "model.layers.{bid}.feed_forward.experts.down_proj",   # llama4
         "encoder.layers.{bid}.mlp.experts.mlp.w2",             # nomic-bert-moe
+        "model.layers.{bid}.block_sparse_moe.experts.down",    # smallthinker
     ),
 
     MODEL_TENSOR.FFN_DOWN_SHEXP: (
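These mappings let the converter turn SmallThinker's HF tensor names into the GGUF names llama.cpp expects. A hedged sketch of that kind of substitution, with a simplified lookup table rather than the actual TensorNameMap implementation:

# simplified resolution: each GGUF base name lists HF patterns with a {bid} placeholder
MAPPINGS = {
    "blk.{bid}.ffn_up_exps":   ["model.layers.{bid}.block_sparse_moe.experts.up"],
    "blk.{bid}.ffn_gate_exps": ["model.layers.{bid}.block_sparse_moe.experts.gate"],
    "blk.{bid}.ffn_down_exps": ["model.layers.{bid}.block_sparse_moe.experts.down"],
    "blk.{bid}.ffn_gate_inp":  ["model.layers.{bid}.block_sparse_moe.primary_router"],
}

def resolve(hf_name: str, bid: int) -> str | None:
    # strip the ".weight" suffix, substitute the block id, then look the base name up
    base = hf_name.removesuffix(".weight")
    for gguf_name, patterns in MAPPINGS.items():
        if base in (p.format(bid=bid) for p in patterns):
            return gguf_name.format(bid=bid) + ".weight"
    return None

print(resolve("model.layers.3.block_sparse_moe.experts.up.weight", bid=3))
# -> blk.3.ffn_up_exps.weight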

src/llama-arch.cpp

Lines changed: 22 additions & 0 deletions
@@ -88,6 +88,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_SMOLLM3,          "smollm3"      },
     { LLM_ARCH_LFM2,             "lfm2"         },
     { LLM_ARCH_DREAM,            "dream"        },
+    { LLM_ARCH_SMALLTHINKER,     "smallthinker" },
     { LLM_ARCH_UNKNOWN,          "(unknown)"    },
 };
 
@@ -1933,6 +1934,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
         }
     },
+    {
+        LLM_ARCH_SMALLTHINKER,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP,   "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,  "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,  "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps" }
+        },
+    },
     {
         LLM_ARCH_DREAM,
         {

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -92,6 +92,7 @@ enum llm_arch {
     LLM_ARCH_SMOLLM3,
     LLM_ARCH_LFM2,
     LLM_ARCH_DREAM,
+    LLM_ARCH_SMALLTHINKER,
     LLM_ARCH_UNKNOWN,
 };
 

src/llama-graph.cpp

Lines changed: 94 additions & 0 deletions
@@ -938,6 +938,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     return moe_out;
 }
 
+ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
+        ggml_tensor * cur,
+        ggml_tensor * probs,
+        ggml_tensor * up_exps,
+        ggml_tensor * gate_exps,
+        ggml_tensor * down_exps,
+        ggml_tensor * exp_probs_b,
+            int64_t   n_expert,
+            int64_t   n_expert_used,
+        llama_expert_gating_func_type gating_op,
+                int   il) const {
+    const int64_t n_embd   = cur->ne[0];
+    const int64_t n_tokens = cur->ne[1];
+
+    // add experts selection bias - introduced in DeepSeek V3
+    // leave probs unbiased as it's later used to get expert weights
+    ggml_tensor * selection_probs = probs;
+    if (exp_probs_b != nullptr) {
+        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
+    // select experts
+    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+    cb(selected_experts->src[0], "ffn_moe_argsort", il);
+    cb(selected_experts, "ffn_moe_topk", il);
+
+    ggml_tensor * weights = ggml_get_rows(ctx0,
+            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    cb(weights, "ffn_moe_weights", il);
+
+    weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
+        weights = ggml_soft_max(ctx0, weights);
+    } else {
+        weights = ggml_sigmoid(ctx0, weights);
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
+        cb(weights_sum, "ffn_moe_weights_sum", il);
+
+        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
+        cb(weights, "ffn_moe_weights_norm", il);
+    }
+
+    weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+
+    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+
+    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(up, "ffn_moe_up", il);
+
+    ggml_tensor * experts = nullptr;
+    cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(cur, "ffn_moe_gate", il);
+
+    cur = ggml_reglu_split(ctx0, cur, up);
+    cb(cur, "ffn_moe_reglu", il);
+
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    cb(experts, "ffn_moe_down", il);
+
+    experts = ggml_mul(ctx0, experts, weights);
+    cb(cur, "ffn_moe_weighted", il);
+
+    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
+
+    assert(n_expert_used > 0);
+
+    // order the views before the adds
+    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
+
+        ggml_build_forward_expand(gf, cur_experts[i]);
+    }
+
+    // aggregate experts
+    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
+    // to avoid potentially a large number of add nodes during warmup
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14753
+    ggml_tensor * moe_out = cur_experts[0];
+
+    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
+    }
+
+    if (n_expert_used == 1) {
+        // avoid returning a non-contiguous tensor
+        moe_out = ggml_cont(ctx0, moe_out);
+    }
+
+    cb(moe_out, "ffn_moe_out", il);
+
+    return moe_out;
+}
+
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd;
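In plain terms, build_moe_ffn_from_probs takes precomputed router probabilities, picks the top-k experts per token, turns the selected probabilities into mixture weights (softmax over the selected scores, or sigmoid followed by sum-normalization), runs each selected expert's gate/up/down projections with a ReGLU activation, and sums the weighted expert outputs. A rough NumPy sketch of that routing math, ignoring the expert-selection bias term and intended as reference semantics only, not the ggml graph code:

import numpy as np

def moe_ffn_from_probs(x, probs, w_up, w_gate, w_down, n_expert_used, softmax_gating=True):
    # x: [n_tokens, n_embd], probs: [n_tokens, n_expert]
    # w_up/w_gate: [n_expert, n_embd, n_ff], w_down: [n_expert, n_ff, n_embd]
    n_tokens, _ = x.shape
    out = np.zeros_like(x)
    for t in range(n_tokens):
        topk = np.argsort(-probs[t])[:n_expert_used]       # select the top-k experts
        w = probs[t, topk]
        if softmax_gating:
            w = np.exp(w - w.max()); w /= w.sum()           # softmax over the selected scores
        else:
            w = 1.0 / (1.0 + np.exp(-w)); w /= w.sum()       # sigmoid, then sum-normalize
        for weight, e in zip(w, topk):
            gate = x[t] @ w_gate[e]
            up   = x[t] @ w_up[e]
            h    = np.maximum(gate, 0.0) * up                # ReGLU: relu(gate) * up
            out[t] += weight * (h @ w_down[e])
    return out

rng = np.random.default_rng(0)
x     = rng.standard_normal((2, 8))
probs = rng.random((2, 4))
out = moe_ffn_from_probs(x, probs,
                         rng.standard_normal((4, 8, 16)),
                         rng.standard_normal((4, 8, 16)),
                         rng.standard_normal((4, 16, 8)),
                         n_expert_used=2)
print(out.shape)  # (2, 8)

The difference from the existing build_moe_ffn is that the router probabilities arrive precomputed rather than being derived from a gate projection inside the function, which is what SmallThinker's pre-attention router needs.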

src/llama-graph.h

Lines changed: 12 additions & 0 deletions
@@ -625,6 +625,18 @@ struct llm_graph_context {
             llama_expert_gating_func_type gating_op,
                     int   il) const;
 
+    ggml_tensor * build_moe_ffn_from_probs(
+            ggml_tensor * cur,
+            ggml_tensor * probs,
+            ggml_tensor * up_exps,
+            ggml_tensor * gate_exps,
+            ggml_tensor * down_exps,
+            ggml_tensor * exp_probs_b,
+                int64_t   n_expert,
+                int64_t   n_expert_used,
+            llama_expert_gating_func_type gating_op,
+                    int   il) const;
+
     //
     // inputs
     //

src/llama-hparams.cpp

Lines changed: 9 additions & 3 deletions
@@ -2,9 +2,15 @@
 
 #include "ggml.h"
 
-void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
+    if (dense_first) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0);
+        }
+    } else {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+        }
     }
 }
 
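set_swa_pattern marks which layers use sliding-window attention: with dense_first, every n_pattern-th layer starting at layer 0 stays dense and the rest use SWA; otherwise the last layer of each group of n_pattern is dense. A small illustrative sketch that mirrors the two boolean expressions above:

def swa_layers(n_layer: int, n_pattern: int, dense_first: bool) -> list[bool]:
    # True means the layer uses sliding-window attention, False means full (dense) attention
    if dense_first:
        return [n_pattern == 0 or (il % n_pattern != 0) for il in range(n_layer)]
    return [n_pattern == 0 or (il % n_pattern < (n_pattern - 1)) for il in range(n_layer)]

print(swa_layers(8, 4, dense_first=True))   # [False, True, True, True, False, True, True, True]
print(swa_layers(8, 4, dense_first=False))  # [True, True, True, False, True, True, True, False]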
