Skip to content

Commit d33b0f9

Browse files
committed
Support for the exclusion of prohibited data types.
Signed-off-by: yuliasherman <[email protected]>
1 parent 7e67171 commit d33b0f9

File tree

7 files changed

+105
-17
lines changed

7 files changed

+105
-17
lines changed

openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/workspace/101_MNIST_FederatedRuntime.ipynb

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,8 @@
490490
"local_runtime = LocalRuntime(\n",
491491
" aggregator=aggregator, collaborators=collaborators, backend=\"single_process\"\n",
492492
")\n",
493+
"# Set prohibited data types\n",
494+
"LocalRuntime.prohibited_data_types = [\"bytes\", \"bytearray\"]\n",
493495
"print(f\"Local runtime collaborators = {local_runtime.collaborators}\")"
494496
]
495497
},
@@ -609,7 +611,9 @@
609611
" collaborators=collaborator_names,\n",
610612
" director=director_info, \n",
611613
" notebook_path='./101_MNIST_FederatedRuntime.ipynb'\n",
612-
")"
614+
")\n",
615+
"# Set prohibited data types\n",
616+
"FederatedRuntime.prohibited_data_types = [\"bytes\", \"bytearray\"]"
613617
]
614618
},
615619
{

openfl/experimental/workflow/component/aggregator/aggregator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ async def run_flow(self) -> FLSpec:
198198
Returns:
199199
flow (FLSpec): Updated instance.
200200
"""
201+
# As an aggregator is created before a runtime, set prohibited data types for the runtime
202+
# before running the flow.
203+
self.flow.runtime.prohibited_data_types = FederatedRuntime.prohibited_data_types
201204
# Start function will be the first step if any flow
202205
f_name = "start"
203206
# Creating a clones from the flow object
@@ -394,7 +397,7 @@ def do_task(self, f_name: str) -> Any:
394397
# Create list of selected collaborator clones
395398
selected_clones = ([],)
396399
for name, clone in self.clones_dict.items():
397-
# Check if collaboraotr is in the list of selected
400+
# Check if collaborator is in the list of selected
398401
# collaborators
399402
if name in self.selected_collaborators:
400403
selected_clones[0].append(clone)

openfl/experimental/workflow/interface/fl_spec.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
filter_attributes,
2323
generate_artifacts,
2424
should_transfer,
25+
validate_data_types,
2526
)
2627

2728

@@ -127,16 +128,16 @@ def runtime(self, runtime: Type[Runtime]) -> None:
127128
def run(self) -> None:
128129
"""Starts the execution of the flow."""
129130
# Submit flow to Runtime
130-
if str(self._runtime) == "LocalRuntime":
131+
if str(self.runtime) == "LocalRuntime":
131132
self._run_local()
132-
elif str(self._runtime) == "FederatedRuntime":
133+
elif str(self.runtime) == "FederatedRuntime":
133134
self._run_federated()
134135
else:
135136
raise Exception("Runtime not implemented")
136137

137138
def _run_local(self) -> None:
138139
"""Executes the flow using LocalRuntime."""
139-
self._setup_initial_state()
140+
self._setup_initial_state_local()
140141
try:
141142
# Execute all Participant (Aggregator & Collaborator) tasks and
142143
# retrieve the final attributes
@@ -164,7 +165,7 @@ def _run_local(self) -> None:
164165
for name, attr in final_attributes:
165166
setattr(self, name, attr)
166167

167-
def _setup_initial_state(self) -> None:
168+
def _setup_initial_state_local(self) -> None:
168169
"""
169170
Sets up the flow's initial state, initializing private attributes for
170171
collaborators and aggregators.
@@ -176,6 +177,7 @@ def _setup_initial_state(self) -> None:
176177
self._foreach_methods = []
177178
FLSpec._reset_clones()
178179
FLSpec._create_clones(self, self.runtime.collaborators)
180+
179181
# Initialize collaborator private attributes
180182
self.runtime.initialize_collaborators()
181183
if self._checkpoint:
@@ -334,7 +336,7 @@ def next(self, f, **kwargs) -> None:
334336
parent = inspect.stack()[1][3]
335337
parent_func = getattr(self, parent)
336338

337-
if str(self._runtime) == "LocalRuntime":
339+
if str(self.runtime) == "LocalRuntime":
338340
# Checkpoint current attributes (if checkpoint==True)
339341
checkpoint(self, parent_func)
340342

@@ -343,10 +345,13 @@ def next(self, f, **kwargs) -> None:
343345
if aggregator_to_collaborator(f, parent_func):
344346
agg_to_collab_ss = self._capture_instance_snapshot(kwargs=kwargs)
345347

346-
# Remove included / excluded attributes from next task
347-
filter_attributes(self, f, **kwargs)
348+
# Remove unwanted attributes from next task
349+
if kwargs:
350+
filter_attributes(self, f, **kwargs)
351+
if self.runtime.prohibited_data_types:
352+
validate_data_types(self.runtime.prohibited_data_types, **kwargs)
348353

349-
if str(self._runtime) == "FederatedRuntime":
354+
if str(self.runtime) == "FederatedRuntime":
350355
if f.collaborator_step and not f.aggregator_step:
351356
self._foreach_methods.append(f.__name__)
352357

@@ -359,6 +364,6 @@ def next(self, f, **kwargs) -> None:
359364
kwargs,
360365
)
361366

362-
elif str(self._runtime) == "LocalRuntime":
367+
elif str(self.runtime) == "LocalRuntime":
363368
# update parameters required to execute execute_task function
364369
self.execute_task_args = [f, parent_func, agg_to_collab_ss, kwargs]

openfl/experimental/workflow/runtime/local_runtime.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,7 @@ def execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **k
738738
# new runtime object will not contain private attributes of
739739
# aggregator or other collaborators
740740
clone.runtime = LocalRuntime(backend="single_process")
741+
clone.runtime.prohibited_data_types = self.prohibited_data_types
741742

742743
# write the clone to the object store
743744
# ensure clone is getting latest _metaflow_interface

openfl/experimental/workflow/runtime/runtime.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,53 @@
1010
from openfl.experimental.workflow.interface.participants import Aggregator, Collaborator
1111

1212

13-
class Runtime:
13+
class AttributeValidationMeta(type):
14+
"""
15+
Metaclass that enforces validation rules on class attributes.
16+
17+
This metaclass ensures that when assigning a value to the `prohibited_data_types`
18+
attribute, it must be a list of strings. If the conditions are not met, a
19+
TypeError is raised.
20+
21+
Example:
22+
class MyClass(metaclass=AttributeValidationMeta):
23+
pass
24+
25+
MyClass.prohibited_data_types = ["int", "float"] # Valid
26+
MyClass.prohibited_data_types = "int" # Raises TypeError
27+
"""
28+
29+
def __setattr__(cls, name, value):
30+
"""
31+
Validates and sets class attributes.
32+
Ensures that `prohibited_data_types`, when assigned, is a list of strings.
33+
34+
Args:
35+
name (str): The attribute name being set.
36+
value (any): The value to be assigned to the attribute.
37+
38+
Raises:
39+
TypeError: If `prohibited_data_types` is not a list or contains non-string elements.
40+
"""
41+
if name == "prohibited_data_types":
42+
if not isinstance(value, list):
43+
raise TypeError(f"'{name}' must be a list, got {type(value).__name__}")
44+
if not all(isinstance(item, str) for item in value):
45+
raise TypeError(f"All elements of '{name}' must be strings")
46+
super().__setattr__(name, value)
47+
48+
49+
class Runtime(metaclass=AttributeValidationMeta):
50+
"""
51+
Base class for federated learning runtimes.
52+
This class serves as an interface for runtimes that execute FLSpec flows.
53+
54+
Attributes:
55+
prohibited_data_types (list): A list of data types that are prohibited from being
56+
transmitted over the network.
57+
"""
58+
prohibited_data_types = []
59+
1460
def __init__(self):
1561
"""Initializes the Runtime object.
1662

openfl/experimental/workflow/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
filter_attributes,
1818
generate_artifacts,
1919
parse_attrs,
20+
validate_data_types,
2021
)
2122
from openfl.experimental.workflow.utilities.stream_redirect import (
2223
RedirectStdStream,

openfl/experimental/workflow/utilities/runtime_utils.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import inspect
88
import itertools
99
from types import MethodType
10+
from typing import List
1011

1112
import numpy as np
1213

@@ -96,6 +97,33 @@ def filter_attributes(ctx, f, **kwargs):
9697
_process_exclusion(ctx, cls_attrs, kwargs["exclude"], f)
9798

9899

100+
def validate_data_types(
101+
prohibited_data_types: List[str], reserved_words=["collaborators"], **kwargs
102+
):
103+
"""Validates that the types of attributes in kwargs are not among the prohibited data types.
104+
Raises a TypeError if any prohibited data type is found.
105+
106+
Args:
107+
prohibited_data_types (List[str]): A list of prohibited data type names
108+
(e.g., ['int', 'float']).
109+
reserved_words: A list of strings that should be allowed as attribute values, even if 'str'
110+
is included in prohibited_data_types.
111+
kwargs (dict): Arbitrary keyword arguments representing attribute names and their values.
112+
113+
Raises:
114+
TypeError: If any prohibited data types are found in kwargs.
115+
ValueError: If prohibited_data_types is empty.
116+
"""
117+
if not prohibited_data_types:
118+
raise ValueError("prohibited_data_types must not be empty.")
119+
for attr_name, attr_value in kwargs.items():
120+
if type(attr_value).__name__ in prohibited_data_types and attr_value not in reserved_words:
121+
raise TypeError(
122+
f"The attribute '{attr_name}' = '{attr_value}' has a prohibited value type: "
123+
f"{type(attr_value).__name__}"
124+
)
125+
126+
99127
def _validate_include_exclude(kwargs, cls_attrs):
100128
"""Validates that 'include' and 'exclude' are not both present, and that
101129
attributes in 'include' or 'exclude' exist in the context.
@@ -152,13 +180,13 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
152180
delattr(ctx, attr)
153181

154182

155-
def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
183+
def checkpoint(ctx, parent_func, checkpoint_reserved_words=["next", "runtime"]):
156184
"""Optionally saves the current state for the task just executed.
157185
158186
Args:
159187
ctx (any): The context to checkpoint.
160188
parent_func (function): The function that was just executed.
161-
chkpnt_reserved_words (list, optional): A list of reserved words to
189+
checkpoint_reserved_words (list, optional): A list of reserved words to
162190
exclude from checkpointing. Defaults to ["next", "runtime"].
163191
164192
Returns:
@@ -173,7 +201,7 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
173201
if ctx._checkpoint:
174202
# all objects will be serialized using Metaflow interface
175203
print(f"Saving data artifacts for {parent_func.__name__}")
176-
artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=chkpnt_reserved_words)
204+
artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=checkpoint_reserved_words)
177205
task_id = ctx._metaflow_interface.create_task(parent_func.__name__)
178206
ctx._metaflow_interface.save_artifacts(
179207
artifacts_iter(),
@@ -195,7 +223,7 @@ def old_check_resource_allocation(num_gpus, each_participant_gpu_usage):
195223
# But at this point the function will raise an error because
196224
# remaining_gpu_memory is never cleared.
197225
# The participant list should remove the participant if it fits in the gpu
198-
# and save the partipant if it doesn't and continue to the next GPU to see
226+
# and save the participant if it doesn't and continue to the next GPU to see
199227
# if it fits in that one, only if we run out of GPUs should this function
200228
# raise an error.
201229
for gpu in np.ones(num_gpus, dtype=int):
@@ -230,7 +258,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
230258
if gpu == 0:
231259
break
232260
if gpu < participant_gpu_usage:
233-
# participant doesn't fitm break to next GPU
261+
# participant doesn't fit, break to next GPU
234262
break
235263
else:
236264
# if participant fits remove from need_assigned

0 commit comments

Comments
 (0)