22from  typing  import  Any , Literal , Optional , Union , cast 
33
44import  matplotlib .pyplot  as  plt 
5+ import  numpy  as  np 
56from  matplotlib  import  colormaps 
67from  matplotlib .axes  import  Axes 
78from  matplotlib .lines  import  Line2D 
@@ -387,6 +388,113 @@ def _plot_cross_validated_estimator(
387388
388389        return  self .ax_ , lines , info_pos_label 
389390
def _plot_average_cross_validated_binary_estimator(
    self,
    *,
    estimator_name: str,
    roc_curve_kwargs: list[dict[str, Any]],
    plot_chance_level: bool = True,
    chance_level_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[Axes, list[Line2D], Union[str, None]]:
    """Plot per-fold ROC curves and their average for a cross-validated estimator.

    One faint line is drawn per cross-validation split, followed by a single
    line for the averaged ROC curve stored in ``self.roc_curve``. Only the
    averaged line receives a legend label.

    Parameters
    ----------
    estimator_name : str
        The name of the estimator.

    roc_curve_kwargs : list of dict
        List of dictionaries containing keyword arguments to customize the
        per-fold ROC curves. It is indexed by the integer split index, so it
        needs one entry per split.

    plot_chance_level : bool, default=True
        Whether to plot the chance level.

    chance_level_kwargs : dict, default=None
        Keyword arguments to be passed to matplotlib's `plot` for rendering
        the chance level line.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes with the ROC curves plotted.

    lines : list of matplotlib.lines.Line2D
        The plotted lines: one per split, then the averaged curve last.

    info_pos_label : str
        String containing positive label information, or an empty string when
        ``self.pos_label`` is None.
    """
    # NOTE(review): the literal segments of the f-strings in this method were
    # truncated in the reviewed diff. They are reconstructed from the visible
    # fragments and from the intact "external" title string; confirm the exact
    # wording against the repository before merging.
    lines: list[Line2D] = []
    # `.item()` requires exactly one category, i.e. a single averaging method.
    average_type = self.roc_curve["average"].cat.categories.item()

    for split_idx in self.roc_curve["split_index"].cat.categories:
        if split_idx is None:
            continue
        split_idx = int(split_idx)
        query = f"label == {self.pos_label!r} & split_index == {split_idx}"
        fold_roc_curve = self.roc_curve.query(query)

        # Per-fold curves are de-emphasized so the average stands out.
        fold_line_kwargs = _validate_style_kwargs(
            {"color": "grey", "alpha": 0.3, "lw": 0.75},
            roc_curve_kwargs[split_idx],
        )

        (line,) = self.ax_.plot(
            fold_roc_curve["fpr"],
            fold_roc_curve["tpr"],
            **fold_line_kwargs,
        )
        lines.append(line)

    # Exactly one line was appended per plotted split.
    n_folds = len(lines)

    # NOTE(review): presumably the averaged rows are selected through the
    # "average" column -- confirm.
    query = f"label == {self.pos_label!r} & average == {average_type!r}"
    average_roc_curve = self.roc_curve.query(query)
    average_roc_auc = self.roc_auc.query(query)["roc_auc"].item()

    average_line_kwargs = _validate_style_kwargs({}, {})
    # NOTE(review): label wording reconstructed -- confirm.
    average_line_kwargs["label"] = (
        f"{average_type.capitalize()} average of {n_folds} folds "
        f"(AUC = {average_roc_auc:0.2f})"
    )

    (line,) = self.ax_.plot(
        average_roc_curve["fpr"],
        average_roc_curve["tpr"],
        **average_line_kwargs,
    )
    lines.append(line)

    info_pos_label = (
        f"\n(Positive label: {self.pos_label})"
        if self.pos_label is not None
        else ""
    )

    if plot_chance_level:
        self.chance_level_ = _add_chance_level(
            self.ax_,
            chance_level_kwargs,
            self._default_chance_level_kwargs,
        )
    else:
        self.chance_level_ = None

    if self.data_source in ("train", "test"):
        title = f"{estimator_name} on $\\bf{{{self.data_source}}}$ set"
    else:
        title = f"{estimator_name} on $\\bf{{external}}$ set"
    self.ax_.legend(bbox_to_anchor=(1.02, 1), title=title)

    return self.ax_, lines, info_pos_label

497+ 
390498    def  _plot_comparison_estimator (
391499        self ,
392500        * ,
@@ -760,17 +868,30 @@ def plot(
760868                chance_level_kwargs = chance_level_kwargs ,
761869            )
762870        elif  self .report_type  ==  "cross-validation" :
763-             self .ax_ , self .lines_ , info_pos_label  =  (
764-                 self ._plot_cross_validated_estimator (
765-                     estimator_name = (
766-                         estimator_name 
767-                         or  self .roc_auc ["estimator_name" ].cat .categories .item ()
768-                     ),
769-                     roc_curve_kwargs = roc_curve_kwargs ,
770-                     plot_chance_level = plot_chance_level ,
771-                     chance_level_kwargs = chance_level_kwargs ,
871+             if  "average"  in  self .roc_auc .columns :
872+                 self .ax_ , self .lines_ , info_pos_label  =  (
873+                     self ._plot_average_cross_validated_binary_estimator (
874+                         estimator_name = (
875+                             estimator_name 
876+                             or  self .roc_auc ["estimator_name" ].cat .categories .item ()
877+                         ),
878+                         roc_curve_kwargs = roc_curve_kwargs ,
879+                         plot_chance_level = plot_chance_level ,
880+                         chance_level_kwargs = chance_level_kwargs ,
881+                     )
882+                 )
883+             else :
884+                 self .ax_ , self .lines_ , info_pos_label  =  (
885+                     self ._plot_cross_validated_estimator (
886+                         estimator_name = (
887+                             estimator_name 
888+                             or  self .roc_auc ["estimator_name" ].cat .categories .item ()
889+                         ),
890+                         roc_curve_kwargs = roc_curve_kwargs ,
891+                         plot_chance_level = plot_chance_level ,
892+                         chance_level_kwargs = chance_level_kwargs ,
893+                     )
772894                )
773-             )
774895        elif  self .report_type  ==  "comparison-estimator" :
775896            self .ax_ , self .lines_ , info_pos_label  =  self ._plot_comparison_estimator (
776897                estimator_names = self .roc_auc ["estimator_name" ].cat .categories ,
@@ -812,6 +933,7 @@ def _compute_data_for_display(
812933        cls ,
813934        y_true : Sequence [YPlotData ],
814935        y_pred : Sequence [YPlotData ],
936+         average : Optional [Literal ["threshold" ]] =  None ,
815937        * ,
816938        report_type : ReportType ,
817939        estimators : Sequence [BaseEstimator ],
@@ -869,6 +991,7 @@ def _compute_data_for_display(
869991        roc_auc_records  =  []
870992
871993        if  ml_task  ==  "binary-classification" :
994+             pos_label_validated  =  cast (PositiveLabel , pos_label_validated )
872995            for  y_true_i , y_pred_i  in  zip (y_true , y_pred ):
873996                fpr_i , tpr_i , thresholds_i  =  roc_curve (
874997                    y_true_i .y ,
@@ -878,8 +1001,6 @@ def _compute_data_for_display(
8781001                )
8791002                roc_auc_i  =  auc (fpr_i , tpr_i )
8801003
881-                 pos_label_validated  =  cast (PositiveLabel , pos_label_validated )
882- 
8831004                for  fpr , tpr , threshold  in  zip (fpr_i , tpr_i , thresholds_i ):
8841005                    roc_curve_records .append (
8851006                        {
@@ -900,8 +1021,63 @@ def _compute_data_for_display(
9001021                        "roc_auc" : roc_auc_i ,
9011022                    }
9021023                )
1024+             if  average  is  not None :
1025+                 if  average  ==  "threshold" :
1026+                     all_thresholds  =  []
1027+                     all_fprs  =  []
1028+                     all_tprs  =  []
1029+ 
1030+                     roc_curves_df  =  DataFrame .from_records (roc_curve_records )
1031+                     for  _ , group  in  roc_curves_df .groupby ("split_index" ):
1032+                         sorted_group  =  group .sort_values ("threshold" , ascending = False )
1033+                         all_thresholds .append (
1034+                             np .array (sorted_group ["threshold" ].values )
1035+                         )
1036+                         all_fprs .append (np .array (sorted_group ["fpr" ].values ))
1037+                         all_tprs .append (np .array (sorted_group ["tpr" ].values ))
1038+ 
1039+                     average_fpr , average_tpr , average_threshold  =  (
1040+                         cls ._threshold_average (
1041+                             xs = all_fprs ,
1042+                             ys = all_tprs ,
1043+                             thresholds = all_thresholds ,
1044+                         )
1045+                     )
1046+                 else :
1047+                     raise  TypeError (
1048+                         "'threshold' is the only supported option for `average`," 
1049+                         f"but got { average }  
1050+                     )
1051+                 average_roc_auc  =  auc (average_fpr , average_tpr )
1052+                 for  fpr , tpr , threshold  in  zip (
1053+                     average_fpr , average_tpr , average_threshold 
1054+                 ):
1055+                     roc_curve_records .append (
1056+                         {
1057+                             "estimator_name" : y_true_i .estimator_name ,
1058+                             "split_index" : None ,
1059+                             "label" : pos_label_validated ,
1060+                             "threshold" : threshold ,
1061+                             "fpr" : fpr ,
1062+                             "tpr" : tpr ,
1063+                             "average" : "threshold" ,
1064+                         }
1065+                     )
1066+                 roc_auc_records .append (
1067+                     {
1068+                         "estimator_name" : y_true_i .estimator_name ,
1069+                         "split_index" : None ,
1070+                         "label" : pos_label_validated ,
1071+                         "roc_auc" : average_roc_auc ,
1072+                         "average" : "threshold" ,
1073+                     }
1074+                 )
9031075
9041076        else :  # multiclass-classification 
1077+             if  average  is  not None :
1078+                 raise  ValueError (
1079+                     "Averaging is not implemented for multi class classification" 
1080+                 )
9051081            # OvR fashion to collect fpr, tpr, and roc_auc 
9061082            for  y_true_i , y_pred_i , est  in  zip (y_true , y_pred , estimators ):
9071083                label_binarizer  =  LabelBinarizer ().fit (est .classes_ )
@@ -942,7 +1118,7 @@ def _compute_data_for_display(
9421118            "estimator_name" : "category" ,
9431119            "split_index" : "category" ,
9441120            "label" : "category" ,
945-         }
1121+         }  |  ({ "average" :  "category" }  if   average   is   not   None   else  {}) 
9461122
9471123        return  cls (
9481124            roc_curve = DataFrame .from_records (roc_curve_records ).astype (dtypes ),
0 commit comments