Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
427 changes: 427 additions & 0 deletions docs/docs/tutorials/memory_quotas_and_deleting_dataframes.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions server/bastionlab_common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub struct BastionLabConfig {

pub public_keys_directory: String,
pub session_expiry_in_secs: u64,
pub max_memory_consumption: usize,
}

fn uri_to_socket(uri: &Uri) -> Result<SocketAddr> {
Expand All @@ -49,6 +50,10 @@ impl BastionLabConfig {
pub fn session_expiry(&self) -> Result<u64> {
Ok(self.session_expiry_in_secs)
}

pub fn max_memory(&self) -> Result<usize> {
Ok(self.max_memory_consumption)
}
}

fn deserialize_uri<'de, D>(deserializer: D) -> Result<Uri, D::Error>
Expand Down
1 change: 1 addition & 0 deletions server/bastionlab_common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod config;
pub mod prelude;
pub mod session;
pub mod telemetry;
pub mod tracking;

pub mod session_proto {
tonic::include_proto!("bastionlab");
Expand Down
60 changes: 60 additions & 0 deletions server/bastionlab_common/src/tracking.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use crate::session::SessionManager;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use tonic::Status;

#[derive(Debug)]
pub struct Tracking {
sess_manager: Arc<SessionManager>,
//Maps users to their total consumption and a hashmap of their dfs and their sizes
pub memory_quota: Arc<RwLock<HashMap<String, (usize, HashMap<String, usize>)>>>,
max_memory: Mutex<usize>,
pub dataframe_user: Arc<RwLock<HashMap<String, String>>>, //Maps dataframe identifiers to users
}

impl Tracking {
pub fn new(sess_manager: Arc<SessionManager>, max_memory: usize) -> Self {
Self {
sess_manager,
memory_quota: Arc::new(RwLock::new(HashMap::new())),
max_memory: Mutex::new(max_memory),
dataframe_user: Arc::new(RwLock::new(HashMap::new())),
}
}

pub fn memory_quota_check(
&self,
size: usize,
user_id: String,
identifier: String,
) -> Result<(), Status> {
//We return immediately if auth is disabled
if !self.sess_manager.auth_enabled() {
return Ok(());
}

let mut memory_quota = self.memory_quota.write().unwrap();
let consumption = memory_quota.get(&user_id);
let resulting_consumption = match consumption {
Some((consumption, identifiers)) => {
if consumption + size > *self.max_memory.lock().unwrap() {
return Err(Status::unknown(
"You have consumed your entire memory quota. Please delete some of your dataframes to free memory.",
));
}
let mut identifiers = identifiers.to_owned();
identifiers.insert(identifier.clone(), size);
(consumption + size, identifiers)
}
None => {
let mut hash_map = HashMap::new();
hash_map.insert(identifier.clone(), size);
(size, hash_map)
}
};
memory_quota.insert(user_id.clone(), resulting_consumption);
let mut dataframe_user = self.dataframe_user.write().unwrap();
dataframe_user.insert(identifier, user_id);
Ok(())
}
}
66 changes: 53 additions & 13 deletions server/bastionlab_polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use bastionlab_common::{
session::SessionManager,
session_proto::ClientInfo,
telemetry::{self, TelemetryEventProps},
tracking::Tracking,
};

use polars::prelude::*;
Expand Down Expand Up @@ -103,14 +104,16 @@ pub struct BastionLabPolars {
dataframes: Arc<RwLock<HashMap<String, DataFrameArtifact>>>,
arrays: Arc<RwLock<HashMap<String, ArrayStore>>>,
sess_manager: Arc<SessionManager>,
tracking: Arc<Tracking>,
}

impl BastionLabPolars {
pub fn new(sess_manager: Arc<SessionManager>) -> Self {
pub fn new(sess_manager: Arc<SessionManager>, tracking: Arc<Tracking>) -> Self {
Self {
dataframes: Arc::new(RwLock::new(HashMap::new())),
arrays: Arc::new(RwLock::new(HashMap::new())),
sess_manager,
tracking,
}
}

Expand Down Expand Up @@ -312,11 +315,14 @@ Reason: {}",
Ok(res)
}

pub fn insert_df(&self, df: DataFrameArtifact) -> String {
let mut dfs = self.dataframes.write().unwrap();
pub fn insert_df(&self, df: DataFrameArtifact, user_id: String) -> Result<String, Status> {
let identifier = format!("{}", Uuid::new_v4());
let size = df.dataframe.estimated_size();
self.tracking
.memory_quota_check(size, user_id, identifier.clone())?;
let mut dfs = self.dataframes.write().unwrap();
dfs.insert(identifier.clone(), df);
identifier
Ok(identifier)
}

pub fn insert_array(&self, array: ArrayStore) -> String {
Expand Down Expand Up @@ -393,12 +399,42 @@ Reason: {}",
Ok(())
}

pub fn delete_dfs(&self, identifier: &str) -> Result<(), Error> {
pub fn delete_dfs(&self, identifier: &str, user_id: String) -> Result<(), Status> {
let owner_check = self.sess_manager.verify_if_owner(&user_id)?;

//Removes the memory occupied by this df from memory quota
let mut memory_quota = self.tracking.memory_quota.write().unwrap();
let mut dataframe_user = self.tracking.dataframe_user.write().unwrap();

let dataframe_owner = if owner_check {
dataframe_user.get(identifier).unwrap()
} else {
let dataframe_owner = dataframe_user.get(identifier).unwrap();
if dataframe_owner == &user_id {
dataframe_owner
} else {
return Err(Status::invalid_argument(
"This dataframe does not belong to you.",
));
}
};

let mut dfs = self.dataframes.write().unwrap();
dfs.remove(identifier);

let path = "data_frames/".to_owned() + identifier + ".json";
std::fs::remove_file(path).unwrap_or(());

let (mut consumption, id_sizes) = memory_quota.get(dataframe_owner).unwrap();
let df_size = id_sizes.get(identifier).unwrap();
consumption = consumption - df_size;

let mut id_sizes = id_sizes.to_owned();
id_sizes.remove(identifier);
memory_quota.insert(user_id, (consumption, id_sizes));

dataframe_user.remove(identifier);

Ok(())
}
}
Expand Down Expand Up @@ -436,7 +472,7 @@ impl PolarsService for BastionLabPolars {
.map_err(|e| Status::internal(format!("Polars error: {e}")))?;

let header = get_df_header(&res.dataframe)?;
let identifier = self.insert_df(res);
let identifier = self.insert_df(res, user_id)?;

let elapsed = start_time.elapsed();

Expand All @@ -461,10 +497,12 @@ impl PolarsService for BastionLabPolars {
let start_time = Instant::now();

let token = self.sess_manager.get_token(&request)?;
let user_id = self.sess_manager.get_user_id(token.clone())?;

let client_info = self.sess_manager.get_client_info(token)?;
let (df, hash) = unserialize_dataframe(request.into_inner()).await?;
let header = get_df_header(&df.dataframe)?;
let identifier = self.insert_df(df);
let identifier = self.insert_df(df, user_id)?;

let elapsed = start_time.elapsed();
telemetry::add_event(
Expand Down Expand Up @@ -560,20 +598,22 @@ impl PolarsService for BastionLabPolars {

let identifier = &request.get_ref().identifier;
let user_id = self.sess_manager.get_user_id(token.clone())?;
let owner_check = self.sess_manager.verify_if_owner(&user_id)?;
if owner_check {
self.delete_dfs(identifier)?;
} else {
return Err(Status::internal("Only data owners can delete dataframes."));
}
self.delete_dfs(identifier, user_id)?;
telemetry::add_event(
TelemetryEventProps::DeleteDataframe {
dataset_name: Some(identifier.clone()),
},
Some(self.sess_manager.get_client_info(token)?),
);

info!(
"Succesfully deleted dataframe {} from the server",
identifier.clone()
);

Ok(Response::new(Empty {}))
}

async fn split(
&self,
request: Request<SplitRequest>,
Expand Down
25 changes: 21 additions & 4 deletions server/python-wheel/src/bastionlab_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,28 @@ def tls_certificates():
print("TLS certificates already generated")


def start_server(bastionlab_path: str, libtorch_path: str) -> BastionLabServer:
def start_server(
bastionlab_path: str, libtorch_path: str, auth_flag: bool, mem_quota: int
) -> BastionLabServer:
import shutil

os.chmod(bastionlab_path, 0o755)
os.chdir(os.getcwd() + "/bin")
os.environ["LD_LIBRARY_PATH"] = libtorch_path + "/lib"
os.environ["DISABLE_AUTHENTICATION"] = "1"
if mem_quota != 0:
with open("config.toml", "w") as outfile:
outfile.write(
'client_to_enclave_untrusted_url = "https://0.0.0.0:50056" \n public_keys_directory = "keys/" \n session_expiry_in_secs = 1500 \n max_memory_consumption = {} \n '.format(
mem_quota
)
)
if auth_flag == False:
os.environ["DISABLE_AUTHENTICATION"] = "1"
else:
os.makedirs(os.getcwd() + "/keys/owners", mode=0o777, exist_ok=True)
os.makedirs(os.getcwd() + "/keys/users", mode=0o777, exist_ok=True)
shutil.copy("../data_owner.pub", os.getcwd() + "/keys/owners")
shutil.copy("../data_scientist.pub", os.getcwd() + "/keys/users")
process = subprocess.Popen([bastionlab_path], env=os.environ)
os.chdir("..")
print("Bastionlab server is now running on port 50056")
Expand Down Expand Up @@ -108,7 +125,7 @@ def stop(srv: BastionLabServer) -> bool:
return False


def start() -> BastionLabServer:
def start(auth_flag: bool = False, mem_quota: int = 0) -> BastionLabServer:
"""Start BastionLab server.
The method will download BastionLab's server binary, then download a specific version of libtorch.
The server will then run, as a subprocess, allowing to run the rest of your Google Colab/Jupyter Notebook environment.
Expand Down Expand Up @@ -140,5 +157,5 @@ def start() -> BastionLabServer:
"Unable to download Libtorch",
)
tls_certificates()
process = start_server(bastionlab_path, libtorch_path)
process = start_server(bastionlab_path, libtorch_path, auth_flag, mem_quota)
return process
2 changes: 1 addition & 1 deletion server/python-wheel/src/bastionlab_server/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.6"
__version__ = "0.3.7"
14 changes: 11 additions & 3 deletions server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use bastionlab_common::{
auth::KeyManagement,
session::SessionManager,
telemetry::{self, TelemetryEventProps},
tracking::Tracking,
};
use bastionlab_polars::BastionLabPolars;
use bastionlab_torch::BastionLabTorch;
Expand Down Expand Up @@ -95,6 +96,14 @@ async fn main() -> Result<()> {
.session_expiry()
.context("Parsing the public session_expiry config")?,
));

let tracking = Arc::new(Tracking::new(
sess_manager.clone(),
config
.max_memory()
.context("Parsing the maximum memory config")?,
));

let server_cert =
fs::read("tls/host_server.pem").context("Reading the tls/host_server.pem file")?;
let server_key =
Expand Down Expand Up @@ -148,13 +157,12 @@ async fn main() -> Result<()> {
};

// Polars
let polars_svc = BastionLabPolars::new(sess_manager.clone());
let polars_svc = BastionLabPolars::new(sess_manager.clone(), tracking.clone());
let builder = {
use bastionlab_polars::{
polars_proto::polars_service_server::PolarsServiceServer, BastionLabPolars,
};
let svc = BastionLabPolars::new(sess_manager.clone());
match BastionLabPolars::load_dfs(&svc) {
match BastionLabPolars::load_dfs(&polars_svc) {
Ok(_) => info!("Successfully loaded saved dataframes"),
Err(_) => info!("There was an error loading saved dataframes"),
};
Expand Down
1 change: 1 addition & 0 deletions server/tools/config.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
client_to_enclave_untrusted_url = "https://0.0.0.0:50056"
public_keys_directory = "keys/"
session_expiry_in_secs = 1500
max_memory_consumption = 5242880