Skip to content

Commit c3a398a

Browse files
committed
fix aoa engine time long
Fix long running time in AOA Engine
1 parent 9d77411 commit c3a398a

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,22 +92,28 @@ def __init__(
9292
self.destination_state_shard_info = destination_state_shard_info
9393
self.left_var_to_right_var_mapping = {}
9494
self.right_var_from_left_var_mapping = {}
95+
self.src_state_keys = set()
96+
self.dst_state_keys = set()
97+
self.init_dst_state_keys()
98+
self.init_dst_state_keys()
9599

96-
def get_all_dst_state_keys(self):
97-
dst_state_keys = set()
100+
def init_src_state_keys(self):
101+
for k in self.source_state_shard_info.keys():
102+
model_state_key, _ = split_optimizer_state_key(k)
103+
self.src_state_keys.add(model_state_key)
104+
105+
def init_dst_state_keys(self):
98106
if self.destination_state_shard_info is None:
99-
return dst_state_keys
107+
return
100108
for k in self.destination_state_shard_info.keys():
101109
model_state_key, _ = split_optimizer_state_key(k)
102-
dst_state_keys.add(model_state_key)
103-
return dst_state_keys
110+
self.dst_state_keys.add(model_state_key)
111+
112+
def get_all_dst_state_keys(self):
113+
return self.dst_state_keys
104114

105115
def get_all_src_state_keys(self):
106-
src_state_keys = set()
107-
for k in self.source_state_shard_info.keys():
108-
model_state_key, _ = split_optimizer_state_key(k)
109-
src_state_keys.add(model_state_key)
110-
return src_state_keys
116+
return self.src_state_keys
111117

112118
def get_num_hidden_layers(
113119
self,

0 commit comments

Comments
 (0)