@@ -19,6 +19,7 @@ def load_plot_data(
1919 data : List ,
2020 effect_size : str = "mean_diff" ,
2121 contrast_type : str = None ,
22+ ci_type : str = "bca" ,
2223 idx : Optional [List [int ]] = None
2324) -> List :
2425 """
@@ -32,6 +33,8 @@ def load_plot_data(
3233 Type of effect size ('mean_diff', 'median_diff', etc.).
3334 contrast_type: str
3435 Type of dabest object to plot ('delta2' or 'mini-meta' or 'delta').
36+ ci_type: str
37+ Type of confidence interval to plot ('bca' or 'pct')
3538 idx: Optional[List[int]], default=None
3639 List of indices to select from the contrast objects if delta-delta experiment.
3740 If None, only the delta-delta objects are plotted.
@@ -53,14 +56,14 @@ def load_plot_data(
5356 current_plot_data = getattr (current_contrast , effect_attr )
5457 bootstraps .append (current_plot_data .results .bootstraps [index ])
5558 differences .append (current_plot_data .results .difference [index ])
56- bcalows .append (current_plot_data .results .bca_low [index ])
57- bcahighs .append (current_plot_data .results .bca_high [index ])
59+ bcalows .append (current_plot_data .results .get ( ci_type + '_low' ) [index ])
60+ bcahighs .append (current_plot_data .results .get ( ci_type + '_high' ) [index ])
5861 else :
5962 contrast_plot_data = [getattr (contrast , effect_attr ) for contrast in data ]
6063 bootstraps_nested = [result .results .bootstraps .to_list () for result in contrast_plot_data ]
6164 differences_nested = [result .results .difference .to_list () for result in contrast_plot_data ]
62- bcalows_nested = [result .results .bca_low .to_list () for result in contrast_plot_data ]
63- bcahighs_nested = [result .results .bca_high .to_list () for result in contrast_plot_data ]
65+ bcalows_nested = [result .results .get ( ci_type + '_low' ) .to_list () for result in contrast_plot_data ]
66+ bcahighs_nested = [result .results .get ( ci_type + '_high' ) .to_list () for result in contrast_plot_data ]
6467
6568 bootstraps = [element for innerList in bootstraps_nested for element in innerList ]
6669 differences = [element for innerList in differences_nested for element in innerList ]
@@ -79,14 +82,14 @@ def load_plot_data(
7982 current_plot_data = getattr (getattr (current_contrast , effect_attr ), contrast_attr )
8083 bootstraps .append (current_plot_data .bootstraps_delta_delta )
8184 differences .append (current_plot_data .difference )
82- bcalows .append (current_plot_data .bca_low )
83- bcahighs .append (current_plot_data .bca_high )
85+ bcalows .append (current_plot_data .results . get ( ci_type + '_low' )[ 0 ] )
86+ bcahighs .append (current_plot_data .results . get ( ci_type + '_high' )[ 0 ] )
8487 elif index == 0 or index == 1 :
8588 current_plot_data = getattr (current_contrast , effect_attr )
8689 bootstraps .append (current_plot_data .results .bootstraps [index ])
8790 differences .append (current_plot_data .results .difference [index ])
88- bcalows .append (current_plot_data .results .bca_low [index ])
89- bcahighs .append (current_plot_data .results .bca_high [index ])
91+ bcalows .append (current_plot_data .results .get ( ci_type + '_low' ) [index ])
92+ bcahighs .append (current_plot_data .results .get ( ci_type + '_high' ) [index ])
9093 else :
9194 raise ValueError ("The selected indices must be 0, 1, or 2." )
9295 else :
@@ -95,14 +98,14 @@ def load_plot_data(
9598 current_plot_data = getattr (getattr (current_contrast , effect_attr ), contrast_attr )
9699 bootstraps .append (current_plot_data .bootstraps_weighted_delta )
97100 differences .append (current_plot_data .difference )
98- bcalows .append (current_plot_data .results .bca_low )
99- bcahighs .append (current_plot_data .results .bca_high )
101+ bcalows .append (current_plot_data .results .get ( ci_type + '_low' )[ 0 ] )
102+ bcahighs .append (current_plot_data .results .get ( ci_type + '_high' )[ 0 ] )
100103 elif index < num_of_groups :
101104 current_plot_data = getattr (current_contrast , effect_attr )
102105 bootstraps .append (current_plot_data .results .bootstraps [index ])
103106 differences .append (current_plot_data .results .difference [index ])
104- bcalows .append (current_plot_data .results .bca_low [index ])
105- bcahighs .append (current_plot_data .results .bca_high [index ])
107+ bcalows .append (current_plot_data .results .get ( ci_type + '_low' ) [index ])
108+ bcahighs .append (current_plot_data .results .get ( ci_type + '_high' ) [index ])
106109 else :
107110 msg1 = "There are only {} groups (starting from zero) in this dabest object. " .format (num_of_groups )
108111 msg2 = "The idx given is {}." .format (index )
@@ -113,8 +116,8 @@ def load_plot_data(
113116
114117 bootstraps = [getattr (result , f"bootstraps_{ attribute_suffix } " ) for result in contrast_plot_data ]
115118 differences = [result .difference for result in contrast_plot_data ]
116- bcalows = [result .bca_low for result in contrast_plot_data ]
117- bcahighs = [result .bca_high for result in contrast_plot_data ]
119+ bcalows = [result .results . get ( ci_type + '_low' )[ 0 ] for result in contrast_plot_data ]
120+ bcahighs = [result .results . get ( ci_type + '_high' )[ 0 ] for result in contrast_plot_data ]
118121
119122 return bootstraps , differences , bcalows , bcahighs
120123
@@ -124,6 +127,7 @@ def check_for_errors(
124127 ax ,
125128 fig_size ,
126129 effect_size ,
130+ ci_type ,
127131 horizontal ,
128132 marker_size ,
129133 custom_palette ,
@@ -140,6 +144,7 @@ def check_for_errors(
140144 yticks ,
141145 yticklabels ,
142146 remove_spines ,
147+ summary_bars ,
143148 ) -> str :
144149
145150 # Contrasts
@@ -203,6 +208,10 @@ def check_for_errors(
203208 raise ValueError ("The `effect_size` argument must be `mean_diff` for mini-meta analyses." )
204209 if data [0 ].delta2 and effect_size not in ['mean_diff' , 'hedges_g' , 'delta_g' ]:
205210 raise ValueError ("The `effect_size` argument must be `mean_diff`, `hedges_g`, or `delta_g` for delta-delta analyses." )
211+
212+ # CI type
213+ if ci_type not in ('bca' , 'pct' ):
214+ raise TypeError ("`ci_type` must be either 'bca' or 'pct'." )
206215
207216 # Horizontal
208217 if not isinstance (horizontal , bool ):
@@ -277,6 +286,15 @@ def check_for_errors(
277286 if not isinstance (remove_spines , bool ):
278287 raise TypeError ("`remove_spines` must be a boolean value." )
279288
289+ # Summary bars
290+ if summary_bars is not None :
291+ if not isinstance (summary_bars , list | tuple ):
292+ raise TypeError ("summary_bars must be a list/tuple of indices (ints)." )
293+ if not all (isinstance (i , int ) for i in summary_bars ):
294+ raise TypeError ("summary_bars must be a list/tuple of indices (ints)." )
295+ if any (i >= number_of_curves_to_plot for i in summary_bars ):
296+ raise ValueError ("Index {} chosen is out of range for the contrast objects." .format ([i for i in summary_bars if i >= number_of_curves_to_plot ]))
297+
280298 return contrast_type
281299
282300
@@ -288,6 +306,7 @@ def get_kwargs(
288306 errorbar_kwargs ,
289307 delta_text_kwargs ,
290308 contrast_bars_kwargs ,
309+ summary_bars_kwargs ,
291310 marker_size
292311 ):
293312 from .misc_tools import merge_two_dicts
@@ -369,9 +388,21 @@ def get_kwargs(
369388 else :
370389 contrast_bars_kwargs = merge_two_dicts (default_contrast_bars_kwargs , contrast_bars_kwargs )
371390
391+ # Summary bars kwargs.
392+ default_summary_bars_kwargs = {
393+ "span_ax" : False ,
394+ "color" : None ,
395+ "alpha" : 0.15 ,
396+ "zorder" :- 3
397+ }
398+ if summary_bars_kwargs is None :
399+ summary_bars_kwargs = default_summary_bars_kwargs
400+ else :
401+ summary_bars_kwargs = merge_two_dicts (default_summary_bars_kwargs , summary_bars_kwargs )
402+
372403
373404 return (violin_kwargs , zeroline_kwargs , marker_kwargs , errorbar_kwargs ,
374- delta_text_kwargs , contrast_bars_kwargs )
405+ delta_text_kwargs , contrast_bars_kwargs , summary_bars_kwargs )
375406
376407
377408
@@ -407,6 +438,7 @@ def forest_plot(
407438 ax : Optional [plt .Axes ] = None ,
408439 fig_size : tuple [int , int ] = None ,
409440 effect_size : str = "mean_diff" ,
441+ ci_type = 'bca' ,
410442 horizontal : bool = False ,
411443
412444 marker_size : int = 10 ,
@@ -431,6 +463,8 @@ def forest_plot(
431463
432464 contrast_bars : bool = True ,
433465 contrast_bars_kwargs : dict = None ,
466+ summary_bars : list | tuple = None ,
467+ summary_bars_kwargs : dict = None ,
434468
435469 violin_kwargs : Optional [dict ] = None ,
436470 zeroline_kwargs : Optional [dict ] = None ,
@@ -455,6 +489,8 @@ def forest_plot(
455489 Figure size for the plot.
456490 effect_size : str
457491 Type of effect size to plot (e.g., 'mean_diff', `hedges_g` or 'delta_g').
492+ ci_type : str
493+ Type of confidence interval to plot (bca' or 'pct')
458494 horizontal : bool, default=False
459495 If True, the plot will be horizontal.
460496 marker_size : int, default=12
@@ -495,6 +531,10 @@ def forest_plot(
495531 If True, it adds bars from the zeroline to the effect size curve.
496532 contrast_bars_kwargs : dict, default=None
497533 Additional keyword arguments for the contrast_bars.
534+ summary_bars: list | tuple, default=None,
535+ If True, it adds summary bars to the relevant effect size curves.
536+ summary_bars_kwargs : dict, default=None,
537+ Additional keyword arguments for the summary_bars.
498538 violin_kwargs : Optional[dict], default=None
499539 Additional arguments for violin plot customization.
500540 zeroline_kwargs : Optional[dict], default=None
@@ -519,6 +559,7 @@ def forest_plot(
519559 ax = ax ,
520560 fig_size = fig_size ,
521561 effect_size = effect_size ,
562+ ci_type = ci_type ,
522563 horizontal = horizontal ,
523564 marker_size = marker_size ,
524565 custom_palette = custom_palette ,
@@ -535,16 +576,17 @@ def forest_plot(
535576 yticks = yticks ,
536577 yticklabels = yticklabels ,
537578 remove_spines = remove_spines ,
579+ summary_bars = summary_bars ,
538580 )
539581
540582 # Load plot data and extract info
541583 bootstraps , differences , bcalows , bcahighs = load_plot_data (
542584 data = data ,
543585 effect_size = effect_size ,
544586 contrast_type = contrast_type ,
587+ ci_type = ci_type ,
545588 idx = idx
546589 )
547-
548590 # Adjust figure size based on orientation
549591 number_of_curves_to_plot = len (bootstraps )
550592 # number_of_curves_to_plot = sum([len(i) for i in idx]) if idx is not None else len(data)
@@ -556,16 +598,17 @@ def forest_plot(
556598 fig , ax = plt .subplots (figsize = fig_size )
557599
558600 # Get Kwargs
559- (violin_kwargs , zeroline_kwargs , marker_kwargs ,
560- errorbar_kwargs , delta_text_kwargs , contrast_bars_kwargs ) = get_kwargs (
561- violin_kwargs = violin_kwargs ,
562- zeroline_kwargs = zeroline_kwargs ,
563- horizontal = horizontal ,
564- marker_kwargs = marker_kwargs ,
565- errorbar_kwargs = errorbar_kwargs ,
566- delta_text_kwargs = delta_text_kwargs ,
567- contrast_bars_kwargs = contrast_bars_kwargs ,
568- marker_size = marker_size
601+ (violin_kwargs , zeroline_kwargs , marker_kwargs , errorbar_kwargs ,
602+ delta_text_kwargs , contrast_bars_kwargs , summary_bars_kwargs ) = get_kwargs (
603+ violin_kwargs = violin_kwargs ,
604+ zeroline_kwargs = zeroline_kwargs ,
605+ horizontal = horizontal ,
606+ marker_kwargs = marker_kwargs ,
607+ errorbar_kwargs = errorbar_kwargs ,
608+ delta_text_kwargs = delta_text_kwargs ,
609+ contrast_bars_kwargs = contrast_bars_kwargs ,
610+ summary_bars_kwargs = summary_bars_kwargs ,
611+ marker_size = marker_size
569612 )
570613
571614 # Plot the violins and make adjustments
@@ -719,6 +762,42 @@ def forest_plot(
719762 else :
720763 ax .add_patch (mpatches .Rectangle ((x , 0 ), 0.25 , y , color = bar_colors [x - 1 ], ** contrast_bars_kwargs ))
721764
765+ # Summary bars
766+ if summary_bars :
767+ _bar_color = summary_bars_kwargs .pop ('color' )
768+ if _bar_color is not None :
769+ bar_colors = [_bar_color ] * number_of_curves_to_plot
770+ else :
771+ bar_colors = violin_colors
772+
773+ span_ax = summary_bars_kwargs .pop ("span_ax" )
774+ summary_xmin , summary_xmax = ax .get_xlim ()
775+ summary_ymin , summary_ymax = ax .get_ylim ()
776+
777+ for summary_index in summary_bars :
778+ if span_ax == True :
779+ starting_location = summary_ymin if horizontal else summary_xmin
780+ else :
781+ starting_location = summary_index + 1
782+
783+ summary_color = bar_colors [summary_index ]
784+ summary_ci_low , summary_ci_high = bcalows [summary_index ], bcahighs [summary_index ]
785+
786+ if horizontal :
787+ ax .add_patch (mpatches .Rectangle (
788+ (summary_ci_low , starting_location ),
789+ summary_ci_high - summary_ci_low , summary_ymax + 1 ,
790+ color = summary_color ,
791+ ** summary_bars_kwargs )
792+ )
793+ else :
794+ ax .add_patch (mpatches .Rectangle (
795+ (starting_location , summary_ci_low ),
796+ summary_xmax + 1 , summary_ci_high - summary_ci_low ,
797+ color = summary_color ,
798+ ** summary_bars_kwargs )
799+ )
800+
722801 ## Invert Y-axis if horizontal
723802 if horizontal :
724803 ax .invert_yaxis ()
0 commit comments