
Commit fd62188

gabe-l-hart and CISC authored
aLoRA Support (#15327)
* feat: Add python-side constants and conversion for adapter.lora.invocation_string
* feat: Add c++ side constants for adapter.lora.invocation_string
* feat: Parse invocation string for adapters from GGUF
* fix(python): Update conversion to alora_invocation_tokens. This is the preferred method in PEFT, which is the source of ground truth: https://github.com/huggingface/peft/pull/2609/files#diff-13380145401d203d5935c5189dd09879f990b81aa63e8e3aaff8ce9110333f0e
* fix(cpp): Update to alora_invocation_tokens on the c++ side
* feat: Add C APIs to get the alora invocation token array from a lora adapter
* feat: Initial implementation of alora cache logic in the server. This does not yet identify the invocation tokens and apply the lora adapter only afterwards, but it does produce correct results when the invocation tokens are at the beginning of the uncached input.
* feat: Identify alora invocation sequences. This currently limits each slot to a single enabled alora. Multiple aloras with different invocation sequences would be possible, but that would require a more complex integration of the adapter toggling, and it is not a well-studied case for alora since it is unclear whether one alora can reuse cache from a previous prefill computed with a different alora.
* feat: Only reuse cache for tokens before the alora invocation start. This is a bit of an edge case, but a user could theoretically try the same query with the alora disabled (just the base model), then retry with the alora. The cached tokens from the first pass would be invalid.
* feat: Handle uncached tokens that come before the alora activation. The solution is to fill the batch only up to the token before the invocation start whenever there are tokens to prefill between those pulled from cache and the invocation start. When this is detected, the alora is temporarily disabled with a scale of 0.0, then immediately re-enabled after it has been initialized for the internal graph. Since the batch does not complete the prompt tokens, the remaining prompt tokens are handled in the next task, pulling all of the non-alora tokens from cache and proceeding with prefill for the alora tokens.
* fix: Use || instead of 'or'. Too much python 🤦
* fix: Fix off-by-one when limiting cached tokens to before the alora start. This was the cause of the inconsistent results from the dummy test script with and without the turn that runs the prompt without the adapter before running it with the adapter.
* fix: Support backwards compatibility for "invocation_string" in adapter_config.json. While this has been replaced in the PEFT PR in favor of alora_invocation_tokens, the existing adapters in the ibm-granite org on HF use "invocation_string", so this enables backwards compatibility and allows testing now (before the PEFT PR changes have percolated everywhere).
* fix: Remove duplicate logging
* feat: Report alora_invocation_string and alora_invocation_tokens from /lora-adapters

Branch: gabe-l-hart/alora-support
Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Sigbjørn Skjæret <[email protected]>
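The invocation detection described above reduces to a backwards scan over the prompt for the adapter's invocation sequence, using the two C API calls added in this commit. A minimal standalone sketch of that logic; the helper name and surrounding setup are illustrative and not part of the llama.cpp API, only the two llama_adapter_get_alora_* calls come from this commit:

    // Returns the index of the first invocation token in `prompt`, or
    // prompt.size() if the invocation sequence is not present.
    #include <cstdint>
    #include <vector>
    #include "llama.h"

    static size_t find_alora_invocation_start(
            const struct llama_adapter_lora * adapter,
            const std::vector<llama_token> & prompt) {
        const uint64_t      n_inv = llama_adapter_get_alora_n_invocation_tokens(adapter);
        const llama_token * inv   = llama_adapter_get_alora_invocation_tokens(adapter);
        if (n_inv == 0 || prompt.size() < n_inv) {
            return prompt.size(); // not an aLoRA, or prompt too short
        }
        // scan backwards so the last occurrence of the sequence wins
        int match_idx = (int) n_inv - 1;
        for (int i = (int) prompt.size() - 1; i >= 0; --i) {
            if (prompt[i] == inv[match_idx]) {
                if (match_idx == 0) {
                    return (size_t) i; // full match; this is the invocation start
                }
                --match_idx; // keep matching the sequence right-to-left
            } else {
                match_idx = (int) n_inv - 1; // mismatch; restart the partial match
            }
        }
        return prompt.size(); // invocation sequence not found
    }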
1 parent 4281c7b commit fd62188

File tree

9 files changed: +232 -14 lines changed


convert_lora_to_gguf.py

Lines changed: 27 additions & 1 deletion
@@ -12,7 +12,7 @@
 from math import prod
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoTokenizer

 import torch

@@ -26,6 +26,8 @@
 # reuse model definitions from convert_hf_to_gguf.py
 from convert_hf_to_gguf import LazyTorchTensor, ModelBase

+from gguf.constants import GGUFValueType
+
 logger = logging.getLogger("lora-to-gguf")


@@ -369,7 +371,31 @@ def set_type(self):
             self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")

         def set_gguf_parameters(self):
+            logger.debug("GGUF KV: %s = %d", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
             self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
+            alora_invocation_tokens = lparams.get("alora_invocation_tokens")
+            invocation_string = lparams.get("invocation_string")
+            if invocation_string and not alora_invocation_tokens:
+                logger.debug("Tokenizing invocation_string -> alora_invocation_tokens")
+                base_model_path_or_id = hparams.get("_name_or_path")
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id)
+                except ValueError:
+                    logger.error("Unable to load tokenizer from %s", base_model_path_or_id)
+                    raise
+                # NOTE: There's an off-by-one with the older aLoRAs where
+                # the invocation string includes the "<|start_of_turn|>"
+                # token, but the adapters themselves were trained to
+                # activate _after_ that first token, so we drop it here.
+                alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
+            if alora_invocation_tokens:
+                logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
+                self.gguf_writer.add_key_value(
+                    gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS,
+                    alora_invocation_tokens,
+                    GGUFValueType.ARRAY,
+                    GGUFValueType.UINT32,
+                )

         def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
             # Never add extra tensors (e.g. rope_freqs) for LoRA adapters

gguf-py/gguf/constants.py

Lines changed: 5 additions & 4 deletions
@@ -231,10 +231,11 @@ class Tokenizer:
         MIDDLE_ID = "tokenizer.ggml.middle_token_id"

     class Adapter:
-        TYPE               = "adapter.type"
-        LORA_ALPHA         = "adapter.lora.alpha"
-        LORA_TASK_NAME     = "adapter.lora.task_name"
-        LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix"
+        TYPE                    = "adapter.type"
+        LORA_ALPHA              = "adapter.lora.alpha"
+        LORA_TASK_NAME          = "adapter.lora.task_name"
+        LORA_PROMPT_PREFIX      = "adapter.lora.prompt_prefix"
+        ALORA_INVOCATION_TOKENS = "adapter.alora.invocation_tokens"

     class IMatrix:
         CHUNK_COUNT = "imatrix.chunk_count"

include/llama.h

Lines changed: 4 additions & 0 deletions
@@ -583,6 +583,10 @@ extern "C" {
     // Note: loaded adapters will be free when the associated model is deleted
     LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);

+    // Get the invocation tokens if the current lora is an alora
+    LLAMA_API uint64_t            llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
+    LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens  (const struct llama_adapter_lora * adapter);
+
     // The following functions operate on a llama_context, hence the naming: llama_verb_...

     // Add a loaded LoRA adapter to given context
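A rough usage sketch for the two new accessors, assuming a context and an adapter have already been loaded elsewhere (error handling omitted). It uses common_token_to_piece() from common/common.h, the same helper the server endpoint below uses to rebuild the printable invocation string:

    #include <cstdio>
    #include <string>
    #include "llama.h"
    #include "common.h"

    static void print_alora_invocation(struct llama_context * ctx,
                                       const struct llama_adapter_lora * adapter) {
        const uint64_t n_inv = llama_adapter_get_alora_n_invocation_tokens(adapter);
        if (n_inv == 0) {
            printf("adapter is a plain LoRA (no invocation tokens)\n");
            return;
        }
        const llama_token * inv = llama_adapter_get_alora_invocation_tokens(adapter);
        std::string invocation_string;
        for (uint64_t i = 0; i < n_inv; ++i) {
            invocation_string += common_token_to_piece(ctx, inv[i]);
        }
        printf("aLoRA invocation (%llu tokens): %s\n",
               (unsigned long long) n_inv, invocation_string.c_str());
    }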

src/llama-adapter.cpp

Lines changed: 33 additions & 0 deletions
@@ -6,6 +6,7 @@

 #include <map>
 #include <cassert>
+#include <sstream>
 #include <stdexcept>

 // vec
@@ -215,6 +216,26 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
         }

         adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
+
+        // parse alora invocation sequence vector
+        const auto & key = llm_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS);
+        const int kid = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (kid >= 0) {
+            if (gguf_get_kv_type(ctx_gguf.get(), kid) != GGUF_TYPE_ARRAY) {
+                throw std::runtime_error("invalid gguf type for " + key);
+            }
+            const auto arr_type = gguf_get_arr_type(ctx_gguf.get(), kid);
+            if (arr_type != GGUF_TYPE_UINT32) {
+                throw std::runtime_error("invalid gguf element type for " + key);
+            }
+            const size_t seq_len = gguf_get_arr_n(ctx_gguf.get(), kid);
+            const void * data    = gguf_get_arr_data(ctx_gguf.get(), kid);
+            adapter.alora_invocation_tokens.resize(seq_len);
+            std::copy(
+                (const llama_token *)data,
+                (const llama_token *)data + seq_len,
+                adapter.alora_invocation_tokens.begin());
+        }
     }

     int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
@@ -450,3 +471,15 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter,
 void llama_adapter_lora_free(llama_adapter_lora * adapter) {
     delete adapter;
 }
+
+uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) {
+    if (!adapter) {
+        return 0;
+    }
+    return adapter->alora_invocation_tokens.size();
+}
+
+const llama_token * llama_adapter_get_alora_invocation_tokens(const llama_adapter_lora * adapter) {
+    GGML_ASSERT(adapter);
+    return adapter->alora_invocation_tokens.data();
+}

src/llama-adapter.h

Lines changed: 3 additions & 0 deletions
@@ -70,6 +70,9 @@ struct llama_adapter_lora {
     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;

+    // activated lora (aLoRA)
+    std::vector<llama_token> alora_invocation_tokens;
+
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;

src/llama-arch.cpp

Lines changed: 5 additions & 4 deletions
@@ -237,10 +237,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_TOKENIZER_FIM_REP_ID,            "tokenizer.ggml.fim_rep_token_id" },
    { LLM_KV_TOKENIZER_FIM_SEP_ID,            "tokenizer.ggml.fim_sep_token_id" },

-   { LLM_KV_ADAPTER_TYPE,               "adapter.type"               },
-   { LLM_KV_ADAPTER_LORA_ALPHA,         "adapter.lora.alpha"         },
-   { LLM_KV_ADAPTER_LORA_TASK_NAME,     "adapter.lora.task_name"     },
-   { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },
+   { LLM_KV_ADAPTER_TYPE,                    "adapter.type"                    },
+   { LLM_KV_ADAPTER_LORA_ALPHA,              "adapter.lora.alpha"              },
+   { LLM_KV_ADAPTER_LORA_TASK_NAME,          "adapter.lora.task_name"          },
+   { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,      "adapter.lora.prompt_prefix"      },
+   { LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" },

    // deprecated
    { LLM_KV_TOKENIZER_PREFIX_ID,             "tokenizer.ggml.prefix_token_id"  },

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -235,6 +235,7 @@ enum llm_kv {
     LLM_KV_ADAPTER_LORA_ALPHA,
     LLM_KV_ADAPTER_LORA_TASK_NAME,
     LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,
+    LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS,

     LLM_KV_POSNET_EMBEDDING_LENGTH,
     LLM_KV_POSNET_BLOCK_COUNT,

tools/server/server.cpp

Lines changed: 113 additions & 5 deletions
@@ -117,7 +117,7 @@ struct slot_params {
     int32_t n_keep    = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t n_predict = -1; // new tokens to predict
-    int32_t n_indent  = 0; // mininum line indentation for the generated text in number of whitespace characters
+    int32_t n_indent  = 0; // minimum line indentation for the generated text in number of whitespace characters

     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -1382,6 +1382,7 @@ struct server_slot {
     common_speculative * spec = nullptr;

     std::vector<common_adapter_lora_info> lora;
+    int32_t alora_invocation_start = -1;

     // the index relative to completion multi-task request
     size_t index = 0;
@@ -1476,6 +1477,9 @@ struct server_slot {
         // clear speculative decoding stats
         n_draft_total = 0;
         n_draft_accepted = 0;
+
+        // clear alora start
+        alora_invocation_start = -1;
     }

     bool need_embd() const {
@@ -2367,11 +2371,65 @@ struct server_context {
         slot.prompt_tokens = std::move(task.prompt_tokens);

         if (!are_lora_equal(slot.params.lora, slot.lora)) {
-            // if lora is changed, we cannot reuse cached tokens
-            slot.cache_tokens.clear();
+            // if lora has changed, check to see if the cache should be cleared
+            if (lora_should_clear_cache(slot.lora, slot.params.lora)) {
+                SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size());
+                slot.cache_tokens.clear();
+            } else {
+                SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size());
+            }
             slot.lora = slot.params.lora;
         }

+        // if using alora, make sure it's only a single one requested and active
+        size_t alora_invocation_start = slot.prompt_tokens.size();
+        if (lora_all_alora(slot.lora)) {
+
+            const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
+            // TODO: This will error out if a user requests two aloras, but only
+            // provides the activation string for one. We could, instead search
+            // for all requested alora activation strings and then either keep
+            // only the last one, or reject if multiple are found.
+            if (enabled_ids.size() != 1) {
+                send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+            const auto & lora = slot.lora[enabled_ids[0]].ptr;
+
+            // get the pointer and count for the invocation tokens
+            const uint64_t      n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
+            const llama_token * invocation_tokens   = llama_adapter_get_alora_invocation_tokens  (lora);
+
+            // scan backwards through the prompt tokens to find the last
+            // occurrence of the invocation sequence
+            int match_idx = static_cast<int>(n_invocation_tokens) - 1;
+            for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) {
+                // the token in this position matches the next token to find in
+                // the invocation sequence
+                if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) {
+                    // if it's a full match, we've found the start
+                    if (match_idx == 0) {
+                        alora_invocation_start = i;
+                        break;
+                    }
+                    // otherwise, check the next token in the sequence
+                    --match_idx;
+                } else {
+                    // no match in this position, so start looking over again
+                    match_idx = static_cast<int>(n_invocation_tokens) - 1;
+                }
+            }
+
+            // if the activation string is not found, disable the alora
+            if (alora_invocation_start == slot.prompt_tokens.size()) {
+                SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
+                slot.lora[enabled_ids[0]].scale = 0.0f;
+            } else {
+                SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
+                slot.alora_invocation_start = alora_invocation_start;
+            }
+        }
+
         if (!slot.prompt_tokens.validate(ctx)) {
             send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
             return false;
@@ -3247,6 +3305,8 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);

         // next, batch any pending prompts without exceeding n_batch
+        float  alora_scale       = -1.0f;
+        size_t alora_disabled_id = 0;
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
@@ -3367,6 +3427,12 @@ struct server_context {
                     // reuse any previously computed tokens that are common with the new prompt
                    slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);

+                    // if there is an alora invoked, don't cache after the invocation start
+                    if (slot.alora_invocation_start >= 0) {
+                        SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start);
+                        slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1);
+                    }
+
                    // reuse chunks from the cached prompt by shifting their KV cache in the new position
                    if (params_base.n_cache_reuse > 0) {
                        size_t head_c = slot.n_past; // cache
@@ -3539,6 +3605,20 @@ struct server_context {
                        slot.n_prompt_tokens_processed += n_pos;
                    }

+                    // If using an alora, there may be uncached tokens that come
+                    // before the invocation sequence. When this happens, the
+                    // tokens before the invocation sequence need to be
+                    // processed without the adpter in a separate batch, then
+                    // the adapter needs to be enabled for the remaining tokens.
+                    if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) {
+                        SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+                        const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
+                        GGML_ASSERT(enabled_loras.size() == 1);
+                        alora_scale = slot.lora[enabled_loras[0]].scale;
+                        slot.lora[enabled_loras[0]].scale = 0.0f;
+                        alora_disabled_id = enabled_loras[0];
+                    }
+
                    // add prompt tokens for processing in the current batch
                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                        // get next token to process
@@ -3547,6 +3627,14 @@ struct server_context {
                            break; // end of text chunk
                        }

+                        // if this is an alora request with pre-invocation
+                        // tokens that are not cached, we need to stop filling
+                        // this batch at those pre-invocation tokens.
+                        if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) {
+                            SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+                            break;
+                        }
+
                        // embedding requires all tokens in the batch to be output
                        const bool need_embd = server_task_type_need_embd(slot.task_type);

@@ -3605,6 +3693,13 @@ struct server_context {
                // apply lora, only need to do it once per batch
                common_set_adapter_lora(ctx, slot_batched->lora);

+                // if the lora is temporarily disabled for an alora, re-enable it
+                // for next time
+                if (alora_scale > 0.0f) {
+                    SRV_DBG("re-enabling alora with scale %f\n", alora_scale);
+                    slot_batched->lora[alora_disabled_id].scale = alora_scale;
+                }
+
                llama_set_embeddings(ctx, slot_batched->need_embd());
            }

@@ -4990,13 +5085,26 @@ int main(int argc, char ** argv) {
        const auto & loras = ctx_server.params_base.lora_adapters;
        for (size_t i = 0; i < loras.size(); ++i) {
            auto & lora = loras[i];
-            result.push_back({
+            json entry = {
                {"id", i},
                {"path", lora.path},
                {"scale", lora.scale},
                {"task_name", lora.task_name},
                {"prompt_prefix", lora.prompt_prefix},
-            });
+            };
+            std::string alora_invocation_string = "";
+            const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr);
+            std::vector<llama_token> alora_invocation_tokens;
+            if (n_alora_tokens) {
+                const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr);
+                for (uint64_t i = 0; i < n_alora_tokens; ++i) {
+                    alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]);
+                    alora_invocation_tokens.push_back(alora_tokens[i]);
+                }
+                entry["alora_invocation_string"] = alora_invocation_string;
+                entry["alora_invocation_tokens"] = alora_invocation_tokens;
+            }
+            result.push_back(std::move(entry));
        }
        res_ok(res, result);
        res.status = 200; // HTTP OK
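A condensed illustration of the scale toggling the server code above performs when uncached tokens precede the aLoRA invocation: the adapter's scale is zeroed for the pre-invocation batch, the adapters are applied once for that batch, and the scale is then restored so the following batch (from the invocation start onward) runs with the adapter active. The helper below is a sketch only; the lora vector, context, and the surrounding decode loop are assumed to be set up as in the server slot:

    #include <vector>
    #include "common.h"

    static void decode_pre_invocation_batch(struct llama_context * ctx,
                                            std::vector<common_adapter_lora_info> & lora,
                                            size_t alora_id) {
        const float saved_scale = lora[alora_id].scale;

        lora[alora_id].scale = 0.0f;        // temporarily disable the aLoRA
        common_set_adapter_lora(ctx, lora); // apply adapters for this batch
        // ... llama_decode() the tokens before the invocation start here ...

        lora[alora_id].scale = saved_scale; // re-enable for the following batch
    }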
