diff --git a/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/workspace/101_MNIST_FederatedRuntime.ipynb b/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/workspace/101_MNIST_FederatedRuntime.ipynb
index 00daa8095e..a778c5c43c 100644
--- a/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/workspace/101_MNIST_FederatedRuntime.ipynb
+++ b/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/workspace/101_MNIST_FederatedRuntime.ipynb
@@ -609,7 +609,9 @@
     "    collaborators=collaborator_names,\n",
     "    director=director_info, \n",
     "    notebook_path='./101_MNIST_FederatedRuntime.ipynb'\n",
-    ")"
+    ")\n",
+    "# Set allowed data types\n",
+    "FederatedRuntime.allowed_data_types = [\"str\", \"int\", \"float\", \"bool\"]"
    ]
   },
   {
diff --git a/openfl/experimental/workflow/component/aggregator/aggregator.py b/openfl/experimental/workflow/component/aggregator/aggregator.py
index 717436f8a0..4df3ce080f 100644
--- a/openfl/experimental/workflow/component/aggregator/aggregator.py
+++ b/openfl/experimental/workflow/component/aggregator/aggregator.py
@@ -198,6 +198,10 @@ async def run_flow(self) -> FLSpec:
         Returns:
             flow (FLSpec): Updated instance.
         """
+        # The aggregator is created before the runtime, so set the prohibited/allowed
+        # data types on the runtime here, before running the flow.
+        self.flow.runtime.prohibited_data_types = FederatedRuntime.prohibited_data_types
+        self.flow.runtime.allowed_data_types = FederatedRuntime.allowed_data_types
         # Start function will be the first step if any flow
         f_name = "start"
         # Creating a clones from the flow object
@@ -394,7 +398,7 @@ def do_task(self, f_name: str) -> Any:
         # Create list of selected collaborator clones
         selected_clones = ([],)
         for name, clone in self.clones_dict.items():
-            # Check if collaboraotr is in the list of selected
+            # Check if collaborator is in the list of selected
             # collaborators
             if name in self.selected_collaborators:
                 selected_clones[0].append(clone)
diff --git a/openfl/experimental/workflow/interface/fl_spec.py b/openfl/experimental/workflow/interface/fl_spec.py
index 3e8365458b..68b2e8f732 100644
--- a/openfl/experimental/workflow/interface/fl_spec.py
+++ b/openfl/experimental/workflow/interface/fl_spec.py
@@ -22,6 +22,7 @@
     filter_attributes,
     generate_artifacts,
     should_transfer,
+    validate_data_types,
 )
 
 
@@ -127,16 +128,16 @@ def runtime(self, runtime: Type[Runtime]) -> None:
     def run(self) -> None:
         """Starts the execution of the flow."""
         # Submit flow to Runtime
-        if str(self._runtime) == "LocalRuntime":
+        if str(self.runtime) == "LocalRuntime":
             self._run_local()
-        elif str(self._runtime) == "FederatedRuntime":
+        elif str(self.runtime) == "FederatedRuntime":
             self._run_federated()
         else:
             raise Exception("Runtime not implemented")
 
     def _run_local(self) -> None:
         """Executes the flow using LocalRuntime."""
-        self._setup_initial_state()
+        self._setup_initial_state_local()
         try:
             # Execute all Participant (Aggregator & Collaborator) tasks and
             # retrieve the final attributes
@@ -164,7 +165,7 @@ def _run_local(self) -> None:
         for name, attr in final_attributes:
             setattr(self, name, attr)
 
-    def _setup_initial_state(self) -> None:
+    def _setup_initial_state_local(self) -> None:
         """
         Sets up the flow's initial state, initializing private attributes for
         collaborators and aggregators.
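
For context, the notebook and aggregator changes above boil down to a single class-level opt-in that the aggregator later copies onto the flow's runtime. A minimal sketch of the user-facing side (the import path follows the tutorial's package layout and is an assumption; only the allowed_data_types line comes from this diff):

    from openfl.experimental.workflow.runtime import FederatedRuntime

    # Allow-list mode: every keyword argument passed to next() must have one of
    # these value types.
    FederatedRuntime.allowed_data_types = ["str", "int", "float", "bool"]

    # Deny-list mode is the alternative; the two are mutually exclusive, and
    # setting both raises ValueError (see AttributeValidationMeta below).
    # FederatedRuntime.prohibited_data_types = ["ndarray", "DataFrame"]
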
@@ -176,6 +177,7 @@ def _setup_initial_state(self) -> None:
         self._foreach_methods = []
         FLSpec._reset_clones()
         FLSpec._create_clones(self, self.runtime.collaborators)
+        # Initialize collaborator private attributes
         self.runtime.initialize_collaborators()
 
         if self._checkpoint:
@@ -334,7 +336,7 @@ def next(self, f, **kwargs) -> None:
         parent = inspect.stack()[1][3]
         parent_func = getattr(self, parent)
 
-        if str(self._runtime) == "LocalRuntime":
+        if str(self.runtime) == "LocalRuntime":
             # Checkpoint current attributes (if checkpoint==True)
             checkpoint(self, parent_func)
 
@@ -343,10 +345,15 @@ def next(self, f, **kwargs) -> None:
         if aggregator_to_collaborator(f, parent_func):
             agg_to_collab_ss = self._capture_instance_snapshot(kwargs=kwargs)
 
-        # Remove included / excluded attributes from next task
-        filter_attributes(self, f, **kwargs)
+        if kwargs:
+            # Remove unwanted attributes from next task
+            filter_attributes(self, f, **kwargs)
+        if self.runtime.prohibited_data_types or self.runtime.allowed_data_types:
+            validate_data_types(
+                self.runtime.prohibited_data_types, self.runtime.allowed_data_types, **kwargs
+            )
 
-        if str(self._runtime) == "FederatedRuntime":
+        if str(self.runtime) == "FederatedRuntime":
             if f.collaborator_step and not f.aggregator_step:
                 self._foreach_methods.append(f.__name__)
 
@@ -359,6 +366,6 @@ def next(self, f, **kwargs) -> None:
                 kwargs,
             )
 
-        elif str(self._runtime) == "LocalRuntime":
+        elif str(self.runtime) == "LocalRuntime":
             # update parameters required to execute execute_task function
             self.execute_task_args = [f, parent_func, agg_to_collab_ss, kwargs]
diff --git a/openfl/experimental/workflow/runtime/local_runtime.py b/openfl/experimental/workflow/runtime/local_runtime.py
index c6bd4e8855..88d714fb9c 100644
--- a/openfl/experimental/workflow/runtime/local_runtime.py
+++ b/openfl/experimental/workflow/runtime/local_runtime.py
@@ -738,6 +738,8 @@ def execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **k
         # new runtime object will not contain private attributes of
         # aggregator or other collaborators
         clone.runtime = LocalRuntime(backend="single_process")
+        clone.runtime.prohibited_data_types = self.prohibited_data_types
+        clone.runtime.allowed_data_types = self.allowed_data_types
 
         # write the clone to the object store
         # ensure clone is getting latest _metaflow_interface
diff --git a/openfl/experimental/workflow/runtime/runtime.py b/openfl/experimental/workflow/runtime/runtime.py
index 317c26ede6..1c135160ef 100644
--- a/openfl/experimental/workflow/runtime/runtime.py
+++ b/openfl/experimental/workflow/runtime/runtime.py
@@ -10,7 +10,77 @@
 from openfl.experimental.workflow.interface.participants import Aggregator, Collaborator
 
 
-class Runtime:
+class AttributeValidationMeta(type):
+    """
+    Metaclass that enforces validation rules on class attributes.
+
+    This metaclass ensures that `prohibited_data_types` and `allowed_data_types`
+    are lists of strings and that both cannot be set at the same time.
+
+    Example:
+        class MyClass(metaclass=AttributeValidationMeta):
+            pass
+
+        MyClass.prohibited_data_types = ["int", "float"]  # Valid
+        MyClass.prohibited_data_types = "int"  # Raises TypeError (not a list)
+        MyClass.allowed_data_types = 42  # Raises TypeError (not a list)
+        MyClass.allowed_data_types = ["str", "bool"]  # Raises ValueError (prohibited already set)
+    """
+
+    def __setattr__(cls, name, value):
+        """
+        Validates and sets class attributes.
+
+        Ensures that `prohibited_data_types` and `allowed_data_types`, when assigned,
+        are lists of strings and that they are not used together.
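+
+        Note that an empty list counts as "unset": to switch modes, clear the
+        active attribute first. A sketch:
+
+            Runtime.prohibited_data_types = ["int"]   # deny-list mode
+            Runtime.allowed_data_types = ["str"]      # ValueError: deny-list active
+            Runtime.prohibited_data_types = []        # clear the deny-list
+            Runtime.allowed_data_types = ["str"]      # now valid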
+
+        Args:
+            name (str): The attribute name being set.
+            value (any): The value to be assigned to the attribute.
+
+        Raises:
+            TypeError: If `prohibited_data_types` or `allowed_data_types` is not a list
+                or contains non-string elements.
+            ValueError: If both `prohibited_data_types` and `allowed_data_types` are set.
+        """
+        if name in {"prohibited_data_types", "allowed_data_types"}:
+            if not isinstance(value, list):
+                raise TypeError(f"'{name}' must be a list, got {type(value).__name__}")
+            if not all(isinstance(item, str) for item in value):
+                raise TypeError(f"All elements of '{name}' must be strings")
+
+            # Ensure both attributes are not set at the same time
+            other_name = (
+                "allowed_data_types" if name == "prohibited_data_types" else "prohibited_data_types"
+            )
+            if getattr(cls, other_name, []):  # Check if the other attribute is already set
+                raise ValueError(
+                    "Cannot set both 'prohibited_data_types' and 'allowed_data_types'."
+                )
+
+        super().__setattr__(name, value)
+
+
+class Runtime(metaclass=AttributeValidationMeta):
+    """
+    Base class for federated learning runtimes.
+    This class serves as an interface for runtimes that execute FLSpec flows.
+
+    Attributes:
+        prohibited_data_types (list): A list of data types that are prohibited from being
+            transmitted over the network.
+        allowed_data_types (list): A list of data types that are explicitly allowed to be
+            transmitted over the network.
+
+    Notes:
+        - Either `prohibited_data_types` or `allowed_data_types` may be specified.
+        - If neither is specified, all data types are allowed to be transmitted.
+        - If both are specified, a `ValueError` will be raised.
+    """
+
+    prohibited_data_types = []
+    allowed_data_types = []
+
     def __init__(self):
         """Initializes the Runtime object.
diff --git a/openfl/experimental/workflow/utilities/__init__.py b/openfl/experimental/workflow/utilities/__init__.py
index 6b7069bc66..63b7e83168 100644
--- a/openfl/experimental/workflow/utilities/__init__.py
+++ b/openfl/experimental/workflow/utilities/__init__.py
@@ -17,6 +17,7 @@
     filter_attributes,
     generate_artifacts,
     parse_attrs,
+    validate_data_types,
 )
 from openfl.experimental.workflow.utilities.stream_redirect import (
     RedirectStdStream,
diff --git a/openfl/experimental/workflow/utilities/runtime_utils.py b/openfl/experimental/workflow/utilities/runtime_utils.py
index 815eca3d38..9a2d196f00 100644
--- a/openfl/experimental/workflow/utilities/runtime_utils.py
+++ b/openfl/experimental/workflow/utilities/runtime_utils.py
@@ -7,6 +7,7 @@
 import inspect
 import itertools
 from types import MethodType
+from typing import List
 
 import numpy as np
 
@@ -96,6 +97,54 @@ def filter_attributes(ctx, f, **kwargs):
         _process_exclusion(ctx, cls_attrs, kwargs["exclude"], f)
 
 
+def validate_data_types(
+    prohibited_data_types: List[str] = None,
+    allowed_data_types: List[str] = None,
+    reserved_words=["collaborators"],
+    **kwargs,
+):
+    """Validates that the types of attributes in kwargs are not among the prohibited data types
+    and are among the allowed data types if specified.
+    Raises a TypeError if any prohibited data type is found or if a type is not allowed.
+
+    Args:
+        prohibited_data_types (List[str], optional): A list of prohibited data type names
+            (e.g., ['int', 'float']).
+        allowed_data_types (List[str], optional): A list of allowed data type names.
+        reserved_words (list, optional): A list of string values that are always allowed as
+            attribute values, even if 'str' is included in prohibited_data_types.
+            Defaults to ["collaborators"].
+        kwargs (dict): Arbitrary keyword arguments representing attribute names and their values.
+
+    Raises:
+        TypeError: If any prohibited data types are found in kwargs or if a type is not allowed.
+        ValueError: If both prohibited_data_types and allowed_data_types are set simultaneously.
+    """
+    if prohibited_data_types is None:
+        prohibited_data_types = []
+    if allowed_data_types is None:
+        allowed_data_types = []
+
+    if prohibited_data_types and allowed_data_types:
+        raise ValueError("Cannot set both 'prohibited_data_types' and 'allowed_data_types'.")
+
+    for attr_name, attr_value in kwargs.items():
+        attr_type = type(attr_value).__name__
+        if (
+            prohibited_data_types
+            and attr_type in prohibited_data_types
+            and attr_value not in reserved_words
+        ):
+            raise TypeError(
+                f"The attribute '{attr_name}' = '{attr_value}' "
+                f"has a prohibited value type: {attr_type}"
+            )
+        if allowed_data_types and attr_type not in allowed_data_types:
+            raise TypeError(
+                f"The attribute '{attr_name}' = '{attr_value}' "
+                f"has a type that is not allowed: {attr_type}"
+            )
+
+
 def _validate_include_exclude(kwargs, cls_attrs):
     """Validates that 'include' and 'exclude' are not both present, and that
     attributes in 'include' or 'exclude' exist in the context.
@@ -152,13 +201,13 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
             delattr(ctx, attr)
 
 
-def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
+def checkpoint(ctx, parent_func, checkpoint_reserved_words=["next", "runtime"]):
     """Optionally saves the current state for the task just executed.
 
     Args:
         ctx (any): The context to checkpoint.
         parent_func (function): The function that was just executed.
-        chkpnt_reserved_words (list, optional): A list of reserved words to
+        checkpoint_reserved_words (list, optional): A list of reserved words to
             exclude from checkpointing. Defaults to ["next", "runtime"].
 
     Returns:
@@ -173,7 +222,7 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
     if ctx._checkpoint:
         # all objects will be serialized using Metaflow interface
         print(f"Saving data artifacts for {parent_func.__name__}")
-        artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=chkpnt_reserved_words)
+        artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=checkpoint_reserved_words)
         task_id = ctx._metaflow_interface.create_task(parent_func.__name__)
         ctx._metaflow_interface.save_artifacts(
             artifacts_iter(),
@@ -195,7 +244,7 @@ def old_check_resource_allocation(num_gpus, each_participant_gpu_usage):
     # But at this point the function will raise an error because
     # remaining_gpu_memory is never cleared.
     # The participant list should remove the participant if it fits in the gpu
-    # and save the partipant if it doesn't and continue to the next GPU to see
+    # and save the participant if it doesn't and continue to the next GPU to see
    # if it fits in that one, only if we run out of GPUs should this function
     # raise an error.
     for gpu in np.ones(num_gpus, dtype=int):
@@ -230,7 +279,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
         if gpu == 0:
             break
         if gpu < participant_gpu_usage:
-            # participant doesn't fitm break to next GPU
+            # participant doesn't fit, break to next GPU
             break
         else:
             # if participant fits remove from need_assigned
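
Taken together, validate_data_types can be exercised in isolation; a quick sketch of each failure mode (the attribute names are illustrative; the import relies on the __init__.py export added above):

    import numpy as np

    from openfl.experimental.workflow.utilities import validate_data_types

    # Deny-list mode: a float is fine, an ndarray is rejected.
    validate_data_types(prohibited_data_types=["ndarray"], loss=0.42)  # passes
    try:
        validate_data_types(prohibited_data_types=["ndarray"], weights=np.zeros(3))
    except TypeError as e:
        print(e)  # ... has a prohibited value type: ndarray

    # Allow-list mode: anything outside the list is rejected.
    try:
        validate_data_types(allowed_data_types=["str", "int"], lr=3e-4)
    except TypeError as e:
        print(e)  # ... has a type that is not allowed: float

    # reserved_words: the literal value "collaborators" (e.g. from foreach=) is
    # exempt even when 'str' itself is prohibited.
    validate_data_types(prohibited_data_types=["str"], foreach="collaborators")  # passes

    # Setting both policies at once is rejected outright.
    try:
        validate_data_types(prohibited_data_types=["int"], allowed_data_types=["str"], x=1)
    except ValueError as e:
        print(e)  # Cannot set both 'prohibited_data_types' and 'allowed_data_types'.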