Skip to content

Commit 69c7f39

Browse files
committed
Support for the exclusion of prohibited data types.
Signed-off-by: yuliasherman <[email protected]>
1 parent bf0fbb8 commit 69c7f39

File tree

7 files changed

+152
-17
lines changed

7 files changed

+152
-17
lines changed

openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/workspace/101_MNIST_FederatedRuntime.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,9 @@
609609
" collaborators=collaborator_names,\n",
610610
" director=director_info, \n",
611611
" notebook_path='./101_MNIST_FederatedRuntime.ipynb'\n",
612-
")"
612+
")\n",
613+
"# Set allowed data types\n",
614+
"FederatedRuntime.allowed_data_types = [\"str\", \"int\", \"float\", \"bool\"]"
613615
]
614616
},
615617
{

openfl/experimental/workflow/component/aggregator/aggregator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ async def run_flow(self) -> FLSpec:
198198
Returns:
199199
flow (FLSpec): Updated instance.
200200
"""
201+
# As an aggregator is created before a runtime, set prohibited/allowed data types for the
202+
# runtime before running the flow.
203+
self.flow.runtime.prohibited_data_types = FederatedRuntime.prohibited_data_types
204+
self.flow.runtime.allowed_data_types = FederatedRuntime.allowed_data_types
201205
# Start function will be the first step if any flow
202206
f_name = "start"
203207
# Creating a clones from the flow object
@@ -394,7 +398,7 @@ def do_task(self, f_name: str) -> Any:
394398
# Create list of selected collaborator clones
395399
selected_clones = ([],)
396400
for name, clone in self.clones_dict.items():
397-
# Check if collaboraotr is in the list of selected
401+
# Check if collaborator is in the list of selected
398402
# collaborators
399403
if name in self.selected_collaborators:
400404
selected_clones[0].append(clone)

openfl/experimental/workflow/interface/fl_spec.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
filter_attributes,
2323
generate_artifacts,
2424
should_transfer,
25+
validate_data_types,
2526
)
2627

2728

@@ -127,16 +128,16 @@ def runtime(self, runtime: Type[Runtime]) -> None:
127128
def run(self) -> None:
128129
"""Starts the execution of the flow."""
129130
# Submit flow to Runtime
130-
if str(self._runtime) == "LocalRuntime":
131+
if str(self.runtime) == "LocalRuntime":
131132
self._run_local()
132-
elif str(self._runtime) == "FederatedRuntime":
133+
elif str(self.runtime) == "FederatedRuntime":
133134
self._run_federated()
134135
else:
135136
raise Exception("Runtime not implemented")
136137

137138
def _run_local(self) -> None:
138139
"""Executes the flow using LocalRuntime."""
139-
self._setup_initial_state()
140+
self._setup_initial_state_local()
140141
try:
141142
# Execute all Participant (Aggregator & Collaborator) tasks and
142143
# retrieve the final attributes
@@ -164,7 +165,7 @@ def _run_local(self) -> None:
164165
for name, attr in final_attributes:
165166
setattr(self, name, attr)
166167

167-
def _setup_initial_state(self) -> None:
168+
def _setup_initial_state_local(self) -> None:
168169
"""
169170
Sets up the flow's initial state, initializing private attributes for
170171
collaborators and aggregators.
@@ -176,6 +177,7 @@ def _setup_initial_state(self) -> None:
176177
self._foreach_methods = []
177178
FLSpec._reset_clones()
178179
FLSpec._create_clones(self, self.runtime.collaborators)
180+
179181
# Initialize collaborator private attributes
180182
self.runtime.initialize_collaborators()
181183
if self._checkpoint:
@@ -334,7 +336,7 @@ def next(self, f, **kwargs) -> None:
334336
parent = inspect.stack()[1][3]
335337
parent_func = getattr(self, parent)
336338

337-
if str(self._runtime) == "LocalRuntime":
339+
if str(self.runtime) == "LocalRuntime":
338340
# Checkpoint current attributes (if checkpoint==True)
339341
checkpoint(self, parent_func)
340342

@@ -343,10 +345,15 @@ def next(self, f, **kwargs) -> None:
343345
if aggregator_to_collaborator(f, parent_func):
344346
agg_to_collab_ss = self._capture_instance_snapshot(kwargs=kwargs)
345347

346-
# Remove included / excluded attributes from next task
347-
filter_attributes(self, f, **kwargs)
348+
if kwargs:
349+
# Remove unwanted attributes from next task
350+
filter_attributes(self, f, **kwargs)
351+
if self.runtime.prohibited_data_types or self.runtime.allowed_data_types:
352+
validate_data_types(
353+
self.runtime.prohibited_data_types, self.runtime.allowed_data_types, **kwargs
354+
)
348355

349-
if str(self._runtime) == "FederatedRuntime":
356+
if str(self.runtime) == "FederatedRuntime":
350357
if f.collaborator_step and not f.aggregator_step:
351358
self._foreach_methods.append(f.__name__)
352359

@@ -359,6 +366,6 @@ def next(self, f, **kwargs) -> None:
359366
kwargs,
360367
)
361368

362-
elif str(self._runtime) == "LocalRuntime":
369+
elif str(self.runtime) == "LocalRuntime":
363370
# update parameters required to execute execute_task function
364371
self.execute_task_args = [f, parent_func, agg_to_collab_ss, kwargs]

openfl/experimental/workflow/runtime/local_runtime.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,8 @@ def execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **k
738738
# new runtime object will not contain private attributes of
739739
# aggregator or other collaborators
740740
clone.runtime = LocalRuntime(backend="single_process")
741+
clone.runtime.prohibited_data_types = self.prohibited_data_types
742+
clone.runtime.allowed_data_types = self.allowed_data_types
741743

742744
# write the clone to the object store
743745
# ensure clone is getting latest _metaflow_interface

openfl/experimental/workflow/runtime/runtime.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,77 @@
1010
from openfl.experimental.workflow.interface.participants import Aggregator, Collaborator
1111

1212

13-
class Runtime:
13+
class AttributeValidationMeta(type):
14+
"""
15+
Metaclass that enforces validation rules on class attributes.
16+
17+
This metaclass ensures that `prohibited_data_types` and `allowed_data_types`
18+
are lists of strings and that they cannot both be set at the same time.
19+
20+
Example:
21+
class MyClass(metaclass=AttributeValidationMeta):
22+
pass
23+
24+
MyClass.prohibited_data_types = ["int", "float"] # Valid
25+
MyClass.allowed_data_types = ["str", "bool"] # Valid
26+
MyClass.prohibited_data_types = "int" # Raises TypeError
27+
MyClass.allowed_data_types = 42 # Raises TypeError
28+
"""
29+
30+
def __setattr__(cls, name, value):
31+
"""
32+
Validates and sets class attributes.
33+
34+
Ensures that `prohibited_data_types` and `allowed_data_types`, when assigned,
35+
are lists of strings and that they are not used together.
36+
37+
Args:
38+
name (str): The attribute name being set.
39+
value (any): The value to be assigned to the attribute.
40+
41+
Raises:
42+
TypeError: If `prohibited_data_types` or `allowed_data_types` is not a list
43+
or contains non-string elements.
44+
ValueError: If both `prohibited_data_types` and `allowed_data_types` are set.
45+
"""
46+
if name in {"prohibited_data_types", "allowed_data_types"}:
47+
if not isinstance(value, list):
48+
raise TypeError(f"'{name}' must be a list, got {type(value).__name__}")
49+
if not all(isinstance(item, str) for item in value):
50+
raise TypeError(f"All elements of '{name}' must be strings")
51+
52+
# Ensure both attributes are not set at the same time
53+
other_name = (
54+
"allowed_data_types" if name == "prohibited_data_types" else "prohibited_data_types"
55+
)
56+
if getattr(cls, other_name, []): # Check if the other attribute is already set
57+
raise ValueError(
58+
"Cannot set both 'prohibited_data_types' and 'allowed_data_types'."
59+
)
60+
61+
super().__setattr__(name, value)
62+
63+
64+
class Runtime(metaclass=AttributeValidationMeta):
65+
"""
66+
Base class for federated learning runtimes.
67+
This class serves as an interface for runtimes that execute FLSpec flows.
68+
69+
Attributes:
70+
prohibited_data_types (list): A list of data types that are prohibited from being
71+
transmitted over the network.
72+
allowed_data_types (list): A list of data types that are explicitly allowed to be
73+
transmitted over the network.
74+
75+
Notes:
76+
- Either `prohibited_data_types` or `allowed_data_types` may be specified.
77+
- If neither is specified, all data types are allowed to be transmitted.
78+
- If both are specified, a `ValueError` will be raised.
79+
"""
80+
81+
prohibited_data_types = []
82+
allowed_data_types = []
83+
1484
def __init__(self):
1585
"""Initializes the Runtime object.
1686

openfl/experimental/workflow/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
filter_attributes,
1818
generate_artifacts,
1919
parse_attrs,
20+
validate_data_types,
2021
)
2122
from openfl.experimental.workflow.utilities.stream_redirect import (
2223
RedirectStdStream,

openfl/experimental/workflow/utilities/runtime_utils.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import inspect
88
import itertools
99
from types import MethodType
10+
from typing import List
1011

1112
import numpy as np
1213

@@ -96,6 +97,54 @@ def filter_attributes(ctx, f, **kwargs):
9697
_process_exclusion(ctx, cls_attrs, kwargs["exclude"], f)
9798

9899

100+
def validate_data_types(
101+
prohibited_data_types: List[str] = None,
102+
allowed_data_types: List[str] = None,
103+
reserved_words=["collaborators"],
104+
**kwargs,
105+
):
106+
"""Validates that the types of attributes in kwargs are not among the prohibited data types
107+
and are among the allowed data types if specified.
108+
Raises a TypeError if any prohibited data type is found or if a type is not allowed.
109+
110+
Args:
111+
prohibited_data_types (List[str], optional): A list of prohibited data type names
112+
(e.g., ['int', 'float']).
113+
allowed_data_types (List[str], optional): A list of allowed data type names.
114+
reserved_words: A list of strings that should be allowed as attribute values, even if 'str'
115+
is included in prohibited_data_types.
116+
kwargs (dict): Arbitrary keyword arguments representing attribute names and their values.
117+
118+
Raises:
119+
TypeError: If any prohibited data types are found in kwargs or if a type is not allowed.
120+
ValueError: If both prohibited_data_types and allowed_data_types are set simultaneously.
121+
"""
122+
if prohibited_data_types is None:
123+
prohibited_data_types = []
124+
if allowed_data_types is None:
125+
allowed_data_types = []
126+
127+
if prohibited_data_types and allowed_data_types:
128+
raise ValueError("Cannot set both 'prohibited_data_types' and 'allowed_data_types'.")
129+
130+
for attr_name, attr_value in kwargs.items():
131+
attr_type = type(attr_value).__name__
132+
if (
133+
prohibited_data_types
134+
and attr_type in prohibited_data_types
135+
and attr_value not in reserved_words
136+
):
137+
raise TypeError(
138+
f"The attribute '{attr_name}' = '{attr_value}' "
139+
f"has a prohibited value type: {attr_type}"
140+
)
141+
if allowed_data_types and attr_type not in allowed_data_types:
142+
raise TypeError(
143+
f"The attribute '{attr_name}' = '{attr_value}' "
144+
f"has a type that is not allowed: {attr_type}"
145+
)
146+
147+
99148
def _validate_include_exclude(kwargs, cls_attrs):
100149
"""Validates that 'include' and 'exclude' are not both present, and that
101150
attributes in 'include' or 'exclude' exist in the context.
@@ -152,13 +201,13 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
152201
delattr(ctx, attr)
153202

154203

155-
def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
204+
def checkpoint(ctx, parent_func, checkpoint_reserved_words=["next", "runtime"]):
156205
"""Optionally saves the current state for the task just executed.
157206
158207
Args:
159208
ctx (any): The context to checkpoint.
160209
parent_func (function): The function that was just executed.
161-
chkpnt_reserved_words (list, optional): A list of reserved words to
210+
checkpoint_reserved_words (list, optional): A list of reserved words to
162211
exclude from checkpointing. Defaults to ["next", "runtime"].
163212
164213
Returns:
@@ -173,7 +222,7 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
173222
if ctx._checkpoint:
174223
# all objects will be serialized using Metaflow interface
175224
print(f"Saving data artifacts for {parent_func.__name__}")
176-
artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=chkpnt_reserved_words)
225+
artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=checkpoint_reserved_words)
177226
task_id = ctx._metaflow_interface.create_task(parent_func.__name__)
178227
ctx._metaflow_interface.save_artifacts(
179228
artifacts_iter(),
@@ -195,7 +244,7 @@ def old_check_resource_allocation(num_gpus, each_participant_gpu_usage):
195244
# But at this point the function will raise an error because
196245
# remaining_gpu_memory is never cleared.
197246
# The participant list should remove the participant if it fits in the gpu
198-
# and save the partipant if it doesn't and continue to the next GPU to see
247+
# and save the participant if it doesn't and continue to the next GPU to see
199248
# if it fits in that one, only if we run out of GPUs should this function
200249
# raise an error.
201250
for gpu in np.ones(num_gpus, dtype=int):
@@ -230,7 +279,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
230279
if gpu == 0:
231280
break
232281
if gpu < participant_gpu_usage:
233-
# participant doesn't fitm break to next GPU
282+
# participant doesn't fit, break to next GPU
234283
break
235284
else:
236285
# if participant fits remove from need_assigned

0 commit comments

Comments
 (0)