import os
import argparse
import logging
import sys

import numpy as np
import pandas as pd
import mlflow
from sklearn.preprocessing import StandardScaler

SCALERS = {}

# Module-level logger so the functions below can log when this file is
# imported as a module; handlers are attached in the __main__ block.
logger = logging.getLogger(__name__)


def get_arg_parser(parser=None):
    """Parse the command line arguments for the preprocessing component using argparse.

    Args:
        parser (argparse.ArgumentParser or CompliantArgumentParser):
            an argument parser instance

    Returns:
        ArgumentParser: the argument parser instance

    Notes:
        if parser is None, creates a new parser instance
    """
    # add arguments that are specific to the component
    if parser is None:
        parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument(
        "--raw_training_data",
        type=str,
        required=True,
        help="Directory containing the raw training data (expects train.csv)",
    )
    parser.add_argument(
        "--raw_testing_data",
        type=str,
        required=True,
        help="Directory containing the raw testing data (expects test.csv)",
    )
    parser.add_argument(
        "--train_output",
        type=str,
        required=True,
        help="Directory where the processed training data is written",
    )
    parser.add_argument(
        "--test_output",
        type=str,
        required=True,
        help="Directory where the processed testing data is written",
    )
    parser.add_argument(
        "--metrics_prefix",
        type=str,
        required=False,
        default="default-prefix",
        help="Metrics prefix",
    )
    return parser
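
# A minimal usage sketch (paths are hypothetical) showing how the parser can
# be exercised in isolation, e.g. from a test or a notebook:
#
#   args = get_arg_parser().parse_args([
#       "--raw_training_data", "/data/raw/train_dir",
#       "--raw_testing_data", "/data/raw/test_dir",
#       "--train_output", "/data/processed/train_dir",
#       "--test_output", "/data/processed/test_dir",
#   ])
#   assert args.metrics_prefix == "default-prefix"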


def apply_transforms(df):
    """Applies transformations to datetime and numerical columns.

    Args:
        df (pd.DataFrame):
            dataframe to transform

    Returns:
        pd.DataFrame: transformed dataframe
    """
    global SCALERS

    datetimes = ["trans_date_trans_time"]  # "dob"
    normalize = [
        "age",
        "merch_lat",
        "merch_long",
        "lat",
        "long",
        "city_pop",
        "trans_date_trans_time",
        "amt",
    ]

    for column in datetimes:
        if column not in df.columns:
            continue
        # convert datetimes to int64 nanosecond timestamps so they can be
        # scaled like any other numerical column (Series.view is deprecated,
        # astype gives the same result)
        df.loc[:, column] = pd.to_datetime(df[column]).astype("int64")
    for column in normalize:
        if column not in df.columns:
            continue

        if column not in SCALERS:
            logger.debug(f"Creating scaler for column: {column}")
            # fit the scaler on the first dataframe that contains this column
            # (the training data) and reuse it for subsequent calls
            scaler = StandardScaler()
            scaler.fit(df[column].values.reshape(-1, 1))
            SCALERS[column] = scaler

        scaler = SCALERS.get(column)
        df.loc[:, column] = scaler.transform(df[column].values.reshape(-1, 1))

    return df
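
# A small sketch (synthetic values, not from the real dataset) illustrating
# that the first call fits the scalers and later calls reuse them:
#
#   train = apply_transforms(pd.DataFrame({"amt": [10.0, 20.0, 30.0]}))
#   test = apply_transforms(pd.DataFrame({"amt": [25.0]}))
#   # test["amt"] is scaled with the mean/std fitted on the train values above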


def preprocess_data(
    raw_training_data,
    raw_testing_data,
    train_data_dir="./",
    test_data_dir="./",
    metrics_prefix="default-prefix",
):
    """Preprocess the raw_training_data and raw_testing_data and save the processed data to train_data_dir and test_data_dir.

    Args:
        raw_training_data: Directory containing the raw training data to process
        raw_testing_data: Directory containing the raw testing data to process
        train_data_dir: Directory where the processed training data will be saved
        test_data_dir: Directory where the processed testing data will be saved
        metrics_prefix: Prefix under which metrics are logged to MLflow

    Returns:
        None
    """

    logger.info(
        f"Raw Training Data path: {raw_training_data}, Raw Testing Data path: {raw_testing_data}, "
        f"Processed Training Data dir path: {train_data_dir}, Processed Testing Data dir path: {test_data_dir}"
    )

    logger.debug("Loading data...")
    train_df = pd.read_csv(os.path.join(raw_training_data, "train.csv"), index_col=0)
    test_df = pd.read_csv(os.path.join(raw_testing_data, "test.csv"), index_col=0)

    # create the output directories before anything is written into them
    os.makedirs(train_data_dir, exist_ok=True)
    os.makedirs(test_data_dir, exist_ok=True)

    if "is_fraud" in train_df.columns:
        # ratio of negative to positive samples, used downstream to
        # re-weight the minority (fraud) class
        fraud_weight = (
            train_df["is_fraud"].value_counts()[0]
            / train_df["is_fraud"].value_counts()[1]
        )
        logger.debug(f"Fraud weight: {fraud_weight}")
        np.savetxt(
            os.path.join(train_data_dir, "fraud_weight.txt"), np.array([fraud_weight])
        )

    logger.debug("Applying transformations...")
    train_data = apply_transforms(train_df)
    test_data = apply_transforms(test_df)

    logger.debug(f"Train data samples: {len(train_data)}")
    logger.debug(f"Test data samples: {len(test_data)}")
    logger.info(f"Saving processed data to {train_data_dir} and {test_data_dir}")

    train_data.to_csv(os.path.join(train_data_dir, "data.csv"))
    test_data.to_csv(os.path.join(test_data_dir, "data.csv"))

    # MLflow logging
    log_metadata(train_data, test_data, metrics_prefix)
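
# Hypothetical direct invocation (placeholder paths), useful when debugging
# the component outside of a pipeline:
#
#   preprocess_data(
#       "/tmp/raw_train", "/tmp/raw_test",
#       train_data_dir="/tmp/processed_train",
#       test_data_dir="/tmp/processed_test",
#       metrics_prefix="debug",
#   )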


def log_metadata(train_df, test_df, metrics_prefix):
    """Log the number of train/test datapoints to the root MLflow run.

    Args:
        train_df (pd.DataFrame): processed training data
        test_df (pd.DataFrame): processed testing data
        metrics_prefix (str): prefix under which the metrics are logged
    """
    with mlflow.start_run() as mlflow_run:
        # get the MLflow client
        mlflow_client = mlflow.tracking.client.MlflowClient()
        # log against the root (pipeline) run so that metrics from all
        # components show up in one place
        root_run_id = mlflow_run.data.tags.get("mlflow.rootRunId")
        logger.debug(f"Root runId: {root_run_id}")
        if root_run_id:
            mlflow_client.log_metric(
                run_id=root_run_id,
                key=f"{metrics_prefix}/Number of train datapoints",
                value=train_df.shape[0],
            )

            mlflow_client.log_metric(
                run_id=root_run_id,
                key=f"{metrics_prefix}/Number of test datapoints",
                value=test_df.shape[0],
            )


def main(cli_args=None):
    """Component main function.

    It parses arguments and executes run() with the right arguments.

    Args:
        cli_args (List[str], optional): list of args to feed script, useful for debugging. Defaults to None.
    """
    # build an arg parser
    parser = get_arg_parser()
    # run the parser on cli args
    args = parser.parse_args(cli_args)
    logger.info(f"Running script with arguments: {args}")

    def run():
        """Run script with arguments (the core of the component).

        Uses the parsed args from the enclosing scope.
        """

        preprocess_data(
            args.raw_training_data,
            args.raw_testing_data,
            args.train_output,
            args.test_output,
            args.metrics_prefix,
        )

    run()


if __name__ == "__main__":
    # send DEBUG-level logging to sys.stdout
    logger.setLevel(logging.DEBUG)
    log_format = logging.Formatter("[%(asctime)s] [%(levelname)s] - %(message)s")
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(log_format)
    logger.addHandler(handler)

    main()
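
# Example command line (hypothetical paths; the script filename is assumed),
# mirroring how a pipeline component might call this script:
#
#   python preprocessing.py \
#       --raw_training_data /data/raw/train_dir \
#       --raw_testing_data /data/raw/test_dir \
#       --train_output /data/processed/train_dir \
#       --test_output /data/processed/test_dir \
#       --metrics_prefix preprocessing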