Skip to content

Commit 758496b

Browse files
committed
Support for the exclusion of prohibited data types.
Signed-off-by: yuliasherman <[email protected]>
1 parent 47ad2d0 commit 758496b

File tree

7 files changed

+65
-15
lines changed

7 files changed

+65
-15
lines changed

openfl/experimental/workflow/component/aggregator/aggregator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ async def run_flow(self) -> FLSpec:
198198
Returns:
199199
flow (FLSpec): Updated instance.
200200
"""
201+
# As an aggregator is created before a runtime, set prohibited data types for the runtime
202+
# before running the flow.
203+
self.flow.runtime.prohibited_data_types = FederatedRuntime.prohibited_data_types
201204
# Start function will be the first step if any flow
202205
f_name = "start"
203206
# Creating a clones from the flow object
@@ -394,7 +397,7 @@ def do_task(self, f_name: str) -> Any:
394397
# Create list of selected collaborator clones
395398
selected_clones = ([],)
396399
for name, clone in self.clones_dict.items():
397-
# Check if collaboraotr is in the list of selected
400+
# Check if collaborator is in the list of selected
398401
# collaborators
399402
if name in self.selected_collaborators:
400403
selected_clones[0].append(clone)

openfl/experimental/workflow/component/collaborator/collaborator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class Collaborator:
2323
Aggregator Server.
2424
2525
private_attrs_callable (Callable): Function for Collaborator
26-
private attriubtes.
26+
private attributes.
2727
private_attrs_kwargs (Dict): Arguments to call private_attrs_callable.
2828
2929
Note:

openfl/experimental/workflow/interface/fl_spec.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
filter_attributes,
2323
generate_artifacts,
2424
should_transfer,
25+
validate_data_types,
2526
)
2627

2728

@@ -164,6 +165,7 @@ def _run_local(self) -> None:
164165
for name, attr in final_attributes:
165166
setattr(self, name, attr)
166167

168+
# Runs only for LocalRuntime
167169
def _setup_initial_state(self) -> None:
168170
"""
169171
Sets up the flow's initial state, initializing private attributes for
@@ -176,6 +178,7 @@ def _setup_initial_state(self) -> None:
176178
self._foreach_methods = []
177179
FLSpec._reset_clones()
178180
FLSpec._create_clones(self, self.runtime.collaborators)
181+
179182
# Initialize collaborator private attributes
180183
self.runtime.initialize_collaborators()
181184
if self._checkpoint:
@@ -343,8 +346,11 @@ def next(self, f, **kwargs) -> None:
343346
if aggregator_to_collaborator(f, parent_func):
344347
agg_to_collab_ss = self._capture_instance_snapshot(kwargs=kwargs)
345348

346-
# Remove included / excluded attributes from next task
347-
filter_attributes(self, f, **kwargs)
349+
# Remove unwanted attributes from next task
350+
if kwargs:
351+
filter_attributes(self, f, **kwargs)
352+
if self._runtime.prohibited_data_types:
353+
validate_data_types(self._runtime.prohibited_data_types, **kwargs)
348354

349355
if str(self._runtime) == "FederatedRuntime":
350356
if f.collaborator_step and not f.aggregator_step:

openfl/experimental/workflow/runtime/local_runtime.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def __init__(
324324
The RayGroups run concurrently while participants in the
325325
group run serially.
326326
The default is 1 RayGroup and can be changed by using the
327-
num_actors=1 kwarg. By using more RayGroups more concurency
327+
num_actors=1 kwarg. By using more RayGroups more concurrency
328328
is allowed with the trade off being that each RayGroup has
329329
extra memory overhead in the form of extra CUDA CONTEXTS.
330330
@@ -737,7 +737,9 @@ def execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **k
737737
# Set new LocalRuntime for clone as it is required
738738
# new runtime object will not contain private attributes of
739739
# aggregator or other collaborators
740-
clone.runtime = LocalRuntime(backend="single_process")
740+
clone.runtime = LocalRuntime(
741+
backend="single_process", prohibited_data_types=self.prohibited_data_types
742+
)
741743

742744
# write the clone to the object store
743745
# ensure clone is getting latest _metaflow_interface

openfl/experimental/workflow/runtime/runtime.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,19 @@
1010
from openfl.experimental.workflow.interface.participants import Aggregator, Collaborator
1111

1212

13-
class Runtime:
13+
class AttributeValidationMeta(type):
14+
def __setattr__(cls, name, value):
15+
if name == "prohibited_data_types":
16+
if not isinstance(value, list):
17+
raise TypeError(f"'{name}' must be a list, got {type(value).__name__}")
18+
if not all(isinstance(item, str) for item in value):
19+
raise TypeError(f"All elements of '{name}' must be strings")
20+
super().__setattr__(name, value)
21+
22+
23+
class Runtime(metaclass=AttributeValidationMeta):
24+
prohibited_data_types = []
25+
1426
def __init__(self):
1527
"""Initializes the Runtime object.
1628

openfl/experimental/workflow/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
filter_attributes,
1818
generate_artifacts,
1919
parse_attrs,
20+
validate_data_types,
2021
)
2122
from openfl.experimental.workflow.utilities.stream_redirect import (
2223
RedirectStdStream,

openfl/experimental/workflow/utilities/runtime_utils.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import inspect
88
import itertools
99
from types import MethodType
10+
from typing import List
1011

1112
import numpy as np
1213

@@ -96,6 +97,31 @@ def filter_attributes(ctx, f, **kwargs):
9697
_process_exclusion(ctx, cls_attrs, kwargs["exclude"], f)
9798

9899

100+
def validate_data_types(
101+
prohibited_data_types: List[str], reserved_words=["collaborators"], **kwargs
102+
):
103+
"""Validates that the types of attributes in kwargs are not among the prohibited data types.
104+
Raises a TypeError if any prohibited data type is found.
105+
106+
Args:
107+
prohibited_data_types (List[str]): A list of prohibited data type names
108+
(e.g., ['int', 'float']).
109+
kwargs (dict): Arbitrary keyword arguments representing attribute names and their values.
110+
111+
Raises:
112+
TypeError: If any prohibited data types are found in kwargs.
113+
ValueError: If prohibited_data_types is empty.
114+
"""
115+
if not prohibited_data_types:
116+
raise ValueError("prohibited_data_types must not be empty.")
117+
for attr_name, attr_value in kwargs.items():
118+
if type(attr_value).__name__ in prohibited_data_types and attr_value not in reserved_words:
119+
raise TypeError(
120+
f"The attribute '{attr_name}' = '{attr_value}' has a prohibited value type: "
121+
f"{type(attr_value).__name__}"
122+
)
123+
124+
99125
def _validate_include_exclude(kwargs, cls_attrs):
100126
"""Validates that 'include' and 'exclude' are not both present, and that
101127
attributes in 'include' or 'exclude' exist in the context.
@@ -152,13 +178,13 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
152178
delattr(ctx, attr)
153179

154180

155-
def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
181+
def checkpoint(ctx, parent_func, checkpoint_reserved_words=["next", "runtime"]):
156182
"""Optionally saves the current state for the task just executed.
157183
158184
Args:
159185
ctx (any): The context to checkpoint.
160186
parent_func (function): The function that was just executed.
161-
chkpnt_reserved_words (list, optional): A list of reserved words to
187+
checkpoint_reserved_words (list, optional): A list of reserved words to
162188
exclude from checkpointing. Defaults to ["next", "runtime"].
163189
164190
Returns:
@@ -173,7 +199,7 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
173199
if ctx._checkpoint:
174200
# all objects will be serialized using Metaflow interface
175201
print(f"Saving data artifacts for {parent_func.__name__}")
176-
artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=chkpnt_reserved_words)
202+
artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=checkpoint_reserved_words)
177203
task_id = ctx._metaflow_interface.create_task(parent_func.__name__)
178204
ctx._metaflow_interface.save_artifacts(
179205
artifacts_iter(),
@@ -188,15 +214,15 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
188214

189215
def old_check_resource_allocation(num_gpus, each_participant_gpu_usage):
190216
remaining_gpu_memory = {}
191-
# TODO for each GPU the funtion tries see if all participant usages fit
217+
# TODO for each GPU the function tries see if all participant usages fit
192218
# into a GPU, it it doesn't it removes that participant from the
193219
# participant list, and adds it to the remaining_gpu_memory dict. So any
194220
# sum of GPU requirements above 1 triggers this.
195-
# But at this point the funtion will raise an error because
221+
# But at this point the function will raise an error because
196222
# remaining_gpu_memory is never cleared.
197223
# The participant list should remove the participant if it fits in the gpu
198-
# and save the partipant if it doesn't and continue to the next GPU to see
199-
# if it fits in that one, only if we run out of GPUs should this funtion
224+
# and save the participant if it doesn't and continue to the next GPU to see
225+
# if it fits in that one, only if we run out of GPUs should this function
200226
# raise an error.
201227
for gpu in np.ones(num_gpus, dtype=int):
202228
for i, (participant_name, participant_gpu_usage) in enumerate(
@@ -230,7 +256,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
230256
if gpu == 0:
231257
break
232258
if gpu < participant_gpu_usage:
233-
# participant doesn't fitm break to next GPU
259+
# participant doesn't fit; break to next GPU
234260
break
235261
else:
236262
# if participant fits remove from need_assigned

0 commit comments

Comments
 (0)