101_MNIST_FederatedRuntime.ipynb
@@ -609,7 +609,9 @@
" collaborators=collaborator_names,\n",
" director=director_info, \n",
" notebook_path='./101_MNIST_FederatedRuntime.ipynb'\n",
")"
")\n",
"# Set allowed data types\n",
"FederatedRuntime.allowed_data_types = [\"str\", \"int\", \"float\", \"bool\"]"
]
},
{
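In plain Python, the new notebook cell boils down to the following sketch (the import path for FederatedRuntime is assumed from this PR's package layout, not shown in the diff):

from openfl.experimental.workflow.runtime import FederatedRuntime

# Restrict the data types that may be transmitted over the network to plain
# strings, integers, floats, and booleans.
FederatedRuntime.allowed_data_types = ["str", "int", "float", "bool"]

# Assignments are validated by a metaclass (see runtime.py below), so a
# non-list value is rejected up front:
try:
    FederatedRuntime.allowed_data_types = "str"  # not a list
except TypeError as err:
    print(err)  # 'allowed_data_types' must be a list, got str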
@@ -198,6 +198,10 @@ async def run_flow(self) -> FLSpec:
         Returns:
             flow (FLSpec): Updated instance.
         """
+        # As an aggregator is created before a runtime, set prohibited/allowed
+        # data types for the runtime before running the flow.
+        self.flow.runtime.prohibited_data_types = FederatedRuntime.prohibited_data_types
+        self.flow.runtime.allowed_data_types = FederatedRuntime.allowed_data_types
         # The start function will be the first step of any flow
         f_name = "start"
         # Create clones from the flow object
@@ -394,7 +398,7 @@ def do_task(self, f_name: str) -> Any:
         # Create list of selected collaborator clones
         selected_clones = ([],)
         for name, clone in self.clones_dict.items():
-            # Check if collaboraotr is in the list of selected
+            # Check if collaborator is in the list of selected
             # collaborators
             if name in self.selected_collaborators:
                 selected_clones[0].append(clone)
25 changes: 16 additions & 9 deletions openfl/experimental/workflow/interface/fl_spec.py
@@ -22,6 +22,7 @@
     filter_attributes,
     generate_artifacts,
     should_transfer,
+    validate_data_types,
 )


@@ -127,16 +128,16 @@ def runtime(self, runtime: Type[Runtime]) -> None:
     def run(self) -> None:
         """Starts the execution of the flow."""
         # Submit flow to Runtime
-        if str(self._runtime) == "LocalRuntime":
+        if str(self.runtime) == "LocalRuntime":
             self._run_local()
-        elif str(self._runtime) == "FederatedRuntime":
+        elif str(self.runtime) == "FederatedRuntime":
             self._run_federated()
         else:
             raise Exception("Runtime not implemented")

     def _run_local(self) -> None:
         """Executes the flow using LocalRuntime."""
-        self._setup_initial_state()
+        self._setup_initial_state_local()
         try:
             # Execute all Participant (Aggregator & Collaborator) tasks and
             # retrieve the final attributes
@@ -164,7 +165,7 @@ def _run_local(self) -> None:
         for name, attr in final_attributes:
             setattr(self, name, attr)

-    def _setup_initial_state(self) -> None:
+    def _setup_initial_state_local(self) -> None:
"""
Sets up the flow's initial state, initializing private attributes for
collaborators and aggregators.
@@ -176,6 +177,7 @@
         self._foreach_methods = []
         FLSpec._reset_clones()
         FLSpec._create_clones(self, self.runtime.collaborators)
+
         # Initialize collaborator private attributes
         self.runtime.initialize_collaborators()
         if self._checkpoint:
@@ -334,7 +336,7 @@ def next(self, f, **kwargs) -> None:
         parent = inspect.stack()[1][3]
         parent_func = getattr(self, parent)

-        if str(self._runtime) == "LocalRuntime":
+        if str(self.runtime) == "LocalRuntime":
             # Checkpoint current attributes (if checkpoint==True)
             checkpoint(self, parent_func)

@@ -343,10 +345,15 @@
         if aggregator_to_collaborator(f, parent_func):
             agg_to_collab_ss = self._capture_instance_snapshot(kwargs=kwargs)

-        # Remove included / excluded attributes from next task
-        filter_attributes(self, f, **kwargs)
+        if kwargs:
+            # Remove unwanted attributes from next task
+            filter_attributes(self, f, **kwargs)
+            if self.runtime.prohibited_data_types or self.runtime.allowed_data_types:
+                validate_data_types(
+                    self.runtime.prohibited_data_types, self.runtime.allowed_data_types, **kwargs
+                )

-        if str(self._runtime) == "FederatedRuntime":
+        if str(self.runtime) == "FederatedRuntime":
             if f.collaborator_step and not f.aggregator_step:
                 self._foreach_methods.append(f.__name__)
@@ -359,6 +366,6 @@
                 kwargs,
             )

-        elif str(self._runtime) == "LocalRuntime":
+        elif str(self.runtime) == "LocalRuntime":
             # update parameters required to execute the execute_task function
             self.execute_task_args = [f, parent_func, agg_to_collab_ss, kwargs]
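For context, here is a hedged, illustrative sketch of where the new check fires; the flow, its steps, and the decorator import path are placeholders in the style of the workflow tutorials, not code from this PR:

from openfl.experimental.workflow.interface import FLSpec
from openfl.experimental.workflow.placement import aggregator, collaborator

class DemoFlow(FLSpec):
    @aggregator
    def start(self):
        # next()'s keyword arguments are what gets validated. The value
        # "collaborators" sits in validate_data_types' reserved_words, so
        # this call passes even if 'str' is a prohibited data type.
        self.next(self.train, foreach="collaborators")

    @collaborator
    def train(self):
        # No kwargs here, so validation is skipped for this transition.
        self.next(self.end)

    @aggregator
    def end(self):
        pass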
2 changes: 2 additions & 0 deletions openfl/experimental/workflow/runtime/local_runtime.py
@@ -738,6 +738,8 @@ def execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **k
             # new runtime object will not contain private attributes of
             # aggregator or other collaborators
             clone.runtime = LocalRuntime(backend="single_process")
+            clone.runtime.prohibited_data_types = self.prohibited_data_types
+            clone.runtime.allowed_data_types = self.allowed_data_types

             # write the clone to the object store
             # ensure clone is getting latest _metaflow_interface
72 changes: 71 additions & 1 deletion openfl/experimental/workflow/runtime/runtime.py
@@ -10,7 +10,77 @@
 from openfl.experimental.workflow.interface.participants import Aggregator, Collaborator


-class Runtime:
+class AttributeValidationMeta(type):
"""
Metaclass that enforces validation rules on class attributes.
This metaclass ensures that `prohibited_data_types` and `allowed_data_types`
are lists of strings and that both cannot be set at the same time.
Example:
class MyClass(metaclass=AttributeValidationMeta):
pass
MyClass.prohibited_data_types = ["int", "float"] # Valid
MyClass.allowed_data_types = ["str", "bool"] # Valid
MyClass.prohibited_data_types = "int" # Raises TypeError
MyClass.allowed_data_types = 42 # Raises TypeError
"""

+    def __setattr__(cls, name, value):
+        """
+        Validates and sets class attributes.
+
+        Ensures that `prohibited_data_types` and `allowed_data_types`, when assigned,
+        are lists of strings and that they are not used together.
+
+        Args:
+            name (str): The attribute name being set.
+            value (any): The value to be assigned to the attribute.
+
+        Raises:
+            TypeError: If `prohibited_data_types` or `allowed_data_types` is not a list
+                or contains non-string elements.
+            ValueError: If both `prohibited_data_types` and `allowed_data_types` are set.
+        """
+        if name in {"prohibited_data_types", "allowed_data_types"}:
+            if not isinstance(value, list):
+                raise TypeError(f"'{name}' must be a list, got {type(value).__name__}")
+            if not all(isinstance(item, str) for item in value):
+                raise TypeError(f"All elements of '{name}' must be strings")
+
+            # Ensure both attributes are not set at the same time
+            other_name = (
+                "allowed_data_types" if name == "prohibited_data_types" else "prohibited_data_types"
+            )
+            if getattr(cls, other_name, []):  # Check if the other attribute is already set
+                raise ValueError(
+                    "Cannot set both 'prohibited_data_types' and 'allowed_data_types'."
+                )
+
+        super().__setattr__(name, value)


+class Runtime(metaclass=AttributeValidationMeta):
+    """
+    Base class for federated learning runtimes.
+
+    This class serves as an interface for runtimes that execute FLSpec flows.
+
+    Attributes:
+        prohibited_data_types (list): A list of data types that are prohibited from being
+            transmitted over the network.
+        allowed_data_types (list): A list of data types that are explicitly allowed to be
+            transmitted over the network.
+
+    Notes:
+        - Either `prohibited_data_types` or `allowed_data_types` may be specified.
+        - If neither is specified, all data types are allowed to be transmitted.
+        - If both are specified, a `ValueError` will be raised.
+    """
+
+    prohibited_data_types = []
+    allowed_data_types = []

     def __init__(self):
         """Initializes the Runtime object.
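Because the rules are enforced at assignment time by the metaclass, the mutual exclusion described in the Notes can be demonstrated directly; a minimal sketch, assuming Runtime is re-exported by the runtime package:

from openfl.experimental.workflow.runtime import Runtime

Runtime.prohibited_data_types = ["bytes"]  # valid: a list of type-name strings
try:
    # The other attribute is already set, so this assignment is rejected.
    Runtime.allowed_data_types = ["str"]
except ValueError as err:
    print(err)  # Cannot set both 'prohibited_data_types' and 'allowed_data_types'.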
1 change: 1 addition & 0 deletions openfl/experimental/workflow/utilities/__init__.py
@@ -17,6 +17,7 @@
     filter_attributes,
     generate_artifacts,
     parse_attrs,
+    validate_data_types,
 )
 from openfl.experimental.workflow.utilities.stream_redirect import (
     RedirectStdStream,
59 changes: 54 additions & 5 deletions openfl/experimental/workflow/utilities/runtime_utils.py
@@ -7,6 +7,7 @@
 import inspect
 import itertools
 from types import MethodType
+from typing import List

 import numpy as np
@@ -96,6 +97,54 @@ def filter_attributes(ctx, f, **kwargs):
         _process_exclusion(ctx, cls_attrs, kwargs["exclude"], f)


+def validate_data_types(
+    prohibited_data_types: List[str] = None,
+    allowed_data_types: List[str] = None,
+    reserved_words=["collaborators"],
+    **kwargs,
+):
+    """Validates that the types of attributes in kwargs are not among the prohibited data
+    types and are among the allowed data types if specified.
+
+    Raises a TypeError if any prohibited data type is found or if a type is not allowed.
+
+    Args:
+        prohibited_data_types (List[str], optional): A list of prohibited data type names
+            (e.g., ['int', 'float']).
+        allowed_data_types (List[str], optional): A list of allowed data type names.
+        reserved_words (List[str], optional): A list of strings that should be allowed as
+            attribute values, even if 'str' is included in prohibited_data_types.
+        kwargs (dict): Arbitrary keyword arguments representing attribute names and their
+            values.
+
+    Raises:
+        TypeError: If any prohibited data types are found in kwargs or if a type is not
+            allowed.
+        ValueError: If both prohibited_data_types and allowed_data_types are set
+            simultaneously.
+    """
+    if prohibited_data_types is None:
+        prohibited_data_types = []
+    if allowed_data_types is None:
+        allowed_data_types = []
+
+    if prohibited_data_types and allowed_data_types:
+        raise ValueError("Cannot set both 'prohibited_data_types' and 'allowed_data_types'.")
+
+    for attr_name, attr_value in kwargs.items():
+        attr_type = type(attr_value).__name__
+        if (
+            prohibited_data_types
+            and attr_type in prohibited_data_types
+            and attr_value not in reserved_words
+        ):
+            raise TypeError(
+                f"The attribute '{attr_name}' = '{attr_value}' "
+                f"has a prohibited value type: {attr_type}"
+            )
+        if allowed_data_types and attr_type not in allowed_data_types:
+            raise TypeError(
+                f"The attribute '{attr_name}' = '{attr_value}' "
+                f"has a type that is not allowed: {attr_type}"
+            )

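A few hedged calls illustrating the behavior documented above (the keyword names are arbitrary examples; the import works because this PR exports validate_data_types from the utilities package):

from openfl.experimental.workflow.utilities import validate_data_types

# Passes: every kwarg is an int, and 'int' is in the allowed list.
validate_data_types(allowed_data_types=["int"], rounds=10, batch_size=32)

# Raises: 'lr' is a float, and 'float' is prohibited.
try:
    validate_data_types(prohibited_data_types=["float"], lr=0.01)
except TypeError as err:
    print(err)  # The attribute 'lr' = '0.01' has a prohibited value type: float

# Passes: "collaborators" is a reserved value, exempt even though 'str'
# is prohibited.
validate_data_types(prohibited_data_types=["str"], foreach="collaborators")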

 def _validate_include_exclude(kwargs, cls_attrs):
     """Validates that 'include' and 'exclude' are not both present, and that
     attributes in 'include' or 'exclude' exist in the context.
@@ -152,13 +201,13 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
             delattr(ctx, attr)


-def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
+def checkpoint(ctx, parent_func, checkpoint_reserved_words=["next", "runtime"]):
     """Optionally saves the current state for the task just executed.

     Args:
         ctx (any): The context to checkpoint.
         parent_func (function): The function that was just executed.
-        chkpnt_reserved_words (list, optional): A list of reserved words to
+        checkpoint_reserved_words (list, optional): A list of reserved words to
             exclude from checkpointing. Defaults to ["next", "runtime"].

     Returns:
@@ -173,7 +222,7 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
     if ctx._checkpoint:
         # all objects will be serialized using Metaflow interface
         print(f"Saving data artifacts for {parent_func.__name__}")
-        artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=chkpnt_reserved_words)
+        artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=checkpoint_reserved_words)
         task_id = ctx._metaflow_interface.create_task(parent_func.__name__)
         ctx._metaflow_interface.save_artifacts(
             artifacts_iter(),
@@ -195,7 +244,7 @@ def old_check_resource_allocation(num_gpus, each_participant_gpu_usage):
     # But at this point the function will raise an error because
     # remaining_gpu_memory is never cleared.
     # The participant list should remove the participant if it fits in the gpu
-    # and save the partipant if it doesn't and continue to the next GPU to see
+    # and save the participant if it doesn't and continue to the next GPU to see
     # if it fits in that one, only if we run out of GPUs should this function
     # raise an error.
     for gpu in np.ones(num_gpus, dtype=int):
@@ -230,7 +279,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
             if gpu == 0:
                 break
             if gpu < participant_gpu_usage:
-                # participant doesn't fitm break to next GPU
+                # participant doesn't fit, break to next GPU
                 break
             else:
                 # if participant fits remove from need_assigned