-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__main__.py
More file actions
146 lines (121 loc) · 2.75 KB
/
__main__.py
File metadata and controls
146 lines (121 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import pandas as pd
import streamlit as st
import plotly.express as px
@st.cache_data(ttl=60 * 60)
def get_summary(path="summary.csv"):
    """Load the run-summary table and add train/dev generalization-gap columns.

    Args:
        path: CSV file to read. Defaults to ``"summary.csv"`` (resolved
            relative to the current working directory, as before).

    Returns:
        The loaded DataFrame with two extra columns:

        * ``Loss_Generalization_Gap`` = ``Loss_mean_Dev - Loss_mean_Train``
        * ``Accuracy_Generalization_Gap`` = ``Accuracy_mean_Train - Accuracy_mean_Dev``

    The result is cached by Streamlit for one hour (``ttl`` is in seconds).
    """
    df = pd.read_csv(path)
    # Both gaps are signed the same way: a positive value means the model
    # does better on the training split than on the dev split.
    df["Loss_Generalization_Gap"] = df["Loss_mean_Dev"] - df["Loss_mean_Train"]
    df["Accuracy_Generalization_Gap"] = df["Accuracy_mean_Train"] - df["Accuracy_mean_Dev"]
    return df
@st.cache_data(ttl=60 * 60)
def filter_optimizers(df, optimizers_selected):
    """Keep only the rows whose ``Optimizer`` value is in *optimizers_selected*.

    Cached by Streamlit for one hour, keyed on the call arguments.
    """
    keep_rows = df["Optimizer"].isin(optimizers_selected)
    return df.loc[keep_rows]
def plot_summary(df, x, y):
    """Render a faceted line chart of metric *y* against *x*, one line per optimizer.

    Facet columns are ``Learning_Rate`` values and facet rows are
    ``Train_Batch_Size`` values. The figure is drawn into the Streamlit page
    and also returned.

    Args:
        df: Summary table. It must contain *x*, *y*, ``Optimizer``,
            ``Learning_Rate`` and ``Train_Batch_Size``. It is copied, not
            modified.
        x: Column for the horizontal axis.
        y: Column for the vertical axis. A plain accuracy metric (name
            contains "Accuracy" but not "Generalization") is scaled to
            percent on a fixed 0-100 axis ("Higher is better"). Every other
            metric is "Lower is better", with an axis starting at 0 and capped
            at 110% of the 90th percentile of its positive values.

    Returns:
        The Plotly figure that was rendered.
    """
    df = df.copy()
    color_col = "Optimizer"
    facet_col = "Learning_Rate"
    facet_row = "Train_Batch_Size"

    is_plain_accuracy = "Accuracy" in y and "Generalization" not in y

    if is_plain_accuracy:
        sub_title = "Higher is better"
        df[y] *= 100  # show fractions as percentages
        range_y = [0, 100]
    else:
        sub_title = "Lower is better"
        # Cap the axis near the 90th percentile of the positive values so a
        # few large values don't flatten the rest of the chart.
        upper = df.loc[df[y] > 0, y].quantile(0.90) * 1.1
        # With no positive values the quantile is NaN; let Plotly autoscale
        # rather than passing an invalid [0, NaN] range.
        range_y = [0, upper] if pd.notna(upper) else None

    # Series.min()/max() skip NaN; numpy's `.values.min()` would return NaN
    # (and an unusable axis range) if any x value were missing.
    range_x = [df[x].min(), df[x].max()]

    title = f"{y.replace('_', ' ')}<br><sup>{sub_title}</sup>"

    fig = px.line(
        data_frame=df,
        x=x,
        y=y,
        facet_col=facet_col,
        facet_row=facet_row,
        facet_row_spacing=0.1,
        color=color_col,
        title=title,
        range_x=range_x,
        range_y=range_y,
        markers=True,
    )

    # 300 px of height per facet row.
    n_rows = df[facet_row].nunique(dropna=False)
    fig.update_layout(height=300 * n_rows)
    fig.update_traces(
        marker={"size": 5},
        line={"width": 1},
        # Bridge missing points so each series draws as one line
        # (the original note says this is required for the dev accuracies).
        connectgaps=True,
    )

    st.plotly_chart(fig, use_container_width=True)
    return fig
def main():
    """Streamlit entry point: build the sidebar controls and draw the summary chart.

    The sidebar picks the x metric, the y metric and a subset of optimizers.
    An empty optimizer selection means "all optimizers".
    """
    st.set_page_config(
        layout="wide"
    )
    summary = get_summary()

    # Candidate x-axis columns. Only offer those actually present in the data,
    # so the radio can't select a column that would crash the plot.
    x_candidates = [
        "Epoch",
        "Train_Time",
    ]
    x_options = [col for col in x_candidates if col in summary.columns]
    # Identifier / hyper-parameter columns, used for colour and facets rather
    # than as y metrics.
    exclude = {
        "Model",
        "Optimizer",
        "Learning_Rate",
        "Train_Batch_Size",
    }
    y_options = [
        col for col in summary.columns
        if col not in x_candidates and col not in exclude
    ]
    optimizers_options = list(summary["Optimizer"].unique())

    if not x_options or not y_options:
        st.error("No plottable X/Y columns found in the summary data.")
        return

    with st.sidebar:
        x_selected = st.radio("X Metric", x_options)
        y_selected = st.radio("Y Metric", y_options)
        optimizers_selected = st.multiselect("Optimizers", optimizers_options)

    # Treat an empty multiselect as "no filter" instead of plotting nothing.
    if not optimizers_selected:
        optimizers_selected = optimizers_options

    summary = summary.pipe(filter_optimizers, optimizers_selected)
    plot_summary(summary, x=x_selected, y=y_selected)


if __name__ == "__main__":
    main()