@@ -92,22 +92,28 @@ def __init__(
9292 self .destination_state_shard_info = destination_state_shard_info
9393 self .left_var_to_right_var_mapping = {}
9494 self .right_var_from_left_var_mapping = {}
95+ self .src_state_keys = set ()
96+ self .dst_state_keys = set ()
97+ self .init_dst_state_keys ()
98+ self .init_dst_state_keys ()
9599
96- def get_all_dst_state_keys (self ):
97- dst_state_keys = set ()
100+ def init_src_state_keys (self ):
101+ for k in self .source_state_shard_info .keys ():
102+ model_state_key , _ = split_optimizer_state_key (k )
103+ self .src_state_keys .add (model_state_key )
104+
105+ def init_dst_state_keys (self ):
98106 if self .destination_state_shard_info is None :
99- return dst_state_keys
107+ return
100108 for k in self .destination_state_shard_info .keys ():
101109 model_state_key , _ = split_optimizer_state_key (k )
102- dst_state_keys .add (model_state_key )
103- return dst_state_keys
110+ self .dst_state_keys .add (model_state_key )
111+
112+ def get_all_dst_state_keys (self ):
113+ return self .dst_state_keys
104114
105115 def get_all_src_state_keys (self ):
106- src_state_keys = set ()
107- for k in self .source_state_shard_info .keys ():
108- model_state_key , _ = split_optimizer_state_key (k )
109- src_state_keys .add (model_state_key )
110- return src_state_keys
116+ return self .src_state_keys
111117
112118 def get_num_hidden_layers (
113119 self ,
0 commit comments