Commit 08e50c3

Simulating PR #154 merge
2 parents: 96b8186 + 0ad578f

20 files changed: +1079 -1061 lines

hpobench/benchmarks/ml/__init__.py

Lines changed: 21 additions & 13 deletions
@@ -1,4 +1,3 @@
-from hpobench.benchmarks.ml.histgb_benchmark import HistGBBenchmark, HistGBBenchmarkBB, HistGBBenchmarkMF
 from hpobench.benchmarks.ml.lr_benchmark import LRBenchmark, LRBenchmarkBB, LRBenchmarkMF
 from hpobench.benchmarks.ml.nn_benchmark import NNBenchmark, NNBenchmarkBB, NNBenchmarkMF
 from hpobench.benchmarks.ml.rf_benchmark import RandomForestBenchmark, RandomForestBenchmarkBB, \
@@ -8,17 +7,26 @@
 from hpobench.benchmarks.ml.yahpo_benchmark import YAHPOGymMORawBenchmark, YAHPOGymRawBenchmark
 
 try:
+    # `xgboost` is from https://xgboost.readthedocs.io/en/latest/install.html#conda
+    # and is not part of the scikit-learn bundle; it is not a strict requirement for running
+    # HPOBench for other spaces or for the tabular benchmarks
     from hpobench.benchmarks.ml.xgboost_benchmark import XGBoostBenchmark, XGBoostBenchmarkBB, XGBoostBenchmarkMF
-except ImportError:
-    pass
+    __all__ = [
+        'LRBenchmark', 'LRBenchmarkBB', 'LRBenchmarkMF',
+        'NNBenchmark', 'NNBenchmarkBB', 'NNBenchmarkMF',
+        'RandomForestBenchmark', 'RandomForestBenchmarkBB', 'RandomForestBenchmarkMF',
+        'SVMBenchmark', 'SVMBenchmarkBB', 'SVMBenchmarkMF',
+        'XGBoostBenchmark', 'XGBoostBenchmarkBB', 'XGBoostBenchmarkMF',
+        'TabularBenchmark',
+        'YAHPOGymMORawBenchmark', 'YAHPOGymRawBenchmark',
+    ]
+except (ImportError, AttributeError) as e:
+    __all__ = [
+        'LRBenchmark', 'LRBenchmarkBB', 'LRBenchmarkMF',
+        'NNBenchmark', 'NNBenchmarkBB', 'NNBenchmarkMF',
+        'RandomForestBenchmark', 'RandomForestBenchmarkBB', 'RandomForestBenchmarkMF',
+        'SVMBenchmark', 'SVMBenchmarkBB', 'SVMBenchmarkMF',
+        'TabularBenchmark',
+        'YAHPOGymMORawBenchmark', 'YAHPOGymRawBenchmark',
 
-
-__all__ = ['HistGBBenchmark', 'HistGBBenchmarkBB', 'HistGBBenchmarkMF',
-           'LRBenchmark', 'LRBenchmarkBB', 'LRBenchmarkMF',
-           'NNBenchmark', 'NNBenchmarkBB', 'NNBenchmarkMF',
-           'RandomForestBenchmark', 'RandomForestBenchmarkBB', 'RandomForestBenchmarkMF',
-           'SVMBenchmark', 'SVMBenchmarkBB', 'SVMBenchmarkMF',
-           'TabularBenchmark',
-           'XGBoostBenchmark', 'XGBoostBenchmarkBB', 'XGBoostBenchmarkMF',
-           'YAHPOGymMORawBenchmark', 'YAHPOGymRawBenchmark',
-           ]
+    ]
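
The rewritten guard makes `xgboost` strictly optional: if its import fails, the XGBoost classes are simply left out of `__all__` while every other benchmark stays importable. A minimal sketch of how downstream code might probe for the optional export; the `HAS_XGBOOST` flag is hypothetical and not part of this commit:

```python
# Sketch only: probes whether the optional XGBoost benchmarks were exported.
# `HAS_XGBOOST` is a hypothetical name, not defined anywhere in HPOBench.
try:
    from hpobench.benchmarks.ml import XGBoostBenchmark
    HAS_XGBOOST = True
except ImportError:
    HAS_XGBOOST = False

# The scikit-learn based benchmarks import regardless of xgboost.
from hpobench.benchmarks.ml import LRBenchmark, RandomForestBenchmark

print("XGBoost benchmarks available:", HAS_XGBOOST)
```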

hpobench/benchmarks/ml/lr_benchmark.py

Lines changed: 207 additions & 27 deletions
@@ -4,30 +4,38 @@
 
 0.0.1:
 * First implementation of the LR Benchmarks.
+0.0.2:
+* Restructuring for consistency and to match ML Benchmark Template updates.
+0.0.3:
+* Adding Learning Curve support.
 """
 
-
+import time
 from typing import Union, Tuple, Dict
 
 import ConfigSpace as CS
 import numpy as np
+import pandas as pd
 from ConfigSpace.hyperparameters import Hyperparameter
 from sklearn.linear_model import SGDClassifier
 
+from hpobench.util.rng_helper import get_rng
 from hpobench.dependencies.ml.ml_benchmark_template import MLBenchmark
 
-__version__ = '0.0.1'
+__version__ = '0.0.3'
 
 
 class LRBenchmark(MLBenchmark):
-    def __init__(self,
-                 task_id: int,
-                 rng: Union[np.random.RandomState, int, None] = None,
-                 valid_size: float = 0.33,
-                 data_path: Union[str, None] = None):
-
-        super(LRBenchmark, self).__init__(task_id, rng, valid_size, data_path)
-        self.cache_size = 500
+    """ Multi-multi-fidelity Logistic Regression Benchmark
+    """
+    def __init__(
+            self,
+            task_id: int,
+            valid_size: float = 0.33,
+            rng: Union[np.random.RandomState, int, None] = None,
+            data_path: Union[str, None] = None
+    ):
+        super(LRBenchmark, self).__init__(task_id, valid_size, rng, data_path)
 
     @staticmethod
     def get_configuration_space(seed: Union[int, None] = None) -> CS.ConfigurationSpace:
@@ -44,7 +52,8 @@ def get_configuration_space(seed: Union[int, None] = None) -> CS.ConfigurationSpace:
         ])
         return cs
 
-    def get_fidelity_space(self, seed: Union[int, None] = None) -> CS.ConfigurationSpace:
+    @staticmethod
+    def get_fidelity_space(seed: Union[int, None] = None) -> CS.ConfigurationSpace:
         fidelity_space = CS.ConfigurationSpace(seed=seed)
         fidelity_space.add_hyperparameters(
             # gray-box setting (multi-multi-fidelity) - iterations + data subsample
@@ -53,17 +62,11 @@ def get_fidelity_space(self, seed: Union[int, None] = None) -> CS.ConfigurationSpace:
         return fidelity_space
 
     @staticmethod
-    def _get_fidelity_choices(iter_choice: str, subsample_choice: str) -> Tuple[Hyperparameter, Hyperparameter]:
+    def _get_fidelity_choices(
+            iter_choice: str, subsample_choice: str
+    ) -> Tuple[Hyperparameter, Hyperparameter]:
         """Fidelity space available --- specifies the fidelity dimensions
-
-        For SVM, only a single fidelity exists, i.e., subsample fraction.
-        if fidelity_choice == 0
-            uses the entire data (subsample=1), reflecting the black-box setup
-        else
-            parameterizes the fraction of data to subsample
-
         """
-
         assert iter_choice in ['fixed', 'variable']
         assert subsample_choice in ['fixed', 'variable']
 
@@ -79,14 +82,16 @@ def _get_fidelity_choices(iter_choice: str, subsample_choice: str) -> Tuple[Hyperparameter, Hyperparameter]:
                 'subsample', lower=0.1, upper=1.0, default_value=1.0, log=False
             )
         )
-
         iter = fidelity1[iter_choice]
         subsample = fidelity2[subsample_choice]
         return iter, subsample
 
-    def init_model(self, config: Union[CS.Configuration, Dict],
-                   fidelity: Union[CS.Configuration, Dict, None] = None,
-                   rng: Union[int, np.random.RandomState, None] = None):
+    def init_model(
+            self,
+            config: Union[CS.Configuration, Dict],
+            fidelity: Union[CS.Configuration, Dict, None] = None,
+            rng: Union[int, np.random.RandomState, None] = None
+    ):
         # initializing model
         rng = self.rng if rng is None else rng
 
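The `'fixed'`/`'variable'` choices above decide whether each fidelity is pinned to its maximum or exposed as a tunable dimension. A sketch of what `_get_fidelity_choices` plausibly constructs with ConfigSpace, using the `subsample` bounds visible in this hunk; the `iter` bounds (10..1000) are assumptions, since that part of the code lies outside the hunk:

```python
# Sketch under assumptions: the real HPOBench helper returns one hyperparameter
# per fidelity; the iteration bounds here are illustrative guesses.
import ConfigSpace as CS

fidelity1 = dict(
    fixed=CS.Constant('iter', value=1000),
    variable=CS.UniformIntegerHyperparameter(
        'iter', lower=10, upper=1000, default_value=1000, log=False
    ),
)
fidelity2 = dict(
    fixed=CS.Constant('subsample', value=1.0),
    variable=CS.UniformFloatHyperparameter(
        'subsample', lower=0.1, upper=1.0, default_value=1.0, log=False
    ),
)

# 'variable' on both axes gives the gray-box (multi-multi-fidelity) setting;
# 'fixed' on both reproduces the black-box setting.
iter_hp, subsample_hp = fidelity1['variable'], fidelity2['variable']
fs = CS.ConfigurationSpace(seed=1)
fs.add_hyperparameters([iter_hp, subsample_hp])
```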
@@ -103,13 +108,185 @@ def init_model(self, config: Union[CS.Configuration, Dict],
             learning_rate="adaptive",
             tol=None,
             random_state=rng,
-
         )
         return model
 
+    def get_model_size(self, model: SGDClassifier = None) -> float:
+        """ Returns the dimensionality as a proxy for the number of model parameters
+
+        Logistic Regression models have a fixed number of parameters given a dataset. Model size
+        is approximated as the number of beta parameters required for the model support, plus
+        the intercept. This depends on the dataset and not on the trained model.
+
+        Parameters
+        ----------
+        model : SGDClassifier
+            Trained LR model. This parameter is required only to maintain the function signature.
+
+        Returns
+        -------
+        float
+        """
+        ndims = self.train_X.shape[1]
+        # accounting for the intercept
+        ndims += 1
+        return ndims
+
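As the docstring states, the size proxy depends only on the dataset, not on the fitted model. A quick illustrative check (the dataset shape is made up):

```python
# Illustrative only: a dataset with 20 features yields a "model size" of 21
# (one beta coefficient per feature, plus the intercept).
import numpy as np

train_X = np.zeros((1000, 20))   # hypothetical training split
ndims = train_X.shape[1] + 1     # mirrors get_model_size
assert ndims == 21
```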
+    def _train_objective(
+            self,
+            config: Dict,
+            fidelity: Dict,
+            shuffle: bool,
+            rng: Union[np.random.RandomState, int, None] = None,
+            evaluation: Union[str, None] = "valid",
+            record_stats: bool = False,
+            get_learning_curve: bool = False,
+            lc_every_k: int = 1,
+            **kwargs
+    ):
+        """Function that instantiates a 'config' on a 'fidelity' and trains it
+
+        The ML model is instantiated and trained on the training split. Optionally, the model is
+        evaluated on the training set. Optionally, the learning curves are collected.
+
+        Parameters
+        ----------
+        config : CS.Configuration, Dict
+            The hyperparameter configuration.
+        fidelity : CS.Configuration, Dict
+            The fidelity configuration.
+        shuffle : bool (optional)
+            If True, shuffles the training split before fitting the ML model.
+        rng : np.random.RandomState, int (optional)
+            The random seed passed to the ML model and, if applicable, used for shuffling the
+            data and subsampling the dataset fraction.
+        evaluation : str (optional)
+            If "valid", the ML model is trained on the training set alone.
+            If "test", the ML model is trained on the training + validation sets.
+        record_stats : bool (optional)
+            If True, records the evaluation metrics of the trained ML model on the training set.
+            This is set to False by default to reduce overall compute time.
+        get_learning_curve : bool (optional)
+            If True, records the learning curve using partial_fit or warm starting, if applicable.
+            This is set to False by default to reduce overall compute time.
+            If enabled, the model is evaluated on both the validation and test sets at each
+            iteration, and optionally on the training set as well.
+        lc_every_k : int (optional)
+            Records the learning curve after every k iterations.
+        """
+        if rng is not None:
+            rng = get_rng(rng, self.rng)
+
+        # initializing model
+        model = self.init_model(config, fidelity, rng)
+
+        # preparing data
+        if evaluation == "valid":
+            train_X = self.train_X
+            train_y = self.train_y
+        elif evaluation == "test":
+            train_X = np.vstack((self.train_X, self.valid_X))
+            train_y = pd.concat((self.train_y, self.valid_y))
+        else:
+            raise ValueError("{} not in ['valid', 'test']".format(evaluation))
+        train_idx = np.arange(len(train_X)) if self.train_idx is None else self.train_idx
+
+        # shuffling data
+        if shuffle:
+            train_idx = self.shuffle_data_idx(train_idx, rng)
+            if isinstance(train_idx, np.ndarray):
+                train_X = train_X[train_idx]
+            else:
+                train_X = train_X.iloc[train_idx]
+            train_y = train_y.iloc[train_idx]
+
+        # subsample here:
+        # application of the other fidelity to the dataset that the model interfaces
+        # carried over from previous HPOBench code that borrowed from FABOLAS' SVM
+        lower_bound_lim = 1.0 / 512.0
+        if self.lower_bound_train_size is None:
+            self.lower_bound_train_size = (10 * self.n_classes) / self.train_X.shape[0]
+            self.lower_bound_train_size = np.max((lower_bound_lim, self.lower_bound_train_size))
+        subsample = np.max((fidelity['subsample'], self.lower_bound_train_size))
+        train_idx = self.rng.choice(
+            np.arange(len(train_X)), size=int(
+                subsample * len(train_X)
+            )
+        )
+        # fitting the model with subsampled data
+        if get_learning_curve:
+            # IMPORTANT to allow partial_fit
+            model.warm_start = True
+            lc_time = 0.0
+            model_fit_time = 0.0
+            learning_curves = dict(train=[], valid=[], test=[])
+            lc_spacings = self._get_lc_spacing(model.max_iter, lc_every_k)
+            iter_start = 0
+            for i in range(len(lc_spacings)):
+                iter_end = lc_spacings[i]
+                start = time.time()
+                # trains model for k steps
+                for j in range(iter_end - iter_start):
+                    model.partial_fit(
+                        train_X[train_idx],
+                        train_y.iloc[train_idx],
+                        np.unique(train_y.iloc[train_idx])
+                    )
+                # adding all partial fit times
+                model_fit_time += time.time() - start
+                iter_start = iter_end
+                lc_start = time.time()
+                if record_stats:
+                    train_pred = model.predict(train_X)
+                    train_loss = 1 - self.scorers['acc'](
+                        train_y, train_pred, **self.scorer_args['acc']
+                    )
+                    learning_curves['train'].append(train_loss)
+                val_pred = model.predict(self.valid_X)
+                val_loss = 1 - self.scorers['acc'](
+                    self.valid_y, val_pred, **self.scorer_args['acc']
+                )
+                learning_curves['valid'].append(val_loss)
+                test_pred = model.predict(self.test_X)
+                test_loss = 1 - self.scorers['acc'](
+                    self.test_y, test_pred, **self.scorer_args['acc']
+                )
+                learning_curves['test'].append(test_loss)
+                # sums the time taken to evaluate and collect data for the learning curves
+                lc_time += time.time() - lc_start
+        else:
+            # default training as per the base benchmark template
+            learning_curves = None
+            lc_time = None
+            start = time.time()
+            model.fit(train_X[train_idx], train_y.iloc[train_idx])
+            model_fit_time = time.time() - start
+        # model inference
+        inference_time = 0.0
+        # can optionally not record evaluation metrics on training set to save compute
+        if record_stats:
+            start = time.time()
+            pred_train = model.predict(train_X)
+            inference_time = time.time() - start
+        # computing statistics on training data
+        scores = dict()
+        score_cost = dict()
+        for k, v in self.scorers.items():
+            scores[k] = 0.0
+            score_cost[k] = 0.0
+            _start = time.time()
+            if record_stats:
+                scores[k] = v(train_y, pred_train, **self.scorer_args[k])
+            score_cost[k] = time.time() - _start + inference_time
+        train_loss = 1 - scores["acc"]
+        return model, model_fit_time, train_loss, scores, score_cost, learning_curves, lc_time
+
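`_train_objective` relies on a `self._get_lc_spacing(model.max_iter, lc_every_k)` helper whose body is not part of this excerpt. A plausible reconstruction, assuming it returns the checkpoint iterations at which the learning curve is recorded, every k-th iteration and always ending at `max_iter`:

```python
# Hypothetical reconstruction: the real helper's body is not shown in this diff.
# Returns checkpoint iterations [k, 2k, ...], ending exactly at max_iter.
def _get_lc_spacing(max_iter: int, lc_every_k: int = 1) -> list:
    spacings = list(range(lc_every_k, max_iter + 1, lc_every_k))
    if not spacings or spacings[-1] != max_iter:
        spacings.append(max_iter)  # always evaluate at the final iteration
    return spacings

assert _get_lc_spacing(10, 3) == [3, 6, 9, 10]
assert _get_lc_spacing(4, 1) == [1, 2, 3, 4]
```

Under this assumption, consecutive checkpoint differences sum to `max_iter`, so the inner `partial_fit` loop in the hunk above performs exactly `max_iter` updates in total regardless of `lc_every_k`.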
 
 class LRBenchmarkBB(LRBenchmark):
-    def get_fidelity_space(self, seed: Union[int, None] = None) -> CS.ConfigurationSpace:
+    """ Black-box version of the LRBenchmark
+    """
+    @staticmethod
+    def get_fidelity_space(seed: Union[int, None] = None) -> CS.ConfigurationSpace:
         fidelity_space = CS.ConfigurationSpace(seed=seed)
         fidelity_space.add_hyperparameters(
             # black-box setting (full fidelity)
@@ -119,7 +296,10 @@ def get_fidelity_space(self, seed: Union[int, None] = None) -> CS.ConfigurationSpace:
 
 
 class LRBenchmarkMF(LRBenchmark):
-    def get_fidelity_space(self, seed: Union[int, None] = None) -> CS.ConfigurationSpace:
+    """ Multi-fidelity version of the LRBenchmark
+    """
+    @staticmethod
+    def get_fidelity_space(seed: Union[int, None] = None) -> CS.ConfigurationSpace:
         fidelity_space = CS.ConfigurationSpace(seed=seed)
         fidelity_space.add_hyperparameters(
             # gray-box setting (multi-fidelity) - iterations
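
Together, `LRBenchmark`, `LRBenchmarkBB`, and `LRBenchmarkMF` expose the same configuration space at three fidelity granularities. A hedged usage sketch following HPOBench's usual benchmark API; the `task_id` is an arbitrary example OpenML task, and treat the exact result keys as assumptions:

```python
# Sketch of typical HPOBench usage; task_id 167149 is just an example OpenML task.
from hpobench.benchmarks.ml import LRBenchmarkMF

benchmark = LRBenchmarkMF(task_id=167149, rng=1)
config = benchmark.get_configuration_space(seed=1).sample_configuration()
fidelity = benchmark.get_fidelity_space(seed=1).sample_configuration()

# objective_function is inherited from the ML benchmark template; 'function_value'
# and 'cost' are the conventional HPOBench result keys.
result = benchmark.objective_function(configuration=config, fidelity=fidelity, rng=1)
print(result['function_value'], result['cost'])
```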
