Skip to content

Commit dd28fd4

Browse files
author
A-lamo
committed
Add: Ported and edited plot analysis from revolve2/gui.
1 parent 841df7c commit dd28fd4

File tree

1 file changed

+238
-0
lines changed

1 file changed

+238
-0
lines changed
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
from PyQt5.QtWidgets import (QWidget, QVBoxLayout)
2+
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
3+
from matplotlib.figure import Figure
4+
from PyQt5.QtCore import QTimer
5+
import numpy as np
6+
7+
8+
class EmbeddedPlotWidget(QWidget):
9+
def __init__(self, parent=None):
10+
super().__init__(parent)
11+
12+
# Create a layout for the widget
13+
layout = QVBoxLayout()
14+
15+
# Create a matplotlib figure and canvas
16+
self.figure = Figure(figsize=(10, 6), dpi=100)
17+
self.canvas = FigureCanvas(self.figure)
18+
19+
# Add canvas to layout
20+
layout.addWidget(self.canvas)
21+
self.setLayout(layout)
22+
23+
def plot_fitness(self, fitnesses_per_generation):
24+
"""
25+
Plot fitness over generations, averaged.
26+
fitnesses_per_generation: list of lists of fitness values
27+
"""
28+
# Clear previous plot
29+
self.figure.clear()
30+
ax = self.figure.add_subplot(111)
31+
32+
# Collect statistics across generations
33+
generations = list(range(len(fitnesses_per_generation)))
34+
max_fitness = [max(gen) for gen in fitnesses_per_generation]
35+
mean_fitness = [np.mean(gen) for gen in fitnesses_per_generation]
36+
std_fitness = [np.std(gen) for gen in fitnesses_per_generation]
37+
38+
# Plot max fitness
39+
ax.plot(
40+
generations,
41+
max_fitness,
42+
label="Max fitness",
43+
color="b",
44+
)
45+
ax.fill_between(
46+
generations,
47+
np.array(max_fitness) - np.array(std_fitness),
48+
np.array(max_fitness) + np.array(std_fitness),
49+
color="b",
50+
alpha=0.2,
51+
)
52+
53+
# Plot mean fitness
54+
ax.plot(
55+
generations,
56+
mean_fitness,
57+
label="Mean fitness",
58+
color="r",
59+
)
60+
ax.fill_between(
61+
generations,
62+
np.array(mean_fitness) - np.array(std_fitness),
63+
np.array(mean_fitness) + np.array(std_fitness),
64+
color="r",
65+
alpha=0.2,
66+
)
67+
68+
# Customize plot
69+
ax.set_xlabel("Generation index")
70+
ax.set_ylabel("Fitness")
71+
ax.set_title("Mean and max fitness with std as shade")
72+
ax.legend()
73+
74+
# Adjust layout and redraw
75+
self.figure.tight_layout()
76+
self.canvas.draw()
77+
78+
# Optionally save the figure
79+
self.figure.savefig("src/ariel/viz/gui/resources/example.png")
80+
81+
82+
# class EmbeddedPlotWidgetDynamic(QWidget):
83+
# def __init__(self, parent=None, num_generations=10, update_interval=1000):
84+
# super().__init__(parent)
85+
86+
# # Create a layout for the widget
87+
# layout = QVBoxLayout()
88+
89+
# # Create a matplotlib figure and canvas
90+
# self.figure = Figure(figsize=(10, 6), dpi=100)
91+
# self.canvas = FigureCanvas(self.figure)
92+
93+
# # Add canvas to layout
94+
# layout.addWidget(self.canvas)
95+
# self.setLayout(layout)
96+
97+
# # Store database path
98+
# self.database_path = None
99+
100+
# # Setup timer for periodic updates
101+
# self.update_timer = QTimer(self)
102+
# self.update_timer.timeout.connect(self.update_plot)
103+
# self.update_timer.start(update_interval) # Update every second
104+
105+
# # Flag to prevent multiple simultaneous updates
106+
# self.is_updating = False
107+
# self.num_generations = num_generations
108+
109+
# self.draw_base_plot(self.figure.add_subplot(111))
110+
111+
# def set_database_path(self, database_path):
112+
# self.database_path = database_path
113+
114+
# def set_num_generations(self, num_generations):
115+
# self.num_generations = num_generations
116+
117+
# def update_plot(self):
118+
# """
119+
# Update the plot with the latest data from the database.
120+
# """
121+
# if self.is_updating or not self.database_path:
122+
# return
123+
124+
# try:
125+
# self.is_updating = True
126+
127+
# # Clear previous plot
128+
# self.figure.clear()
129+
# ax = self.figure.add_subplot(111)
130+
131+
# # Open database
132+
# dbengine = open_database_sqlite(
133+
# self.database_path,
134+
# open_method=OpenMethod.OPEN_IF_EXISTS
135+
# )
136+
137+
# # Read data
138+
# df = pd.read_sql(
139+
# select(
140+
# Experiment.id.label("experiment_id"),
141+
# Generation.generation_index,
142+
# Individual.fitness,
143+
# )
144+
# .join_from(Experiment, Generation, Experiment.id == Generation.experiment_id)
145+
# .join_from(Generation, Population, Generation.population_id == Population.id)
146+
# .join_from(Population, Individual, Population.id == Individual.population_id),
147+
# dbengine,
148+
# )
149+
150+
# if df.empty or not ((df["generation_index"] == 0) & df["fitness"].notna()).any():
151+
# self.draw_base_plot(ax) # Show base plot
152+
# return
153+
154+
# # Aggregate data
155+
# agg_per_experiment_per_generation = (
156+
# df.groupby(["experiment_id", "generation_index"])
157+
# .agg({"fitness": ["max", "mean"]})
158+
# .reset_index()
159+
# )
160+
# agg_per_experiment_per_generation.columns = [
161+
# "experiment_id",
162+
# "generation_index",
163+
# "max_fitness",
164+
# "mean_fitness",
165+
# ]
166+
167+
# agg_per_generation = (
168+
# agg_per_experiment_per_generation.groupby("generation_index")
169+
# .agg({"max_fitness": ["mean", "std"], "mean_fitness": ["mean", "std"]})
170+
# .reset_index()
171+
# )
172+
# agg_per_generation.columns = [
173+
# "generation_index",
174+
# "max_fitness",
175+
# "max_fitness_std",
176+
# "mean_fitness",
177+
# "mean_fitness_std",
178+
# ]
179+
180+
# ax.set_xlim(0, self.num_generations)
181+
182+
# # Plot max fitness
183+
# ax.plot(
184+
# agg_per_generation["generation_index"],
185+
# agg_per_generation["max_fitness"],
186+
# label="Max fitness",
187+
# color="b",
188+
# )
189+
# ax.fill_between(
190+
# agg_per_generation["generation_index"],
191+
# agg_per_generation["max_fitness"] - agg_per_generation["max_fitness_std"],
192+
# agg_per_generation["max_fitness"] + agg_per_generation["max_fitness_std"],
193+
# color="b",
194+
# alpha=0.2,
195+
# )
196+
197+
# # Plot mean fitness
198+
# ax.plot(
199+
# agg_per_generation["generation_index"],
200+
# agg_per_generation["mean_fitness"],
201+
# label="Mean fitness",
202+
# color="r",
203+
# )
204+
# ax.fill_between(
205+
# agg_per_generation["generation_index"],
206+
# agg_per_generation["mean_fitness"] - agg_per_generation["mean_fitness_std"],
207+
# agg_per_generation["mean_fitness"] + agg_per_generation["mean_fitness_std"],
208+
# color="r",
209+
# alpha=0.2,
210+
# )
211+
212+
# # Customize plot
213+
# ax.set_xlabel("Generation index")
214+
# ax.set_ylabel("Fitness")
215+
# ax.set_title("Mean and max fitness across repetitions with std as shade")
216+
# ax.legend()
217+
218+
# # Adjust layout and redraw
219+
# self.figure.tight_layout()
220+
# self.canvas.draw()
221+
222+
# except Exception as e:
223+
# print(f"Error updating plot: {e}")
224+
225+
# finally:
226+
# self.is_updating = False
227+
228+
229+
# def draw_base_plot(self, ax):
230+
# """
231+
# Draws a placeholder base plot until valid data is available.
232+
# """
233+
# ax.set_xlabel("Generation index")
234+
# ax.set_ylabel("Fitness")
235+
# ax.set_title("No data yet...")
236+
# ax.text(0.5, 0.5, "Waiting for first generation to complete", fontsize=14, ha="center", va="center", transform=ax.transAxes)
237+
# self.figure.tight_layout()
238+
# self.canvas.draw()

0 commit comments

Comments
 (0)