diff --git a/src/llama_engine.cc b/src/llama_engine.cc
index f1326df..6e4de49 100644
--- a/src/llama_engine.cc
+++ b/src/llama_engine.cc
@@ -6,6 +6,7 @@
 #include "json-schema-to-grammar.h"
 #include "json/writer.h"
 #include "llama_utils.h"
+#include "src/llama-arch.h"
 #include "trantor/utils/Logger.h"
 
 #if defined(_WIN32)
@@ -805,6 +806,13 @@ void LlamaEngine::HandleInferenceImpl(
     for (const auto& elem : completion.logit_bias) {
       arr.push_back(llama::inferences::ConvertJsonCppToNlohmann(elem));
     }
+
+    if (si.ctx.model->arch == LLM_ARCH_QWEN3) {
+      json qwen3 = json::array();
+      qwen3.push_back(151643);
+      qwen3.push_back(-100000);
+      arr.push_back(qwen3);
+    }
     data["logit_bias"] = std::move(arr);
     int n_probs = completion.n_probs;
     const Json::Value& messages = completion.messages;
diff --git a/src/llama_server_context.h b/src/llama_server_context.h
index 05c581b..16b1cb6 100644
--- a/src/llama_server_context.h
+++ b/src/llama_server_context.h
@@ -11,6 +11,7 @@
 
 // External
 
+#include "src/llama-model.h"
 #include "llama_client_slot.h"
 
 #if defined(_WIN32)
@@ -107,11 +108,11 @@ static T json_value(const json& body, const std::string& key,
 struct LlamaServerContext {
   common_init_result llama_init;
-
+
   llama_model* model = nullptr;
   llama_context* ctx = nullptr;
-  const llama_vocab * vocab = nullptr;
+  const llama_vocab* vocab = nullptr;
   clip_ctx* clp_ctx = nullptr;
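
Note on the new Qwen3 branch above: in the llama.cpp server API, each `logit_bias` entry is a `[token_id, bias]` pair, and the bias is added to that token's raw logit before sampling. Token 151643 corresponds to Qwen's `<|endoftext|>` token, so appending `[151643, -100000]` effectively bans it for Qwen3 models. Below is a minimal sketch of that mechanism; it is illustrative only, and the function name, map type, and plain logits vector are assumptions for the example, not the engine's actual code path.

// Minimal sketch of how a [token_id, bias] logit_bias entry is
// typically applied at sampling time. Illustrative only: the names
// and types here are assumptions, not the engine's real code path.
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

using llama_token = std::int32_t;

// Add each configured bias to the matching token's raw logit. A very
// large negative bias (e.g. -100000) drives the token's post-softmax
// probability to effectively zero, banning it from being sampled.
void apply_logit_bias(std::vector<float>& logits,
                      const std::unordered_map<llama_token, float>& bias) {
  for (const auto& [token, value] : bias) {
    if (token >= 0 && static_cast<std::size_t>(token) < logits.size()) {
      logits[token] += value;
    }
  }
}

Biasing the token rather than filtering it keeps the vocabulary intact and composes naturally with any user-supplied `logit_bias` entries already collected in `arr`.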