From 616c8d2d0167c0af5b6cf33eda540f4a59f2c361 Mon Sep 17 00:00:00 2001 From: Piotr Migdal Date: Fri, 15 Jul 2022 14:15:15 +0200 Subject: [PATCH] prediction plot quick fix --- livelossplot/outputs/matplotlib_subplots.py | 8 ++++++++ livelossplot/plot_losses.py | 2 ++ 2 files changed, 10 insertions(+) diff --git a/livelossplot/outputs/matplotlib_subplots.py b/livelossplot/outputs/matplotlib_subplots.py index 6057650..5b6d91f 100644 --- a/livelossplot/outputs/matplotlib_subplots.py +++ b/livelossplot/outputs/matplotlib_subplots.py @@ -13,6 +13,11 @@ def draw(self, *args, **kwargs): def __call__(self, *args, **kwargs): self.draw(*args, **kwargs) + def set_output_mode(self, mode: str): + """Set notebook or script mode - not implemented yet""" + ... + + class LossSubplot(BaseSubplot): """To rewrire, this one now won't work""" @@ -59,6 +64,7 @@ def draw(self, logs): plt.title(self.title) plt.xlabel('epoch') plt.legend(loc='center right') + plt.show() class Plot1D(BaseSubplot): @@ -77,6 +83,7 @@ def draw(self, *args, **kwargs): plt.plot(self.X, self.predict(self.model, self.X), '-', label="Model") plt.title("Prediction") plt.legend(loc='lower right') + plt.show() class Plot2d(BaseSubplot): @@ -119,3 +126,4 @@ def send(self, logger): plt.scatter(self.X[:, 0], self.X[:, 1], c=self.Y, cmap=self.cm_points) if self.X_test is not None: plt.scatter(self.X_test[:, 0], self.X_test[:, 1], c=self.Y_test, cmap=self.cm_points, alpha=0.3) + plt.show() diff --git a/livelossplot/plot_losses.py b/livelossplot/plot_losses.py index e32b498..69fa02b 100644 --- a/livelossplot/plot_losses.py +++ b/livelossplot/plot_losses.py @@ -4,6 +4,7 @@ import livelossplot from livelossplot.main_logger import MainLogger from livelossplot import outputs +from IPython.display import clear_output BO = TypeVar('BO', bound=outputs.BaseOutput) @@ -37,6 +38,7 @@ def update(self, *args, **kwargs): def send(self): """Method will send logs to every output class""" + clear_output(wait=True) for output in self.outputs: output.send(self.logger)