8 changes: 3 additions & 5 deletions common/common.cpp
@@ -960,15 +960,13 @@ struct common_init_result common_init_from_params(common_params & params) {

bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;

if (!has_eos && !has_sep) {
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
if (!has_eos && !has_sep && !has_rerank_prompt) {
LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
ok = false;
} else if (!has_eos) {
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
} else if (!has_sep) {
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
ok = false;
}

if (!ok) {
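
The relaxed check above accepts a bundled "rerank" chat template as an alternative to EOS/SEP tokens. A minimal sketch of the same capability probe as a standalone helper (hypothetical, not part of the patch; it only uses public llama.h calls):

#include "llama.h"

// true if the model can be used for reranking: either the vocab provides the
// EOS/SEP tokens needed for the [BOS]query[EOS][SEP]doc[EOS] format, or the
// GGUF ships a named "rerank" chat template (e.g. Qwen3-Reranker conversions)
static bool supports_rerank(const llama_model * model) {
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const bool has_eos  = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
    const bool has_sep  = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
    const bool has_tmpl = llama_model_chat_template(model, "rerank") != NULL;
    return has_eos || has_sep || has_tmpl;
}
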
64 changes: 64 additions & 0 deletions convert_hf_to_gguf.py
@@ -3652,11 +3652,29 @@ def prepare_tensors(self):
class Qwen3Model(Qwen2Model):
model_arch = gguf.MODEL_ARCH.QWEN3

# extra logic for rerank models
is_rerank: bool = False
is_tied_embeddings: bool = False
token_false_id: int | None = None
token_true_id: int | None = None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# track for intern-s1-mini
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
self.origin_hf_arch = hparams.get('architectures', [None])[0]

# a bit hacky, but currently the only way to detect if this is a rerank model
# ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
readme_path = self.dir_model / "README.md"
readme_text = ""
if readme_path.exists():
with readme_path.open("r", encoding="utf-8") as f:
readme_text = f.read()
if "# Qwen3-Reranker" in readme_text:
self._find_rerank_config()

def set_vocab(self):
# deal with intern-s1-mini
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
@@ -3665,6 +3683,52 @@ def set_vocab(self):

super().set_vocab()

def _find_rerank_config(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)

self.is_rerank = True
self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
self.token_false_id = tokenizer.convert_tokens_to_ids("no")
self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
self.sep_token_id = tokenizer.convert_tokens_to_ids("|")

assert self.token_false_id is not None and self.token_true_id is not None

def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.is_rerank:
self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
self.gguf_writer.add_classifier_output_labels(["yes", "no"])
self.gguf_writer.add_chat_template([{
"name": "rerank",
"template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
"<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}\n"
"<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
}])

def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
# extract "yes" and "no" tokens from the output lm_head tensor
false_row = data_torch[self.token_false_id]
true_row = data_torch[self.token_true_id]
return torch.stack([true_row, false_row], dim=0)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if self.is_rerank:
is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
is_real_head = not self.is_tied_embeddings and "lm_head" in name
if is_tied_head or is_real_head:
cls_out_head = (
gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight",
self._get_cls_out_tensor(data_torch),
)
if is_tied_head:
embed = (self.map_tensor_name(name), data_torch)
return [cls_out_head, embed]
if is_real_head:
return [cls_out_head]

return super().modify_tensors(data_torch, name, bid)

@ModelBase.register("Qwen3MoeForCausalLM")
class Qwen3MoeModel(Qwen2MoeModel):
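
The conversion above repurposes two rows of the output head as the 2-way classifier stored in cls.output. As a sketch of what _get_cls_out_tensor computes, with $W$ the lm_head matrix (or the tied token-embedding matrix), $d$ the hidden size, and $t_{\text{yes}}, t_{\text{no}}$ the token ids looked up from the tokenizer:

$$W_{\text{cls}} = \begin{bmatrix} W_{t_{\text{yes}}} \\ W_{t_{\text{no}}} \end{bmatrix} \in \mathbb{R}^{2 \times d}, \qquad z = W_{\text{cls}}\, h = (z_{\text{yes}}, z_{\text{no}}),$$

where $h$ is the last-token hidden state selected by RANK pooling at inference time.
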
43 changes: 28 additions & 15 deletions examples/embedding/embedding.cpp
@@ -95,8 +95,13 @@ int main(int argc, char ** argv) {
params.n_batch = params.n_ctx;
}

// For non-causal models, batch size must be equal to ubatch size
params.n_ubatch = params.n_batch;
// for non-causal models, batch size must be equal to ubatch size
if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) {
params.n_ubatch = params.n_batch;
}

// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();

llama_backend_init();
llama_numa_init(params.numa);
@@ -144,6 +149,7 @@ int main(int argc, char ** argv) {
// get added sep and eos token, if any
const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
const char * rerank_prompt = llama_model_chat_template(model, "rerank");

// tokenize the prompts and trim
std::vector<std::vector<int32_t>> inputs;
@@ -153,21 +159,28 @@ int main(int argc, char ** argv) {
// split classification pairs and insert expected separator tokens
if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
std::string final_prompt;

for (size_t i = 0; i < pairs.size(); i++) {
final_prompt += pairs[i];
if (i != pairs.size() - 1) {
if (!added_eos_token.empty()) {
final_prompt += added_eos_token;
}
if (!added_sep_token.empty()) {
final_prompt += added_sep_token;
if (rerank_prompt != nullptr) {
const std::string query = pairs[0];
const std::string doc = pairs[1];
std::string final_prompt = rerank_prompt;
string_replace_all(final_prompt, "{query}" , query);
string_replace_all(final_prompt, "{document}", doc );
inp = common_tokenize(vocab, final_prompt, true, false);
} else {
std::string final_prompt;
for (size_t i = 0; i < pairs.size(); i++) {
final_prompt += pairs[i];
if (i != pairs.size() - 1) {
if (!added_eos_token.empty()) {
final_prompt += added_eos_token;
}
if (!added_sep_token.empty()) {
final_prompt += added_sep_token;
}
}
}
inp = common_tokenize(ctx, final_prompt, true, true);
}

inp = common_tokenize(ctx, final_prompt, true, true);
} else {
inp = common_tokenize(ctx, prompt, true, true);
}
@@ -229,7 +242,7 @@ int main(int argc, char ** argv) {
const uint64_t n_toks = inp.size();

// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
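
A hypothetical helper distilled from the template branch above, with the needed includes spelled out (the query/document strings are illustrative only; string_replace_all comes from common.h, as used in the patch):

#include "common.h"   // string_replace_all
#include "llama.h"    // llama_model_chat_template
#include <string>

// expand the model's bundled "rerank" template for one query/document pair;
// returns an empty string when the model does not ship such a template
static std::string build_rerank_prompt(const llama_model * model,
                                       const std::string & query,
                                       const std::string & doc) {
    const char * tmpl = llama_model_chat_template(model, "rerank");
    if (tmpl == nullptr) {
        return "";
    }
    std::string prompt = tmpl;
    string_replace_all(prompt, "{query}",    query);
    string_replace_all(prompt, "{document}", doc);
    return prompt;
}

For example, build_rerank_prompt(model, "what is panda?", "The giant panda is a bear species endemic to China.") yields the full <|im_start|>system ... <Query>: ... <Document>: ... prompt that the loop above then tokenizes and scores with RANK pooling.
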
1 change: 1 addition & 0 deletions src/llama-arch.cpp
@@ -707,6 +707,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_CLS_OUT, "cls.output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
11 changes: 8 additions & 3 deletions src/llama-graph.cpp
@@ -204,7 +204,10 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
std::vector<int> target_pos(n_seqs_unq, -1);
std::vector<int> target_row(n_seqs_unq, -1);

bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
const bool last = (
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
);

for (int i = 0; i < n_tokens; ++i) {
const llama_pos pos = ubatch->pos[i];
@@ -1177,7 +1180,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
}

ggml_tensor * llm_graph_context::build_inp_cls() const {
auto inp = std::make_unique<llm_graph_input_cls>(cparams);
auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);

auto & cur = inp->cls;

@@ -1899,7 +1902,9 @@ void llm_graph_context::build_pooling(
// Single layer classification head (direct projection)
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
cur = ggml_mul_mat(ctx0, cls_out, inp);
if (cls_out_b) {
if (arch == LLM_ARCH_QWEN3) {
cur = ggml_log(ctx0, ggml_soft_max(ctx0, cur));
} else if (cls_out_b) {
cur = ggml_add(ctx0, cur, cls_out_b);
}
} else {
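
For the QWEN3 branch above, the classifier logits are turned into log-probabilities before being read back as the rank score, i.e.

$$\text{score} = \bigl[\log \operatorname{softmax}(z)\bigr]_{\text{yes}} = z_{\text{yes}} - \log\!\left(e^{z_{\text{yes}}} + e^{z_{\text{no}}}\right) = \log P(\text{yes}),$$

so exponentiating the returned value recovers a relevance probability in [0, 1], in line with how the Qwen3-Reranker reference usage derives a score from the yes/no logits.
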
3 changes: 2 additions & 1 deletion src/llama-graph.h
@@ -206,14 +206,15 @@ class llm_graph_input_mean : public llm_graph_input_i {

class llm_graph_input_cls : public llm_graph_input_i {
public:
llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
virtual ~llm_graph_input_cls() = default;

void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * cls; // I32 [n_batch]

const llama_cparams cparams;
const llm_arch arch;
};

class llm_graph_input_rs : public llm_graph_input_i {
3 changes: 3 additions & 0 deletions src/llama-model.cpp
@@ -3062,6 +3062,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}

// output rerank head
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

12 changes: 3 additions & 9 deletions tools/server/server.cpp
@@ -4928,21 +4928,15 @@ int main(int argc, char ** argv) {
return;
}

std::vector<server_tokens> tokenized_queries = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true);
if (tokenized_queries.size() != 1) {
res_error(res, format_error_response("\"query\" must contain only a single prompt", ERROR_TYPE_INVALID_REQUEST));
}

// create and queue the task
json responses = json::array();
bool error = false;
std::unordered_set<int> task_ids;
{
std::vector<server_task> tasks;
auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
tasks.reserve(tokenized_docs.size());
for (size_t i = 0; i < tokenized_docs.size(); i++) {
auto tmp = format_rerank(ctx_server.vocab, tokenized_queries[0], tokenized_docs[i]);
tasks.reserve(documents.size());
for (size_t i = 0; i < documents.size(); i++) {
auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
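
The handler now builds one rerank task per raw document string against the single query. A sketch of the request shape it consumes (field names match the query/documents variables used above; the literal values and the json alias from server.cpp are assumptions for illustration):

json body = {
    {"query",     "how tall is the Eiffel Tower?"},
    {"documents", {"The Eiffel Tower is 330 metres tall.", "Pandas eat bamboo."}}
};
// one SERVER_TASK_TYPE_RERANK task is queued per entry of "documents", each
// carrying format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i])
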
68 changes: 40 additions & 28 deletions tools/server/utils.hpp
@@ -1317,34 +1317,6 @@ static std::string fnv_hash(const uint8_t * data, size_t len) {
return std::to_string(hash);
}


// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
static server_tokens format_rerank(const struct llama_vocab * vocab, server_tokens & query, server_tokens & doc) {
server_tokens result = {};

// Get EOS token - use SEP token as fallback if EOS is not available
llama_token eos_token = llama_vocab_eos(vocab);
if (eos_token == LLAMA_TOKEN_NULL) {
eos_token = llama_vocab_sep(vocab);
}
if (llama_vocab_get_add_bos(vocab)) {
result.push_back(llama_vocab_bos(vocab));
}
result.push_back(query);
if (llama_vocab_get_add_eos(vocab)) {
result.push_back(eos_token);
}
if (llama_vocab_get_add_sep(vocab)) {
result.push_back(llama_vocab_sep(vocab));
}
result.push_back(doc);
if (llama_vocab_get_add_eos(vocab)) {
result.push_back(eos_token);
}
return result;
}


static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
mtmd::bitmaps bitmaps;
for (auto & file : files) {
@@ -1450,3 +1422,43 @@ static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * voc
}
return result;
}

// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
static server_tokens format_rerank(const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, const std::string & query, const std::string & doc) {
server_tokens result = {};

const char * rerank_prompt = llama_model_chat_template(model, "rerank");

if (rerank_prompt != nullptr) {
std::string prompt = rerank_prompt;
string_replace_all(prompt, "{query}" , query);
string_replace_all(prompt, "{document}", doc );
server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true);
result.push_back(tokens);
} else {
// Get EOS token - use SEP token as fallback if EOS is not available
server_tokens query_tokens = tokenize_input_subprompt(vocab, mctx, query, false, false);
server_tokens doc_tokens = tokenize_input_subprompt(vocab, mctx, doc, false, false);
llama_token eos_token = llama_vocab_eos(vocab);
if (eos_token == LLAMA_TOKEN_NULL) {
eos_token = llama_vocab_sep(vocab);
}

if (llama_vocab_get_add_bos(vocab)) {
result.push_back(llama_vocab_bos(vocab));
}
result.push_back(query_tokens);
if (llama_vocab_get_add_eos(vocab)) {
result.push_back(eos_token);
}
if (llama_vocab_get_add_sep(vocab)) {
result.push_back(llama_vocab_sep(vocab));
}
result.push_back(doc_tokens);
if (llama_vocab_get_add_eos(vocab)) {
result.push_back(eos_token);
}
}

return result;
}
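
A hypothetical usage sketch of the relocated helper (the literals are illustrative; model, vocab, and mctx come from the server context as in the handler above):

// template path (e.g. Qwen3-Reranker): the result is the fully expanded chat prompt;
// fallback path: the result is [BOS]query[EOS][SEP]doc[EOS] assembled from vocab tokens
server_tokens toks = format_rerank(model, vocab, mctx,
                                   "which planet is known as the red planet?",
                                   "Mars is often called the Red Planet.");
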