diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BinarizerTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BinarizerTests.cs
new file mode 100644
index 000000000..567674301
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BinarizerTests.cs
@@ -0,0 +1,81 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using Microsoft.Spark.ML.Feature;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+using Microsoft.Spark.UnitTest.TestUtils;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
+{
+    [Collection("Spark E2E Tests")]
+    public class BinarizerTests : FeatureBaseTests<Binarizer>
+    {
+        private readonly SparkSession _spark;
+
+        public BinarizerTests(SparkFixture fixture) : base(fixture)
+        {
+            _spark = fixture.Spark;
+        }
+
+        [Fact]
+        public void TestBinarizer()
+        {
+            string inputCol = "feature";
+            DataFrame input = _spark.CreateDataFrame(
+                new List<GenericRow>
+                {
+                    new GenericRow(new object[] {0, 0.1}),
+                    new GenericRow(new object[] {1, 0.8}),
+                    new GenericRow(new object[] {2, 0.2})
+                },
+                new StructType(new List<StructField>
+                {
+                    new StructField("id", new IntegerType()), new StructField(inputCol, new DoubleType())
+                }));
+            string expectedUid = "theUid";
+            string outputCol = "binarized_feature";
+            double threshold = 0.5;
+            Binarizer binarizer = new Binarizer(expectedUid)
+                .SetInputCol(inputCol)
+                .SetOutputCol(outputCol)
+                .SetThreshold(threshold);
+            DataFrame output = binarizer.Transform(input);
+            StructType outputSchema = binarizer.TransformSchema(input.Schema());
+
+            Assert.Contains(output.Schema().Fields, (f => f.Name == outputCol));
+            Assert.Contains(outputSchema.Fields, (f => f.Name == outputCol));
+            Assert.Equal(inputCol, binarizer.GetInputCol());
+            Assert.Equal(outputCol, binarizer.GetOutputCol());
+            Assert.Equal(threshold, binarizer.GetThreshold());
+
+            using (var tempDirectory = new TemporaryDirectory())
+            {
+                string savePath = Path.Join(tempDirectory.Path, "Binarizer");
+                binarizer.Save(savePath);
+
+                Binarizer loadedBinarizer = Binarizer.Load(savePath);
+                Assert.Equal(loadedBinarizer.Uid(), binarizer.Uid());
+            }
+
+            Assert.Equal(expectedUid, binarizer.Uid());
+        }
+
+        [Fact]
+        public void TestBinarizerWithArrayParams()
+        {
+            string[] inputCol = new[] {"col1", "col2"};
+            string[] outputCol = new[] {"feature1", "feature2"};
+            double[] threshold = new[] {0.5, 0.8};
+            Binarizer binarizer = new Binarizer()
+                .SetInputCols(inputCol)
+                .SetOutputCols(outputCol)
+                .SetThresholds(threshold);
+
+            Assert.Equal(inputCol, binarizer.GetInputCols());
+            Assert.Equal(outputCol, binarizer.GetOutputCols());
+            Assert.Equal(threshold, binarizer.GetThresholds());
+        }
+    }
+}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/Binarizer.cs b/src/csharp/Microsoft.Spark/ML/Feature/Binarizer.cs
new file mode 100644
index 000000000..d5888752d
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Feature/Binarizer.cs
@@ -0,0 +1,172 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+
+namespace Microsoft.Spark.ML.Feature
+{
+    /// <summary>
+    /// A <see cref="Binarizer"/>, Binarize a column of continuous features given a threshold.
+    /// </summary>
+    public class Binarizer : FeatureBase<Binarizer>, IJvmObjectReferenceProvider
+    {
+        private static readonly string s_binarizerClassName =
+            "org.apache.spark.ml.feature.Binarizer";
+
+        /// <summary>
+        /// Creates a <see cref="Binarizer"/> without any parameters.
+        /// </summary>
+        public Binarizer() : base(s_binarizerClassName)
+        {
+        }
+
+        /// <summary>
+        /// Creates a <see cref="Binarizer"/> with a UID that is used to give the
+        /// <see cref="Binarizer"/> a unique ID.
+        /// </summary>
+        /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
+        public Binarizer(string uid) : base(s_binarizerClassName, uid)
+        {
+        }
+
+        internal Binarizer(JvmObjectReference jvmObject) : base(jvmObject)
+        {
+        }
+
+        JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;
+
+        /// <summary>
+        /// Gets the column that the <see cref="Binarizer"/> should read from.
+        /// </summary>
+        /// <returns>string, input column</returns>
+        public string GetInputCol() => (string)(_jvmObject.Invoke("getInputCol"));
+
+        /// <summary>
+        /// Sets the column that the <see cref="Binarizer"/> should read from.
+        /// </summary>
+        /// <param name="value">The name of the column to use as the source</param>
+        /// <returns>New <see cref="Binarizer"/> object</returns>
+        public Binarizer SetInputCol(string value) =>
+            WrapAsBinarizer(_jvmObject.Invoke("setInputCol", value));
+
+        /// <summary>
+        /// Gets the columns that the <see cref="Binarizer"/> should read from.
+        /// </summary>
+        /// <returns>array of strings, input columns</returns>
+        public string[] GetInputCols() => (string[])(_jvmObject.Invoke("getInputCols"));
+
+        /// <summary>
+        /// Sets the columns that the <see cref="Binarizer"/> should read from.
+        /// </summary>
+        /// <param name="value">The names of the columns to use as the source</param>
+        /// <returns>New <see cref="Binarizer"/> object</returns>
+        public Binarizer SetInputCols(string[] value) =>
+            WrapAsBinarizer(_jvmObject.Invoke("setInputCols", value));
+
+        /// <summary>
+        /// Sets the threshold used to binarize continuous features.
+        /// </summary>
+        /// <param name="value">Threshold value</param>
+        /// <returns>New <see cref="Binarizer"/> object</returns>
+        public Binarizer SetThreshold(double value) =>
+            WrapAsBinarizer(_jvmObject.Invoke("setThreshold", value));
+
+        /// <summary>
+        /// Gets the threshold used to binarize continuous features.
+        /// </summary>
+        /// <returns>double, the threshold</returns>
+        public double GetThreshold() => (double)(_jvmObject.Invoke("getThreshold"));
+
+        /// <summary>
+        /// Sets the thresholds used to binarize continuous features.
+        /// </summary>
+        /// <param name="value">Threshold values</param>
+        /// <returns>New <see cref="Binarizer"/> object</returns>
+        public Binarizer SetThresholds(double[] value) =>
+            WrapAsBinarizer(_jvmObject.Invoke("setThresholds", value));
+
+        /// <summary>
+        /// Gets the thresholds used to binarize continuous features.
+        /// </summary>
+        /// <returns>array of double, the thresholds</returns>
+        public double[] GetThresholds() => (double[])(_jvmObject.Invoke("getThresholds"));
+
+        /// <summary>
+        /// The <see cref="Binarizer"/> will create a new column in the DataFrame, this is the
+        /// name of the new column.
+        /// </summary>
+        /// <returns>string, the output column</returns>
+        public string GetOutputCol() => (string)(_jvmObject.Invoke("getOutputCol"));
+
+        /// <summary>
+        /// The <see cref="Binarizer"/> will create a new column in the DataFrame, this is the
+        /// name of the new column.
+        /// </summary>
+        /// <param name="value">The name of the new column</param>
+        /// <returns>New <see cref="Binarizer"/> object</returns>
+        public Binarizer SetOutputCol(string value) =>
+            WrapAsBinarizer(_jvmObject.Invoke("setOutputCol", value));
+
+        /// <summary>
+        /// The <see cref="Binarizer"/> will create new columns in the DataFrame, these are the
+        /// names of the new columns.
+        /// </summary>
+        /// <returns>array of strings, the output columns</returns>
+        public string[] GetOutputCols() => (string[])(_jvmObject.Invoke("getOutputCols"));
+
+        /// <summary>
+        /// The <see cref="Binarizer"/> will create new columns in the DataFrame, these are the
+        /// names of the new columns.
+        /// </summary>
+        /// <param name="value">The names of the new columns</param>
+        /// <returns>New <see cref="Binarizer"/> object</returns>
+        public Binarizer SetOutputCols(string[] value) =>
+            WrapAsBinarizer(_jvmObject.Invoke("setOutputCols", value));
+
+        /// <summary>
+        /// Executes the <see cref="Binarizer"/> and transforms the DataFrame to include the new
+        /// column.
+        /// </summary>
+        /// <param name="source">The DataFrame to transform</param>
+        /// <returns>
+        /// New <see cref="DataFrame"/> object with the source <see cref="DataFrame"/> transformed
+        /// </returns>
+        public DataFrame Transform(DataFrame source) =>
+            new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", source));
+
+        /// <summary>
+        /// Executes the <see cref="Binarizer"/> and transforms the schema.
+        /// </summary>
+        /// <param name="value">The schema to be transformed</param>
+        /// <returns>
+        /// New <see cref="StructType"/> object with the schema transformed
+        /// </returns>
+        public StructType TransformSchema(StructType value) =>
+            new StructType(
+                (JvmObjectReference)_jvmObject.Invoke(
+                    "transformSchema",
+                    DataType.FromJson(_jvmObject.Jvm, value.Json)));
+
+        /// <summary>
+        /// Loads the <see cref="Binarizer"/> that was previously saved using Save.
+        /// </summary>
+        /// <param name="path">The path the previous <see cref="Binarizer"/> was saved to</param>
+        /// <returns>New <see cref="Binarizer"/> object, loaded from path</returns>
+        public static Binarizer Load(string path)
+        {
+            return WrapAsBinarizer(
+                SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+                    s_binarizerClassName, "load", path));
+        }
+
+        /// <summary>
+        /// Wraps a JVM object reference as a <see cref="Binarizer"/>.
+        /// </summary>
+        private static Binarizer WrapAsBinarizer(object obj) =>
+            new Binarizer((JvmObjectReference)obj);
+    }
+}