1010from sklearn import metrics
1111import math
1212
13+ from naslib .predictors .zerocost import ZeroCost
1314from naslib .search_spaces .core .query_metrics import Metric
1415from naslib .utils import generate_kfold , cross_validation
1516
17+ from naslib import utils
18+
1619logger = logging .getLogger (__name__ )
1720
1821
@@ -47,6 +50,9 @@ def __init__(self, predictor, config=None):
4750 self .num_arches_to_mutate = 5
4851 self .max_mutation_rate = 3
4952
53+ # For ZeroCost proxies
54+ self .dataloader = None
55+
5056 def adapt_search_space (
5157 self , search_space , load_labeled , scope = None , dataset_api = None
5258 ):
@@ -70,6 +76,9 @@ def adapt_search_space(
7076 "This search space is not yet implemented in PredictorEvaluator."
7177 )
7278
79+ if isinstance (self .predictor , ZeroCost ):
80+ self .dataloader , _ , _ , _ , _ = utils .get_train_val_loaders (self .config )
81+
7382 def get_full_arch_info (self , arch ):
7483 """
7584 Given an arch, return the accuracy, train_time,
@@ -139,10 +148,8 @@ def load_dataset(self, load_labeled=False, data_size=10, arch_hash_map={}):
139148 arch .load_labeled_architecture (dataset_api = self .dataset_api )
140149
141150 arch_hash = arch .get_hash ()
142- if False : # removing this for consistency, for now
143- continue
144- else :
145- arch_hash_map [arch_hash ] = True
151+
152+ arch_hash_map [arch_hash ] = True
146153
147154 accuracy , train_time , info_dict = self .get_full_arch_info (arch )
148155 xdata .append (arch )
@@ -295,7 +302,11 @@ def single_evaluate(self, train_data, test_data, fidelity):
295302 hyperparams = self .predictor .get_hyperparams ()
296303
297304 fit_time_end = time .time ()
298- test_pred = self .predictor .query (xtest , test_info )
305+ if isinstance (self .predictor , ZeroCost ):
306+ [g .parse () for g in xtest ] # parse the graphs because they will be used
307+ test_pred = self .predictor .query_batch (xtest , self .dataloader )
308+ else :
309+ test_pred = self .predictor .query (xtest , test_info )
299310 query_time_end = time .time ()
300311
301312 # If the predictor is an ensemble, take the mean
0 commit comments