diff --git a/client/src/bastionlab/polars/policy.py b/client/src/bastionlab/polars/policy.py index bfad10a1..e356d660 100644 --- a/client/src/bastionlab/polars/policy.py +++ b/client/src/bastionlab/polars/policy.py @@ -131,17 +131,24 @@ class Policy: safe_zone: Rule unsafe_handling: UnsafeAction - savable: bool + savable: bool = True + convertable: bool = False def serialize(self) -> str: if self.savable: savable_str = "true" else: savable_str = "false" - return f'{{"safe_zone":{self.safe_zone.serialize()},"unsafe_handling":{self.unsafe_handling.serialize()},"savable":{savable_str}}}' + + if self.convertable: + convertable_str = "true" + else: + convertable_str = "false" + + return f'{{"safe_zone":{self.safe_zone.serialize()},"unsafe_handling":{self.unsafe_handling.serialize()},"savable":{savable_str},"convertable":{convertable_str}}}' -DEFAULT_POLICY = Policy(Aggregation(10), Review(), True) +DEFAULT_POLICY = Policy(Aggregation(10), Review(), True, True) """ Default BastionLab Client Policy `Policy(Aggregation(10), Review(), True)` """ diff --git a/docs/docs/tutorials/data_conversion.ipynb b/docs/docs/tutorials/data_conversion.ipynb index 1b020822..1bcc2493 100644 --- a/docs/docs/tutorials/data_conversion.ipynb +++ b/docs/docs/tutorials/data_conversion.ipynb @@ -151,7 +151,9 @@ "from bastionlab.polars.policy import Policy, TrueRule, Log\n", "\n", "df = pl.read_csv(\"titanic.csv\")\n", - "policy = Policy(safe_zone=TrueRule(), unsafe_handling=Log(), savable=False)\n", + "policy = Policy(\n", + " safe_zone=TrueRule(), unsafe_handling=Log(), savable=False, convertable=True\n", + ")\n", "rdf = client.polars.send_df(df.limit(100), policy=policy)\n", "\n", "rdf" @@ -459,7 +461,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.8.10 64-bit", "language": "python", "name": "python3" }, @@ -473,12 +475,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10 (default, Jun 22 2022, 20:18:18) \n[GCC 9.4.0]" + "version": "3.8.10" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a" } } }, diff --git a/server/bastionlab_conversion/src/converter.rs b/server/bastionlab_conversion/src/converter.rs index 28de0025..3fd1d9f3 100644 --- a/server/bastionlab_conversion/src/converter.rs +++ b/server/bastionlab_conversion/src/converter.rs @@ -37,6 +37,11 @@ impl ConversionService for Converter { ) -> Result, Status> { self.sess_manager.verify_request(&request)?; let identifier = &request.get_ref().identifier; + let df_artifact = self.polars.get_df_artifact(identifier)?; + + if !df_artifact.policy.check_convertable() { + return Err(Status::unknown("Dataframe is not convertable")); + } let df = self.polars.get_df_unchecked(&identifier)?; diff --git a/server/bastionlab_polars/src/access_control.rs b/server/bastionlab_polars/src/access_control.rs index 2c13fcd6..325f1228 100644 --- a/server/bastionlab_polars/src/access_control.rs +++ b/server/bastionlab_polars/src/access_control.rs @@ -8,6 +8,7 @@ pub struct Policy { safe_zone: Rule, unsafe_handling: UnsafeAction, savable: bool, + convertable: bool, } impl Policy { @@ -26,6 +27,7 @@ impl Policy { safe_zone: Rule::AtLeastNOf(2, vec![self.safe_zone.clone(), other.safe_zone.clone()]), unsafe_handling: self.unsafe_handling.merge(other.unsafe_handling), savable: self.savable && other.savable, + convertable: self.convertable && other.convertable, } } @@ -34,12 +36,17 @@ impl Policy { safe_zone: Rule::True, unsafe_handling: UnsafeAction::Log, savable: true, + convertable: true, } } pub fn check_savable(&self) -> bool { return self.savable; } + + pub fn check_convertable(&self) -> bool { + return self.convertable; + } } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] diff --git a/server/bastionlab_polars/src/composite_plan.rs b/server/bastionlab_polars/src/composite_plan.rs index 00d064bf..ec10edd7 100644 --- a/server/bastionlab_polars/src/composite_plan.rs +++ b/server/bastionlab_polars/src/composite_plan.rs @@ -119,11 +119,14 @@ impl CompositePlan { "Could not apply with_row_count: no input data frame", ))?; let df = frame.df.with_row_count(&name, Some(0)).map_err(|e| { - Status::invalid_argument(format!("Error while running with_row_count: {}", e)) + Status::invalid_argument(format!( + "Error while running with_row_count: {}", + e + )) })?; let stats = frame.stats; stack.push(StackFrame { df, stats }); - } + } } } diff --git a/server/bastionlab_polars/src/lib.rs b/server/bastionlab_polars/src/lib.rs index d8d0cc3d..a58c9812 100644 --- a/server/bastionlab_polars/src/lib.rs +++ b/server/bastionlab_polars/src/lib.rs @@ -59,7 +59,7 @@ pub struct DelayedDataFrame { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DataFrameArtifact { dataframe: DataFrame, - policy: Policy, + pub policy: Policy, fetchable: VerificationResult, blacklist: Vec, query_details: String, @@ -251,6 +251,19 @@ Reason: {}", }) } + pub fn get_df_artifact(&self, identifier: &str) -> Result { + let dfs = self.dataframes.read().unwrap(); + Ok(dfs + .get(identifier) + .ok_or_else(|| { + Status::not_found(format!( + "Could not find dataframe: identifier={}", + identifier + )) + })? + .clone()) + } + pub fn get_df_unchecked(&self, identifier: &str) -> Result { let dfs = self.dataframes.read().unwrap(); Ok(dfs diff --git a/tests/test_conversion.py b/tests/test_conversion.py index aec537cc..a023214c 100644 --- a/tests/test_conversion.py +++ b/tests/test_conversion.py @@ -32,7 +32,7 @@ def test_df_to_tensor_conv(self): } ).with_column((pl.col("a") * pl.col("b")).alias("c")) - rdf = client.polars.send_df(df, Policy(TrueRule(), Log(), False)) + rdf = client.polars.send_df(df, Policy(TrueRule(), Log(), False, True)) arr = rdf.to_array() @@ -53,7 +53,7 @@ def test_split_remote_array_with_negs(self): } ).with_column((pl.col("a") * pl.col("b")).alias("c")) - rdf = client.polars.send_df(df, Policy(TrueRule(), Log(), False)) + rdf = client.polars.send_df(df, Policy(TrueRule(), Log(), False, True)) arr = rdf.to_array() with self.assertRaises(ValueError) as ve: @@ -69,7 +69,7 @@ def test_split_remote_array(self): } ).with_column((pl.col("a") * pl.col("b")).alias("c")) - rdf = client.polars.send_df(df, Policy(TrueRule(), Log(), False)) + rdf = client.polars.send_df(df, Policy(TrueRule(), Log(), False, True)) arr = rdf.to_array() diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 0d0ca504..b39a8d7c 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -43,7 +43,7 @@ def test_dtypes(self): for df in frames: rdf = client.polars.send_df( - df, policy=Policy(TrueRule(), Log(), savable=False) + df, policy=Policy(TrueRule(), Log(), savable=False, convertable=True) ) df2 = rdf.select(pl.all()).collect().fetch() @@ -67,7 +67,7 @@ def test_mixed_types_dataframe(self): for df in frames: rdf = client.polars.send_df( - df, policy=Policy(TrueRule(), Log(), savable=False) + df, policy=Policy(TrueRule(), Log(), savable=False, convertable=True) ) df2 = rdf.select(pl.all()).collect().fetch() diff --git a/tests/test_queries.py b/tests/test_queries.py index 10408a92..b3f89271 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -29,7 +29,10 @@ def testingdf(self): connection = Connection("localhost", 50056) client = connection.client policy = Policy( - safe_zone=Aggregation(min_agg_size=1), unsafe_handling=Log(), savable=False + safe_zone=Aggregation(min_agg_size=1), + unsafe_handling=Log(), + savable=False, + convertable=True, ) rdf = client.polars.send_df(df, policy) self.assertNotEqual(rdf, None) @@ -40,7 +43,10 @@ def testingquery(self): connection = Connection("localhost", 50056) client = connection.client policy = Policy( - safe_zone=Aggregation(min_agg_size=1), unsafe_handling=Log(), savable=False + safe_zone=Aggregation(min_agg_size=1), + unsafe_handling=Log(), + savable=False, + convertable=True, ) rdf = client.polars.send_df(df, policy) per_class_rates = ( @@ -57,7 +63,12 @@ def testingquery2(self): df = pl.read_csv("titanic.csv").limit(50) connection = Connection("localhost", 50056) client = connection.client - policy = Policy(safe_zone=Aggregation(1), unsafe_handling=Log(), savable=False) + policy = Policy( + safe_zone=Aggregation(1), + unsafe_handling=Log(), + savable=False, + convertable=True, + ) rdf = client.polars.send_df(df, policy) per_sex_rates = ( rdf.select([pl.col("Sex"), pl.col("Survived")]) diff --git a/tests/test_training.py b/tests/test_training.py index cbc53bac..924f1c5b 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -63,7 +63,9 @@ def test_df_to_tensor_conv(self): client = connection.client get_covid_dataset() df = pl.read_csv("covid.csv").limit(200) - policy = Policy(safe_zone=TrueRule(), unsafe_handling=Log(), savable=True) + policy = Policy( + safe_zone=TrueRule(), unsafe_handling=Log(), savable=True, convertable=True + ) rdf = client.polars.send_df(df, policy=policy, sanitized_columns=["Name"]) rdf = rdf.drop( [