|
| 1 | +from PyQt5.QtWidgets import (QWidget, QVBoxLayout) |
| 2 | +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas |
| 3 | +from matplotlib.figure import Figure |
| 4 | +from PyQt5.QtCore import QTimer |
| 5 | +import numpy as np |
| 6 | + |
| 7 | + |
| 8 | +class EmbeddedPlotWidget(QWidget): |
| 9 | + def __init__(self, parent=None): |
| 10 | + super().__init__(parent) |
| 11 | + |
| 12 | + # Create a layout for the widget |
| 13 | + layout = QVBoxLayout() |
| 14 | + |
| 15 | + # Create a matplotlib figure and canvas |
| 16 | + self.figure = Figure(figsize=(10, 6), dpi=100) |
| 17 | + self.canvas = FigureCanvas(self.figure) |
| 18 | + |
| 19 | + # Add canvas to layout |
| 20 | + layout.addWidget(self.canvas) |
| 21 | + self.setLayout(layout) |
| 22 | + |
| 23 | + def plot_fitness(self, fitnesses_per_generation): |
| 24 | + """ |
| 25 | + Plot fitness over generations, averaged. |
| 26 | + fitnesses_per_generation: list of lists of fitness values |
| 27 | + """ |
| 28 | + # Clear previous plot |
| 29 | + self.figure.clear() |
| 30 | + ax = self.figure.add_subplot(111) |
| 31 | + |
| 32 | + # Collect statistics across generations |
| 33 | + generations = list(range(len(fitnesses_per_generation))) |
| 34 | + max_fitness = [max(gen) for gen in fitnesses_per_generation] |
| 35 | + mean_fitness = [np.mean(gen) for gen in fitnesses_per_generation] |
| 36 | + std_fitness = [np.std(gen) for gen in fitnesses_per_generation] |
| 37 | + |
| 38 | + # Plot max fitness |
| 39 | + ax.plot( |
| 40 | + generations, |
| 41 | + max_fitness, |
| 42 | + label="Max fitness", |
| 43 | + color="b", |
| 44 | + ) |
| 45 | + ax.fill_between( |
| 46 | + generations, |
| 47 | + np.array(max_fitness) - np.array(std_fitness), |
| 48 | + np.array(max_fitness) + np.array(std_fitness), |
| 49 | + color="b", |
| 50 | + alpha=0.2, |
| 51 | + ) |
| 52 | + |
| 53 | + # Plot mean fitness |
| 54 | + ax.plot( |
| 55 | + generations, |
| 56 | + mean_fitness, |
| 57 | + label="Mean fitness", |
| 58 | + color="r", |
| 59 | + ) |
| 60 | + ax.fill_between( |
| 61 | + generations, |
| 62 | + np.array(mean_fitness) - np.array(std_fitness), |
| 63 | + np.array(mean_fitness) + np.array(std_fitness), |
| 64 | + color="r", |
| 65 | + alpha=0.2, |
| 66 | + ) |
| 67 | + |
| 68 | + # Customize plot |
| 69 | + ax.set_xlabel("Generation index") |
| 70 | + ax.set_ylabel("Fitness") |
| 71 | + ax.set_title("Mean and max fitness with std as shade") |
| 72 | + ax.legend() |
| 73 | + |
| 74 | + # Adjust layout and redraw |
| 75 | + self.figure.tight_layout() |
| 76 | + self.canvas.draw() |
| 77 | + |
| 78 | + # Optionally save the figure |
| 79 | + self.figure.savefig("src/ariel/viz/gui/resources/example.png") |
| 80 | + |
| 81 | + |
| 82 | +# class EmbeddedPlotWidgetDynamic(QWidget): |
| 83 | +# def __init__(self, parent=None, num_generations=10, update_interval=1000): |
| 84 | +# super().__init__(parent) |
| 85 | + |
| 86 | +# # Create a layout for the widget |
| 87 | +# layout = QVBoxLayout() |
| 88 | + |
| 89 | +# # Create a matplotlib figure and canvas |
| 90 | +# self.figure = Figure(figsize=(10, 6), dpi=100) |
| 91 | +# self.canvas = FigureCanvas(self.figure) |
| 92 | + |
| 93 | +# # Add canvas to layout |
| 94 | +# layout.addWidget(self.canvas) |
| 95 | +# self.setLayout(layout) |
| 96 | + |
| 97 | +# # Store database path |
| 98 | +# self.database_path = None |
| 99 | + |
| 100 | +# # Setup timer for periodic updates |
| 101 | +# self.update_timer = QTimer(self) |
| 102 | +# self.update_timer.timeout.connect(self.update_plot) |
| 103 | +# self.update_timer.start(update_interval) # Update every second |
| 104 | + |
| 105 | +# # Flag to prevent multiple simultaneous updates |
| 106 | +# self.is_updating = False |
| 107 | +# self.num_generations = num_generations |
| 108 | + |
| 109 | +# self.draw_base_plot(self.figure.add_subplot(111)) |
| 110 | + |
| 111 | +# def set_database_path(self, database_path): |
| 112 | +# self.database_path = database_path |
| 113 | + |
| 114 | +# def set_num_generations(self, num_generations): |
| 115 | +# self.num_generations = num_generations |
| 116 | + |
| 117 | +# def update_plot(self): |
| 118 | +# """ |
| 119 | +# Update the plot with the latest data from the database. |
| 120 | +# """ |
| 121 | +# if self.is_updating or not self.database_path: |
| 122 | +# return |
| 123 | + |
| 124 | +# try: |
| 125 | +# self.is_updating = True |
| 126 | + |
| 127 | +# # Clear previous plot |
| 128 | +# self.figure.clear() |
| 129 | +# ax = self.figure.add_subplot(111) |
| 130 | + |
| 131 | +# # Open database |
| 132 | +# dbengine = open_database_sqlite( |
| 133 | +# self.database_path, |
| 134 | +# open_method=OpenMethod.OPEN_IF_EXISTS |
| 135 | +# ) |
| 136 | + |
| 137 | +# # Read data |
| 138 | +# df = pd.read_sql( |
| 139 | +# select( |
| 140 | +# Experiment.id.label("experiment_id"), |
| 141 | +# Generation.generation_index, |
| 142 | +# Individual.fitness, |
| 143 | +# ) |
| 144 | +# .join_from(Experiment, Generation, Experiment.id == Generation.experiment_id) |
| 145 | +# .join_from(Generation, Population, Generation.population_id == Population.id) |
| 146 | +# .join_from(Population, Individual, Population.id == Individual.population_id), |
| 147 | +# dbengine, |
| 148 | +# ) |
| 149 | + |
| 150 | +# if df.empty or not ((df["generation_index"] == 0) & df["fitness"].notna()).any(): |
| 151 | +# self.draw_base_plot(ax) # Show base plot |
| 152 | +# return |
| 153 | + |
| 154 | +# # Aggregate data |
| 155 | +# agg_per_experiment_per_generation = ( |
| 156 | +# df.groupby(["experiment_id", "generation_index"]) |
| 157 | +# .agg({"fitness": ["max", "mean"]}) |
| 158 | +# .reset_index() |
| 159 | +# ) |
| 160 | +# agg_per_experiment_per_generation.columns = [ |
| 161 | +# "experiment_id", |
| 162 | +# "generation_index", |
| 163 | +# "max_fitness", |
| 164 | +# "mean_fitness", |
| 165 | +# ] |
| 166 | + |
| 167 | +# agg_per_generation = ( |
| 168 | +# agg_per_experiment_per_generation.groupby("generation_index") |
| 169 | +# .agg({"max_fitness": ["mean", "std"], "mean_fitness": ["mean", "std"]}) |
| 170 | +# .reset_index() |
| 171 | +# ) |
| 172 | +# agg_per_generation.columns = [ |
| 173 | +# "generation_index", |
| 174 | +# "max_fitness", |
| 175 | +# "max_fitness_std", |
| 176 | +# "mean_fitness", |
| 177 | +# "mean_fitness_std", |
| 178 | +# ] |
| 179 | + |
| 180 | +# ax.set_xlim(0, self.num_generations) |
| 181 | + |
| 182 | +# # Plot max fitness |
| 183 | +# ax.plot( |
| 184 | +# agg_per_generation["generation_index"], |
| 185 | +# agg_per_generation["max_fitness"], |
| 186 | +# label="Max fitness", |
| 187 | +# color="b", |
| 188 | +# ) |
| 189 | +# ax.fill_between( |
| 190 | +# agg_per_generation["generation_index"], |
| 191 | +# agg_per_generation["max_fitness"] - agg_per_generation["max_fitness_std"], |
| 192 | +# agg_per_generation["max_fitness"] + agg_per_generation["max_fitness_std"], |
| 193 | +# color="b", |
| 194 | +# alpha=0.2, |
| 195 | +# ) |
| 196 | + |
| 197 | +# # Plot mean fitness |
| 198 | +# ax.plot( |
| 199 | +# agg_per_generation["generation_index"], |
| 200 | +# agg_per_generation["mean_fitness"], |
| 201 | +# label="Mean fitness", |
| 202 | +# color="r", |
| 203 | +# ) |
| 204 | +# ax.fill_between( |
| 205 | +# agg_per_generation["generation_index"], |
| 206 | +# agg_per_generation["mean_fitness"] - agg_per_generation["mean_fitness_std"], |
| 207 | +# agg_per_generation["mean_fitness"] + agg_per_generation["mean_fitness_std"], |
| 208 | +# color="r", |
| 209 | +# alpha=0.2, |
| 210 | +# ) |
| 211 | + |
| 212 | +# # Customize plot |
| 213 | +# ax.set_xlabel("Generation index") |
| 214 | +# ax.set_ylabel("Fitness") |
| 215 | +# ax.set_title("Mean and max fitness across repetitions with std as shade") |
| 216 | +# ax.legend() |
| 217 | + |
| 218 | +# # Adjust layout and redraw |
| 219 | +# self.figure.tight_layout() |
| 220 | +# self.canvas.draw() |
| 221 | + |
| 222 | +# except Exception as e: |
| 223 | +# print(f"Error updating plot: {e}") |
| 224 | + |
| 225 | +# finally: |
| 226 | +# self.is_updating = False |
| 227 | + |
| 228 | + |
| 229 | +# def draw_base_plot(self, ax): |
| 230 | +# """ |
| 231 | +# Draws a placeholder base plot until valid data is available. |
| 232 | +# """ |
| 233 | +# ax.set_xlabel("Generation index") |
| 234 | +# ax.set_ylabel("Fitness") |
| 235 | +# ax.set_title("No data yet...") |
| 236 | +# ax.text(0.5, 0.5, "Waiting for first generation to complete", fontsize=14, ha="center", va="center", transform=ax.transAxes) |
| 237 | +# self.figure.tight_layout() |
| 238 | +# self.canvas.draw() |
0 commit comments