@@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
    // Iterate and write all the keys first, each row is a cell
    // Get whole range at a time
    for (uint32_t il = 0; il < n_layer; ++il) {
+       // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+       if (r_l[il] == nullptr) continue;

        // Write key type
        const int32_t r_type_i = (int32_t)r_l[il]->type;
@@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::

    if (!s_trans) {
        for (uint32_t il = 0; il < n_layer; ++il) {
+           // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+           if (s_l[il] == nullptr) continue;

            // Write value type
            const int32_t s_type_i = (int32_t)s_l[il]->type;
@@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
        // When v is transposed, we also need the element size and get the element ranges from each row
        const uint32_t mem_size = size;
        for (uint32_t il = 0; il < n_layer; ++il) {
+           // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+           if (s_l[il] == nullptr) continue;
+
            const uint32_t n_embd_s = hparams.n_embd_s();

            // Write value type
@@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell

    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
    for (uint32_t il = 0; il < n_layer; ++il) {
+       // skip null layers
+       if (r_l[il] == nullptr) continue;

        // Read type of key
        int32_t r_type_i_ref;
@@ -978,11 +987,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell

    if (!s_trans) {
        for (uint32_t il = 0; il < n_layer; ++il) {
+           // skip null layers
+           if (s_l[il] == nullptr) continue;

            // Read type of value
            int32_t s_type_i_ref;
            io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
            const int32_t s_type_i = (int32_t)s_l[il]->type;
+
            if (s_type_i != s_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                return false;
@@ -1005,6 +1017,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
    } else {
        // For each layer, read the values for each cell (transposed)
        for (uint32_t il = 0; il < n_layer; ++il) {
+           // skip null layers
+           if (s_l[il] == nullptr) continue;
+
            const uint32_t n_embd_s = hparams.n_embd_s();

            // Read type of value
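For readers following the change, here is a minimal standalone sketch (not llama.cpp code; LayerState, write_state, and read_state are hypothetical stand-ins for the r_l/s_l tensors and the llama_io_* helpers) of why the writer and the reader must apply the same null-layer check: any layer skipped while writing must also be skipped while reading, or the stream offsets drift.

// Sketch only: per-layer state serialization that skips null layers.
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

struct LayerState {
    std::vector<float> data; // stand-in for one layer's recurrent state tensor
};

// Writer: layers with a null state contribute nothing to the stream.
std::vector<float> write_state(const std::vector<std::unique_ptr<LayerState>> & layers) {
    std::vector<float> stream;
    for (const auto & l : layers) {
        if (l == nullptr) continue; // mirrors `if (r_l[il] == nullptr) continue;`
        stream.insert(stream.end(), l->data.begin(), l->data.end());
    }
    return stream;
}

// Reader: the same skip keeps the read offset aligned with what was written.
void read_state(const std::vector<float> & stream, std::vector<std::unique_ptr<LayerState>> & layers) {
    size_t off = 0;
    for (auto & l : layers) {
        if (l == nullptr) continue; // must match the writer's check
        for (auto & x : l->data) {
            x = stream[off++];
        }
    }
    assert(off == stream.size()); // every written element was consumed
}

int main() {
    std::vector<std::unique_ptr<LayerState>> layers(3);
    layers[0] = std::make_unique<LayerState>();
    layers[0]->data = {1.0f, 2.0f};
    // layers[1] stays null, e.g. a layer that has no recurrent state
    layers[2] = std::make_unique<LayerState>();
    layers[2]->data = {3.0f, 4.0f};

    const std::vector<float> stream = write_state(layers);
    read_state(stream, layers);
    return 0;
}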