File tree Expand file tree Collapse file tree 1 file changed +8
-5
lines changed
python/paddle/distributed/flex_checkpoint/aoa Expand file tree Collapse file tree 1 file changed +8
-5
lines changed Original file line number Diff line number Diff line change @@ -92,15 +92,18 @@ def __init__(
9292 self .destination_state_shard_info = destination_state_shard_info
9393 self .left_var_to_right_var_mapping = {}
9494 self .right_var_from_left_var_mapping = {}
95+ self .dst_state_keys = set ()
96+ self .init_dst_state_keys ()
9597
96- def get_all_dst_state_keys (self ):
97- dst_state_keys = set ()
98+ def init_dst_state_keys (self ):
9899 if self .destination_state_shard_info is None :
99- return dst_state_keys
100+ return
100101 for k in self .destination_state_shard_info .keys ():
101102 model_state_key , _ = split_optimizer_state_key (k )
102- dst_state_keys .add (model_state_key )
103- return dst_state_keys
103+ self .dst_state_keys .add (model_state_key )
104+
105+ def get_all_dst_state_keys (self ):
106+ return self .dst_state_keys
104107
105108 def get_all_src_state_keys (self ):
106109 src_state_keys = set ()
You can’t perform that action at this time.
0 commit comments