@@ -901,17 +901,17 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const int64_t n_embd = hparams.n_embd;
 
     // when computing embeddings, all tokens are output
-    const bool embd_all = cparams.embeddings;
+    const bool output_all = cparams.embeddings;
 
-    if (!batch_allocr->init(batch_inp, vocab, memory.get(), n_embd, embd_all)) {
+    if (!batch_allocr->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
     const uint32_t n_tokens_all  = batch_allocr->get_n_tokens();
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_all) {
+    if (output_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -940,7 +940,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch_allocr.get(), cparams.n_ubatch, embd_all);
+        mstate = memory->init_batch(batch_allocr.get(), cparams.n_ubatch, output_all);
         if (!mstate) {
             return -2;
         }
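For reference, the rename from `embd_all` to `output_all` does not change behavior: the flag still means "every token in the batch is an output", which is what pooled embeddings require. A minimal standalone sketch of that invariant check, using a hypothetical helper name `check_output_all` (not part of llama.cpp), could look like this:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the guard in llama_context::decode:
// when output_all is set (embeddings mode), every token must be an output,
// i.e. n_outputs_all must equal n_tokens_all, otherwise the batch is invalid.
static bool check_output_all(bool output_all, uint32_t n_outputs_all, uint32_t n_tokens_all) {
    if (output_all && n_outputs_all != n_tokens_all) {
        fprintf(stderr, "pooled embedding requires that all tokens are output (n_outputs_all = %u, n_tokens_all = %u)\n",
                n_outputs_all, n_tokens_all);
        return false;
    }
    return true;
}

int main() {
    // a batch of 8 tokens where only the last token is marked as output
    // fails the check when output_all is requested
    const bool ok = check_output_all(/*output_all=*/true, /*n_outputs_all=*/1, /*n_tokens_all=*/8);
    return ok ? 0 : 1;
}
```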