diff --git a/.github/scripts/test_notebooks.sh b/.github/scripts/test_notebooks.sh index 75ab4b9c..8bd10a35 100644 --- a/.github/scripts/test_notebooks.sh +++ b/.github/scripts/test_notebooks.sh @@ -6,6 +6,8 @@ echo "Removing cells..." python3 .github/scripts/remove_cells.py echo "Downloading datasets..." wget 'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv' -O tmp/titanic.csv +wget 'https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip' -O tmp/smsspamcollection.zip +unzip tmp/smsspamcollection.zip -d tmp # wget 'https://raw.githubusercontent.com/rinbaruah/COVID_preconditions_Kaggle/master/Data/covid.csv' -O tmp/covid.csv cp tmp/titanic.csv tmp/train.csv diff --git a/client/src/bastionlab/torch/data.py b/client/src/bastionlab/torch/data.py index 363e6786..da077045 100644 --- a/client/src/bastionlab/torch/data.py +++ b/client/src/bastionlab/torch/data.py @@ -101,7 +101,6 @@ def _get_tensor_metadata(meta_bytes: bytes): meta = TensorMetaData() meta.ParseFromString(meta_bytes) - print(meta.input_dtype) return [torch_dtypes[dt] for dt in meta.input_dtype], [ torch.Size(list(meta.input_shape)) ] diff --git a/docs/docs/how-to-guides/covid_19_deep_learning_cleaning.ipynb b/docs/docs/how-to-guides/covid_19_deep_learning_cleaning.ipynb new file mode 100644 index 00000000..c6c2389a --- /dev/null +++ b/docs/docs/how-to-guides/covid_19_deep_learning_cleaning.ipynb @@ -0,0 +1,1195 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "

Training a Deep Learning Model

\n", + " \n", + " \"Open\n", + "
\n", + "________________________________________________________________\n", + "\n", + "AI models are very good at preventing accidents or diagnosing medical conditions. They can analyse pixels systematically in ways our eyes can't and avoid many human errors. But for them to be trained at all, they need massive amounts of data that are often very sensitive since they hold lots of private informations (*name, age, sex, adress, previous medical conditions...*). \n", + "\n", + "Take the Covid-19 crisis for example: many collaborations couldn't happen because of privacy concerns despite the general urgency.\n", + "\n", + "This is where BastionLab comes in. Our framework offers tools to share datasets and train AI models with security garantees. It lets data scientists handle datasets remotely and train ML models without ever having access to the full data in clear.\n", + "\n", + "In this notebook, we'll use a **real-world Covid-19 dataset** to show you how you could **clean datasets**, **run queries** and **visualize data** with BastionLab and Torch. \n", + "\n", + "## Pre-requisites\n", + "___________________________________________\n", + "\n", + "### Installation and dataset\n", + "\n", + "In order to run this notebook, we need to:\n", + "- Have [Python3.7](https://www.python.org/downloads/) (or greater) and [Python Pip](https://pypi.org/project/pip/) already installed\n", + "- Install [BastionLab](https://bastionlab.readthedocs.io/en/latest/docs/getting-started/installation/)\n", + "- Download [the dataset](https://raw.githubusercontent.com/rinbaruah/COVID_preconditions_Kaggle/master/Data/covid.csv) we will be using in this tutorial.\n", + "\n", + "We'll do so by running the code block below. \n", + "\n", + ">You can see our [Installation page](https://bastionlab.readthedocs.io/en/latest/docs/getting-started/installation/) to find the installation method that best suits your needs." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# pip packages\n", + "!pip install bastionlab\n", + "!pip install bastionlab_server\n", + "\n", + "# download the dataset\n", + "!wget 'https://raw.githubusercontent.com/rinbaruah/COVID_preconditions_Kaggle/master/Data/covid.csv'" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This dataset collects data on outcomes of Covid cases based on various pre-conditions such as asthma and diabetes. This is a version of a huge dataset provided by the Mexican government based on the population of Mexico, so the insights gained from it may not be valid for other geographical areas in the world." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Launch and connect to the server" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# launch bastionlab_server test package\n", + "import bastionlab_server\n", + "\n", + "srv = bastionlab_server.start()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + ">*Note that the bastionlab_server package we install here was created for testing purposes. You can also install BastionLab server using our Docker image or from source (especially for non-test purposes). 
Check out our [Installation Tutorial](https://bastionlab.readthedocs.io/en/latest/docs/getting-started/installation/) for more details.*"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# connect to the server\n",
"from bastionlab import Connection\n",
"\n",
"connection = Connection(\"localhost\")\n",
"client = connection.client"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Upload the dataframe to the server"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll quickly upload the dataset to the server with an open safety policy, since setting up BastionLab is not the focus of this tutorial. It will allow us to demonstrate features without having to approve any data access requests. *You can check out how to define a safe privacy policy [here](https://bastionlab.readthedocs.io/en/latest/docs/tutorials/defining_policy_privacy/).*"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/kwabena/base/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"data": {
"text/plain": [
"FetchableLazyFrame(identifier=49d13d0e-5b48-4a96-844c-a8a526ac6680)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import polars as pl\n",
"from bastionlab.polars.policy import Policy, TrueRule, Log\n",
"\n",
"df = pl.read_csv(\"covid.csv\")\n",
"\n",
"policy = Policy(safe_zone=TrueRule(), unsafe_handling=Log(), savable=True)\n",
"rdf = client.polars.send_df(df, policy=policy, sanitized_columns=[\"Name\"])\n",
"\n",
"rdf"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"
\n", + "This policy is not suitable for production. Please note that we only use it for demonstration purposes, to avoid having to approve any data access requests in the tutorial.

\n", + "\n", + "We'll check that we're properly connected and that we have the authorizations by running a simple query:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'id': polars.datatypes.Utf8,\n", + " 'sex': polars.datatypes.Int64,\n", + " 'patient_type': polars.datatypes.Int64,\n", + " 'entry_date': polars.datatypes.Utf8,\n", + " 'date_symptoms': polars.datatypes.Utf8,\n", + " 'date_died': polars.datatypes.Utf8,\n", + " 'intubed': polars.datatypes.Int64,\n", + " 'pneumonia': polars.datatypes.Int64,\n", + " 'age': polars.datatypes.Int64,\n", + " 'pregnancy': polars.datatypes.Int64,\n", + " 'diabetes': polars.datatypes.Int64,\n", + " 'copd': polars.datatypes.Int64,\n", + " 'asthma': polars.datatypes.Int64,\n", + " 'inmsupr': polars.datatypes.Int64,\n", + " 'hypertension': polars.datatypes.Int64,\n", + " 'other_disease': polars.datatypes.Int64,\n", + " 'cardiovascular': polars.datatypes.Int64,\n", + " 'obesity': polars.datatypes.Int64,\n", + " 'renal_chronic': polars.datatypes.Int64,\n", + " 'tobacco': polars.datatypes.Int64,\n", + " 'contact_other_covid': polars.datatypes.Int64,\n", + " 'covid_res': polars.datatypes.Int64,\n", + " 'icu': polars.datatypes.Int64}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rdf.schema" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data cleaning\n", + "_______________________________________________________________\n", + "\n", + "Let's start by preparing our dataset.\n", + "\n", + "### Dropping columns\n", + "\n", + "Firstly let's use the `drop` method to remove the columns we don't need for our training model: `entry_date`, `date_symptoms`, `date_died`, `patient_type`, `sex`, `id` and `date`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Dropping columns that don't influence our model\n", + "rdf = rdf.drop(\n", + " [\"entry_date\", \"date_symptoms\", \"date_died\", \"patient_type\", \"sex\", \"id\", \"date\"]\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Checking dtypes\n", + "\n", + "We want to make sure that all categorical columns like `diabetes` have an integer dtype. These are columns that contain either `2` to represent true, the patient did have diabetes, `1` to represent false, the patient didn't have diabetes or `97`, `98` or `99` which are used to represent `unknown`. Any continuous value such as `age` should be represented by a float.\n", + "\n", + "By printing out the schema of our RemoteLazyFrame, we see that `age` is an integer value." 
+ ]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'intubed': polars.datatypes.Int64,\n",
" 'pneumonia': polars.datatypes.Int64,\n",
" 'age': polars.datatypes.Int64,\n",
" 'pregnancy': polars.datatypes.Int64,\n",
" 'diabetes': polars.datatypes.Int64,\n",
" 'copd': polars.datatypes.Int64,\n",
" 'asthma': polars.datatypes.Int64,\n",
" 'inmsupr': polars.datatypes.Int64,\n",
" 'hypertension': polars.datatypes.Int64,\n",
" 'other_disease': polars.datatypes.Int64,\n",
" 'cardiovascular': polars.datatypes.Int64,\n",
" 'obesity': polars.datatypes.Int64,\n",
" 'renal_chronic': polars.datatypes.Int64,\n",
" 'tobacco': polars.datatypes.Int64,\n",
" 'contact_other_covid': polars.datatypes.Int64,\n",
" 'covid_res': polars.datatypes.Int64,\n",
" 'icu': polars.datatypes.Int64}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rdf.schema"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Handling null/unknown values\n",
"\n",
"In the case of this dataset, we don't have null values as such, but we do have `97`, `98` and `99` used to signify \"unknown\" in our categorical columns.\n",
"\n",
"To decide on the best strategy for handling these values, let's first get a sense of their scale.\n",
"\n",
"For each column, we'll count the values that fall between `97` and `99` using Polars' `is_between` function, then express that count as a percentage of the total number of values in the column."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n", + "\n", + "\n", + "shape: (1, 17)\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "intubed\n", + "\n", + "pneumonia\n", + "\n", + "age\n", + "\n", + "pregnancy\n", + "\n", + "diabetes\n", + "\n", + "copd\n", + "\n", + "asthma\n", + "\n", + "inmsupr\n", + "\n", + "hypertension\n", + "\n", + "other_disease\n", + "\n", + "cardiovascular\n", + "\n", + "obesity\n", + "\n", + "renal_chronic\n", + "\n", + "tobacco\n", + "\n", + "contact_other_covid\n", + "\n", + "covid_res\n", + "\n", + "icu\n", + "
\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "\n", + "f64\n", + "
\n", + "78.505371\n", + "\n", + "0.001941\n", + "\n", + "0.036534\n", + "\n", + "50.952697\n", + "\n", + "0.349628\n", + "\n", + "0.308682\n", + "\n", + "0.309212\n", + "\n", + "0.349452\n", + "\n", + "0.321919\n", + "\n", + "0.458523\n", + "\n", + "0.321566\n", + "\n", + "0.31433\n", + "\n", + "0.316271\n", + "\n", + "0.336568\n", + "\n", + "30.891349\n", + "\n", + "0.0\n", + "\n", + "78.505547\n", + "
\n", + "
" + ], + "text/plain": [ + "shape: (1, 17)\n", + "┌───────────┬─────────┬──────────┬───────────┬─────┬──────────┬────────────┬───────────┬───────────┐\n", + "│ intubed ┆ pneumon ┆ age ┆ pregnancy ┆ ... ┆ tobacco ┆ contact_ot ┆ covid_res ┆ icu │\n", + "│ --- ┆ ia ┆ --- ┆ --- ┆ ┆ --- ┆ her_covid ┆ --- ┆ --- │\n", + "│ f64 ┆ --- ┆ f64 ┆ f64 ┆ ┆ f64 ┆ --- ┆ f64 ┆ f64 │\n", + "│ ┆ f64 ┆ ┆ ┆ ┆ ┆ f64 ┆ ┆ │\n", + "╞═══════════╪═════════╪══════════╪═══════════╪═════╪══════════╪════════════╪═══════════╪═══════════╡\n", + "│ 78.505371 ┆ 0.00194 ┆ 0.036534 ┆ 50.952697 ┆ ... ┆ 0.336568 ┆ 30.891349 ┆ 0.0 ┆ 78.505547 │\n", + "│ ┆ 1 ┆ ┆ ┆ ┆ ┆ ┆ ┆ │\n", + "└───────────┴─────────┴──────────┴───────────┴─────┴──────────┴────────────┴───────────┴───────────┘" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Get percentage of values in column between 96 and 100\n", + "percent_missing = rdf.select(\n", + " pl.col(x).is_between(96, 100).sum().alias(x) * 100 / pl.col(x).count()\n", + " for x in rdf.columns\n", + ")\n", + "percent_missing.collect().fetch()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since the `intubed`, `pregnancy`, `icu` and `contact_other_covid` columns contain significant amounts of \"unknown\" values, we will `drop()` them from our model." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "rdf = rdf.drop([\"intubed\", \"pregnancy\", \"contact_other_covid\", \"icu\"])" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next we will delete any rows which have a value which is not `1` or `2` for our categorical columns. This essentially deletes all of these unknown 97, 98 and 99 values, while ensuring there are no other unexpected values. To do that, we'll use the `filter()` method." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "rdf = rdf.filter(pl.col([x for x in rdf.columns]).is_between(0, 3))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### View dataset size and columns\n", + "\n", + "Now that we have finished cleaning our dataset, we can take a look again at some information about our dataset so we can confirm our dataset is still sufficiently large for training our model." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'pneumonia': polars.datatypes.Int64,\n", + " 'age': polars.datatypes.Int64,\n", + " 'diabetes': polars.datatypes.Int64,\n", + " 'copd': polars.datatypes.Int64,\n", + " 'asthma': polars.datatypes.Int64,\n", + " 'inmsupr': polars.datatypes.Int64,\n", + " 'hypertension': polars.datatypes.Int64,\n", + " 'other_disease': polars.datatypes.Int64,\n", + " 'cardiovascular': polars.datatypes.Int64,\n", + " 'obesity': polars.datatypes.Int64,\n", + " 'renal_chronic': polars.datatypes.Int64,\n", + " 'tobacco': polars.datatypes.Int64,\n", + " 'covid_res': polars.datatypes.Int64}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rdf.schema" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "shape: (1, 13)\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "pneumonia\n", + "\n", + "age\n", + "\n", + "diabetes\n", + "\n", + "copd\n", + "\n", + "asthma\n", + "\n", + "inmsupr\n", + "\n", + "hypertension\n", + "\n", + "other_disease\n", + "\n", + "cardiovascular\n", + "\n", + "obesity\n", + "\n", + "renal_chronic\n", + "\n", + "tobacco\n", + "\n", + "covid_res\n", + "
\n", + "u32\n", + "\n", + "u32\n", + "\n", + "u32\n", + "\n", + "u32\n", + "\n", + "u32\n", + "\n", + "u32\n", + "\n", + "u32\n", + "\n", + "u32\n", + "\n", + "u32\n", + "\n", + "u32\n", + "\n", + "u32\n", + "\n", + "u32\n", + "\n", + "u32\n", + "
\n", + "3170\n", + "\n", + "3170\n", + "\n", + "3170\n", + "\n", + "3170\n", + "\n", + "3170\n", + "\n", + "3170\n", + "\n", + "3170\n", + "\n", + "3170\n", + "\n", + "3170\n", + "\n", + "3170\n", + "\n", + "3170\n", + "\n", + "3170\n", + "\n", + "3170\n", + "
\n", + "
" + ], + "text/plain": [ + "shape: (1, 13)\n", + "┌───────────┬──────┬──────────┬──────┬─────┬─────────┬───────────────┬─────────┬───────────┐\n", + "│ pneumonia ┆ age ┆ diabetes ┆ copd ┆ ... ┆ obesity ┆ renal_chronic ┆ tobacco ┆ covid_res │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ u32 ┆ u32 ┆ u32 ┆ u32 ┆ ┆ u32 ┆ u32 ┆ u32 ┆ u32 │\n", + "╞═══════════╪══════╪══════════╪══════╪═════╪═════════╪═══════════════╪═════════╪═══════════╡\n", + "│ 3170 ┆ 3170 ┆ 3170 ┆ 3170 ┆ ... ┆ 3170 ┆ 3170 ┆ 3170 ┆ 3170 │\n", + "└───────────┴──────┴──────────┴──────┴─────┴─────────┴───────────────┴─────────┴───────────┘" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# get percentage of values in column between 96 and 100\n", + "size = rdf.select(pl.col(x).count().alias(x) for x in rdf.columns)\n", + "size.collect().fetch()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Transform label column to binary data" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the original dataset, the label column uses `2` for covid present and `1` for covid absent. We'll change that by using the `width_column()` method with Polar's `\"when-then-otherwise\"` statement to transform those values to binary: `1` for covid present and `0` for covid absent." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "label = \"covid_res\"\n", + "rdf = rdf.with_column(pl.when(pl.col(label) == 2).then(1).otherwise(0).alias(label))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, we call `collect()` to run all the data pre-processing applied to the RemoteDataFrame." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "rdf = rdf.collect()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data conversion\n", + "_______________________________________________\n", + "\n", + "### Splitting the dataset\n", + "\n", + "Now that our data is clean, let's convert the Covid dataset to a trainable dataset on BastionLab. \n", + "\n", + "First, we'll split the Covid dataset into training and testing datasets using the `train_test_split()` method in BastionLab. We split it so 80% of the information goes to the train set and 20% to test set. It is important that the test dataset is made from the original data but isn't used in the training set." 
+ ]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from bastionlab.polars import train_test_split\n",
"\n",
"# Get all input columns (without the `covid_res` column, which is the label) of the RemoteDataFrame\n",
"cols = list(filter(lambda a: a != label, rdf.columns))\n",
"\n",
"# We convert the selected RemoteDataFrames into RemoteArrays\n",
"inputs_array = rdf.select(cols).collect().to_array()\n",
"labels_array = rdf.select(label).collect().to_array()\n",
"\n",
"# The RemoteArrays are shuffled and split\n",
"# into a training and a testing part.\n",
"(\n",
"    train_inputs_array,\n",
"    test_inputs_array,\n",
"    train_labels_array,\n",
"    test_labels_array,\n",
") = train_test_split(\n",
"    inputs_array,\n",
"    labels_array,\n",
"    test_size=0.2,\n",
"    shuffle=True,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### RemoteDataFrame to RemoteTensor conversion\n",
"\n",
"Deep learning models only accept tensors, so we'll convert the train and test `RemoteArrays` (both inputs and labels) into `RemoteTensors` in the following code block."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"train_inputs_tensor = train_inputs_array.to_tensor()\n",
"test_inputs_tensor = test_inputs_array.to_tensor()\n",
"\n",
"train_labels_tensor = train_labels_array.to_tensor()\n",
"test_labels_tensor = test_labels_array.to_tensor()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the model we're about to create expects `float32` input tensors, we cast the `RemoteTensors` to `torch.float32`."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Cast RemoteTensor from `int64` to `float`\n",
"train_inputs_tensor = train_inputs_tensor.to(torch.float)\n",
"test_inputs_tensor = test_inputs_tensor.to(torch.float)\n",
"\n",
"train_labels_tensor = train_labels_tensor.to(torch.float)\n",
"test_labels_tensor = test_labels_tensor.to(torch.float)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"With both the inputs and the labels converted, everything is now a `RemoteTensor` with the right dtype."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, a `RemoteDataset` is created from the `RemoteTensors` above, which is how the model will read the data from now on.\n",
"> Note that a `RemoteDataset` is a pointer to a collection of `RemoteTensors` (*inputs, labels*)."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"train_dataset = client.torch.RemoteDataset(\n",
"    inputs=[train_inputs_tensor], labels=train_labels_tensor\n",
")\n",
"test_dataset = client.torch.RemoteDataset(\n",
"    inputs=[test_inputs_tensor], labels=test_labels_tensor\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating the deep learning regression model\n",
"_______________________________________________________\n",
"\n",
"In this section, we'll use a PyTorch linear layer to create a linear regression model.\n",
"\n",
"PyTorch models make it simple to create neural networks. 
In this example, we use a single [Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) layer to perform our linear regression.\n",
"\n",
"Read more about PyTorch models [here](https://pytorch.org/docs/stable/nn.html).\n",
"\n",
"The model takes 12 input features (*age, asthma, etc.*) and outputs 1 feature (*1 for covid present, 0 for covid absent*)."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of input features 12\n"
]
}
],
"source": [
"import torch\n",
"from torch.nn import Module, Linear\n",
"\n",
"# in_features is the size of the last dimension of the input tensor.\n",
"in_features = train_inputs_tensor.shape[-1]\n",
"\n",
"print(\"Number of input features\", in_features)\n",
"\n",
"\n",
"class LinearRegression(Module):\n",
"    def __init__(self, in_features: int) -> None:\n",
"        super().__init__()\n",
"        self.layer1 = Linear(in_features, 1)\n",
"\n",
"    def forward(self, tensor):\n",
"        return self.layer1(tensor)\n",
"\n",
"\n",
"# An instance of our Covid LinearRegression model is created\n",
"model = LinearRegression(in_features=in_features)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Uploading the model to BastionLab\n",
"__________________________________________________\n",
"\n",
"Once we have created and instantiated our model with the right number of input features (12), we can upload it to the BastionLab server. \n",
"\n",
"In this step, you'll be able to parametrize the learner with the type of loss, the optimizer to use, and the dataset to train on.\n",
"\n",
"A learner is BastionLab's representation of a trainer (the training loop from classical deep learning literature). It takes a reference to the `RemoteDataset`, the maximum number of samples per batch (`max_batch_size`), and a few other parameters that control how training runs.\n",
"\n",
"For our learner, we use the `l2` loss function, since we are tackling this binary classification task with a linear regressor.\n",
"\n",
"We also use the SGD optimizer, one of the most common optimizers in the literature. The learning rate (`lr`) and momentum (`momentum`) were chosen through a hyperparameter fine-tuning run."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sending LinearRegression: 100%|████████████████████| 3.15k/3.15k [00:00<00:00, 1.31MB/s]\n"
]
}
],
"source": [
"from bastionlab.torch.optimizer import SGD\n",
"\n",
"# The model is uploaded to the BastionLab Torch service.\n",
"remote_learner = connection.client.torch.RemoteLearner(\n",
"    model,\n",
"    train_dataset,\n",
"    max_batch_size=64,\n",
"    loss=\"l2\",\n",
"    optimizer=SGD(lr=1e-3, momentum=0.09),\n",
"    model_name=\"LinearRegression\",\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The `fit` method is analogous to its counterparts in classical machine learning libraries; in BastionLab, it triggers the training loop remotely. 
\n", + "\n", + "With the parameter `nb_epochs`, we can control how many epochs we train our model.\n", + "\n", + "When we train with all the batches in the `RemoteDataset`, it describes a single `epoch`.\n", + "\n", + "Again, the number of epochs was chosen after a hyperparameter fine-tuning task." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 19/100 - train: 54%|██████████▊ | 21/39 [00:00<00:00, 97.71batch/s, l2=0.0000 (+/- 814.7276)] \n", + "Epoch 39/100 - train: 82%|████████████████▍ | 32/39 [00:00<00:00, 150.41batch/s, l2=0.0000 (+/- 543.1517)] \n", + "Epoch 60/100 - train: 62%|████████████▎ | 24/39 [00:00<00:00, 112.41batch/s, l2=0.0000 (+/- 716.9603)] \n", + "Epoch 80/100 - train: 92%|██████████████████▍ | 36/39 [00:00<00:00, 169.82batch/s, l2=10.0000 (+/- 484.4327)] \n", + "Epoch 100/100 - train: 100%|████████████████████| 39/39 [00:00<00:00, 6972.63batch/s, l2=10.0000 (+/- 448.1003)]\n" + ] + } + ], + "source": [ + "# The linear regression model is trained on the dataset here.\n", + "remote_learner.fit(nb_epochs=100)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to verify the model's accuracy, we validate the model on the `test_dataset` using the `test` method on the `RemoteLearner`.\n", + "\n", + "This method performs a single epoch iteration over our `RemoteDataset`." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 1/1 - test: 100%|████████████████████| 39/39 [00:00<00:00, 187.11batch/s, l2=10.0000 (+/- 459.5899)] \n" + ] + } + ], + "source": [ + "# The linear regression model is validated with the `test_dataset`\n", + "remote_learner.test(test_dataset)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After training ends, the trained model is fetched using the `get_model()` method.\n", + "\n", + "`get_model` also exists on the `RemoteLearner`, a method used by BastionLab to download the trained model." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LinearRegression(\n", + " (layer1): Linear(in_features=12, out_features=1, bias=True)\n", + ")" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The trained model is fetched from the BastionLab Torch service.\n", + "remote_learner.get_model()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our deep learning model has been trained. 
All that's left to do is close our connection:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "# connection.close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "b02da7949da3af218e39ce2f31e32c0a1b93d864bdc2c2f7ec5cedc3569d5902" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/docs/how-to-guides/distilbert_example_notebook.ipynb b/docs/docs/how-to-guides/distilbert_example_notebook.ipynb index c3d0b6ca..257e9d0a 100644 --- a/docs/docs/how-to-guides/distilbert_example_notebook.ipynb +++ b/docs/docs/how-to-guides/distilbert_example_notebook.ipynb @@ -35,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -62,11 +62,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "!wget https: // archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip\n", + "!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip\n", "!unzip smsspamcollection.zip" ] }, @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -196,7 +196,7 @@ "└───────┴─────────────────────────────────────┘" ] }, - "execution_count": 1, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -228,9 +228,18 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/kwabena/base/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "from transformers import DistilBertTokenizer\n", "import torch\n", @@ -284,7 +293,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -322,7 +331,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -357,7 +366,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -389,17 +398,17 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.bias']\n", + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'pre_classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } @@ -463,27 +472,18 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 9, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Sending SMSSpamCollection: 100%|████████████████████| 35.7k/35.7k [00:00<00:00, 38.2MB/s]\n", - "Sending SMSSpamCollection (test): 100%|████████████████████| 35.7k/35.7k [00:00<00:00, 34.6MB/s]\n" - ] - } - ], + "outputs": [], "source": [ "from bastionlab import Connection\n", "\n", "# The Data owner privately uploads their model online\n", "client = Connection(\"localhost\").client.torch\n", "\n", - "remote_dataset = client.RemoteDataset(\n", - " train_set, validation_set, name=\"SMSSpamCollection\"\n", - ")" + "train_dataset = client.RemoteDataset(train_set, name=\"SMSSpamCollection-train\")\n", + "\n", + "test_dataset = client.RemoteDataset(validation_set, name=\"SMSSpamCollection-validation\")" ] }, { @@ -507,16 +507,17 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 10, "metadata": {}, 
"outputs": [ { "data": { "text/plain": [ - "['SMSSpamCollection (e4377764d92780aca061fb21f3afcb8b3b0d44cc30a27d9e3123d92eb63259e1): size=64, desc=N/A']" + "['identifier: \"df65d0fe-aed8-47f8-bd3a-bd4b7044d6aa\"\\nname: \"SMSSpamCollection-validation\"\\n',\n", + " 'identifier: \"a5259dfa-5ba0-468c-82b0-beeeff201fd7\"\\nname: \"SMSSpamCollection-train\"\\n']" ] }, - "execution_count": 13, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -525,7 +526,7 @@ "client = Connection(\"localhost\").client.torch\n", "\n", "# Fetches the list of all `RemoteDataset` on the BastionLab Torch service\n", - "remote_datasets = client.list_remote_datasets()\n", + "remote_datasets = client.get_available_datasets()\n", "\n", "# Here, we print the list of the available RemoteDatasets on the BastionLab Torch service\n", "# It will display in this form `[\"(Name): nb_samples=int, dtype=str\"]`\n", @@ -542,23 +543,20 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "RemoteDataset(identifier=a5259dfa-5ba0-468c-82b0-beeeff201fd7, name=SMSSpamCollection-train, privacy_limit=-1.0, inputs=[RemoteTensor(identifier=15ddf154-8862-43d9-8f5c-cb57d1e6db38, dtype=torch.int64, shape=torch.Size([64, 32])), RemoteTensor(identifier=faf17cf0-4793-4897-bbbc-54ca342bcbe6, dtype=torch.int64, shape=torch.Size([64, 32]))], label=RemoteTensor(identifier=fe73e77e-2372-4794-b76a-e17db19e86f8, dtype=torch.int64, shape=torch.Size([64]))) RemoteDataset(identifier=df65d0fe-aed8-47f8-bd3a-bd4b7044d6aa, name=SMSSpamCollection-validation, privacy_limit=-1.0, inputs=[RemoteTensor(identifier=1698bfb4-2cf8-4565-b34f-28e92f6c3dda, dtype=torch.int64, shape=torch.Size([64, 32])), RemoteTensor(identifier=6a239f8b-f053-4b35-ab27-fdc4c7f670a1, dtype=torch.int64, shape=torch.Size([64, 32]))], label=RemoteTensor(identifier=f88a462e-0606-410b-b5f4-629ceae2fe3d, dtype=torch.int64, shape=torch.Size([64])))\n" + ] } ], "source": [ "# Here, only the first element of the dataset is printed\n", - "remote_datasets[0]" + "print(train_dataset, test_dataset)" ] }, { @@ -582,22 +580,24 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Sending DistilBERT: 100%|████████████████████| 268M/268M [00:03<00:00, 71.2MB/s] \n", - "Epoch 1/2 - train: 100%|████████████████████| 32/32 [00:13<00:00, 2.46batch/s, cross_entropy=0.5558 (+/- 0.0000)] \n", - "Epoch 2/2 - train: 100%|████████████████████| 32/32 [00:11<00:00, 2.70batch/s, cross_entropy=0.5000 (+/- 0.0000)]\n", - "Epoch 1/1 - test: 100%|████████████████████| 32/32 [00:01<00:00, 26.33batch/s, accuracy=0.8874 (+/- 0.0000)] \n" + "/home/kwabena/base/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py:217: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. 
In any other case, this might cause the trace to be incorrect.\n", + " mask, torch.tensor(torch.finfo(scores.dtype).min)\n", + "Sending DistilBERT: 100%|████████████████████| 268M/268M [00:09<00:00, 29.7MB/s] \n", + "Epoch 1/2 - train: 100%|████████████████████| 32/32 [00:09<00:00, 3.51batch/s, cross_entropy=10.0000 (+/- 457.0489)] \n", + "Epoch 2/2 - train: 100%|████████████████████| 32/32 [00:07<00:00, 4.37batch/s, cross_entropy=10.0000 (+/- 443.1989)]\n", + "Epoch 1/1 - test: 100%|████████████████████| 32/32 [00:00<00:00, 52.59batch/s, accuracy=1.0000 (+/- 823.3106)] \n" ] } ], "source": [ - "from bastionlab.torch.optimizer_config import Adam\n", + "from bastionlab.torch.optimizer import Adam\n", "\n", "# Torch is selected from the created client connection\n", "# BastionLab has multiple services (Polars, Torch, etc)\n", @@ -608,7 +608,7 @@ "# the training dataset to use.\n", "remote_learner = client.RemoteLearner(\n", " model,\n", - " remote_datasets[0],\n", + " train_dataset,\n", " max_batch_size=2,\n", " loss=\"cross_entropy\",\n", " optimizer=Adam(lr=5e-5),\n", @@ -619,7 +619,7 @@ "remote_learner.fit(nb_epochs=2)\n", "\n", "# The trained model is tested with the `accuracy` metric.\n", - "remote_learner.test(metric=\"accuracy\")\n", + "remote_learner.test(test_dataset=test_dataset, metric=\"accuracy\")\n", "\n", "# The trained model is fetched using the get_model() method\n", "trained_model = remote_learner.get_model()" @@ -628,7 +628,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "base", "language": "python", "name": "python3" }, @@ -646,7 +646,7 @@ }, "vscode": { "interpreter": { - "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + "hash": "b02da7949da3af218e39ce2f31e32c0a1b93d864bdc2c2f7ec5cedc3569d5902" } } },