Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions areal/engine/fsdp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,6 @@ def _update_weights_from_disk(self, meta: WeightUpdateMeta):
# dist.barrier() is called once _save_model_to_hf has finished

if dist.get_rank() == 0:
fut.result()

update_name = names.update_weights_from_disk(
self.config.experiment_name,
self.config.trial_name,
Expand All @@ -401,6 +399,8 @@ def _update_weights_from_disk(self, meta: WeightUpdateMeta):
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
)

fut.result()

dist.barrier(device_ids=[self.device.index])
current_platform.synchronize()

Expand All @@ -420,6 +420,9 @@ def connect_engine(self, engine: InferenceEngine, meta: WeightUpdateMeta):
)
self.rollout_engine = engine

if meta.type == "disk":
return

if not self.weight_update_group_initialized:
self._init_weight_update_from_distributed(meta)
self.weight_update_group_initialized = True
Expand Down
Loading