
Commit e423f1b

Vertical Federated Learning for MNIST and CCFRAUD examples (#248)
Implementation of basic vertical federated learning algorithm for CCFRAUD and MNIST examples.
1 parent e7e3b2a commit e423f1b

File tree

36 files changed: +5068 -19 lines changed


docs/README.md

Lines changed: 9 additions & 0 deletions
@@ -17,6 +17,7 @@
 - [Concepts](#concepts)
 - [Why should you consider Federated Learning?](#why-should-you-consider-federated-learning)
 - [How to plan for your Federated Learning project](#how-to-plan-for-your-federated-learning-project)
+- [Vertical federated learning](#vertical-federated-learning)
 - [Glossary](#glossary)
 - [Tutorials](#tutorials)
 - [What this repo has to offer?](#what-this-repo-has-to-offer)
@@ -93,6 +94,14 @@ Creating such a graph of jobs can be complex. This repository provides a recipe
 
 We wrote a generic guide on how to get started, ramp-up and mature your [FL project](./concepts/plan-your-fl-project.md).
 
+## Vertical federated learning
+
+> :warning: EXPERIMENTAL :warning: We are delighted to share our solution for vertical federated learning; however, please keep in mind that it is still in active development.
+
+Vertical federated learning is a branch of federated learning where the data are split across features (vertically) instead of across samples (horizontally). This poses a communication challenge: the nodes running the code need to exchange the intermediate outputs, and the corresponding gradients, of aligned samples.
+
+We provide **MNIST** and **CCFRAUD** examples that use vertical federated learning. These are essentially copies of the original examples with the features scattered across the nodes. We invite you to learn more about this approach in the [vertical federated learning tutorial](./tutorials/vertical-fl.md).
+
 ## Glossary
 
 The complete glossary list can be seen [**here**](./concepts/glossary.md).

docs/pics/fldatatypes.png

45.5 KB

docs/pics/vfltrainingloop.png

77.6 KB

docs/tutorials/vertical-fl.md

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
# Cross-silo vertical federated learning

## Background
Vertical federated learning (VFL) is a branch of federated learning where the data are split across features among the participants, rather than across samples as in horizontal FL. In other words, it takes federated learning a step further: it enables cross-organization collaboration without requiring the parties to share the same features, while keeping the privacy and security of each individual's data intact. Some real-world examples include, but are not limited to:
- Finance: several institutions owning different pieces of data about their clients (e.g. bank account data, credit card data, loan data, etc.)
- Healthcare: different healthcare facilities may own different modalities (e.g. x-ray scans, prescriptions, patient health records, etc.)
- Retail: each retailer owns different information about a customer, and aggregating this information may result in better recommendations for the customer

<br/><br/>
<div align="center">
<img src="../pics/fldatatypes.png" alt="Homogenous vs heterogenous data" width="400">
</div>

> Note: In this tutorial we refer to the "host" as the party who owns the data labels (and optionally some of the features), and to "contributors" as the parties who own only features and provide the host with the intermediate outputs of their share of the network.

## Objective and contents
This tutorial guides you through the steps required to set up VFL experiments and points out the important parts of the code. We target the MNIST (handwritten digit recognition) and [CCFRAUD (financial tabular data)](../real-world-examples/ccfraud.md) examples in order to showcase the versatility of the solution with respect to the type of data. All of the examples here use mean aggregation, and the assumption is that the host owns only the labels while the features are equally distributed among the contributors.

## Infrastructure
The first step towards successfully running a VFL example is to provision the infrastructure. To do so, please navigate to the [quickstart](../quickstart.md) and use the **single-button deployment for vnet infrastructure deployment**. This is necessary for the nodes to be able to communicate with each other.

## Install the required dependencies

You'll need python to submit experiments to AzureML. You can install the required dependencies by running:

```bash
conda env create --file ./examples/pipelines/environment.yml
conda activate fl_experiment_conda_env
```

Alternatively, you can just install the required dependencies:

```bash
python -m pip install -r ./examples/pipelines/requirements.txt
```

## Data provisioning
The data format for VFL is different from regular FL. That is why each of our examples contains its own script for uploading the data needed for that example.

> Note: This will split the data such that each contributor owns its portion of the features and the host owns only the labels.
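For intuition, here is a minimal, hypothetical sketch of what such a column-wise (vertical) split looks like for a tabular dataset. The `split_vertically` helper below is illustrative only and is not part of this repository; the upload scripts above perform the actual split for each example.

```python
# Illustrative only: a naive column-wise split where the host keeps the labels
# and each contributor receives a disjoint slice of the feature columns.
import numpy as np
import pandas as pd


def split_vertically(df: pd.DataFrame, label_column: str, n_contributors: int):
    """Return (host_df, contributor_dfs) with the features split column-wise."""
    feature_columns = [c for c in df.columns if c != label_column]
    host_df = df[[label_column]]  # the host owns only the labels
    contributor_dfs = [
        df[list(chunk)]
        for chunk in np.array_split(np.array(feature_columns), n_contributors)
    ]
    return host_df, contributor_dfs
```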

### CCFRAUD

Please follow the steps in [CCFRAUD - Add your Kaggle credentials to the workspace key vault](../real-world-examples/ccfraud.md#Add-your-Kaggle-credentials-to-the-workspace-key-vault). Afterwards, follow the same steps as for **MNIST** below, and **please do not forget to replace `--example MNIST_VERTICAL` with `--example CCFRAUD_VERTICAL`**.

### MNIST

This can all be performed with ease using a data provisioning pipeline. To run it, follow these steps:

1. If you are not using the quickstart setup, adjust the config file `config.yaml` in `examples/pipelines/utils/upload_data/` to match your setup.

2. Submit the experiment by running:

```bash
python ./examples/pipelines/utils/upload_data/submit.py --example MNIST_VERTICAL --workspace_name "<workspace-name>" --resource_group "<resource-group-name>" --subscription_id "<subscription-id>"
```

> Note: You can use the `--offline` flag when running the job to just build and validate the pipeline without submitting it.

:star: You can simplify this command by entering your workspace details in the file `config.yaml` in this same directory.

:warning: Proceed to the next step only once the pipeline completes. This pipeline will create data in 3 distinct locations.

## Model preparation for VFL
How the model should be orchestrated in VFL is an ongoing research topic. We have decided to go with the most common approach and split the model between the host and the contributors, an approach also referred to as **split learning**. This setup can easily be altered by moving layers between the parties, all the way to hosting the whole model on the contributors while the host provides only the aggregation and/or activation function. We believe this best demonstrates the capabilities of VFL on AzureML, and most existing models can be split this way without requiring too much work.
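As a rough illustration of such a split, the sketch below defines a contributor-side "bottom" model and a host-side "top" model. It assumes PyTorch and uses hypothetical class names (`ContributorBottom`, `HostTop`); it is not the repository's component code, and layers can be moved between the two parts exactly as described above.

```python
# A minimal split-learning sketch (assumed PyTorch, illustrative class names).
import torch
from torch import nn


class ContributorBottom(nn.Module):
    """Contributor-owned part of the model: sees only that party's features."""

    def __init__(self, n_features: int, latent_dim: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class HostTop(nn.Module):
    """Host-owned part of the model: consumes the aggregated intermediate outputs."""

    def __init__(self, latent_dim: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 1),  # single logit, e.g. for a fraud / not-fraud label
        )

    def forward(self, aggregated: torch.Tensor) -> torch.Tensor:
        return self.net(aggregated)
```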

## Training

### Overview
Before we run the training itself, let's take a step back and look at how training works in a VFL setup. It is roughly depicted in the figure below, and a minimal code sketch follows the figure. One step needs to take place ahead of the training:

- **Private entity intersection and alignment** - before the training takes place we need to make sure that all of the parties involved share the same sample space and that these samples are aligned during the training. **Our samples provide these guarantees by design, but please make sure this is true for your custom data. This can be achieved, for example, by adding a preprocessing step before training, as we do not provide any form of PSI as of now.**

Afterwards, we can continue with the regular training loop:
- **Forward pass in contributors** - all contributors, and optionally the host, perform a forward pass on their part of the model with the features they own
- **Intermediate outputs transfer** - all outputs from the previous step are sent to the host, which performs an aggregation (for simplicity's sake we use the mean operation)
- **Loss computation** - the host either performs a forward pass on its part of the network or just passes the aggregated outputs of the previous step through an activation function, followed by the loss computation
- **Gradients computation** - if the host owns some part of the network, it performs a backward pass; in all cases it then computes gradients w.r.t. the inputs
- **Gradient transfer** - all contributors, and optionally the host, receive the gradients w.r.t. their intermediate outputs
- **Backward pass** - the gradients are used to perform a backward pass and update the network weights

<br/><br/>
<div align="center">
<img src="../pics/vfltrainingloop.png" alt="Vertical federated learning training loop" width="400">
</div>
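The sketch below condenses this loop into a single process, assuming PyTorch and the hypothetical `ContributorBottom`/`HostTop` modules from the previous section. It is not the repository's training code: in the actual examples the intermediate outputs and gradients are exchanged over the network between the host and contributor jobs rather than staying in memory.

```python
# Single-process sketch of one VFL training step with mean aggregation (assumed PyTorch).
import torch
from torch import nn


def vfl_training_step(bottoms, top, optimizers, features_per_party, labels):
    """One mini-batch of split learning with mean aggregation."""
    loss_fn = nn.BCEWithLogitsLoss()
    for optimizer in optimizers:
        optimizer.zero_grad()

    # 1) Forward pass in contributors: each party uses only its own feature slice.
    intermediate_outputs = [
        bottom(features) for bottom, features in zip(bottoms, features_per_party)
    ]

    # 2) Intermediate outputs "transfer": the host aggregates them with a mean.
    aggregated = torch.stack(intermediate_outputs, dim=0).mean(dim=0)

    # 3) Loss computation: the host runs its part of the network and the loss.
    logits = top(aggregated).squeeze(-1)
    loss = loss_fn(logits, labels.float())

    # 4) Gradients computation and "transfer": backpropagation pushes gradients
    #    w.r.t. each intermediate output back into the contributors' bottom models.
    loss.backward()

    # 5) Backward pass / update: every party updates its own weights.
    for optimizer in optimizers:
        optimizer.step()
    return loss.item()
```

The mean aggregation mirrors the choice made in these examples; a different aggregation (for instance concatenation) would change only the aggregation step and the input size of the host's part of the network.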

### Steps to launch
1. If you are not using the quickstart setup, adjust the config file `config.yaml` in `examples/pipelines/<example-name>/` to match your setup.

2. Submit the experiment by running:

```bash
python ./examples/pipelines/<example-name>/submit.py --config examples/pipelines/<example-name>/config.yaml --workspace_name "<workspace-name>" --resource_group "<resource-group-name>" --subscription_id "<subscription-id>"
```

> Note: You can use the `--offline` flag when running the job to just build and validate the pipeline without submitting it.

:star: You can simplify this command by entering your workspace details in the file `config.yaml` in this same directory.

## Tips and pitfalls
1. **Vertical Federated Learning comes at a cost**
   There is significant overhead when running vertical federated learning due to the heavy communication among participants. As the training loop shows, there are two transfers per mini-batch: one for the forward-pass outputs and one for the gradients. This means that training may take longer than expected.
2. **Intersection and entity alignment**
   The samples need to be aligned across participants ahead of the training, after creating the intersection of the samples present on all involved parties (a naive sketch of this alignment follows this list). This process can reveal information to other entities that we may want to keep private. Fortunately, there are **private set intersection** methods available that come to the rescue.
3. **Communication encryption**
   Even though the intermediate outputs and gradients are not raw data, they have still been inferred using private data. Therefore, it is good practice to encrypt the data when communicating with parties outside of Azure.
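For illustration only, the sketch below aligns samples by intersecting shared IDs in the clear and ordering every party's rows identically. The `sample_id` column and `align_samples` helper are hypothetical; a real deployment should replace the plain intersection with a private set intersection protocol (see the resource below), since exchanging IDs directly reveals membership.

```python
# Illustrative only: naive (non-private) sample alignment across parties.
# Each party's dataframe is assumed to carry a shared "sample_id" column.
from typing import List

import pandas as pd


def align_samples(frames: List[pd.DataFrame], id_column: str = "sample_id") -> List[pd.DataFrame]:
    """Keep only the sample IDs present at every party, in one canonical order."""
    common_ids = set(frames[0][id_column])
    for frame in frames[1:]:
        common_ids &= set(frame[id_column])
    ordered_ids = sorted(common_ids)
    return [
        frame.set_index(id_column).loc[ordered_ids].reset_index()
        for frame in frames
    ]
```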

## Additional resources
- [Private set intersection algorithm overview](https://xianmu.github.io/posts/2018-11-03-private-set-intersection-based-on-rsa-blind-signature.html)

examples/components/CCFRAUD/preprocessing/run.py

Lines changed: 3 additions & 0 deletions
@@ -116,6 +116,9 @@ def preprocess_data(
     logger.debug(f"Train data samples: {len(train_data)}")
     logger.debug(f"Test data samples: {len(test_data)}")
 
+    os.makedirs(train_data_dir, exist_ok=True)
+    os.makedirs(test_data_dir, exist_ok=True)
+
     train_data = train_data.sort_values(by="trans_date_trans_time")
     test_data = test_data.sort_values(by="trans_date_trans_time")

examples/components/CCFRAUD/upload_data/run.py

Lines changed: 2 additions & 7 deletions
@@ -17,7 +17,7 @@
     3: [["South"], ["Midwest"], ["West", "Northeast"]],
     4: [["South"], ["West"], ["Midwest"], ["Northeast"]],
 }
-CATEGORICAL_PROPS = ["category", "region", "gender", "state", "job"]
+CATEGORICAL_PROPS = ["category", "region", "gender", "state"]
 ENCODERS = {}
 
 
@@ -67,23 +67,18 @@ def preprocess_data(df):
     useful_props = [
         "amt",
         "age",
-        # "cc_num",
         "merch_lat",
         "merch_long",
         "category",
         "region",
         "gender",
         "state",
-        # "zip",
         "lat",
         "long",
         "city_pop",
-        "job",
-        # "dob",
         "trans_date_trans_time",
         "is_fraud",
     ]
-    categorical = ["category", "region", "gender", "state", "job"]
 
     df.loc[:, "age"] = (pd.Timestamp.now() - pd.to_datetime(df["dob"])) // pd.Timedelta(
         "1y"
@@ -92,7 +87,7 @@ def preprocess_data(df):
     # Filter only useful columns
     df = df[useful_props]
 
-    for column in categorical:
+    for column in CATEGORICAL_PROPS:
         encoder = ENCODERS.get(column)
         encoded_data = encoder.transform(df[column].values.reshape(-1, 1)).toarray()
         encoded_df = pd.DataFrame(
Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
import os
import argparse
import logging
import sys
import numpy as np

from sklearn.preprocessing import StandardScaler
import pandas as pd
import mlflow

SCALERS = {}


def get_arg_parser(parser=None):
    """Parse the command line arguments for merge using argparse.

    Args:
        parser (argparse.ArgumentParser or CompliantArgumentParser):
            an argument parser instance

    Returns:
        ArgumentParser: the argument parser instance

    Notes:
        if parser is None, creates a new parser instance
    """
    # add arguments that are specific to the component
    if parser is None:
        parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument("--raw_training_data", type=str, required=True, help="")
    parser.add_argument("--raw_testing_data", type=str, required=True, help="")
    parser.add_argument("--train_output", type=str, required=True, help="")
    parser.add_argument("--test_output", type=str, required=True, help="")
    parser.add_argument(
        "--metrics_prefix", type=str, required=False, help="Metrics prefix"
    )
    return parser


def apply_transforms(df):
    """Applies transformation for datetime and numerical columns

    Args:
        df (pd.DataFrame):
            dataframe to transform

    Returns:
        pd.DataFrame: transformed dataframe
    """
    global SCALERS

    datetimes = ["trans_date_trans_time"]  # "dob"
    normalize = [
        "age",
        "merch_lat",
        "merch_long",
        "lat",
        "long",
        "city_pop",
        "trans_date_trans_time",
        "amt",
    ]

    for column in datetimes:
        if column not in df.columns:
            continue
        df.loc[:, column] = pd.to_datetime(df[column]).view("int64")
    for column in normalize:
        if column not in df.columns:
            continue

        if column not in SCALERS:
            print(f"Creating encoder for column: {column}")
            # Fit a scaler the first time this column is seen
            scaler = StandardScaler()
            scaler.fit(df[column].values.reshape(-1, 1))
            SCALERS[column] = scaler

        scaler = SCALERS.get(column)
        df.loc[:, column] = scaler.transform(df[column].values.reshape(-1, 1))

    return df


def preprocess_data(
    raw_training_data,
    raw_testing_data,
    train_data_dir="./",
    test_data_dir="./",
    metrics_prefix="default-prefix",
):
    """Preprocess the raw_training_data and raw_testing_data and save the processed data to train_data_dir and test_data_dir.

    Args:
        raw_training_data: Training data directory that needs to be processed
        raw_testing_data: Testing data directory that needs to be processed
        train_data_dir: Train data directory where processed train data will be saved
        test_data_dir: Test data directory where processed test data will be saved
    Returns:
        None
    """

    logger.info(
        f"Raw Training Data path: {raw_training_data}, Raw Testing Data path: {raw_testing_data}, Processed Training Data dir path: {train_data_dir}, Processed Testing Data dir path: {test_data_dir}"
    )

    logger.debug(f"Loading data...")
    train_df = pd.read_csv(raw_training_data + f"/train.csv", index_col=0)
    test_df = pd.read_csv(raw_testing_data + f"/test.csv", index_col=0)

    if "is_fraud" in train_df.columns:
        fraud_weight = (
            train_df["is_fraud"].value_counts()[0]
            / train_df["is_fraud"].value_counts()[1]
        )
        logger.debug(f"Fraud weight: {fraud_weight}")
        np.savetxt(train_data_dir + "/fraud_weight.txt", np.array([fraud_weight]))

    logger.debug(f"Applying transformations...")
    train_data = apply_transforms(train_df)
    test_data = apply_transforms(test_df)

    logger.debug(f"Train data samples: {len(train_data)}")
    logger.debug(f"Test data samples: {len(test_data)}")
    logger.info(f"Saving processed data to {train_data_dir} and {test_data_dir}")

    os.makedirs(train_data_dir, exist_ok=True)
    os.makedirs(test_data_dir, exist_ok=True)

    train_data.to_csv(train_data_dir + "/data.csv")
    test_data.to_csv(test_data_dir + "/data.csv")

    # Mlflow logging
    log_metadata(train_data, test_data, metrics_prefix)


def log_metadata(train_df, test_df, metrics_prefix):
    with mlflow.start_run() as mlflow_run:
        # get Mlflow client
        mlflow_client = mlflow.tracking.client.MlflowClient()
        root_run_id = mlflow_run.data.tags.get("mlflow.rootRunId")
        logger.debug(f"Root runId: {root_run_id}")
        if root_run_id:
            mlflow_client.log_metric(
                run_id=root_run_id,
                key=f"{metrics_prefix}/Number of train datapoints",
                value=f"{train_df.shape[0]}",
            )

            mlflow_client.log_metric(
                run_id=root_run_id,
                key=f"{metrics_prefix}/Number of test datapoints",
                value=f"{test_df.shape[0]}",
            )


def main(cli_args=None):
    """Component main function.

    It parses arguments and executes run() with the right arguments.

    Args:
        cli_args (List[str], optional): list of args to feed script, useful for debugging. Defaults to None.
    """
    # build an arg parser
    parser = get_arg_parser()
    # run the parser on cli args
    args = parser.parse_args(cli_args)
    logger.info(f"Running script with arguments: {args}")

    def run():
        """Run script with arguments (the core of the component).

        Args:
            args (argparse.namespace): command line arguments provided to script
        """

        preprocess_data(
            args.raw_training_data,
            args.raw_testing_data,
            args.train_output,
            args.test_output,
            args.metrics_prefix,
        )

    run()


if __name__ == "__main__":
    # Set logging to sys.out
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    log_format = logging.Formatter("[%(asctime)s] [%(levelname)s] - %(message)s")
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(log_format)
    logger.addHandler(handler)

    main()
