Skip to content

Commit 74d3777

Browse files
committed
fix aoa engine time long
1 parent 9d77411 commit 74d3777

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,18 @@ def __init__(
9292
self.destination_state_shard_info = destination_state_shard_info
9393
self.left_var_to_right_var_mapping = {}
9494
self.right_var_from_left_var_mapping = {}
95+
self.dst_state_keys = set()
96+
self.init_dst_state_keys()
9597

96-
def get_all_dst_state_keys(self):
97-
dst_state_keys = set()
98+
def init_dst_state_keys(self):
9899
if self.destination_state_shard_info is None:
99-
return dst_state_keys
100+
return
100101
for k in self.destination_state_shard_info.keys():
101102
model_state_key, _ = split_optimizer_state_key(k)
102-
dst_state_keys.add(model_state_key)
103-
return dst_state_keys
103+
self.dst_state_keys.add(model_state_key)
104+
105+
def get_all_dst_state_keys(self):
106+
return self.dst_state_keys
104107

105108
def get_all_src_state_keys(self):
106109
src_state_keys = set()

0 commit comments

Comments
 (0)