Skip to content

Commit 718ed51

Browse files
committed
Forest plots can now do regular delta experiments
1 parent 9f27bab commit 718ed51

File tree

6 files changed

+197
-60
lines changed

6 files changed

+197
-60
lines changed

dabest/forest_plot.py

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
def load_plot_data(
1717
data: List,
1818
effect_size: str = "mean_diff",
19-
contrast_type: str = 'delta2',
19+
contrast_type: str = None,
2020
idx: Optional[List[int]] = None
2121
) -> List:
2222
"""
@@ -29,7 +29,7 @@ def load_plot_data(
2929
effect_size: str
3030
Type of effect size ('mean_diff', 'median_diff', etc.).
3131
contrast_type: str
32-
Type of dabest object to plot ('delta2' or 'mini-meta')
32+
Type of dabest object to plot ('delta2' or 'mini-meta' or 'delta').
3333
idx: Optional[List[int]], default=None
3434
List of indices to select from the contrast objects if delta-delta experiment.
3535
If None, only the delta-delta objects are plotted.
@@ -40,37 +40,61 @@ def load_plot_data(
4040
"""
4141
# Effect size and contrast types
4242
effect_attr = "hedges_g" if effect_size == 'delta_g' else effect_size
43-
contrast_attr = {"delta2": "delta_delta", "mini_meta": "mini_meta"}.get(contrast_type)
4443

45-
if idx is not None:
46-
bootstraps, differences, bcalows, bcahighs = [], [], [], []
47-
for current_idx, index_group in enumerate(idx):
48-
current_contrast = data[current_idx]
49-
if len(index_group)>0:
50-
for index in index_group:
51-
if index == 2:
52-
current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)
53-
bootstraps.append(current_plot_data.bootstraps_delta_delta)
54-
differences.append(current_plot_data.difference)
55-
bcalows.append(current_plot_data.bca_low)
56-
bcahighs.append(current_plot_data.bca_high)
57-
elif index == 0 or index == 1:
44+
if contrast_type == "delta":
45+
if idx is not None:
46+
bootstraps, differences, bcalows, bcahighs = [], [], [], []
47+
for current_idx, index_group in enumerate(idx):
48+
current_contrast = data[current_idx]
49+
if len(index_group)>0:
50+
for index in index_group:
5851
current_plot_data = getattr(current_contrast, effect_attr)
5952
bootstraps.append(current_plot_data.results.bootstraps[index])
6053
differences.append(current_plot_data.results.difference[index])
6154
bcalows.append(current_plot_data.results.bca_low[index])
6255
bcahighs.append(current_plot_data.results.bca_high[index])
63-
else:
64-
raise ValueError("The selected indices must be 0, 1, or 2.")
56+
else:
57+
contrast_plot_data = [getattr(contrast, effect_attr) for contrast in data]
58+
bootstraps_nested = [result.results.bootstraps.to_list() for result in contrast_plot_data]
59+
differences_nested = [result.results.difference.to_list() for result in contrast_plot_data]
60+
bcalows_nested = [result.results.bca_low.to_list() for result in contrast_plot_data]
61+
bcahighs_nested = [result.results.bca_high.to_list() for result in contrast_plot_data]
62+
63+
bootstraps = [element for innerList in bootstraps_nested for element in innerList]
64+
differences = [element for innerList in differences_nested for element in innerList]
65+
bcalows = [element for innerList in bcalows_nested for element in innerList]
66+
bcahighs = [element for innerList in bcahighs_nested for element in innerList]
6567
else:
66-
contrast_plot_data = [getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in data]
68+
contrast_attr = {"delta2": "delta_delta", "mini_meta": "mini_meta"}.get(contrast_type)
69+
if idx is not None:
70+
bootstraps, differences, bcalows, bcahighs = [], [], [], []
71+
for current_idx, index_group in enumerate(idx):
72+
current_contrast = data[current_idx]
73+
if len(index_group)>0:
74+
for index in index_group:
75+
if index == 2:
76+
current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)
77+
bootstraps.append(current_plot_data.bootstraps_delta_delta)
78+
differences.append(current_plot_data.difference)
79+
bcalows.append(current_plot_data.bca_low)
80+
bcahighs.append(current_plot_data.bca_high)
81+
elif index == 0 or index == 1:
82+
current_plot_data = getattr(current_contrast, effect_attr)
83+
bootstraps.append(current_plot_data.results.bootstraps[index])
84+
differences.append(current_plot_data.results.difference[index])
85+
bcalows.append(current_plot_data.results.bca_low[index])
86+
bcahighs.append(current_plot_data.results.bca_high[index])
87+
else:
88+
raise ValueError("The selected indices must be 0, 1, or 2.")
89+
else:
90+
contrast_plot_data = [getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in data]
6791

68-
attribute_suffix = "weighted_delta" if contrast_type == "mini_meta" else "delta_delta"
92+
attribute_suffix = "weighted_delta" if contrast_type == "mini_meta" else "delta_delta"
6993

70-
bootstraps = [getattr(result, f"bootstraps_{attribute_suffix}") for result in contrast_plot_data]
71-
differences = [result.difference for result in contrast_plot_data]
72-
bcalows = [result.bca_low for result in contrast_plot_data]
73-
bcahighs = [result.bca_high for result in contrast_plot_data]
94+
bootstraps = [getattr(result, f"bootstraps_{attribute_suffix}") for result in contrast_plot_data]
95+
differences = [result.difference for result in contrast_plot_data]
96+
bcalows = [result.bca_low for result in contrast_plot_data]
97+
bcahighs = [result.bca_high for result in contrast_plot_data]
7498

7599
return bootstraps, differences, bcalows, bcahighs
76100

@@ -103,11 +127,20 @@ def check_for_errors(
103127
raise ValueError("The `data` argument must be a non-empty list of dabest objects.")
104128

105129
## Check if all contrasts are delta-delta or all are mini-meta
106-
contrast_type = "delta2" if data[0].delta2 else "mini_meta"
130+
131+
contrast_type = ("delta2" if data[0].delta2
132+
else "mini_meta" if data[0].is_mini_meta
133+
else "delta"
134+
)
135+
136+
# contrast_type = "delta2" if data[0].delta2 else "mini_meta"
107137
for contrast in data:
108-
check_contrast_type = "delta2" if contrast.delta2 else "mini_meta"
138+
check_contrast_type = ("delta2" if contrast.delta2
139+
else "mini_meta" if contrast.is_mini_meta
140+
else "delta"
141+
)
109142
if check_contrast_type != contrast_type:
110-
raise ValueError("Each dabest object supplied must be the same experimental type (mini-meta or delta-delta)")
143+
raise ValueError("Each dabest object supplied must be the same experimental type (mini-meta or delta-delta or neither.)")
111144

112145
# Idx
113146
if idx is not None:
@@ -426,7 +459,8 @@ def forest_plot(
426459
)
427460

428461
# Adjust figure size based on orientation
429-
number_of_curves_to_plot = sum([len(i) for i in idx]) if idx is not None else len(data)
462+
number_of_curves_to_plot = len(bootstraps)
463+
# number_of_curves_to_plot = sum([len(i) for i in idx]) if idx is not None else len(data)
430464
if ax is not None:
431465
fig = ax.figure
432466
else:

nbs/API/forest_plot.ipynb

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
"def load_plot_data(\n",
7676
" data: List, \n",
7777
" effect_size: str = \"mean_diff\", \n",
78-
" contrast_type: str = 'delta2',\n",
78+
" contrast_type: str = None,\n",
7979
" idx: Optional[List[int]] = None\n",
8080
") -> List:\n",
8181
" \"\"\"\n",
@@ -88,7 +88,7 @@
8888
" effect_size: str\n",
8989
" Type of effect size ('mean_diff', 'median_diff', etc.).\n",
9090
" contrast_type: str\n",
91-
" Type of dabest object to plot ('delta2' or 'mini-meta')\n",
91+
" Type of dabest object to plot ('delta2' or 'mini-meta' or 'delta').\n",
9292
" idx: Optional[List[int]], default=None\n",
9393
" List of indices to select from the contrast objects if delta-delta experiment. \n",
9494
" If None, only the delta-delta objects are plotted.\n",
@@ -99,37 +99,61 @@
9999
" \"\"\"\n",
100100
" # Effect size and contrast types\n",
101101
" effect_attr = \"hedges_g\" if effect_size == 'delta_g' else effect_size\n",
102-
" contrast_attr = {\"delta2\": \"delta_delta\", \"mini_meta\": \"mini_meta\"}.get(contrast_type)\n",
103102
"\n",
104-
" if idx is not None:\n",
105-
" bootstraps, differences, bcalows, bcahighs = [], [], [], []\n",
106-
" for current_idx, index_group in enumerate(idx):\n",
107-
" current_contrast = data[current_idx]\n",
108-
" if len(index_group)>0:\n",
109-
" for index in index_group:\n",
110-
" if index == 2:\n",
111-
" current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)\n",
112-
" bootstraps.append(current_plot_data.bootstraps_delta_delta)\n",
113-
" differences.append(current_plot_data.difference)\n",
114-
" bcalows.append(current_plot_data.bca_low)\n",
115-
" bcahighs.append(current_plot_data.bca_high)\n",
116-
" elif index == 0 or index == 1:\n",
103+
" if contrast_type == \"delta\":\n",
104+
" if idx is not None:\n",
105+
" bootstraps, differences, bcalows, bcahighs = [], [], [], []\n",
106+
" for current_idx, index_group in enumerate(idx):\n",
107+
" current_contrast = data[current_idx]\n",
108+
" if len(index_group)>0:\n",
109+
" for index in index_group:\n",
117110
" current_plot_data = getattr(current_contrast, effect_attr)\n",
118111
" bootstraps.append(current_plot_data.results.bootstraps[index])\n",
119112
" differences.append(current_plot_data.results.difference[index])\n",
120113
" bcalows.append(current_plot_data.results.bca_low[index])\n",
121114
" bcahighs.append(current_plot_data.results.bca_high[index])\n",
122-
" else:\n",
123-
" raise ValueError(\"The selected indices must be 0, 1, or 2.\")\n",
115+
" else:\n",
116+
" contrast_plot_data = [getattr(contrast, effect_attr) for contrast in data]\n",
117+
" bootstraps_nested = [result.results.bootstraps.to_list() for result in contrast_plot_data]\n",
118+
" differences_nested = [result.results.difference.to_list() for result in contrast_plot_data]\n",
119+
" bcalows_nested = [result.results.bca_low.to_list() for result in contrast_plot_data]\n",
120+
" bcahighs_nested = [result.results.bca_high.to_list() for result in contrast_plot_data]\n",
121+
" \n",
122+
" bootstraps = [element for innerList in bootstraps_nested for element in innerList]\n",
123+
" differences = [element for innerList in differences_nested for element in innerList]\n",
124+
" bcalows = [element for innerList in bcalows_nested for element in innerList]\n",
125+
" bcahighs = [element for innerList in bcahighs_nested for element in innerList]\n",
124126
" else:\n",
125-
" contrast_plot_data = [getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in data]\n",
127+
" contrast_attr = {\"delta2\": \"delta_delta\", \"mini_meta\": \"mini_meta\"}.get(contrast_type)\n",
128+
" if idx is not None:\n",
129+
" bootstraps, differences, bcalows, bcahighs = [], [], [], []\n",
130+
" for current_idx, index_group in enumerate(idx):\n",
131+
" current_contrast = data[current_idx]\n",
132+
" if len(index_group)>0:\n",
133+
" for index in index_group:\n",
134+
" if index == 2:\n",
135+
" current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)\n",
136+
" bootstraps.append(current_plot_data.bootstraps_delta_delta)\n",
137+
" differences.append(current_plot_data.difference)\n",
138+
" bcalows.append(current_plot_data.bca_low)\n",
139+
" bcahighs.append(current_plot_data.bca_high)\n",
140+
" elif index == 0 or index == 1:\n",
141+
" current_plot_data = getattr(current_contrast, effect_attr)\n",
142+
" bootstraps.append(current_plot_data.results.bootstraps[index])\n",
143+
" differences.append(current_plot_data.results.difference[index])\n",
144+
" bcalows.append(current_plot_data.results.bca_low[index])\n",
145+
" bcahighs.append(current_plot_data.results.bca_high[index])\n",
146+
" else:\n",
147+
" raise ValueError(\"The selected indices must be 0, 1, or 2.\")\n",
148+
" else:\n",
149+
" contrast_plot_data = [getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in data]\n",
126150
"\n",
127-
" attribute_suffix = \"weighted_delta\" if contrast_type == \"mini_meta\" else \"delta_delta\"\n",
151+
" attribute_suffix = \"weighted_delta\" if contrast_type == \"mini_meta\" else \"delta_delta\"\n",
128152
"\n",
129-
" bootstraps = [getattr(result, f\"bootstraps_{attribute_suffix}\") for result in contrast_plot_data]\n",
130-
" differences = [result.difference for result in contrast_plot_data]\n",
131-
" bcalows = [result.bca_low for result in contrast_plot_data]\n",
132-
" bcahighs = [result.bca_high for result in contrast_plot_data]\n",
153+
" bootstraps = [getattr(result, f\"bootstraps_{attribute_suffix}\") for result in contrast_plot_data]\n",
154+
" differences = [result.difference for result in contrast_plot_data]\n",
155+
" bcalows = [result.bca_low for result in contrast_plot_data]\n",
156+
" bcahighs = [result.bca_high for result in contrast_plot_data]\n",
133157
"\n",
134158
" return bootstraps, differences, bcalows, bcahighs\n",
135159
"\n",
@@ -162,11 +186,20 @@
162186
" raise ValueError(\"The `data` argument must be a non-empty list of dabest objects.\")\n",
163187
" \n",
164188
" ## Check if all contrasts are delta-delta or all are mini-meta\n",
165-
" contrast_type = \"delta2\" if data[0].delta2 else \"mini_meta\"\n",
189+
"\n",
190+
" contrast_type = (\"delta2\" if data[0].delta2 \n",
191+
" else \"mini_meta\" if data[0].is_mini_meta\n",
192+
" else \"delta\"\n",
193+
" )\n",
194+
"\n",
195+
" # contrast_type = \"delta2\" if data[0].delta2 else \"mini_meta\"\n",
166196
" for contrast in data:\n",
167-
" check_contrast_type = \"delta2\" if contrast.delta2 else \"mini_meta\"\n",
197+
" check_contrast_type = (\"delta2\" if contrast.delta2 \n",
198+
" else \"mini_meta\" if contrast.is_mini_meta\n",
199+
" else \"delta\"\n",
200+
" )\n",
168201
" if check_contrast_type != contrast_type:\n",
169-
" raise ValueError(\"Each dabest object supplied must be the same experimental type (mini-meta or delta-delta)\")\n",
202+
" raise ValueError(\"Each dabest object supplied must be the same experimental type (mini-meta or delta-delta or neither.)\")\n",
170203
"\n",
171204
" # Idx\n",
172205
" if idx is not None:\n",
@@ -485,7 +518,8 @@
485518
" )\n",
486519
"\n",
487520
" # Adjust figure size based on orientation\n",
488-
" number_of_curves_to_plot = sum([len(i) for i in idx]) if idx is not None else len(data)\n",
521+
" number_of_curves_to_plot = len(bootstraps)\n",
522+
" # number_of_curves_to_plot = sum([len(i) for i in idx]) if idx is not None else len(data)\n",
489523
" if ax is not None:\n",
490524
" fig = ax.figure\n",
491525
" else:\n",
25.2 KB
Loading
13.3 KB
Loading

nbs/tests/mpl_image_tests/test_05_forest_plot.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,14 @@ def create_mini_meta_dataset(N=20, seed=9999, control_locs=[3, 3.5, 3.25], contr
153153
contrasts_mini_meta = [contrast_mini_meta01, contrast_mini_meta02, contrast_mini_meta03]
154154

155155

156+
delta1 = load(data = df_mini_meta01,
157+
idx=(("Control 1", "Test 1"), ("Control 2", "Test 2"), ("Control 3", "Test 3")))
158+
delta2 = load(data = df_mini_meta02,
159+
idx=(("Control 1", "Test 1"), ("Control 2", "Test 2"), ("Control 3", "Test 3")))
160+
delta3 = load(data = df_mini_meta03,
161+
idx=(("Control 1", "Test 1"), ("Control 2", "Test 2"), ("Control 3", "Test 3")))
162+
contrasts_deltas = [delta1, delta2, delta3]
163+
156164
# Import your forest_plot function here
157165
from dabest.forest_plot import forest_plot
158166

@@ -353,4 +361,21 @@ def test_516_deltadelta_eserrorbarkwargs_forest():
353361
es_errorbar_kwargs={
354362
'color': 'red', 'lw': 4, 'linestyle': '--', 'alpha': 0.6,
355363
}
356-
)
364+
)
365+
366+
367+
@pytest.mark.mpl_image_compare(tolerance=8)
368+
def test_517_regular_delta_no_idx():
369+
plt.rcdefaults()
370+
return forest_plot(
371+
contrasts_deltas,
372+
)
373+
374+
@pytest.mark.mpl_image_compare(tolerance=8)
375+
def test_518_regular_delta_idx():
376+
plt.rcdefaults()
377+
return forest_plot(
378+
contrasts_deltas,
379+
idx = [(0,), (0,), (0,)],
380+
labels=['Drug1 \nTest 1 - Control 1', 'Drug2 \nTest 2 - Control 2', 'Drug3 \nTest 3 - Control 3']
381+
)

0 commit comments

Comments
 (0)