diff --git a/docs/docs/tutorials/memory_quotas_and_deleting_dataframes.ipynb b/docs/docs/tutorials/memory_quotas_and_deleting_dataframes.ipynb new file mode 100644 index 00000000..ccdcfa0c --- /dev/null +++ b/docs/docs/tutorials/memory_quotas_and_deleting_dataframes.ipynb @@ -0,0 +1,427 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "

Memory quotas and deletion of dataframes

\n", + " \n", + " \"Open\n", + "
\n", + "\n", + "________________________\n", + "\n", + "Memory quotas define how much memory each user can consume on the server. \n", + "\n", + "Because each operation results in a new dataframe, it becomes necessary to add limits to how much memory each user may consume.\n", + "\n", + "This tutorial demonstrates how memory quotas work and how to delete dataframes to free memory.\n", + "\n", + "Data owners can set the desired memory quota (in bytes) in the config.toml before launching the server.\n", + "\n", + "Memory quotas only work when authentication is enabled.\n", + "\n", + "## Pre-requisites\n", + "___________________________________________\n", + "\n", + "### Installation and dataset\n", + "\n", + "In order to run this notebook, we need to:\n", + "- Have [Python3.7](https://www.python.org/downloads/) (or greater) and [Python Pip](https://pypi.org/project/pip/) installed\n", + "- Install [BastionLab](https://bastionlab.readthedocs.io/en/latest/docs/getting-started/installation/)\n", + "- Download [the dataset](https://www.kaggle.com/competitions/titanic) we will be using in this tutorial.\n", + "\n", + "We'll do so by running the code block below. \n", + "\n", + ">If you are running this notebook on your machine instead of [Google Colab](https://colab.research.google.com/github/mithril-security/bastionlab/blob/v0.3.6/docs/docs/tutorials/memory_quotas.ipynb), you can see our [Installation page](https://bastionlab.readthedocs.io/en/latest/docs/getting-started/installation/) to find the installation method that best suits your needs." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2023-02-14 17:42:28-- https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv\n", + "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8002::154, 2606:50c0:8000::154, 2606:50c0:8003::154, ...\n", + "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8002::154|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 60302 (59K) [text/plain]\n", + "Saving to: ‘titanic.csv.6’\n", + "\n", + "titanic.csv.6 100%[===================>] 58.89K --.-KB/s in 0.005s \n", + "\n", + "2023-02-14 17:42:28 (12.4 MB/s) - ‘titanic.csv.6’ saved [60302/60302]\n", + "\n" + ] + } + ], + "source": [ + "# pip packages\n", + "!pip install bastionlab\n", + "!pip install bastionlab_server\n", + "\n", + "# download the Titanic dataset\n", + "!wget 'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our dataset is based on the Titanic dataset, one of the most popular datasets used for understanding machine learning which contains information relating to the passengers aboard the Titanic." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 143, + "metadata": {}, + "outputs": [], + "source": [ + "from bastionlab import Identity\n", + "\n", + "# Create `Identity` for data owner.\n", + "data_owner = Identity.create(\"data_owner\")\n", + "\n", + "# Create `Identity` for data scientist.\n", + "data_scientist = Identity.create(\"data_scientist\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Launch and connect to the server" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BastionLab server (version 0.3.7) already installed\n", + "Libtorch (version 1.13.1) already installed\n", + "TLS certificates already generated\n", + "Bastionlab server is now running on port 50056\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2023-02-15T16:19:37Z INFO bastionlab] Authentication is enabled.\n", + "[2023-02-15T16:19:37Z INFO bastionlab] Telemetry is disabled.\n", + "[2023-02-15T16:19:37Z INFO bastionlab] Successfully loaded saved dataframes\n", + "[2023-02-15T16:19:37Z INFO bastionlab] BastionLab server listening on 0.0.0.0:50056.\n", + "[2023-02-15T16:19:37Z INFO bastionlab] Server ready to take requests\n" + ] + } + ], + "source": [ + "# launch the bastionlab_server test package\n", + "import bastionlab_server\n", + "\n", + "# the True parameter turns authentication on for the server; the number is the memory quota (in bytes) permitted per user\n", + "srv = bastionlab_server.start(True, 121860)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ">*Note that the bastionlab_server package we install here was created for testing purposes. You can also install BastionLab server using our Docker image or from source (especially for non-test purposes). Check out our [Installation Tutorial](../getting-started/installation.md) for more details.*\n", + "\n", + "It's important to note that in a typical workflow, the data owner would send a set of keys to the server, so that authentication is required for all users at the point of connection. **BastionLab offers an authentication feature**, and authentication must be enabled for memory quotas to work. You can refer to the [authentication tutorial](https://bastionlab.readthedocs.io/en/latest/docs/tutorials/authentication/) to set up authentication on your server." + ] + }, + { + "cell_type": "code", + "execution_count": 144, + "metadata": {}, + "outputs": [], + "source": [ + "# connecting to the server\n", + "from bastionlab import Connection\n", + "\n", + "connection = Connection(\"localhost\", identity=data_scientist)\n", + "client = connection.client" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Upload the dataframe to the server" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll quickly upload the dataset to the server with an open safety policy, since setting up BastionLab is not the focus of this tutorial. This allows us to demonstrate features without having to approve any data access requests. You can check out how to define a safe privacy policy [here](https://bastionlab.readthedocs.io/en/latest/docs/tutorials/defining_policy_privacy/)."
+ ] + }, + { + "cell_type": "code", + "execution_count": 145, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "FetchableLazyFrame(identifier=2eaabc4e-a908-4246-8883-57e5f2ee1c36)" + ] + }, + "execution_count": 145, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import polars as pl\n", + "from bastionlab.polars.policy import Policy, TrueRule, Log\n", + "\n", + "df = pl.read_csv(\"titanic.csv\")\n", + "\n", + "policy = Policy(safe_zone=TrueRule(), unsafe_handling=Log(), savable=True)\n", + "rdf = client.polars.send_df(df, policy=policy, sanitized_columns=[\"Name\"])\n", + "\n", + "rdf" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ">*This policy is not suitable for production. Please note that we only use it for demonstration purposes, to avoid having to approve any data access requests in the tutorial.
*\n", + "\n", + "We'll check that we're properly connected and that we have the authorizations by running a simple query:" + ] + }, + { + "cell_type": "code", + "execution_count": 146, + "metadata": {}, + "outputs": [], + "source": [ + "per_class_rates = (\n", + " rdf.select([pl.col(\"Pclass\"), pl.col(\"Survived\")])\n", + " .groupby(pl.col(\"Pclass\"))\n", + " .agg(pl.col(\"Survived\").mean())\n", + " .sort(\"Survived\", reverse=True)\n", + " .collect()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "List the dataframes available on the server to see if the query executed successfully." + ] + }, + { + "cell_type": "code", + "execution_count": 147, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[FetchableLazyFrame(identifier=9f32f95c-cb38-43c3-9c8f-cad90df4cd3a),\n", + " FetchableLazyFrame(identifier=2eaabc4e-a908-4246-8883-57e5f2ee1c36)]" + ] + }, + "execution_count": 147, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client.polars.list_dfs()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's run that query one more time and see what happens." + ] + }, + { + "cell_type": "code", + "execution_count": 148, + "metadata": {}, + "outputs": [ + { + "ename": "GRPCException", + "evalue": "Received gRPC error: code=StatusCode.UNKNOWN message=You have consumed your entire memory quota. Please ask the data owner to delete your dataframes to free memory.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31m_InactiveRpcError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m~/bastionlab/client/src/bastionlab/errors.py:90\u001b[0m, in \u001b[0;36mGRPCException.map_error\u001b[0;34m(f)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m---> 90\u001b[0m \u001b[39mreturn\u001b[39;00m f()\n\u001b[1;32m 91\u001b[0m \u001b[39mexcept\u001b[39;00m _InactiveRpcError \u001b[39mas\u001b[39;00m e:\n", + "File \u001b[0;32m~/bastionlab/client/src/bastionlab/polars/client.py:153\u001b[0m, in \u001b[0;36mBastionLabPolars._run_query..\u001b[0;34m()\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclient\u001b[39m.\u001b[39mrefresh_session_if_needed()\n\u001b[1;32m 152\u001b[0m res \u001b[39m=\u001b[39m GRPCException\u001b[39m.\u001b[39mmap_error(\n\u001b[0;32m--> 153\u001b[0m \u001b[39mlambda\u001b[39;00m: \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstub\u001b[39m.\u001b[39;49mRunQuery(Query(composite_plan\u001b[39m=\u001b[39;49mcomposite_plan))\n\u001b[1;32m 154\u001b[0m )\n\u001b[1;32m 155\u001b[0m \u001b[39mreturn\u001b[39;00m FetchableLazyFrame\u001b[39m.\u001b[39m_from_reference(\u001b[39mself\u001b[39m, res)\n", + "File \u001b[0;32m~/.local/lib/python3.8/site-packages/grpc/_channel.py:946\u001b[0m, in \u001b[0;36m_UnaryUnaryMultiCallable.__call__\u001b[0;34m(self, request, timeout, metadata, credentials, wait_for_ready, compression)\u001b[0m\n\u001b[1;32m 944\u001b[0m state, call, \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_blocking(request, timeout, metadata, credentials,\n\u001b[1;32m 945\u001b[0m wait_for_ready, compression)\n\u001b[0;32m--> 946\u001b[0m \u001b[39mreturn\u001b[39;00m _end_unary_response_blocking(state, call, \u001b[39mFalse\u001b[39;49;00m, \u001b[39mNone\u001b[39;49;00m)\n", + "File \u001b[0;32m~/.local/lib/python3.8/site-packages/grpc/_channel.py:849\u001b[0m, 
in \u001b[0;36m_end_unary_response_blocking\u001b[0;34m(state, call, with_call, deadline)\u001b[0m\n\u001b[1;32m 848\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 849\u001b[0m \u001b[39mraise\u001b[39;00m _InactiveRpcError(state)\n", + "\u001b[0;31m_InactiveRpcError\u001b[0m: <_InactiveRpcError of RPC that terminated with:\n\tstatus = StatusCode.UNKNOWN\n\tdetails = \"You have consumed your entire memory quota. Please ask the data owner to delete your dataframes to free memory.\"\n\tdebug_error_string = \"{\"created\":\"@1676479058.963225235\",\"description\":\"Error received from peer ipv4:127.0.0.1:50056\",\"file\":\"src/core/lib/surface/call.cc\",\"file_line\":966,\"grpc_message\":\"You have consumed your entire memory quota. Please ask the data owner to delete your dataframes to free memory.\",\"grpc_status\":2}\"\n>", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mGRPCException\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn [148], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m per_class_rates \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m----> 2\u001b[0m \u001b[43mrdf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mPclass\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mSurvived\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroupby\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mPclass\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43magg\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mSurvived\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msort\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mSurvived\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreverse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m )\n", + "File \u001b[0;32m~/bastionlab/client/src/bastionlab/polars/remote_polars.py:309\u001b[0m, in \u001b[0;36mRemoteLazyFrame.collect\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcollect\u001b[39m(\u001b[39mself\u001b[39m: LDF) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m 
LDF:\n\u001b[1;32m 305\u001b[0m \u001b[39m\"\"\"runs any pending queries/actions on RemoteLazyFrame that have not yet been performed.\u001b[39;00m\n\u001b[1;32m 306\u001b[0m \u001b[39m Returns:\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \u001b[39m FetchableLazyFrame: FetchableLazyFrame of datarame after any queries have been performed\u001b[39;00m\n\u001b[1;32m 308\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 309\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_meta\u001b[39m.\u001b[39;49m_polars_client\u001b[39m.\u001b[39;49m_run_query(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcomposite_plan)\n", + "File \u001b[0;32m~/bastionlab/client/src/bastionlab/polars/client.py:152\u001b[0m, in \u001b[0;36mBastionLabPolars._run_query\u001b[0;34m(self, composite_plan)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m\u001b[39mremote_polars\u001b[39;00m \u001b[39mimport\u001b[39;00m FetchableLazyFrame\n\u001b[1;32m 150\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclient\u001b[39m.\u001b[39mrefresh_session_if_needed()\n\u001b[0;32m--> 152\u001b[0m res \u001b[39m=\u001b[39m GRPCException\u001b[39m.\u001b[39;49mmap_error(\n\u001b[1;32m 153\u001b[0m \u001b[39mlambda\u001b[39;49;00m: \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstub\u001b[39m.\u001b[39;49mRunQuery(Query(composite_plan\u001b[39m=\u001b[39;49mcomposite_plan))\n\u001b[1;32m 154\u001b[0m )\n\u001b[1;32m 155\u001b[0m \u001b[39mreturn\u001b[39;00m FetchableLazyFrame\u001b[39m.\u001b[39m_from_reference(\u001b[39mself\u001b[39m, res)\n", + "File \u001b[0;32m~/bastionlab/client/src/bastionlab/errors.py:92\u001b[0m, in \u001b[0;36mGRPCException.map_error\u001b[0;34m(f)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mreturn\u001b[39;00m f()\n\u001b[1;32m 91\u001b[0m \u001b[39mexcept\u001b[39;00m _InactiveRpcError \u001b[39mas\u001b[39;00m e:\n\u001b[0;32m---> 92\u001b[0m \u001b[39mraise\u001b[39;00m GRPCException(e)\n\u001b[1;32m 93\u001b[0m \u001b[39mexcept\u001b[39;00m _MultiThreadedRendezvous \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 94\u001b[0m \u001b[39mraise\u001b[39;00m GRPCException(e)\n", + "\u001b[0;31mGRPCException\u001b[0m: Received gRPC error: code=StatusCode.UNKNOWN message=You have consumed your entire memory quota. Please ask the data owner to delete your dataframes to free memory." + ] + } + ], + "source": [ + "per_class_rates = (\n", + " rdf.select([pl.col(\"Pclass\"), pl.col(\"Survived\")])\n", + " .groupby(pl.col(\"Pclass\"))\n", + " .agg(pl.col(\"Survived\").mean())\n", + " .sort(\"Survived\", reverse=True)\n", + " .collect()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We've hit the memory limit and cannot create any more dataframes. Let's double check by listing the dataframes available on the server." + ] + }, + { + "cell_type": "code", + "execution_count": 149, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[FetchableLazyFrame(identifier=9f32f95c-cb38-43c3-9c8f-cad90df4cd3a),\n", + " FetchableLazyFrame(identifier=2eaabc4e-a908-4246-8883-57e5f2ee1c36)]" + ] + }, + "execution_count": 149, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client.polars.list_dfs()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Deleting the dataframe\n", + "_______________________________________\n", + "\n", + "Now let's delete the resulting dataframe from our previous operation. 
This will delete the dataframe from the server's memory. If the dataframe was previously saved using `save()` for persistence, it will be deleted from local storage as well.\n", + "\n", + "Data owners can delete any dataframe. Data scientists can only delete dataframes that they created, for example as the result of an operation." + ] + }, + { + "cell_type": "code", + "execution_count": 150, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[FetchableLazyFrame(identifier=2eaabc4e-a908-4246-8883-57e5f2ee1c36)]" + ] + }, + "execution_count": 150, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "per_class_rates.delete()\n", + "\n", + "client.polars.list_dfs()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that the deleted dataframe is no longer available on the server." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, close the connection and stop the server." + ] + }, + { + "cell_type": "code", + "execution_count": 142, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BastionLab's server already stopped\n" + ] + }, + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 142, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connection.close()\n", + "bastionlab_server.stop(srv)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "vscode": { + "interpreter": { + "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/server/bastionlab_common/src/config.rs b/server/bastionlab_common/src/config.rs index 831f40a1..8d015fb5 100644 --- a/server/bastionlab_common/src/config.rs +++ b/server/bastionlab_common/src/config.rs @@ -26,6 +26,7 @@ pub struct BastionLabConfig { pub public_keys_directory: String, pub session_expiry_in_secs: u64, + pub max_memory_consumption: usize, } fn uri_to_socket(uri: &Uri) -> Result<SocketAddr> { @@ -49,6 +50,10 @@ impl BastionLabConfig { pub fn session_expiry(&self) -> Result<u64> { Ok(self.session_expiry_in_secs) } + + pub fn max_memory(&self) -> Result<usize> { + Ok(self.max_memory_consumption) + } } fn deserialize_uri<'de, D>(deserializer: D) -> Result<Uri, D::Error> diff --git a/server/bastionlab_common/src/lib.rs b/server/bastionlab_common/src/lib.rs index 87d59d39..f14bf757 100644 --- a/server/bastionlab_common/src/lib.rs +++ b/server/bastionlab_common/src/lib.rs @@ -5,6 +5,7 @@ pub mod config; pub mod prelude; pub mod session; pub mod telemetry; +pub mod tracking; pub mod session_proto { tonic::include_proto!("bastionlab"); diff --git a/server/bastionlab_common/src/tracking.rs b/server/bastionlab_common/src/tracking.rs new file mode 100644 index 00000000..230af16a --- /dev/null +++ b/server/bastionlab_common/src/tracking.rs @@ -0,0 +1,60 @@ +use crate::session::SessionManager; +use std::collections::HashMap; +use std::sync::{Arc, Mutex, RwLock}; +use tonic::Status; + +#[derive(Debug)] +pub struct Tracking { + sess_manager: Arc<SessionManager>, + // Maps each user to their total consumption and a map from their dataframe identifiers to sizes + pub memory_quota: Arc<RwLock<HashMap<String, (usize, HashMap<String, usize>)>>>, + max_memory: Mutex<usize>, + pub
dataframe_user: Arc<RwLock<HashMap<String, String>>>, // Maps dataframe identifiers to users +} + +impl Tracking { + pub fn new(sess_manager: Arc<SessionManager>, max_memory: usize) -> Self { + Self { + sess_manager, + memory_quota: Arc::new(RwLock::new(HashMap::new())), + max_memory: Mutex::new(max_memory), + dataframe_user: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub fn memory_quota_check( + &self, + size: usize, + user_id: String, + identifier: String, + ) -> Result<(), Status> { + // Return immediately if authentication is disabled: quotas are only enforced for authenticated users + if !self.sess_manager.auth_enabled() { + return Ok(()); + } + + let mut memory_quota = self.memory_quota.write().unwrap(); + let consumption = memory_quota.get(&user_id); + let resulting_consumption = match consumption { + Some((consumption, identifiers)) => { + if consumption + size > *self.max_memory.lock().unwrap() { + return Err(Status::unknown( + "You have consumed your entire memory quota. Please delete some of your dataframes to free memory.", + )); + } + let mut identifiers = identifiers.to_owned(); + identifiers.insert(identifier.clone(), size); + (consumption + size, identifiers) + } + None => { + let mut hash_map = HashMap::new(); + hash_map.insert(identifier.clone(), size); + (size, hash_map) + } + }; + memory_quota.insert(user_id.clone(), resulting_consumption); + let mut dataframe_user = self.dataframe_user.write().unwrap(); + dataframe_user.insert(identifier, user_id); + Ok(()) + } +} diff --git a/server/bastionlab_polars/src/lib.rs b/server/bastionlab_polars/src/lib.rs index ddaab1a7..b88001e2 100644 --- a/server/bastionlab_polars/src/lib.rs +++ b/server/bastionlab_polars/src/lib.rs @@ -4,6 +4,7 @@ use bastionlab_common::{ session::SessionManager, session_proto::ClientInfo, telemetry::{self, TelemetryEventProps}, + tracking::Tracking, }; use polars::prelude::*; @@ -103,14 +104,16 @@ pub struct BastionLabPolars { dataframes: Arc<RwLock<HashMap<String, DataFrameArtifact>>>, arrays: Arc<RwLock<HashMap<String, ArrayStore>>>, sess_manager: Arc<SessionManager>, + tracking: Arc<Tracking>, } impl BastionLabPolars { - pub fn new(sess_manager: Arc<SessionManager>) -> Self { + pub fn new(sess_manager: Arc<SessionManager>, tracking: Arc<Tracking>) -> Self { Self { dataframes: Arc::new(RwLock::new(HashMap::new())), arrays: Arc::new(RwLock::new(HashMap::new())), sess_manager, + tracking, } } @@ -312,11 +315,14 @@ Reason: {}", Ok(res) } - pub fn insert_df(&self, df: DataFrameArtifact) -> String { - let mut dfs = self.dataframes.write().unwrap(); + pub fn insert_df(&self, df: DataFrameArtifact, user_id: String) -> Result<String, Status> { let identifier = format!("{}", Uuid::new_v4()); + // Check the user's quota before storing; estimated_size() gives the dataframe's in-memory size in bytes + let size = df.dataframe.estimated_size(); + self.tracking + .memory_quota_check(size, user_id, identifier.clone())?; + let mut dfs = self.dataframes.write().unwrap(); dfs.insert(identifier.clone(), df); - identifier + Ok(identifier) } pub fn insert_array(&self, array: ArrayStore) -> String { @@ -393,12 +399,42 @@ Reason: {}", Ok(()) } - pub fn delete_dfs(&self, identifier: &str) -> Result<(), Error> { + pub fn delete_dfs(&self, identifier: &str, user_id: String) -> Result<(), Status> { + let owner_check = self.sess_manager.verify_if_owner(&user_id)?; + + // Data owners may delete any dataframe; data scientists only their own + let mut memory_quota = self.tracking.memory_quota.write().unwrap(); + let mut dataframe_user = self.tracking.dataframe_user.write().unwrap(); + + let dataframe_owner = match dataframe_user.get(identifier) { + Some(owner) => owner.to_owned(), + None => return Err(Status::not_found("Unknown dataframe identifier.")), + }; + if !owner_check && dataframe_owner != user_id { + return Err(Status::invalid_argument( + "This dataframe does not belong to you.", + )); + } +
let mut dfs = self.dataframes.write().unwrap(); dfs.remove(identifier); let path = "data_frames/".to_owned() + identifier + ".json"; std::fs::remove_file(path).unwrap_or(()); + + // Remove this dataframe's size from its owner's recorded consumption + let (consumption, id_sizes) = memory_quota.get(&dataframe_owner).unwrap(); + let df_size = id_sizes.get(identifier).unwrap(); + let consumption = consumption - df_size; + + let mut id_sizes = id_sizes.to_owned(); + id_sizes.remove(identifier); + // Credit the freed memory to the dataframe's owner, who is not necessarily the caller + memory_quota.insert(dataframe_owner, (consumption, id_sizes)); + + dataframe_user.remove(identifier); + Ok(()) + } } @@ -436,7 +472,7 @@ impl PolarsService for BastionLabPolars { .map_err(|e| Status::internal(format!("Polars error: {e}")))?; let header = get_df_header(&res.dataframe)?; - let identifier = self.insert_df(res); + let identifier = self.insert_df(res, user_id)?; let elapsed = start_time.elapsed(); @@ -461,10 +497,12 @@ impl PolarsService for BastionLabPolars { let start_time = Instant::now(); let token = self.sess_manager.get_token(&request)?; + let user_id = self.sess_manager.get_user_id(token.clone())?; + let client_info = self.sess_manager.get_client_info(token)?; let (df, hash) = unserialize_dataframe(request.into_inner()).await?; let header = get_df_header(&df.dataframe)?; - let identifier = self.insert_df(df); + let identifier = self.insert_df(df, user_id)?; let elapsed = start_time.elapsed(); telemetry::add_event( @@ -560,20 +598,22 @@ impl PolarsService for BastionLabPolars { let identifier = &request.get_ref().identifier; let user_id = self.sess_manager.get_user_id(token.clone())?; - let owner_check = self.sess_manager.verify_if_owner(&user_id)?; - if owner_check { - self.delete_dfs(identifier)?; - } else { - return Err(Status::internal("Only data owners can delete dataframes.")); - } + self.delete_dfs(identifier, user_id)?; telemetry::add_event( TelemetryEventProps::DeleteDataframe { dataset_name: Some(identifier.clone()), }, Some(self.sess_manager.get_client_info(token)?), ); + + info!( + "Successfully deleted dataframe {} from the server", + identifier + ); + + Ok(Response::new(Empty {})) + } + async fn split( &self, request: Request, diff --git a/server/python-wheel/src/bastionlab_server/server.py b/server/python-wheel/src/bastionlab_server/server.py index f905347a..f42f9ebd 100644 --- a/server/python-wheel/src/bastionlab_server/server.py +++ b/server/python-wheel/src/bastionlab_server/server.py @@ -69,11 +69,28 @@ def tls_certificates(): print("TLS certificates already generated") -def start_server(bastionlab_path: str, libtorch_path: str) -> BastionLabServer: +def start_server( + bastionlab_path: str, libtorch_path: str, auth_flag: bool, mem_quota: int +) -> BastionLabServer: + import shutil + os.chmod(bastionlab_path, 0o755) os.chdir(os.getcwd() + "/bin") os.environ["LD_LIBRARY_PATH"] = libtorch_path + "/lib" - os.environ["DISABLE_AUTHENTICATION"] = "1" + if mem_quota != 0: + # write a config.toml that sets the per-user memory quota + with open("config.toml", "w") as outfile: + outfile.write( + 'client_to_enclave_untrusted_url = "https://0.0.0.0:50056" \n public_keys_directory = "keys/" \n session_expiry_in_secs = 1500 \n max_memory_consumption = {} \n '.format( + mem_quota + ) + ) + if not auth_flag: + os.environ["DISABLE_AUTHENTICATION"] = "1" + else: + # with authentication on, register the test keys as data owner and data scientist + os.makedirs(os.getcwd() + "/keys/owners", mode=0o777, exist_ok=True) + os.makedirs(os.getcwd() + "/keys/users", mode=0o777, exist_ok=True) + shutil.copy("../data_owner.pub", os.getcwd() + "/keys/owners") + shutil.copy("../data_scientist.pub", os.getcwd() + "/keys/users") process = subprocess.Popen([bastionlab_path], env=os.environ) os.chdir("..")
print("Bastionlab server is now running on port 50056") @@ -108,7 +125,7 @@ def stop(srv: BastionLabServer) -> bool: return False -def start() -> BastionLabServer: +def start(auth_flag: bool = False, mem_quota: int = 0) -> BastionLabServer: """Start BastionLab server. The method will download BastionLab's server binary, then download a specific version of libtorch. The server will then run, as a subprocess, allowing to run the rest of your Google Colab/Jupyter Notebook environment. @@ -140,5 +157,5 @@ def start() -> BastionLabServer: "Unable to download Libtorch", ) tls_certificates() - process = start_server(bastionlab_path, libtorch_path) + process = start_server(bastionlab_path, libtorch_path, auth_flag, mem_quota) return process diff --git a/server/python-wheel/src/bastionlab_server/version.py b/server/python-wheel/src/bastionlab_server/version.py index d7b30e12..8879c6c7 100644 --- a/server/python-wheel/src/bastionlab_server/version.py +++ b/server/python-wheel/src/bastionlab_server/version.py @@ -1 +1 @@ -__version__ = "0.3.6" +__version__ = "0.3.7" diff --git a/server/src/main.rs b/server/src/main.rs index 05dd7546..730690e3 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -4,6 +4,7 @@ use bastionlab_common::{ auth::KeyManagement, session::SessionManager, telemetry::{self, TelemetryEventProps}, + tracking::Tracking, }; use bastionlab_polars::BastionLabPolars; use bastionlab_torch::BastionLabTorch; @@ -95,6 +96,14 @@ async fn main() -> Result<()> { .session_expiry() .context("Parsing the public session_expiry config")?, )); + + let tracking = Arc::new(Tracking::new( + sess_manager.clone(), + config + .max_memory() + .context("Parsing the maximum memory config")?, + )); + let server_cert = fs::read("tls/host_server.pem").context("Reading the tls/host_server.pem file")?; let server_key = @@ -148,13 +157,12 @@ async fn main() -> Result<()> { }; // Polars - let polars_svc = BastionLabPolars::new(sess_manager.clone()); + let polars_svc = BastionLabPolars::new(sess_manager.clone(), tracking.clone()); let builder = { use bastionlab_polars::{ polars_proto::polars_service_server::PolarsServiceServer, BastionLabPolars, }; - let svc = BastionLabPolars::new(sess_manager.clone()); - match BastionLabPolars::load_dfs(&svc) { + match BastionLabPolars::load_dfs(&polars_svc) { Ok(_) => info!("Successfully loaded saved dataframes"), Err(_) => info!("There was an error loading saved dataframes"), }; diff --git a/server/tools/config.toml b/server/tools/config.toml index 2e5db5b5..a9790b4d 100644 --- a/server/tools/config.toml +++ b/server/tools/config.toml @@ -1,3 +1,4 @@ client_to_enclave_untrusted_url = "https://0.0.0.0:50056" public_keys_directory = "keys/" session_expiry_in_secs = 1500 +max_memory_consumption = 5242880