Skip to content

Commit 39ad0a3

Browse files
committed
Support for the exclusion of prohibited data types.
Signed-off-by: yuliasherman <[email protected]>
1 parent 2d7c32c commit 39ad0a3

File tree

7 files changed

+68
-16
lines changed

7 files changed

+68
-16
lines changed

openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/workspace/101_MNIST_FederatedRuntime.ipynb

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,8 @@
490490
"local_runtime = LocalRuntime(\n",
491491
" aggregator=aggregator, collaborators=collaborators, backend=\"single_process\"\n",
492492
")\n",
493+
"# Set prohibited data types\n",
494+
"LocalRuntime.prohibited_data_types = [\"bytes\", \"bytearray\"]\n",
493495
"print(f\"Local runtime collaborators = {local_runtime.collaborators}\")"
494496
]
495497
},
@@ -609,7 +611,9 @@
609611
" collaborators=collaborator_names,\n",
610612
" director=director_info, \n",
611613
" notebook_path='./101_MNIST_FederatedRuntime.ipynb'\n",
612-
")"
614+
")\n",
615+
"# Set prohibited data types\n",
616+
"FederatedRuntime.prohibited_data_types = [\"bytes\", \"bytearray\"]"
613617
]
614618
},
615619
{

openfl/experimental/workflow/component/aggregator/aggregator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ async def run_flow(self) -> FLSpec:
198198
Returns:
199199
flow (FLSpec): Updated instance.
200200
"""
201+
# As an aggregator is created before a runtime, set prohibited data types for the runtime
202+
# before running the flow.
203+
self.flow.runtime.prohibited_data_types = FederatedRuntime.prohibited_data_types
201204
# Start function will be the first step if any flow
202205
f_name = "start"
203206
# Creating a clones from the flow object
@@ -394,7 +397,7 @@ def do_task(self, f_name: str) -> Any:
394397
# Create list of selected collaborator clones
395398
selected_clones = ([],)
396399
for name, clone in self.clones_dict.items():
397-
# Check if collaboraotr is in the list of selected
400+
# Check if collaborator is in the list of selected
398401
# collaborators
399402
if name in self.selected_collaborators:
400403
selected_clones[0].append(clone)

openfl/experimental/workflow/interface/fl_spec.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
filter_attributes,
2323
generate_artifacts,
2424
should_transfer,
25+
validate_data_types,
2526
)
2627

2728

@@ -136,7 +137,7 @@ def run(self) -> None:
136137

137138
def _run_local(self) -> None:
138139
"""Executes the flow using LocalRuntime."""
139-
self._setup_initial_state()
140+
self._setup_initial_state_local()
140141
try:
141142
# Execute all Participant (Aggregator & Collaborator) tasks and
142143
# retrieve the final attributes
@@ -164,7 +165,7 @@ def _run_local(self) -> None:
164165
for name, attr in final_attributes:
165166
setattr(self, name, attr)
166167

167-
def _setup_initial_state(self) -> None:
168+
def _setup_initial_state_local(self) -> None:
168169
"""
169170
Sets up the flow's initial state, initializing private attributes for
170171
collaborators and aggregators.
@@ -176,6 +177,7 @@ def _setup_initial_state(self) -> None:
176177
self._foreach_methods = []
177178
FLSpec._reset_clones()
178179
FLSpec._create_clones(self, self.runtime.collaborators)
180+
179181
# Initialize collaborator private attributes
180182
self.runtime.initialize_collaborators()
181183
if self._checkpoint:
@@ -343,8 +345,11 @@ def next(self, f, **kwargs) -> None:
343345
if aggregator_to_collaborator(f, parent_func):
344346
agg_to_collab_ss = self._capture_instance_snapshot(kwargs=kwargs)
345347

346-
# Remove included / excluded attributes from next task
347-
filter_attributes(self, f, **kwargs)
348+
# Remove unwanted attributes from next task
349+
if kwargs:
350+
filter_attributes(self, f, **kwargs)
351+
if self._runtime.prohibited_data_types:
352+
validate_data_types(self._runtime.prohibited_data_types, **kwargs)
348353

349354
if str(self._runtime) == "FederatedRuntime":
350355
if f.collaborator_step and not f.aggregator_step:

openfl/experimental/workflow/runtime/local_runtime.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def __init__(
324324
The RayGroups run concurrently while participants in the
325325
group run serially.
326326
The default is 1 RayGroup and can be changed by using the
327-
num_actors=1 kwarg. By using more RayGroups more concurency
327+
num_actors=1 kwarg. By using more RayGroups more concurrency
328328
is allowed with the trade off being that each RayGroup has
329329
extra memory overhead in the form of extra CUDA CONTEXTS.
330330
@@ -738,6 +738,7 @@ def execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **k
738738
# new runtime object will not contain private attributes of
739739
# aggregator or other collaborators
740740
clone.runtime = LocalRuntime(backend="single_process")
741+
clone.runtime.prohibited_data_types = self.prohibited_data_types
741742

742743
# write the clone to the object store
743744
# ensure clone is getting latest _metaflow_interface

openfl/experimental/workflow/runtime/runtime.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,19 @@
1010
from openfl.experimental.workflow.interface.participants import Aggregator, Collaborator
1111

1212

13-
class Runtime:
13+
class AttributeValidationMeta(type):
14+
def __setattr__(cls, name, value):
15+
if name == "prohibited_data_types":
16+
if not isinstance(value, list):
17+
raise TypeError(f"'{name}' must be a list, got {type(value).__name__}")
18+
if not all(isinstance(item, str) for item in value):
19+
raise TypeError(f"All elements of '{name}' must be strings")
20+
super().__setattr__(name, value)
21+
22+
23+
class Runtime(metaclass=AttributeValidationMeta):
24+
prohibited_data_types = []
25+
1426
def __init__(self):
1527
"""Initializes the Runtime object.
1628

openfl/experimental/workflow/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
filter_attributes,
1818
generate_artifacts,
1919
parse_attrs,
20+
validate_data_types,
2021
)
2122
from openfl.experimental.workflow.utilities.stream_redirect import (
2223
RedirectStdStream,

openfl/experimental/workflow/utilities/runtime_utils.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import inspect
88
import itertools
99
from types import MethodType
10+
from typing import List
1011

1112
import numpy as np
1213

@@ -96,6 +97,31 @@ def filter_attributes(ctx, f, **kwargs):
9697
_process_exclusion(ctx, cls_attrs, kwargs["exclude"], f)
9798

9899

100+
def validate_data_types(
101+
prohibited_data_types: List[str], reserved_words=["collaborators"], **kwargs
102+
):
103+
"""Validates that the types of attributes in kwargs are not among the prohibited data types.
104+
Raises a TypeError if any prohibited data type is found.
105+
106+
Args:
107+
prohibited_data_types (List[str]): A list of prohibited data type names
108+
(e.g., ['int', 'float']).
109+
kwargs (dict): Arbitrary keyword arguments representing attribute names and their values.
110+
111+
Raises:
112+
TypeError: If any prohibited data types are found in kwargs.
113+
ValueError: If prohibited_data_types is empty.
114+
"""
115+
if not prohibited_data_types:
116+
raise ValueError("prohibited_data_types must not be empty.")
117+
for attr_name, attr_value in kwargs.items():
118+
if type(attr_value).__name__ in prohibited_data_types and attr_value not in reserved_words:
119+
raise TypeError(
120+
f"The attribute '{attr_name}' = '{attr_value}' has a prohibited value type: "
121+
f"{type(attr_value).__name__}"
122+
)
123+
124+
99125
def _validate_include_exclude(kwargs, cls_attrs):
100126
"""Validates that 'include' and 'exclude' are not both present, and that
101127
attributes in 'include' or 'exclude' exist in the context.
@@ -152,13 +178,13 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
152178
delattr(ctx, attr)
153179

154180

155-
def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
181+
def checkpoint(ctx, parent_func, checkpoint_reserved_words=["next", "runtime"]):
156182
"""Optionally saves the current state for the task just executed.
157183
158184
Args:
159185
ctx (any): The context to checkpoint.
160186
parent_func (function): The function that was just executed.
161-
chkpnt_reserved_words (list, optional): A list of reserved words to
187+
checkpoint_reserved_words (list, optional): A list of reserved words to
162188
exclude from checkpointing. Defaults to ["next", "runtime"].
163189
164190
Returns:
@@ -173,7 +199,7 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
173199
if ctx._checkpoint:
174200
# all objects will be serialized using Metaflow interface
175201
print(f"Saving data artifacts for {parent_func.__name__}")
176-
artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=chkpnt_reserved_words)
202+
artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=checkpoint_reserved_words)
177203
task_id = ctx._metaflow_interface.create_task(parent_func.__name__)
178204
ctx._metaflow_interface.save_artifacts(
179205
artifacts_iter(),
@@ -188,15 +214,15 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
188214

189215
def old_check_resource_allocation(num_gpus, each_participant_gpu_usage):
190216
remaining_gpu_memory = {}
191-
# TODO for each GPU the funtion tries see if all participant usages fit
217+
# TODO for each GPU the function tries see if all participant usages fit
192218
# into a GPU, it it doesn't it removes that participant from the
193219
# participant list, and adds it to the remaining_gpu_memory dict. So any
194220
# sum of GPU requirements above 1 triggers this.
195-
# But at this point the funtion will raise an error because
221+
# But at this point the function will raise an error because
196222
# remaining_gpu_memory is never cleared.
197223
# The participant list should remove the participant if it fits in the gpu
198-
# and save the partipant if it doesn't and continue to the next GPU to see
199-
# if it fits in that one, only if we run out of GPUs should this funtion
224+
# and save the participant if it doesn't and continue to the next GPU to see
225+
# if it fits in that one, only if we run out of GPUs should this function
200226
# raise an error.
201227
for gpu in np.ones(num_gpus, dtype=int):
202228
for i, (participant_name, participant_gpu_usage) in enumerate(
@@ -230,7 +256,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
230256
if gpu == 0:
231257
break
232258
if gpu < participant_gpu_usage:
233-
# participant doesn't fitm break to next GPU
259+
# participant doesn't fit, break to next GPU
234260
break
235261
else:
236262
# if participant fits remove from need_assigned

0 commit comments

Comments
 (0)