Skip to content

Commit e83fb4c

Browse files
ishant162payalchateoparvanov
authored
[Workflow API] Optimize flow state being transferred between participants in LocalRuntime & FederatedRuntime (#1589)
* add execute_task_args to reserved_keywords Signed-off-by: Ishant Thakare <[email protected]> * optimized flow state Signed-off-by: Ishant Thakare <[email protected]> --------- Signed-off-by: Ishant Thakare <[email protected]> Co-authored-by: Payal Chaurasiya <[email protected]> Co-authored-by: teoparvanov <[email protected]>
1 parent a61993e commit e83fb4c

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

openfl/experimental/workflow/component/aggregator/aggregator.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import inspect
99
import queue
1010
import time
11+
from copy import deepcopy
1112
from logging import getLogger
1213
from threading import Event
1314
from typing import Any, Callable, Dict, List, Tuple
@@ -16,7 +17,11 @@
1617

1718
from openfl.experimental.workflow.interface import FLSpec
1819
from openfl.experimental.workflow.runtime import FederatedRuntime
19-
from openfl.experimental.workflow.utilities import aggregator_to_collaborator, checkpoint
20+
from openfl.experimental.workflow.utilities import (
21+
aggregator_to_collaborator,
22+
checkpoint,
23+
generate_artifacts,
24+
)
2025
from openfl.experimental.workflow.utilities.metaflow_utils import MetaflowInterface
2126

2227
logger = getLogger(__name__)
@@ -43,7 +48,8 @@ class Aggregator:
4348
collaborator_task_results (Event): Event to inform aggregator that
4449
collaborators have sent the results.
4550
__collaborator_tasks_queue (Dict[Queue]): queue for each collaborator.
46-
flow (Any): Flow class.
51+
flow (FLSpec): Flow class.
52+
final_flow_state (FLSpec): Final flow state.
4753
name (str): aggregator in string format.
4854
checkpoint (bool): Whether to save checkpoint or not (default=False).
4955
private_attrs_callable (Callable): Function for Aggregator private
@@ -124,6 +130,7 @@ def __init__(
124130
self.__collaborator_tasks_queue = {collab: queue.Queue() for collab in self.authorized_cols}
125131

126132
self.flow = flow
133+
self.final_flow_state = deepcopy(flow)
127134
self.checkpoint = checkpoint
128135
self.flow._foreach_methods = []
129136
logger.info("MetaflowInterface creation.")
@@ -181,6 +188,12 @@ def _log_big_warning(self) -> None:
181188
f" WARNED!!!"
182189
)
183190

191+
def _update_final_flow(self) -> None:
192+
"""Update the final flow state with current flow artifacts."""
193+
artifacts_iter, _ = generate_artifacts(ctx=self.flow)
194+
for name, attr in artifacts_iter():
195+
setattr(self.final_flow_state, name, deepcopy(attr))
196+
184197
@staticmethod
185198
def _get_sleep_time() -> int:
186199
"""Sleep 10 seconds.
@@ -250,7 +263,8 @@ async def run_flow(self) -> FLSpec:
250263
self.flow.restore_instance_snapshot(self.flow, list(self.instance_snapshot))
251264
delattr(self, "instance_snapshot")
252265

253-
return self.flow
266+
self._update_final_flow()
267+
return self.final_flow_state
254268

255269
def call_checkpoint(
256270
self, name: str, ctx: Any, f: Callable, stream_buffer: bytes = None

openfl/experimental/workflow/utilities/runtime_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,16 @@ def parse_attrs(ctx, exclude=[], reserved_words=["next", "runtime", "input"]):
4545
return cls_attrs, valid_artifacts
4646

4747

48-
def generate_artifacts(ctx, reserved_words=["next", "runtime", "input", "checkpoint"]):
48+
def generate_artifacts(
49+
ctx, reserved_words=["next", "runtime", "input", "checkpoint", "execute_task_args"]
50+
):
4951
"""Generates artifacts from the given context, excluding specified reserved
5052
words.
5153
5254
Args:
5355
ctx (any): The context to generate artifacts from.
5456
reserved_words (list, optional): A list of reserved words to exclude.
55-
Defaults to ["next", "runtime", "input", "checkpoint"].
57+
Defaults to ["next", "runtime", "input", "checkpoint", "execute_task_args"].
5658
5759
Returns:
5860
tuple: A tuple containing a generator of artifacts and a list of
@@ -152,14 +154,14 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
152154
delattr(ctx, attr)
153155

154156

155-
def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
157+
def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime", "execute_task_args"]):
156158
"""Optionally saves the current state for the task just executed.
157159
158160
Args:
159161
ctx (any): The context to checkpoint.
160162
parent_func (function): The function that was just executed.
161163
chkpnt_reserved_words (list, optional): A list of reserved words to
162-
exclude from checkpointing. Defaults to ["next", "runtime"].
164+
exclude from checkpointing. Defaults to ["next", "runtime", "execute_task_args"].
163165
164166
Returns:
165167
step_stdout (io.StringIO): parent_func stdout

0 commit comments

Comments
 (0)