@@ -1263,9 +1263,6 @@ def _get_display(
12631263 display : display_class
12641264 The display.
12651265 """
1266- if self ._parent ._reports_type == "CrossValidationReport" :
1267- raise NotImplementedError ()
1268-
12691266 if "seed" in display_kwargs and display_kwargs ["seed" ] is None :
12701267 cache_key = None
12711268 else :
@@ -1288,55 +1285,121 @@ def _get_display(
12881285 y_true : list [YPlotData ] = []
12891286 y_pred : list [YPlotData ] = []
12901287
1291- for report , report_name in zip (
1292- self ._parent .reports_ , self ._parent .report_names_
1293- ):
1294- report_X , report_y , _ = report .metrics ._get_X_y_and_data_source_hash (
1295- data_source = data_source ,
1296- X = X ,
1297- y = y ,
1298- )
1288+ if self ._parent ._reports_type == "EstimatorReport" :
1289+ for report , report_name in zip (
1290+ self ._parent .reports_ , self ._parent .report_names_
1291+ ):
1292+ report_X , report_y , _ = (
1293+ report .metrics ._get_X_y_and_data_source_hash (
1294+ data_source = data_source ,
1295+ X = X ,
1296+ y = y ,
1297+ )
1298+ )
12991299
1300- y_true .append (
1301- YPlotData (
1302- estimator_name = report_name ,
1303- split_index = None ,
1304- y = report_y ,
1300+ y_true .append (
1301+ YPlotData (
1302+ estimator_name = report_name ,
1303+ split_index = None ,
1304+ y = report_y ,
1305+ )
13051306 )
1306- )
1307- results = _get_cached_response_values (
1308- cache = report ._cache ,
1309- estimator_hash = report ._hash ,
1310- estimator = report ._estimator ,
1311- X = report_X ,
1312- response_method = response_method ,
1307+ results = _get_cached_response_values (
1308+ cache = report ._cache ,
1309+ estimator_hash = report ._hash ,
1310+ estimator = report ._estimator ,
1311+ X = report_X ,
1312+ response_method = response_method ,
1313+ data_source = data_source ,
1314+ data_source_hash = None ,
1315+ pos_label = display_kwargs .get ("pos_label" ),
1316+ )
1317+ for key , value , is_cached in results :
1318+ if not is_cached :
1319+ report ._cache [key ] = value
1320+ if key [- 1 ] != "predict_time" :
1321+ y_pred .append (
1322+ YPlotData (
1323+ estimator_name = report_name ,
1324+ split_index = None ,
1325+ y = value ,
1326+ )
1327+ )
1328+
1329+ progress .update (main_task , advance = 1 , refresh = True )
1330+
1331+ display = display_class ._compute_data_for_display (
1332+ y_true = y_true ,
1333+ y_pred = y_pred ,
1334+ report_type = "comparison-estimator" ,
1335+ estimators = [report .estimator_ for report in self ._parent .reports_ ],
1336+ estimator_names = self ._parent .report_names_ ,
1337+ ml_task = self ._parent ._ml_task ,
13131338 data_source = data_source ,
1314- data_source_hash = None ,
1315- pos_label = display_kwargs .get ("pos_label" ),
1339+ ** display_kwargs ,
13161340 )
1317- for key , value , is_cached in results :
1318- if not is_cached :
1319- report ._cache [key ] = value
1320- if key [- 1 ] != "predict_time" :
1321- y_pred .append (
1341+
1342+ else :
1343+ for report , report_name in zip (
1344+ self ._parent .reports_ , self ._parent .report_names_
1345+ ):
1346+ for split_index , estimator_report in enumerate (
1347+ report .estimator_reports_
1348+ ):
1349+ report_X , report_y , _ = (
1350+ estimator_report .metrics ._get_X_y_and_data_source_hash (
1351+ data_source = data_source ,
1352+ X = X ,
1353+ y = y ,
1354+ )
1355+ )
1356+
1357+ y_true .append (
13221358 YPlotData (
13231359 estimator_name = report_name ,
1324- split_index = None ,
1325- y = value ,
1360+ split_index = split_index ,
1361+ y = report_y ,
13261362 )
13271363 )
1328- progress .update (main_task , advance = 1 , refresh = True )
13291364
1330- display = display_class ._compute_data_for_display (
1331- y_true = y_true ,
1332- y_pred = y_pred ,
1333- report_type = "comparison-estimator" ,
1334- estimators = [report .estimator_ for report in self ._parent .reports_ ],
1335- estimator_names = self ._parent .report_names_ ,
1336- ml_task = self ._parent ._ml_task ,
1337- data_source = data_source ,
1338- ** display_kwargs ,
1339- )
1365+ results = _get_cached_response_values (
1366+ cache = estimator_report ._cache ,
1367+ estimator_hash = estimator_report ._hash ,
1368+ estimator = estimator_report .estimator_ ,
1369+ X = report_X ,
1370+ response_method = response_method ,
1371+ data_source = data_source ,
1372+ data_source_hash = None ,
1373+ pos_label = display_kwargs .get ("pos_label" ),
1374+ )
1375+ for key , value , is_cached in results :
1376+ if not is_cached :
1377+ report ._cache [key ] = value
1378+ if key [- 1 ] != "predict_time" :
1379+ y_pred .append (
1380+ YPlotData (
1381+ estimator_name = report_name ,
1382+ split_index = split_index ,
1383+ y = value ,
1384+ )
1385+ )
1386+
1387+ progress .update (main_task , advance = 1 , refresh = True )
1388+
1389+ display = display_class ._compute_data_for_display (
1390+ y_true = y_true ,
1391+ y_pred = y_pred ,
1392+ report_type = "comparison-cross-validation" ,
1393+ estimators = [
1394+ estimator_report .estimator_
1395+ for report in self ._parent .reports_
1396+ for estimator_report in report .estimator_reports_
1397+ ],
1398+ estimator_names = self ._parent .report_names_ ,
1399+ ml_task = self ._parent ._ml_task ,
1400+ data_source = data_source ,
1401+ ** display_kwargs ,
1402+ )
13401403
13411404 if cache_key is not None :
13421405 # Unless seed is an int (i.e. the call is deterministic),
@@ -1476,6 +1539,8 @@ def precision_recall(
14761539 >>> display = comparison_report.metrics.precision_recall()
14771540 >>> display.plot()
14781541 """
1542+ if self ._parent ._reports_type == "CrossValidationReport" :
1543+ raise NotImplementedError ()
14791544 response_method = ("predict_proba" , "decision_function" )
14801545 display_kwargs = {"pos_label" : pos_label }
14811546 display = cast (
@@ -1560,6 +1625,8 @@ def prediction_error(
15601625 >>> display = comparison_report.metrics.prediction_error()
15611626 >>> display.plot(kind="actual_vs_predicted")
15621627 """
1628+ if self ._parent ._reports_type == "CrossValidationReport" :
1629+ raise NotImplementedError ()
15631630 display_kwargs = {"subsample" : subsample , "seed" : seed }
15641631 display = cast (
15651632 PredictionErrorDisplay ,
0 commit comments