Skip to content

Commit 7233358

Browse files
l3utterflyCISC
andauthored
memory : handle saving/loading null layers in recurrent memory (#14675)
* Update llama-memory-recurrent.cpp handle saving/loading null layers in recurrent memory * fixed styling issues and updated comments * fix styling issue Co-authored-by: Sigbjørn Skjæret <[email protected]> --------- Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 6c88b3b commit 7233358

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

src/llama-memory-recurrent.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
768768
// Iterate and write all the keys first, each row is a cell
769769
// Get whole range at a time
770770
for (uint32_t il = 0; il < n_layer; ++il) {
771+
// skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
772+
if (r_l[il] == nullptr) continue;
771773

772774
// Write key type
773775
const int32_t r_type_i = (int32_t)r_l[il]->type;
@@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
787789

788790
if (!s_trans) {
789791
for (uint32_t il = 0; il < n_layer; ++il) {
792+
// skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
793+
if (s_l[il] == nullptr) continue;
790794

791795
// Write value type
792796
const int32_t s_type_i = (int32_t)s_l[il]->type;
@@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
807811
// When v is transposed, we also need the element size and get the element ranges from each row
808812
const uint32_t mem_size = size;
809813
for (uint32_t il = 0; il < n_layer; ++il) {
814+
// skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
815+
if (s_l[il] == nullptr) continue;
816+
810817
const uint32_t n_embd_s = hparams.n_embd_s();
811818

812819
// Write value type
@@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
951958

952959
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
953960
for (uint32_t il = 0; il < n_layer; ++il) {
961+
// skip null layers
962+
if (r_l[il] == nullptr) continue;
954963

955964
// Read type of key
956965
int32_t r_type_i_ref;
@@ -978,11 +987,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
978987

979988
if (!s_trans) {
980989
for (uint32_t il = 0; il < n_layer; ++il) {
990+
// skip null layers
991+
if (s_l[il] == nullptr) continue;
981992

982993
// Read type of value
983994
int32_t s_type_i_ref;
984995
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
985996
const int32_t s_type_i = (int32_t)s_l[il]->type;
997+
986998
if (s_type_i != s_type_i_ref) {
987999
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
9881000
return false;
@@ -1005,6 +1017,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
10051017
} else {
10061018
// For each layer, read the values for each cell (transposed)
10071019
for (uint32_t il = 0; il < n_layer; ++il) {
1020+
// skip null layers
1021+
if (s_l[il] == nullptr) continue;
1022+
10081023
const uint32_t n_embd_s = hparams.n_embd_s();
10091024

10101025
// Read type of value

0 commit comments

Comments
 (0)