From 4a11c7d221034b1c7e8a802f48ba482e0bc8d7c7 Mon Sep 17 00:00:00 2001 From: Yongjin Cho Date: Mon, 28 Apr 2025 23:41:07 +0900 Subject: [PATCH] Fix errors when max_cols is 1 --- livelossplot/outputs/matplotlib_plot.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/livelossplot/outputs/matplotlib_plot.py b/livelossplot/outputs/matplotlib_plot.py index 0197acf..a7025d6 100644 --- a/livelossplot/outputs/matplotlib_plot.py +++ b/livelossplot/outputs/matplotlib_plot.py @@ -59,6 +59,8 @@ def send(self, logger: MainLogger): max_rows = math.ceil((len(log_groups) + len(self.extra_plots)) / self.max_cols) fig, axes = plt.subplots(max_rows, self.max_cols) + if not isinstance(axes, np.ndarray): + axes = np.array([[axes]]) axes = axes.reshape(-1, self.max_cols) self._before_plots(fig, axes, len(log_groups))