diff --git a/config/samples/xgboost-dist/utils.py b/config/samples/xgboost-dist/utils.py
index 283af8ba..451d3212 100644
--- a/config/samples/xgboost-dist/utils.py
+++ b/config/samples/xgboost-dist/utils.py
@@ -18,6 +18,9 @@ import oss2
 import json
 import pandas as pd
+import boto3
+import gc
+from io import StringIO
 
 from sklearn import datasets
@@ -44,7 +46,7 @@ def extract_xgbooost_cluster_env():
     return master_addr, master_port, rank, world_size
 
 
-def read_train_data(rank, num_workers, path):
+def read_train_data(rank, num_workers, path, datasource):
     """
     Read file based on the rank of worker.
     We use the sklearn.iris data for demonstration
@@ -54,20 +56,48 @@ def read_train_data(rank, num_workers, path):
     :param path: the input file name or the place to read the data
+    :param datasource: the data source type, 'sample' (sklearn iris) or 's3'
     :return: XGBoost Dmatrix
     """
-    iris = datasets.load_iris()
-    x = iris.data
-    y = iris.target
-
-    start, end = get_range_data(len(x), rank, num_workers)
-    x = x[start:end, :]
-    y = y[start:end]
-
-    x = pd.DataFrame(x)
-    y = pd.DataFrame(y)
-    dtrain = xgb.DMatrix(data=x, label=y)
-
-    logging.info("Read data from IRIS data source with range from %d to %d",
-                 start, end)
+    if datasource == 'sample':
+        iris = datasets.load_iris()
+        x = iris.data
+        y = iris.target
+
+        start, end = get_range_data(len(x), rank, num_workers)
+        x = x[start:end, :]
+        y = y[start:end]
+
+        x = pd.DataFrame(x)
+        y = pd.DataFrame(y)
+        dtrain = xgb.DMatrix(data=x, label=y)
+
+        logging.info("Read data from IRIS data source with range from %d to %d", start, end)
+    elif datasource == 's3':
+        # 1. Connect to S3. I propose we add an extra method to do the authentication for all the different external storage backends and data sources.
+        # region_name = 'region'
+        # aws_access_key_id = 'aws_access_key_id'
+        # aws_secret_access_key = 'aws_secret_access_key'
+        # s3 = boto3.client('s3', region_name=region_name, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key)
+
+        # 2. Get the total length (in bytes) of the object.
+        # Open question: how to pass the bucket and key; maybe they can come from parse_parameters.
+        obj = s3.get_object(Bucket=bucket, Key=key)
+        length = obj['ContentLength']
+        start, end = get_range_data(length, rank, num_workers)
+
+        # 3. Read the byte range that belongs to this worker.
+        byte_obj = s3.get_object(Bucket=bucket, Key=key, Range='bytes=%d-%d' % (start, end))
+        string_obj = byte_obj['Body'].read().decode('utf-8')
+        df = pd.read_csv(StringIO(string_obj))
+        del byte_obj, string_obj
+        gc.collect()
+        # Proposed handling for rows cut off at the block boundaries: drop a partial first/last row.
+        if df.iloc[-1].count() != df.iloc[-2].count():
+            df.drop(df.index[-1], inplace=True)
+        if df.iloc[0].count() != df.iloc[1].count():
+            df.drop(df.index[0], inplace=True)
+        x, y = df.iloc[:, :-1], df.iloc[:, [-1]]
+        dtrain = xgb.DMatrix(data=x, label=y)
+        logging.info("Read data from S3 with range from %d(Bytes) to %d(Bytes)", start, end)
 
     return dtrain
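
For discussion, here is a minimal sketch of the shared authentication/parameter helper proposed in comment #1 above. The helper name, the environment-variable names, and the assumption that `path` is an `s3://bucket/key` URI are illustrative only and not part of this change; the real version could just as well take these values from `parse_parameters`.

# Hypothetical helper, sketched only to illustrate the proposal above.
import os
from urllib.parse import urlparse

import boto3


def get_s3_client_and_location(path):
    """Return an S3 client plus the (bucket, key) parsed from an s3:// path.

    Credentials and region are read from environment variables here purely
    as an example; they could also come from parse_parameters.
    """
    client = boto3.client(
        's3',
        region_name=os.environ.get('S3_REGION'),
        aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
    )
    parsed = urlparse(path)          # e.g. 's3://my-bucket/iris/train.csv'
    bucket = parsed.netloc
    key = parsed.path.lstrip('/')
    return client, bucket, key


# read_train_data could then start its S3 branch with:
#     s3, bucket, key = get_s3_client_and_location(path)
#     obj = s3.get_object(Bucket=bucket, Key=key)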