2 changes: 2 additions & 0 deletions openfl-workspace/torch/mnist/plan/cols.yaml
@@ -2,3 +2,5 @@
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

collaborators:
- collaborator1
- collaborator2
5 changes: 4 additions & 1 deletion openfl-workspace/torch/mnist/plan/plan.yaml
@@ -24,7 +24,7 @@ compression_pipeline:
template: openfl.pipelines.NoCompressionPipeline
data_loader:
settings:
batch_size: 64
batch_size: 1
collaborator_count: 2
template: src.dataloader.PyTorchMNISTInMemory
network:
@@ -39,16 +39,19 @@ tasks:
apply: global
metrics:
- acc
use_tqdm: 1
locally_tuned_model_validation:
function: validate_task
kwargs:
apply: local
metrics:
- acc
use_tqdm: 1
settings: {}
train:
function: train_task
kwargs:
epochs: 1
metrics:
- loss
use_tqdm: 1
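
For context, the two entries added to cols.yaml above pair with `collaborator_count: 2` here, and `use_tqdm: 1` is forwarded to each task function as a keyword argument. Below is a minimal sketch of how a task might consume that flag, assuming a PyTorchTaskRunner-style signature; the names are illustrative, not necessarily OpenFL's exact API.

import tqdm

def validate_task(model, data_loader, use_tqdm=False, **kwargs):
    loader = data_loader.get_valid_loader()
    if use_tqdm:
        # plan.yaml's `use_tqdm: 1` arrives here as a truthy kwarg
        loader = tqdm.tqdm(loader, desc="validate")
    correct = total = 0
    for images, labels in loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return {"acc": correct / max(total, 1)}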
14 changes: 4 additions & 10 deletions openfl-workspace/torch/mnist/src/cnn_model.py
@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class DigitRecognizerCNN(nn.Module):
"""
@@ -44,10 +45,8 @@ def __init__(self, **kwargs):
fc2 (nn.Linear): Second fully connected layer with 500 input features and 10 output features.
"""
super(DigitRecognizerCNN, self).__init__(**kwargs)
self.conv1 = nn.Conv2d(1, 20, 2, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(800, 500)
self.fc2 = nn.Linear(500, 10)
self.model = models.resnet50(pretrained=True)
self.fc2 = nn.Linear(1000, 10)  # Map ResNet-50's 1000 output features to the 10 digit classes

def forward(self, x):
"""
@@ -63,12 +62,7 @@ def forward(self, x):
Returns:
torch.Tensor: Output tensor after passing through the CNN layers.
"""
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 800)
x = F.relu(self.fc1(x))
x = self.model(x)
x = self.fc2(x)

return x
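
A quick smoke test of the reworked model, as a hedged sketch: it assumes torchvision can fetch the pretrained ResNet-50 weights, and that inputs are 3-channel 64x64 tensors as produced by the modified dataloader below.

import torch
from src.cnn_model import DigitRecognizerCNN

model = DigitRecognizerCNN()
model.eval()
# ResNet-50 needs 3 channels; 64x64 inputs work because of its adaptive pooling
x = torch.randn(2, 3, 64, 64)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([2, 10]): 1000 ResNet-50 features mapped to 10 classes by fc2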
37 changes: 28 additions & 9 deletions openfl-workspace/torch/mnist/src/dataloader.py
@@ -6,8 +6,21 @@
from openfl.federated import PyTorchDataLoader
from torchvision import datasets
from torchvision import transforms
import torch
import numpy as np
from logging import getLogger
from torch.utils.data import DataLoader

# Define the preprocessing transformations
preprocess = transforms.Compose(
[
transforms.Resize(64), # Resize the shorter side to 64
transforms.CenterCrop(64), # Crop the center to a 64x64 square
transforms.ToTensor(), # Convert to a PyTorch tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)

logger = getLogger(__name__)

@@ -30,20 +43,24 @@ def __init__(self, data_path, batch_size, **kwargs):
int(data_path)
except:
raise ValueError(
"Expected `%s` to be representable as `int`, as it refers to the data shard " +
"number used by the collaborator.",
data_path
"Expected `%s` to be representable as `int`, as it refers to the data shard "
+ "number used by the collaborator.",
data_path,
)

num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
shard_num=int(data_path), **kwargs
)
self.X_train = X_train
self.y_train = y_train
t = torch.from_numpy

number = 1  # Synthetic benchmark data: a single random sample per collaborator
self.X_train = t(np.random.random([number, 3, 64, 64])).float()
self.y_train = t(np.random.randint(0, 10, [number]))  # Random labels covering all 10 classes

self.train_loader = self.get_train_loader()

self.X_valid = X_valid
self.y_valid = y_valid
self.X_valid = t(np.random.random([number, 3, 64, 64])).float()
self.y_valid = t(np.random.randint(0, 10, [number]))
self.val_loader = self.get_valid_loader()

self.num_classes = num_classes
@@ -76,7 +93,7 @@ def load_mnist_shard(
num_classes = 10

(X_train, y_train), (X_valid, y_valid) = _load_raw_datashards(
shard_num, collaborator_count, transform=transforms.ToTensor()
shard_num, collaborator_count, transform=preprocess
)

logger.info(f"MNIST > X_train Shape : {X_train.shape}")
@@ -121,7 +138,9 @@ def _load_raw_datashards(shard_num, collaborator_count, transform=None):
2 tuples: (image, label) of the training, validation dataset
"""
train_data, val_data = (
datasets.MNIST("data", train=train, download=True, transform=transform)
datasets.MNIST(
"~/workspace/giant_data", train=train, download=True, transform=transform
)
for train in (True, False)
)
X_train_tot, y_train_tot = train_data.train_data, train_data.train_labels
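
Note that raw MNIST images are single-channel, while the Normalize step above uses three-channel ImageNet statistics; the synthetic 3x64x64 tensors sidestep this, but applying `preprocess` to real data would require an RGB conversion first. An illustrative check of the pipeline on a stand-in RGB image (assumes Pillow is available):

from PIL import Image
import numpy as np

img = Image.fromarray((np.random.rand(28, 28, 3) * 255).astype(np.uint8))  # stand-in RGB image
tensor = preprocess(img)
print(tensor.shape)  # torch.Size([3, 64, 64]): resized, center-cropped, normalized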
11 changes: 11 additions & 0 deletions openfl/callbacks/callback_list.py
@@ -3,6 +3,7 @@
from openfl.callbacks.callback import Callback
from openfl.callbacks.memory_profiler import MemoryProfiler
from openfl.callbacks.metric_writer import MetricWriter
from tictoc import bench_dict, timer


class CallbackList(Callback):
@@ -68,6 +69,16 @@ def _add_default_callbacks(self, add_memory_profiler, add_metric_writer):
self.callbacks.append(self._metric_writer)

def on_round_begin(self, round_num: int, logs=None):
if logs == 'agg':  # The aggregator passes 'agg' through `logs` to mark its own round boundaries
if round_num > 0:
elapsed_time = timer.toc()  # Duration of the round that just finished
timer.tic()
with open("elapsed_time.txt", "a") as file:
file.write(str(elapsed_time) + "\n")
else:
timer.tic()

bench_dict['global'].gstep()
for callback in self.callbacks:
callback.on_round_begin(round_num, logs)

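
The `tictoc` helpers used here are a project-local module whose full API is not shown in this diff; below is a minimal stdlib-only sketch of the same round-timing idea, under that assumption.

import time

_round_start = None

def log_round_time(round_num: int, path: str = "elapsed_time.txt"):
    """Append the previous round's wall-clock duration, then restart the timer."""
    global _round_start
    if round_num > 0 and _round_start is not None:
        with open(path, "a") as f:
            f.write(f"{time.monotonic() - _round_start}\n")
    _round_start = time.monotonic()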
86 changes: 82 additions & 4 deletions openfl/component/aggregator/aggregator.py
@@ -14,12 +14,13 @@

import openfl.callbacks as callbacks_module
from openfl.component.aggregator.straggler_handling import StragglerPolicy, WaitForAllPolicy
from openfl.databases import PersistentTensorDB, TensorDB
from openfl.databases import PersistentTensorDB, TensorDB, TRY_CHANGE
from openfl.interface.aggregation_functions import SecureWeightedAverage, WeightedAverage
from openfl.pipelines import NoCompressionPipeline, TensorCodec
from openfl.protocols import base_pb2, utils
from openfl.protocols.base_pb2 import NamedTensor
from openfl.utilities import TaskResultKey, TensorKey, change_tags
from tictoc import bench_dict

logger = logging.getLogger(__name__)

@@ -231,7 +232,7 @@ def __init__(
# TODO: Aggregator has no concrete notion of round_begin.
# https://github.com/securefederatedai/openfl/pull/1195#discussion_r1879479537
self.callbacks.on_experiment_begin()
self.callbacks.on_round_begin(self.round_number)
self.callbacks.on_round_begin(self.round_number, 'agg')

def _recover(self):
"""Populates the aggregator state to the state it was prior a restart"""
@@ -476,6 +477,7 @@ def get_tasks(self, collaborator_name):
sleep_time (int): Sleep time.
time_to_quit (bool): Whether it's time to quit.
"""
bench_dict['global'].step('wait get tasks')
logger.debug(
f"Aggregator GetTasks function reached from collaborator {collaborator_name}..."
)
@@ -543,6 +545,7 @@ def get_tasks(self, collaborator_name):
# Start straggler handling policy for timer based callback is required
# for %age based policy callback is not required
self.straggler_handling_policy.start_policy(callback=self._straggler_cutoff_time_elapsed)
bench_dict['global'].step('get_tasks')

return tasks, self.round_number, sleep_time, time_to_quit

@@ -607,6 +610,9 @@ def get_aggregated_tensor(
Raises:
ValueError: if Aggregator does not have an aggregated tensor for {tensor_key}.
"""
bench_dict['global'].step('wait get aggregated tensor')
bench_dict['get_aggregate_tensor'].gstep()

if "compressed" in tags or require_lossless:
compress_lossless = True
else:
@@ -624,16 +630,23 @@
tags = change_tags(tags, remove_field="compressed")
if "lossy_compressed" in tags:
tags = change_tags(tags, remove_field="lossy_compressed")

bench_dict['get_aggregate_tensor'].step('change tag')

tensor_key = TensorKey(tensor_name, self.uuid, round_number, report, tags)
tensor_name, origin, round_number, report, tags = tensor_key

bench_dict['get_aggregate_tensor'].step('get tensorkey')

if "aggregated" in tags and "delta" in tags and round_number != 0:
agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, ("aggregated",))
else:
agg_tensor_key = tensor_key

bench_dict['get_aggregate_tensor'].step('tensorkey if')

nparray = self.tensor_db.get_tensor_from_cache(agg_tensor_key)
bench_dict['get_aggregate_tensor'].step('tensor from cache')

start_retrieving_time = time.time()
while nparray is None:
@@ -642,6 +655,7 @@
nparray = self.tensor_db.get_tensor_from_cache(agg_tensor_key)
if (time.time() - start_retrieving_time) > 60:
break
bench_dict['get_aggregate_tensor'].step('wait for tensorkey')

if nparray is None:
raise ValueError(f"Aggregator does not have an aggregated tensor for {tensor_key}")
@@ -652,6 +666,9 @@
named_tensor = self._nparray_to_named_tensor(
agg_tensor_key, nparray, send_model_deltas=True, compress_lossless=compress_lossless
)
bench_dict['get_aggregate_tensor'].step('_nparray_to_named_tensor')
bench_dict['get_aggregate_tensor'].gstop()
bench_dict['global'].step('get_aggregate_tensor')

return named_tensor

@@ -773,6 +790,7 @@ def send_local_task_results(
Returns:
None
"""
bench_dict['global'].step('wait send local task')
# Check if secure aggregation is enabled.
if self._secure_aggregation_enabled:
secagg_setup = self.secagg.process_secagg_setup_tensors(named_tensors)
@@ -794,7 +812,7 @@
f"Collaborator {collaborator_name} is sending task results "
f"for {task_name}, round {round_number}"
)

bench_dict['global'].step('send task results')
self.process_task_results(
collaborator_name, round_number, task_name, data_size, named_tensors
)
@@ -869,6 +887,7 @@ def process_task_results(

# Check if collaborator or round is done.
self._is_collaborator_done(collaborator_name, round_number)
bench_dict['global'].step('process task')
self._end_of_round_with_stragglers_check()

def _end_of_round_with_stragglers_check(self):
@@ -1191,6 +1210,9 @@ def _end_of_round_check(self):
for task_name in self.assigner.get_all_tasks_for_round(self.round_number):
logs.update(self._compute_validation_related_task_metrics(task_name))

# End of round callbacks.
self.callbacks.on_round_end(self.round_number, logs)

# Once all of the task results have been processed
self._end_of_round_check_done[self.round_number] = True

Expand All @@ -1211,7 +1233,20 @@ def _end_of_round_check(self):

# End of round callbacks.
# todo handle case when aggregator restarted before callback was successful

bench_dict['global'].step('save model')

self.callbacks.on_round_end(self.round_number, logs)
bench_dict['global'].step('on round end')
if self.round_number % 10 == 0:
bench_dict.save()
bench_dict['global'].step('save tictoc')

if self.round_number % 3 == 0:
self.tensor_db.tensor_db.to_pickle(f'tensor_db_{str(self.round_number).zfill(2)}.pkl')
if TRY_CHANGE:
self.tensor_db.secondary_db.to_pickle(f'secondary_tensor_db_{str(self.round_number).zfill(2)}.pkl')
bench_dict['global'].step('save_db')

self.round_number += 1

@@ -1227,15 +1262,58 @@ def _end_of_round_check(self):
logger.info("Experiment Completed. Cleaning up...")
# End of experiment callbacks.
self.callbacks.on_experiment_end()
bench_dict.save()
else:
logger.info("Starting round %s...", self.round_number)
# https://github.com/securefederatedai/openfl/pull/1195#discussion_r1879479537
self.callbacks.on_round_begin(self.round_number)
bench_dict['global'].step('other')
bench_dict['global'].gstop()
self.callbacks.on_round_begin(self.round_number, 'agg')

# Cleaning tensor db
self.tensor_db.clean_up(self.db_store_rounds)
bench_dict['global'].step('Cleaning tensor db')
# Reset straggler handling policy for the next round.
self.straggler_handling_policy.reset_policy_for_round()
bench_dict['global'].step('reset straggler')

def _has_analytics_results(self):
"""
Check if the current round has analytics results.

Returns:
bool: True if the current round has analytics results, False otherwise.
"""
analytics_result = self.tensor_db.get_tensors_by_round_and_tags(
self.round_number, ("analytics",)
)
return len(analytics_result) > 0

def save_analytics_result(self):
"""
Save analytics results to a JSON file.
This method retrieves tensors tagged with "analytics" for the current round
from the tensor database and saves them as a JSON file at the path specified
by `self.last_state_path`. The tensor values are converted to lists if they
are NumPy arrays.
The saved JSON file contains a dictionary where the keys are tensor names
and the values are the corresponding tensor data.
Logs the saved analytics result for reference.
Returns:
None
"""
analytics_result = self.tensor_db.get_tensors_by_round_and_tags(
self.round_number, ("analytics",)
)
if len(analytics_result) > 0 and self.last_state_path:
with open(self.last_state_path, "w") as jsonfile:
analytics_result_json = {}
for tensorkey, values in analytics_result.items():
if isinstance(values, np.ndarray):
values = values.tolist()
analytics_result_json[tensorkey.tensor_name] = values
json.dump(analytics_result_json, jsonfile, indent=4)
logger.debug(f"Analytics result: {analytics_result_json}")

def _has_analytics_results(self):
"""
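
The recurring `bench_dict['global'].step(...)` calls throughout this file implement named step timing between successive checkpoints. Since `tictoc` is project-local and its API is only visible through these call sites, here is a stand-in sketch of the presumed semantics: each `step(name)` attributes the time since the previous checkpoint to `name`, and `save()` persists the accumulated durations.

import json
import time
from collections import defaultdict

class StepTimer:
    """Attribute elapsed wall-clock time to named steps, tictoc-style (assumed semantics)."""

    def __init__(self):
        self.durations = defaultdict(float)
        self._last = time.monotonic()

    def step(self, name: str):
        # Charge the time since the previous checkpoint to `name`
        now = time.monotonic()
        self.durations[name] += now - self._last
        self._last = now

    def save(self, path: str = "bench.json"):
        with open(path, "w") as f:
            json.dump(self.durations, f, indent=2)

# Usage mirroring the aggregator instrumentation above:
bench = defaultdict(StepTimer)
bench["global"].step("wait get tasks")
# ... handle the request ...
bench["global"].step("get_tasks")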