From a416194d6852fe0380f25740ff93599903a8dc99 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Tue, 18 May 2021 14:14:51 -0400 Subject: [PATCH 01/15] move slider backcompat to separate file --- mpl_interactions/_widget_backfill.py | 339 ++++++++++++++++++++++++++ mpl_interactions/widgets.py | 342 +-------------------------- 2 files changed, 344 insertions(+), 337 deletions(-) create mode 100644 mpl_interactions/_widget_backfill.py diff --git a/mpl_interactions/_widget_backfill.py b/mpl_interactions/_widget_backfill.py new file mode 100644 index 00000000..87bf5c94 --- /dev/null +++ b/mpl_interactions/_widget_backfill.py @@ -0,0 +1,339 @@ +""" +Implementing matplotlib widgets for back compat +""" +from matplotlib.widgets import AxesWidget +from matplotlib import cbook, ticker + +# slider widgets are taken almost verbatim from https://github.com/matplotlib/matplotlib/pull/18829/files +# which was written by me - but incorporates much of the existing matplotlib slider infrastructure +class SliderBase(AxesWidget): + def __init__( + self, ax, orientation, closedmin, closedmax, valmin, valmax, valfmt, dragging, valstep + ): + if ax.name == "3d": + raise ValueError("Sliders cannot be added to 3D Axes") + + super().__init__(ax) + + self.orientation = orientation + self.closedmin = closedmin + self.closedmax = closedmax + self.valmin = valmin + self.valmax = valmax + self.valstep = valstep + self.drag_active = False + self.valfmt = valfmt + + if orientation == "vertical": + ax.set_ylim((valmin, valmax)) + axis = ax.yaxis + else: + ax.set_xlim((valmin, valmax)) + axis = ax.xaxis + + self._fmt = axis.get_major_formatter() + if not isinstance(self._fmt, ticker.ScalarFormatter): + self._fmt = ticker.ScalarFormatter() + self._fmt.set_axis(axis) + self._fmt.set_useOffset(False) # No additive offset. + self._fmt.set_useMathText(True) # x sign before multiplicative offset. + + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_navigate(False) + self.connect_event("button_press_event", self._update) + self.connect_event("button_release_event", self._update) + if dragging: + self.connect_event("motion_notify_event", self._update) + self._observers = cbook.CallbackRegistry() + + def _stepped_value(self, val): + if self.valstep: + return self.valmin + round((val - self.valmin) / self.valstep) * self.valstep + return val + + def disconnect(self, cid): + """ + Remove the observer with connection id *cid* + + Parameters + ---------- + cid : int + Connection id of the observer to be removed + """ + self._observers.disconnect(cid) + + def reset(self): + """Reset the slider to the initial value""" + if self.val != self.valinit: + self.set_val(self.valinit) + + +class RangeSlider(SliderBase): + """ + A slider representing a floating point range. + + Create a slider from *valmin* to *valmax* in axes *ax*. For the slider to + remain responsive you must maintain a reference to it. Call + :meth:`on_changed` to connect to the slider event. + + Attributes + ---------- + val : tuple of float + Slider value. + """ + + def __init__( + self, + ax, + label, + valmin, + valmax, + valinit=None, + valfmt=None, + closedmin=True, + closedmax=True, + dragging=True, + valstep=None, + orientation="horizontal", + **kwargs, + ): + """ + Parameters + ---------- + ax : Axes + The Axes to put the slider in. + label : str + Slider label. + valmin : float + The minimum value of the slider. + valmax : float + The maximum value of the slider. + valinit : tuple of float or None, default: None + The initial positions of the slider. If None the initial positions + will be at the 25th and 75th percentiles of the range. + valfmt : str, default: None + %-format string used to format the slider values. If None, a + `.ScalarFormatter` is used instead. + closedmin : bool, default: True + Whether the slider interval is closed on the bottom. + closedmax : bool, default: True + Whether the slider interval is closed on the top. + dragging : bool, default: True + If True the slider can be dragged by the mouse. + valstep : float, default: None + If given, the slider will snap to multiples of *valstep*. + orientation : {'horizontal', 'vertical'}, default: 'horizontal' + The orientation of the slider. + + Notes + ----- + Additional kwargs are passed on to ``self.poly`` which is the + `~matplotlib.patches.Rectangle` that draws the slider knob. See the + `.Rectangle` documentation for valid property names (``facecolor``, + ``edgecolor``, ``alpha``, etc.). + """ + super().__init__( + ax, orientation, closedmin, closedmax, valmin, valmax, valfmt, dragging, valstep + ) + + self.val = valinit + if valinit is None: + valinit = np.array([valmin + 0.25 * valmax, valmin + 0.75 * valmax]) + else: + valinit = self._value_in_bounds(valinit) + self.val = valinit + self.valinit = valinit + if orientation == "vertical": + self.poly = ax.axhspan(valinit[0], valinit[1], 0, 1, **kwargs) + else: + self.poly = ax.axvspan(valinit[0], valinit[1], 0, 1, **kwargs) + + if orientation == "vertical": + self.label = ax.text( + 0.5, + 1.02, + label, + transform=ax.transAxes, + verticalalignment="bottom", + horizontalalignment="center", + ) + + self.valtext = ax.text( + 0.5, + -0.02, + self._format(valinit), + transform=ax.transAxes, + verticalalignment="top", + horizontalalignment="center", + ) + else: + self.label = ax.text( + -0.02, + 0.5, + label, + transform=ax.transAxes, + verticalalignment="center", + horizontalalignment="right", + ) + + self.valtext = ax.text( + 1.02, + 0.5, + self._format(valinit), + transform=ax.transAxes, + verticalalignment="center", + horizontalalignment="left", + ) + + self.set_val(valinit) + + def _min_in_bounds(self, min): + """ + Ensure the new min value is between valmin and self.val[1] + """ + if min <= self.valmin: + if not self.closedmin: + return self.val[0] + min = self.valmin + + if min > self.val[1]: + min = self.val[1] + return self._stepped_value(min) + + def _max_in_bounds(self, max): + """ + Ensure the new max value is between valmax and self.val[0] + """ + if max >= self.valmax: + if not self.closedmax: + return self.val[1] + max = self.valmax + + if max <= self.val[0]: + max = self.val[0] + return self._stepped_value(max) + + def _value_in_bounds(self, val): + return (self._min_in_bounds(val[0]), self._max_in_bounds(val[1])) + + def _update_val_from_pos(self, pos): + """ + Given a position update the *val* + """ + idx = np.argmin(np.abs(self.val - pos)) + if idx == 0: + val = self._min_in_bounds(pos) + self.set_min(val) + else: + val = self._max_in_bounds(pos) + self.set_max(val) + + def _update(self, event): + """Update the slider position.""" + if self.ignore(event) or event.button != 1: + return + + if event.name == "button_press_event" and event.inaxes == self.ax: + self.drag_active = True + event.canvas.grab_mouse(self.ax) + + if not self.drag_active: + return + + elif (event.name == "button_release_event") or ( + event.name == "button_press_event" and event.inaxes != self.ax + ): + self.drag_active = False + event.canvas.release_mouse(self.ax) + return + if self.orientation == "vertical": + self._update_val_from_pos(event.ydata) + else: + self._update_val_from_pos(event.xdata) + + def _format(self, val): + """Pretty-print *val*.""" + if self.valfmt is not None: + return (self.valfmt % val[0], self.valfmt % val[1]) + else: + # fmt.get_offset is actually the multiplicative factor, if any. + _, s1, s2, _ = self._fmt.format_ticks([self.valmin, *val, self.valmax]) + # fmt.get_offset is actually the multiplicative factor, if any. + s1 += self._fmt.get_offset() + s2 += self._fmt.get_offset() + # use raw string to avoid issues with backslashes from + return rf"({s1}, {s2})" + + def set_min(self, min): + """ + Set the lower value of the slider to *min* + + Parameters + ---------- + min : float + """ + self.set_val((min, self.val[1])) + + def set_max(self, max): + """ + Set the lower value of the slider to *max* + + Parameters + ---------- + max : float + """ + self.set_val((self.val[0], max)) + + def set_val(self, val): + """ + Set slider value to *val* + + Parameters + ---------- + val : tuple or arraylike of float + """ + val = np.sort(np.asanyarray(val)) + if val.shape != (2,): + raise ValueError(f"val must have shape (2,) but has shape {val.shape}") + val[0] = self._min_in_bounds(val[0]) + val[1] = self._max_in_bounds(val[1]) + xy = self.poly.xy + if self.orientation == "vertical": + xy[0] = 0, val[0] + xy[1] = 0, val[1] + xy[2] = 1, val[1] + xy[3] = 1, val[0] + xy[4] = 0, val[0] + else: + xy[0] = val[0], 0 + xy[1] = val[0], 1 + xy[2] = val[1], 1 + xy[3] = val[1], 0 + xy[4] = val[0], 0 + self.poly.xy = xy + self.valtext.set_text(self._format(val)) + if self.drawon: + self.ax.figure.canvas.draw_idle() + self.val = val + if self.eventson: + self._observers.process("changed", val) + + def on_changed(self, func): + """ + When the slider value is changed call *func* with the new + slider value + + Parameters + ---------- + func : callable + Function to call when slider is changed. + The function must accept a numpy array with shape (2,) float + as its argument. + + Returns + ------- + int + Connection id (which can be used to disconnect *func*) + """ + return self._observers.connect("changed", func) diff --git a/mpl_interactions/widgets.py b/mpl_interactions/widgets.py index 044a7eef..8a114575 100644 --- a/mpl_interactions/widgets.py +++ b/mpl_interactions/widgets.py @@ -1,7 +1,10 @@ from matplotlib.cbook import CallbackRegistry from matplotlib.widgets import AxesWidget -from matplotlib import cbook, ticker -import numpy as np + +try: + from matplotlib.widgets import RangeSlider, SliderBase +except ImportError: + from ._widget_backfill import RangeSlider, SliderBase __all__ = [ "scatter_selector", @@ -138,338 +141,3 @@ def on_changed(self, func): Connection id (which can be used to disconnect *func*) """ return self._observers.connect("picked", lambda val: func(val)) - - -# slider widgets are taken almost verbatim from https://github.com/matplotlib/matplotlib/pull/18829/files -# which was written by me - but incorporates much of the existing matplotlib slider infrastructure -class SliderBase(AxesWidget): - def __init__( - self, ax, orientation, closedmin, closedmax, valmin, valmax, valfmt, dragging, valstep - ): - if ax.name == "3d": - raise ValueError("Sliders cannot be added to 3D Axes") - - super().__init__(ax) - - self.orientation = orientation - self.closedmin = closedmin - self.closedmax = closedmax - self.valmin = valmin - self.valmax = valmax - self.valstep = valstep - self.drag_active = False - self.valfmt = valfmt - - if orientation == "vertical": - ax.set_ylim((valmin, valmax)) - axis = ax.yaxis - else: - ax.set_xlim((valmin, valmax)) - axis = ax.xaxis - - self._fmt = axis.get_major_formatter() - if not isinstance(self._fmt, ticker.ScalarFormatter): - self._fmt = ticker.ScalarFormatter() - self._fmt.set_axis(axis) - self._fmt.set_useOffset(False) # No additive offset. - self._fmt.set_useMathText(True) # x sign before multiplicative offset. - - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_navigate(False) - self.connect_event("button_press_event", self._update) - self.connect_event("button_release_event", self._update) - if dragging: - self.connect_event("motion_notify_event", self._update) - self._observers = cbook.CallbackRegistry() - - def _stepped_value(self, val): - if self.valstep: - return self.valmin + round((val - self.valmin) / self.valstep) * self.valstep - return val - - def disconnect(self, cid): - """ - Remove the observer with connection id *cid* - - Parameters - ---------- - cid : int - Connection id of the observer to be removed - """ - self._observers.disconnect(cid) - - def reset(self): - """Reset the slider to the initial value""" - if self.val != self.valinit: - self.set_val(self.valinit) - - -class RangeSlider(SliderBase): - """ - A slider representing a floating point range. - - Create a slider from *valmin* to *valmax* in axes *ax*. For the slider to - remain responsive you must maintain a reference to it. Call - :meth:`on_changed` to connect to the slider event. - - Attributes - ---------- - val : tuple of float - Slider value. - """ - - def __init__( - self, - ax, - label, - valmin, - valmax, - valinit=None, - valfmt=None, - closedmin=True, - closedmax=True, - dragging=True, - valstep=None, - orientation="horizontal", - **kwargs, - ): - """ - Parameters - ---------- - ax : Axes - The Axes to put the slider in. - label : str - Slider label. - valmin : float - The minimum value of the slider. - valmax : float - The maximum value of the slider. - valinit : tuple of float or None, default: None - The initial positions of the slider. If None the initial positions - will be at the 25th and 75th percentiles of the range. - valfmt : str, default: None - %-format string used to format the slider values. If None, a - `.ScalarFormatter` is used instead. - closedmin : bool, default: True - Whether the slider interval is closed on the bottom. - closedmax : bool, default: True - Whether the slider interval is closed on the top. - dragging : bool, default: True - If True the slider can be dragged by the mouse. - valstep : float, default: None - If given, the slider will snap to multiples of *valstep*. - orientation : {'horizontal', 'vertical'}, default: 'horizontal' - The orientation of the slider. - - Notes - ----- - Additional kwargs are passed on to ``self.poly`` which is the - `~matplotlib.patches.Rectangle` that draws the slider knob. See the - `.Rectangle` documentation for valid property names (``facecolor``, - ``edgecolor``, ``alpha``, etc.). - """ - super().__init__( - ax, orientation, closedmin, closedmax, valmin, valmax, valfmt, dragging, valstep - ) - - self.val = valinit - if valinit is None: - valinit = np.array([valmin + 0.25 * valmax, valmin + 0.75 * valmax]) - else: - valinit = self._value_in_bounds(valinit) - self.val = valinit - self.valinit = valinit - if orientation == "vertical": - self.poly = ax.axhspan(valinit[0], valinit[1], 0, 1, **kwargs) - else: - self.poly = ax.axvspan(valinit[0], valinit[1], 0, 1, **kwargs) - - if orientation == "vertical": - self.label = ax.text( - 0.5, - 1.02, - label, - transform=ax.transAxes, - verticalalignment="bottom", - horizontalalignment="center", - ) - - self.valtext = ax.text( - 0.5, - -0.02, - self._format(valinit), - transform=ax.transAxes, - verticalalignment="top", - horizontalalignment="center", - ) - else: - self.label = ax.text( - -0.02, - 0.5, - label, - transform=ax.transAxes, - verticalalignment="center", - horizontalalignment="right", - ) - - self.valtext = ax.text( - 1.02, - 0.5, - self._format(valinit), - transform=ax.transAxes, - verticalalignment="center", - horizontalalignment="left", - ) - - self.set_val(valinit) - - def _min_in_bounds(self, min): - """ - Ensure the new min value is between valmin and self.val[1] - """ - if min <= self.valmin: - if not self.closedmin: - return self.val[0] - min = self.valmin - - if min > self.val[1]: - min = self.val[1] - return self._stepped_value(min) - - def _max_in_bounds(self, max): - """ - Ensure the new max value is between valmax and self.val[0] - """ - if max >= self.valmax: - if not self.closedmax: - return self.val[1] - max = self.valmax - - if max <= self.val[0]: - max = self.val[0] - return self._stepped_value(max) - - def _value_in_bounds(self, val): - return (self._min_in_bounds(val[0]), self._max_in_bounds(val[1])) - - def _update_val_from_pos(self, pos): - """ - Given a position update the *val* - """ - idx = np.argmin(np.abs(self.val - pos)) - if idx == 0: - val = self._min_in_bounds(pos) - self.set_min(val) - else: - val = self._max_in_bounds(pos) - self.set_max(val) - - def _update(self, event): - """Update the slider position.""" - if self.ignore(event) or event.button != 1: - return - - if event.name == "button_press_event" and event.inaxes == self.ax: - self.drag_active = True - event.canvas.grab_mouse(self.ax) - - if not self.drag_active: - return - - elif (event.name == "button_release_event") or ( - event.name == "button_press_event" and event.inaxes != self.ax - ): - self.drag_active = False - event.canvas.release_mouse(self.ax) - return - if self.orientation == "vertical": - self._update_val_from_pos(event.ydata) - else: - self._update_val_from_pos(event.xdata) - - def _format(self, val): - """Pretty-print *val*.""" - if self.valfmt is not None: - return (self.valfmt % val[0], self.valfmt % val[1]) - else: - # fmt.get_offset is actually the multiplicative factor, if any. - _, s1, s2, _ = self._fmt.format_ticks([self.valmin, *val, self.valmax]) - # fmt.get_offset is actually the multiplicative factor, if any. - s1 += self._fmt.get_offset() - s2 += self._fmt.get_offset() - # use raw string to avoid issues with backslashes from - return rf"({s1}, {s2})" - - def set_min(self, min): - """ - Set the lower value of the slider to *min* - - Parameters - ---------- - min : float - """ - self.set_val((min, self.val[1])) - - def set_max(self, max): - """ - Set the lower value of the slider to *max* - - Parameters - ---------- - max : float - """ - self.set_val((self.val[0], max)) - - def set_val(self, val): - """ - Set slider value to *val* - - Parameters - ---------- - val : tuple or arraylike of float - """ - val = np.sort(np.asanyarray(val)) - if val.shape != (2,): - raise ValueError(f"val must have shape (2,) but has shape {val.shape}") - val[0] = self._min_in_bounds(val[0]) - val[1] = self._max_in_bounds(val[1]) - xy = self.poly.xy - if self.orientation == "vertical": - xy[0] = 0, val[0] - xy[1] = 0, val[1] - xy[2] = 1, val[1] - xy[3] = 1, val[0] - xy[4] = 0, val[0] - else: - xy[0] = val[0], 0 - xy[1] = val[0], 1 - xy[2] = val[1], 1 - xy[3] = val[1], 0 - xy[4] = val[0], 0 - self.poly.xy = xy - self.valtext.set_text(self._format(val)) - if self.drawon: - self.ax.figure.canvas.draw_idle() - self.val = val - if self.eventson: - self._observers.process("changed", val) - - def on_changed(self, func): - """ - When the slider value is changed call *func* with the new - slider value - - Parameters - ---------- - func : callable - Function to call when slider is changed. - The function must accept a numpy array with shape (2,) float - as its argument. - - Returns - ------- - int - Connection id (which can be used to disconnect *func*) - """ - return self._observers.connect("changed", func) From 76968928e544f6f986c9fa1bf0bcf9758e56cca9 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Tue, 18 May 2021 14:17:13 -0400 Subject: [PATCH 02/15] Create SliderWrappers These are sliders that wrap either ipywidgets or matplotlib sliders This should make it easier to expose the raw sliders created by mpl-interactions to the users. --- mpl_interactions/widgets.py | 209 ++++++++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) diff --git a/mpl_interactions/widgets.py b/mpl_interactions/widgets.py index 8a114575..b53dcad1 100644 --- a/mpl_interactions/widgets.py +++ b/mpl_interactions/widgets.py @@ -1,3 +1,26 @@ +import numpy as np + +from traitlets import ( + HasTraits, + Int, + Float, + Union, + observe, + dlink, + link, + Tuple, + Unicode, + validate, + TraitError, + Any, +) +from traittypes import Array + +try: + import ipywidgets as widgets +except ImportError: + widgets = None +from matplotlib import widgets as mwidgets from matplotlib.cbook import CallbackRegistry from matplotlib.widgets import AxesWidget @@ -11,6 +34,9 @@ "scatter_selector_index", "scatter_selector_value", "RangeSlider", + "SliderWrapper", + "IntSlider", + "IndexSlider", ] @@ -141,3 +167,186 @@ def on_changed(self, func): Connection id (which can be used to disconnect *func*) """ return self._observers.connect("picked", lambda val: func(val)) + + +_gross_traits = [ + "add_traits", + "class_own_trait_events", + "class_own_traits", + "class_trait_names", + "class_traits", + "cross_validation_lock", + "has_trait", + "hold_trait_notifications", + "notify_change", + "on_trait_change", + "set_trait", + "setup_instance", + "trait_defaults", + "trait_events", + "trait_has_value", + "trait_metadata", + "trait_names", + "trait_values", + "traits", +] + + +class SliderWrapper(HasTraits): + """ + A warpper class that provides a consistent interface for both + ipywidgets and matplotlib sliders. + """ + + min = Union([Int(), Float(), Tuple([Int(), Int()]), Tuple(Float(), Float())]) + max = Union([Int(), Float(), Tuple([Int(), Int()]), Tuple(Float(), Float())]) + value = Union([Float(), Int(), Tuple([Int(), Int()]), Tuple(Float(), Float())]) + step = Union([Int(), Float(allow_none=True)]) + index = Int(allow_none=True) + label = Unicode() + readout_format = Unicode("{:.2f}") + + def __init__(self, slider, readout_format=None, setup_value_callbacks=True): + super().__init__() + self._raw_slider = slider + # eventually we can just rely on SliderBase here + # for now keep both for compatibility with mpl < 3.4 + self._mpl = isinstance(slider, (mwidgets.Slider, SliderBase)) + if self._mpl: + self.observe(lambda change: setattr(self._raw_slider, "valmin", change["new"]), "min") + self.observe(lambda change: setattr(self._raw_slider, "valmax", change["new"]), "max") + self.observe(lambda change: self._raw_slider.label.set_text(change["new"]), "label") + if setup_value_callbacks: + self.observe(lambda change: self._raw_slider.set_val(change["new"]), "value") + self._raw_slider.on_changed(lambda val: setattr(self, "value", val)) + self.value = self._raw_slider.val + self.min = self._raw_slider.valmin + self.max = self._raw_slider.valmax + self.step = self._raw_slider.valstep + self.label = self._raw_slider.label.get_text() + else: + if setup_value_callbacks: + link((slider, "value"), (self, "value")) + link((slider, "min"), (self, "min")) + link((slider, "max"), (self, "max")) + link((slider, "step"), (self, "step")) + link((slider, "description"), (self, "label")) + self._callbacks = [] + + @observe("value") + def _on_changed(self, change): + for c in self._callbacks: + c(change["new"]) + + def on_changed(self, callback): + # callback registry? + self._callbacks.append(callback) + + def _get_widget_for_display(self): + return self._raw_slider + + def _ipython_display_(self): + if self._mpl: + pass + else: + return self._get_widget_for_display() + + def __dir__(self): + # hide all the cruft from traitlets for shfit+Tab + return [i for i in super().__dir__() if i not in _gross_traits] + + +class IntSlider(SliderWrapper): + min = Int() + max = Int() + value = Int() + + +class IndexSlider(IntSlider): + """ + A slider class to index through an array of values. + """ + + index = Int() + max_index = Int() + value = Any() + values = Array() + # gotta make values traitlike - traittypes? + + def __init__(self, values, label="", mpl_slider_ax=None, play_button=False): + """ + Parameters + ---------- + values : 1D arraylike + The values to index over + label : str + The slider label + mpl_slider_ax : matplotlib.axes or None + If *None* an ipywidgets slider will be created + """ + if play_button is not False: + raise ValueError("play buttons not yet implemented fool!") + self.values = np.atleast_1d(values) + self.readout_format = "{:.2f}" + if mpl_slider_ax is not None: + # make mpl_slider + slider = mwidgets.Slider( + mpl_slider_ax, + label=label, + valinit=0, + valmin=0, + valmax=self.values.shape[0] - 1, + valstep=1, + ) + + def onchange(val): + self.index = int(val) + slider.valtext.set_text(self.readout_format.format(self.values[int(val)])) + + slider.on_changed(onchange) + self.values + elif widgets: + slider = widgets.IntSlider( + 0, 0, self.values.shape[0] - 1, step=1, readout=False, description=label + ) + self._readout = widgets.Label(value=str(self.values[0])) + widgets.dlink( + (slider, "value"), + (self._readout, "value"), + transform=lambda x: self.readout_format.format(self.values[x]), + ) + link((slider, "value"), (self, "index")) + link((slider, "max"), (self, "max_index")) + else: + raise ValueError("mpl_slider_ax cannot be None if ipywidgets is not available") + super().__init__(slider, setup_value_callbacks=False) + self.value = self.values[self.index] + + def _get_widget_for_display(self): + return widgets.HBox([self._raw_slider, self._readout]) + + @validate("value") + def _validate_value(self, proposal): + if not proposal["value"] in self.values: + raise TraitError( + f"{proposal['value']} is not in the set of values for this index slider." + " To see or change the set of valid values use the `.values` attribute" + ) + # call `int` because traitlets can't handle np int64 + index = int(np.argmin(np.abs(self.values - proposal["value"]))) + self.index = index + + return proposal["value"] + + @observe("index") + def _obs_index(self, change): + # call .item because traitlets is unhappy with numpy types + self.value = self.values[change["new"]].item() + + @validate("values") + def _validate_values(self, proposal): + values = proposal["value"] + if values.ndim > 1: + raise TraitError("Expected 1d array but got an array with shape %s" % (values.shape)) + self.max_index = values.shape[0] + return values From a8b8a769aef18f48f71a063104da24b138bd9dcc Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Tue, 18 May 2021 14:20:14 -0400 Subject: [PATCH 03/15] Switch to using SliderWrappers instead of raw sliders --- mpl_interactions/controller.py | 131 +++++++++++------------- mpl_interactions/helpers.py | 177 +++++++++++++++++++++++++++++---- mpl_interactions/pyplot.py | 19 ++-- 3 files changed, 223 insertions(+), 104 deletions(-) diff --git a/mpl_interactions/controller.py b/mpl_interactions/controller.py index 11d41b2c..924c0e76 100644 --- a/mpl_interactions/controller.py +++ b/mpl_interactions/controller.py @@ -11,7 +11,8 @@ create_slider_format_dict, kwarg_to_ipywidget, kwarg_to_mpl_widget, - create_mpl_controls_fig, + maybe_create_mpl_controls_axes, + kwarg_to_widget, notebook_backend, process_mpl_widget, ) @@ -61,14 +62,28 @@ def __init__( self._user_callbacks = defaultdict(list) self.add_kwargs(kwargs, slider_formats, play_buttons) - def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_duplicates=False): + def add_kwargs( + self, + kwargs, + slider_formats=None, + play_buttons=None, + allow_duplicates=False, + index_kwargs=None, + ): """ If you pass a redundant kwarg it will just be overwritten maybe should only raise a warning rather than an error? need to implement matplotlib widgets also a big question is how to dynamically update the display of matplotlib widgets. + + Parameters + ---------- + index_kwargs : list of str or None + A list of which sliders should use an index for their callbacks. """ + if not index_kwargs: + index_kwargs = [] if isinstance(play_buttons, bool) or isinstance(play_buttons, str) or play_buttons is None: _play_buttons = defaultdict(lambda: play_buttons) elif isinstance(play_buttons, defaultdict): @@ -85,76 +100,48 @@ def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_dupli slider_formats = create_slider_format_dict(slider_formats) for k, v in slider_formats.items(): self.slider_format_strings[k] = v - if self.use_ipywidgets: - for k, v in kwargs.items(): - if k in self.params: - if allow_duplicates: - continue - else: - raise ValueError("can't overwrite an existing param in the controller") - if isinstance(v, AxesWidget): - self.params[k], self.controls[k], _ = process_mpl_widget( - v, partial(self.slider_updated, key=k) - ) + axes, fig = maybe_create_mpl_controls_axes(kwargs) + if fig is not None: + self.control_figures.append((fig)) + for k, v in kwargs.items(): + if k in self.params: + if allow_duplicates: + continue else: - self.params[k], control = kwarg_to_ipywidget( - k, - v, - partial(self.slider_updated, key=k), - self.slider_format_strings[k], - play_button=_play_buttons[k], - ) - if control: - self.controls[k] = control - self.vbox.children = list(self.vbox.children) + [control] - if k == "vmin_vmax": - self.params["vmin"] = self.params["vmin_vmax"][0] - self.params["vmax"] = self.params["vmin_vmax"][1] - else: - if len(kwargs) > 0: - mpl_layout = create_mpl_controls_fig(kwargs) - self.control_figures.append(mpl_layout[0]) - widget_y = 0.05 - for k, v in kwargs.items(): - if k in self.params: - if allow_duplicates: - continue - else: - raise ValueError("Can't overwrite an existing param in the controller") - self.params[k], control, cb, widget_y = kwarg_to_mpl_widget( - mpl_layout[0], - mpl_layout[1:], - widget_y, - k, - v, - partial(self.slider_updated, key=k), - self.slider_format_strings[k], - ) - if control: - self.controls[k] = control - if k == "vmin_vmax": - self.params["vmin"] = self.params["vmin_vmax"][0] - self.params["vmax"] = self.params["vmin_vmax"][1] - - def _slider_updated(self, change, key, values): + raise ValueError("can't overwrite an existing param in the controller") + # TODO: accept existing mpl widget + # if isinstance(v, AxesWidget): + # self.params[k], self.controls[k], _ = process_mpl_widget( + # v, partial(self.slider_updated, key=k) + # ) + # else: + ax = axes.pop() + control = kwarg_to_widget(k, v, ax, play_button=_play_buttons[k]) + if k in index_kwargs: + self.params[k] = control.index + control.observe(partial(self._slider_updated, key=k), names="index") + else: + self.params[k] = control.value + control.observe(partial(self._slider_updated, key=k), names="value") + + if control: + self.controls[k] = control + if ax is None: + self.vbox.children = list(self.vbox.children) + [ + control._get_widget_for_display() + ] + if k == "vmin_vmax": + self.params["vmin"] = self.params["vmin_vmax"][0] + self.params["vmax"] = self.params["vmin_vmax"][1] + + def _slider_updated(self, change, key): """ gotta also give the indices in order to support hyperslicer without horrifying contortions """ - if values is None: - self.params[key] = change["new"] - else: - c = change["new"] - if isinstance(c, tuple): - # This is for range sliders which return 2 indices - self.params[key] = values[[*change["new"]]] - if key == "vmin_vmax": - self.params["vmin"] = self.params[key][0] - self.params["vmax"] = self.params[key][1] - else: - # int casting due to a bug in numpy < 1.19 - # see https://github.com/ianhi/mpl-interactions/pull/155 - self.params[key] = values[int(change["new"])] - self.indices[key] = change["new"] + self.params[key] = change["new"] + if key == "vmin_vmax": + self.params["vmin"] = self.params[key][0] + self.params["vmax"] = self.params[key][1] if self.use_cache: cache = {} else: @@ -162,14 +149,12 @@ def _slider_updated(self, change, key, values): for f, params in self._update_funcs[key]: ps = {} - idxs = {} for k in params: ps[k] = self.params[k] - idxs[k] = self.indices[k] - f(params=ps, indices=idxs, cache=cache) + f(params=ps, cache=cache) + # TODO: see if can combine these with update_funcs for only one loop for f, params in self._user_callbacks[key]: f(**{key: self.params[key] for key in params}) - for f in self.figs[key]: f.canvas.draw_idle() @@ -255,7 +240,7 @@ def save_animation( fig : figure param : str the name of the kwarg to use to animate - interval : int, default: 2o + interval : int, default: 20 interval between frames in ms func_anim_kwargs : dict kwargs to pass the creation of the underlying FuncAnimation diff --git a/mpl_interactions/helpers.py b/mpl_interactions/helpers.py index 80b15cee..354ce6e2 100644 --- a/mpl_interactions/helpers.py +++ b/mpl_interactions/helpers.py @@ -1,6 +1,7 @@ from collections import defaultdict from collections.abc import Callable, Iterable from functools import partial +from .widgets import IndexSlider from numbers import Number import matplotlib.widgets as mwidgets @@ -12,7 +13,7 @@ except ImportError: pass from matplotlib import get_backend -from matplotlib.pyplot import axes, gca, gcf, figure +from matplotlib.pyplot import gca, gcf, figure from numpy.distutils.misc_util import is_sequence from .widgets import RangeSlider @@ -37,9 +38,10 @@ "create_slider_format_dict", "gogogo_figure", "gogogo_display", - "create_mpl_controls_fig", + "maybe_create_mpl_controls_axes", "eval_xy", "choose_fmt_str", + "kwarg_to_widget", ] @@ -256,6 +258,106 @@ def eval_xy(x_, y_, params, cache=None): return np.asanyarray(x), np.asanyarray(y) +def kwarg_to_widget(key, val, mpl_widget_ax=None): + """ + Parameters + ---------- + key : str + val : slider value specification + The value to be interpreted and possibly transformed into an ipywidget + mpl_widget_ax : matplotlib axis, optional + If given then create a matplotlib widget instead of an ipywidget. + + Returns + ------- + widget : + A widget that can be `observed` and will have a `.value` attribute + and a `.index` attribute if applicable. + """ + init_val = 0 + control = None + if isinstance(val, set): + if len(val) == 1: + val = val.pop() + if isinstance(val, tuple): + # want the categories to be ordered + pass + else: + # fixed parameter + # TODO: for mpl as well + return widgets.fixed(val) + else: + val = list(val) + + # TODO: categorical - Make wrappers here! + # if len(val) <= 3: + # selector = widgets.RadioButtons(options=val) + # else: + # selector = widgets.Select(options=val) + # selector.observe(partial(update, values=val), names="index") + # return val[0], selector + if isinstance(val, widgets.Widget) or isinstance(val, widgets.fixed): + if not hasattr(val, "value"): + raise TypeError( + "widgets passed as parameters must have the `value` trait." + "But the widget passed for {key} does not have a `.value` attribute" + ) + if isinstance(val, widgets.fixed): + return val + # TODO: elif ( + # isinstance(val, widgets.Select) + # or isinstance(val, widgets.SelectionSlider) + # or isinstance(val, widgets.RadioButtons) + # ): + # # all the selection widget inherit a private _Selection :( + # # it looks unlikely to change but still would be nice to just check + # # if its a subclass + # return val + # # val.observe(partial(update, values=val.options), names="index") + # else: + # # set values to None and hope for the best + # val.observe(partial(update, values=None), names="value") + # return val.value, val + # # val.observe(partial(update, key=key, label=None), names=["value"]) + else: + # TODO: Range sliders + # if isinstance(val, tuple) and val[0] in ["r", "range", "rang", "rage"]: + # # also check for some reasonably easy mispellings + # if isinstance(val[1], (np.ndarray, list)): + # vals = val[1] + # else: + # vals = np.linspace(*val[1:]) + # label = widgets.Label(value=str(vals[0])) + # slider = widgets.IntRangeSlider( + # value=(0, vals.size - 1), min=0, max=vals.size - 1, readout=False, description=key + # ) + # widgets.dlink( + # (slider, "value"), + # (label, "value"), + # transform=lambda x: slider_format_string.format(vals[x[0]]) + # + " - " + # + slider_format_string.format(vals[x[1]]), + # ) + # slider.observe(partial(update, values=vals), names="value") + # controls = widgets.HBox([slider, label]) + # return vals[[0, -1]], controls + + if isinstance(val, tuple) and len(val) in [2, 3]: + # treat as an argument to linspace + # idk if it's acceptable to overwrite kwargs like this + # but I think at this point kwargs is just a dict like any other + val = np.linspace(*val) + val = np.atleast_1d(val) + if val.ndim > 1: + raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar") + if len(val) == 1: + # don't need to create a slider + # TODO: make fixed available for mpl as well. + return widgets.fixed(val) + else: + return IndexSlider(val, key, mpl_widget_ax) + + def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None): """ Parameters @@ -360,24 +462,25 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None) return val, None else: # params[key] = val[0] - label = widgets.Label(value=slider_format_string.format(val[0])) - slider = widgets.IntSlider(min=0, max=val.size - 1, readout=False, description=key) - widgets.dlink( - (slider, "value"), - (label, "value"), - transform=lambda x: slider_format_string.format(val[x]), - ) - slider.observe(partial(update, values=val), names="value") - if play_button is not None and play_button is not False: - play = widgets.Play(min=0, max=val.size - 1, step=1) - widgets.jslink((play, "value"), (slider, "value")) - if isinstance(play_button, str) and play_button.lower() == "right": - control = widgets.HBox([slider, label, play]) - else: - control = widgets.HBox([play, slider, label]) - else: - control = widgets.HBox([slider, label]) - return val[0], control + slider = IndexSlider(val, key) + # label = widgets.Label(value=slider_format_string.format(val[0])) + # slider = widgets.IntSlider(min=0, max=val.size - 1, readout=False, description=key) + # widgets.dlink( + # (slider, "value"), + # (label, "value"), + # transform=lambda x: slider_format_string.format(val[x]), + # ) + slider.observe(partial(update, values=val), names="index") + # if play_button is not None and play_button is not False: + # play = widgets.Play(min=0, max=val.size - 1, step=1) + # widgets.jslink((play, "value"), (slider, "value")) + # if isinstance(play_button, str) and play_button.lower() == "right": + # control = widgets.HBox([slider, label, play]) + # else: + # control = widgets.HBox([play, slider, label]) + # else: + # control = widgets.HBox([slider, label]) + return val[0], slider._get_widget_for_display() def extract_num_options(val): @@ -420,7 +523,7 @@ def changeify_radio(val, labels, update): update({"new": labels.index(value)}) -def create_mpl_controls_fig(kwargs): +def maybe_create_mpl_controls_axes(kwargs): """ Returns ------- @@ -441,16 +544,20 @@ def create_mpl_controls_fig(kwargs): I think maybe the correct approach is to use transforms and actually specify things in inches - Ian 2020-09-27 """ - init_fig = gcf() n_opts = 0 n_radio = 0 n_sliders = 0 + order = [] + radio_info = [] for key, val in kwargs.items(): if isinstance(val, set): new_opts = extract_num_options(val) if new_opts > 0: n_radio += 1 n_opts += new_opts + order.append("radio") + longest_len = max(list(map(lambda x: len(list(x)), map(str, val)))) + radio_info.append((new_opts, longest_len)) elif ( not isinstance(val, mwidgets.AxesWidget) and not "ipywidgets" in str(val.__class__) # do this to avoid depending on ipywidgets @@ -458,7 +565,16 @@ def create_mpl_controls_fig(kwargs): and len(val) > 1 ): n_sliders += 1 + order.append("slider") + else: + order.append(None) + + if n_sliders == 0 and n_radio == 0: + # do we need to make anything? + # if no just return None for all the axes + return order, None + init_fig = gcf() # These are roughly the sizes used in the matplotlib widget tutorial # https://matplotlib.org/3.2.2/gallery/widgets/slider_demo.html#sphx-glr-gallery-widgets-slider-demo-py slider_in = 0.15 @@ -486,6 +602,23 @@ def create_mpl_controls_fig(kwargs): # reset the active figure - necessary to make legends behave as expected # maybe this should really be handled via axes? idk figure(init_fig.number) + widget_y = 0.05 + axes = [] + for i, o in enumerate(order): + if o == "slider": + axes.append(fig.add_axes([0.2, 0.9 - widget_y - gap_height, 0.65, slider_height])) + widget_y += slider_height + gap_height + elif o == "radio": + n, longest_len = radio_info.pop() + width = max(0.15, 0.015 * longest_len) + axes.append( + fig.add_axes([0.2, 0.9 - widget_y - radio_height * n, width, radio_height * n]) + ) + widget_y += radio_height * n + gap_height + else: + axes.append(None) + return axes, fig + return fig, slider_height, radio_height, gap_height diff --git a/mpl_interactions/pyplot.py b/mpl_interactions/pyplot.py index c033c42e..c5583de0 100644 --- a/mpl_interactions/pyplot.py +++ b/mpl_interactions/pyplot.py @@ -155,8 +155,9 @@ def f(x, tau): controls, params = gogogo_controls( kwargs, controls, display_controls, slider_formats, play_buttons ) + print(params) - def update(params, indices, cache): + def update(params, cache): if x_and_y: x_, y_ = eval_xy(x, y, params, cache) # broadcast so that we can always index @@ -374,7 +375,7 @@ def f(loc, scale): pc = PatchCollection([]) ax.add_collection(pc, autolim=True) - def update(params, indices, cache): + def update(params, cache): arr_ = callable_else_value(arr, params, cache) new_x, new_y, new_patches = simple_hist(arr_, density=density, bins=bins, weights=weights) stretch(ax, new_x, new_y) @@ -501,7 +502,7 @@ def interactive_scatter( kwargs, controls, display_controls, slider_formats, play_buttons, extra_ctrls ) - def update(params, indices, cache): + def update(params, cache): if parametric: out = callable_else_value_no_cast(x, params) if not isinstance(out, tuple): @@ -702,7 +703,7 @@ def vmin(**kwargs): def vmax(**kwargs): return kwargs["vmax"] - def update(params, indices, cache): + def update(params, cache): if isinstance(X, Callable): # check this here to avoid setting the data if we don't need to # use the callable_else_value fxn to make use of easy caching @@ -822,7 +823,7 @@ def interactive_axhline( kwargs, controls, display_controls, slider_formats, play_buttons, extra_ctrls ) - def update(params, indices, cache): + def update(params, cache): y_ = callable_else_value(y, params, cache).item() line.set_ydata([y_, y_]) xmin_ = callable_else_value(xmin, params, cache).item() @@ -919,7 +920,7 @@ def interactive_axvline( kwargs, controls, display_controls, slider_formats, play_buttons, extra_ctrls ) - def update(params, indices, cache): + def update(params, cache): x_ = callable_else_value(x, params, cache).item() line.set_xdata([x_, x_]) ymin_ = callable_else_value(ymin, params, cache).item() @@ -1007,7 +1008,7 @@ def interactive_title( kwargs, controls, display_controls, slider_formats, play_buttons ) - def update(params, indices, cache): + def update(params, cache): ax.set_title( callable_else_value_no_cast(title, params, cache).format(**params), fontdict=fontdict, @@ -1094,7 +1095,7 @@ def interactive_xlabel( kwargs, controls, display_controls, slider_formats, play_buttons ) - def update(params, indices, cache): + def update(params, cache): ax.set_xlabel( callable_else_value_no_cast(xlabel, params, cache).format(**params), fontdict=fontdict, @@ -1179,7 +1180,7 @@ def interactive_ylabel( kwargs, controls, display_controls, slider_formats, play_buttons ) - def update(params, indices, cache): + def update(params, cache): ax.set_ylabel( callable_else_value_no_cast(ylabel, params, cache).format(**params), fontdict=fontdict, From faadeafc43706172c36e2501486280c6f4639359 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Tue, 18 May 2021 14:31:53 -0400 Subject: [PATCH 04/15] fix logic around creating axes --- mpl_interactions/controller.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mpl_interactions/controller.py b/mpl_interactions/controller.py index 924c0e76..283ba242 100644 --- a/mpl_interactions/controller.py +++ b/mpl_interactions/controller.py @@ -50,7 +50,6 @@ def __init__( self.vbox = widgets.VBox([]) else: self.control_figures = [] # storage for figures made of matplotlib sliders - self.use_cache = use_cache self.kwargs = kwargs self.slider_format_strings = create_slider_format_dict(slider_formats) @@ -69,6 +68,7 @@ def add_kwargs( play_buttons=None, allow_duplicates=False, index_kwargs=None, + use_ipywidgets=True, ): """ If you pass a redundant kwarg it will just be overwritten @@ -100,9 +100,14 @@ def add_kwargs( slider_formats = create_slider_format_dict(slider_formats) for k, v in slider_formats.items(): self.slider_format_strings[k] = v - axes, fig = maybe_create_mpl_controls_axes(kwargs) - if fig is not None: - self.control_figures.append((fig)) + + if not use_ipywidgets: + axes, fig = maybe_create_mpl_controls_axes(kwargs) + if fig is not None: + self.control_figures.append((fig)) + else: + axes = [None] * len(kwargs) + for k, v in kwargs.items(): if k in self.params: if allow_duplicates: From 60ee8f9f3ecad784225a173ada5b2f7b33fce550 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Tue, 18 May 2021 14:51:56 -0400 Subject: [PATCH 05/15] add play button support --- mpl_interactions/helpers.py | 6 ++++-- mpl_interactions/widgets.py | 21 +++++++++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/mpl_interactions/helpers.py b/mpl_interactions/helpers.py index 354ce6e2..d7e00e35 100644 --- a/mpl_interactions/helpers.py +++ b/mpl_interactions/helpers.py @@ -258,7 +258,7 @@ def eval_xy(x_, y_, params, cache=None): return np.asanyarray(x), np.asanyarray(y) -def kwarg_to_widget(key, val, mpl_widget_ax=None): +def kwarg_to_widget(key, val, mpl_widget_ax=None, play_button=False): """ Parameters ---------- @@ -267,6 +267,8 @@ def kwarg_to_widget(key, val, mpl_widget_ax=None): The value to be interpreted and possibly transformed into an ipywidget mpl_widget_ax : matplotlib axis, optional If given then create a matplotlib widget instead of an ipywidget. + play_button : bool or "left" or "right" + Whether to create a play button and where to put it. Returns ------- @@ -355,7 +357,7 @@ def kwarg_to_widget(key, val, mpl_widget_ax=None): # TODO: make fixed available for mpl as well. return widgets.fixed(val) else: - return IndexSlider(val, key, mpl_widget_ax) + return IndexSlider(val, key, mpl_widget_ax, play_button=play_button) def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None): diff --git a/mpl_interactions/widgets.py b/mpl_interactions/widgets.py index b53dcad1..69ef1ffa 100644 --- a/mpl_interactions/widgets.py +++ b/mpl_interactions/widgets.py @@ -18,6 +18,8 @@ try: import ipywidgets as widgets + from ipywidgets.widgets.widget_link import jslink + from IPython.display import display except ImportError: widgets = None from matplotlib import widgets as mwidgets @@ -202,7 +204,6 @@ class SliderWrapper(HasTraits): max = Union([Int(), Float(), Tuple([Int(), Int()]), Tuple(Float(), Float())]) value = Union([Float(), Int(), Tuple([Int(), Int()]), Tuple(Float(), Float())]) step = Union([Int(), Float(allow_none=True)]) - index = Int(allow_none=True) label = Unicode() readout_format = Unicode("{:.2f}") @@ -249,7 +250,7 @@ def _ipython_display_(self): if self._mpl: pass else: - return self._get_widget_for_display() + display(self._get_widget_for_display()) def __dir__(self): # hide all the cruft from traitlets for shfit+Tab @@ -284,8 +285,6 @@ def __init__(self, values, label="", mpl_slider_ax=None, play_button=False): mpl_slider_ax : matplotlib.axes or None If *None* an ipywidgets slider will be created """ - if play_button is not False: - raise ValueError("play buttons not yet implemented fool!") self.values = np.atleast_1d(values) self.readout_format = "{:.2f}" if mpl_slider_ax is not None: @@ -315,6 +314,15 @@ def onchange(val): (self._readout, "value"), transform=lambda x: self.readout_format.format(self.values[x]), ) + self._play_button = None + if play_button: + self._play_button = widgets.Play(step=1) + self._play_button_on_left = not ( + isinstance(play_button, str) and play_button == "right" + ) + jslink((slider, "value"), (self._play_button, "value")) + jslink((slider, "min"), (self._play_button, "min")) + jslink((slider, "max"), (self._play_button, "max")) link((slider, "value"), (self, "index")) link((slider, "max"), (self, "max_index")) else: @@ -323,6 +331,11 @@ def onchange(val): self.value = self.values[self.index] def _get_widget_for_display(self): + if self._play_button: + if self._play_button_on_left: + return widgets.HBox([self._play_button, self._raw_slider, self._readout]) + else: + return widgets.HBox([self._raw_slider, self._readout, self._play_button]) return widgets.HBox([self._raw_slider, self._readout]) @validate("value") From 1908d97acd6d60e8cc3e53afdc4de830d4980f85 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Tue, 18 May 2021 15:44:56 -0400 Subject: [PATCH 06/15] backfill: add missing import --- mpl_interactions/_widget_backfill.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mpl_interactions/_widget_backfill.py b/mpl_interactions/_widget_backfill.py index 87bf5c94..5e0e5ffd 100644 --- a/mpl_interactions/_widget_backfill.py +++ b/mpl_interactions/_widget_backfill.py @@ -3,6 +3,7 @@ """ from matplotlib.widgets import AxesWidget from matplotlib import cbook, ticker +import numpy as np # slider widgets are taken almost verbatim from https://github.com/matplotlib/matplotlib/pull/18829/files # which was written by me - but incorporates much of the existing matplotlib slider infrastructure From 2b791d8cf867c16647758310379f70e5ab21b000 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Tue, 18 May 2021 15:45:49 -0400 Subject: [PATCH 07/15] Add formatting to IndexSliders. Default to using mpl's ScalarFormatter --- mpl_interactions/widgets.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/mpl_interactions/widgets.py b/mpl_interactions/widgets.py index 69ef1ffa..8647e3f6 100644 --- a/mpl_interactions/widgets.py +++ b/mpl_interactions/widgets.py @@ -1,4 +1,5 @@ import numpy as np +from numbers import Number from traitlets import ( HasTraits, @@ -25,6 +26,7 @@ from matplotlib import widgets as mwidgets from matplotlib.cbook import CallbackRegistry from matplotlib.widgets import AxesWidget +from matplotlib.ticker import ScalarFormatter try: from matplotlib.widgets import RangeSlider, SliderBase @@ -205,7 +207,6 @@ class SliderWrapper(HasTraits): value = Union([Float(), Int(), Tuple([Int(), Int()]), Tuple(Float(), Float())]) step = Union([Int(), Float(allow_none=True)]) label = Unicode() - readout_format = Unicode("{:.2f}") def __init__(self, slider, readout_format=None, setup_value_callbacks=True): super().__init__() @@ -274,7 +275,9 @@ class IndexSlider(IntSlider): values = Array() # gotta make values traitlike - traittypes? - def __init__(self, values, label="", mpl_slider_ax=None, play_button=False): + def __init__( + self, values, label="", mpl_slider_ax=None, readout_format=None, play_button=False + ): """ Parameters ---------- @@ -286,7 +289,9 @@ def __init__(self, values, label="", mpl_slider_ax=None, play_button=False): If *None* an ipywidgets slider will be created """ self.values = np.atleast_1d(values) - self.readout_format = "{:.2f}" + self.readout_format = readout_format + self._scalar_formatter = ScalarFormatter(useOffset=False) + self._scalar_formatter.create_dummy_axis() if mpl_slider_ax is not None: # make mpl_slider slider = mwidgets.Slider( @@ -300,7 +305,7 @@ def __init__(self, values, label="", mpl_slider_ax=None, play_button=False): def onchange(val): self.index = int(val) - slider.valtext.set_text(self.readout_format.format(self.values[int(val)])) + slider.valtext.set_text(self._format_value(self.values[int(val)])) slider.on_changed(onchange) self.values @@ -312,7 +317,7 @@ def onchange(val): widgets.dlink( (slider, "value"), (self._readout, "value"), - transform=lambda x: self.readout_format.format(self.values[x]), + transform=lambda x: self._format_value(self.values[x]), ) self._play_button = None if play_button: @@ -330,6 +335,14 @@ def onchange(val): super().__init__(slider, setup_value_callbacks=False) self.value = self.values[self.index] + def _format_value(self, value): + if self.readout_format is None: + if isinstance(value, Number): + return self._scalar_formatter.format_data_short(value) + else: + return str(value) + return self.readout_format.format(value) + def _get_widget_for_display(self): if self._play_button: if self._play_button_on_left: @@ -346,8 +359,7 @@ def _validate_value(self, proposal): " To see or change the set of valid values use the `.values` attribute" ) # call `int` because traitlets can't handle np int64 - index = int(np.argmin(np.abs(self.values - proposal["value"]))) - self.index = index + self.index = int(np.where(self.values == proposal["value"])[0][0]) return proposal["value"] From 40a7184304a7f49e5cd4f7e75dfe1cc8953dd938 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Tue, 18 May 2021 15:48:16 -0400 Subject: [PATCH 08/15] remove errant print --- mpl_interactions/pyplot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mpl_interactions/pyplot.py b/mpl_interactions/pyplot.py index c5583de0..df91e705 100644 --- a/mpl_interactions/pyplot.py +++ b/mpl_interactions/pyplot.py @@ -155,7 +155,6 @@ def f(x, tau): controls, params = gogogo_controls( kwargs, controls, display_controls, slider_formats, play_buttons ) - print(params) def update(params, cache): if x_and_y: From 4205f7c2bd708a89f71f1a492bc6363f61d4e86a Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Tue, 18 May 2021 15:52:00 -0400 Subject: [PATCH 09/15] play buttons error for mpl --- mpl_interactions/widgets.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mpl_interactions/widgets.py b/mpl_interactions/widgets.py index 8647e3f6..d16ad98c 100644 --- a/mpl_interactions/widgets.py +++ b/mpl_interactions/widgets.py @@ -294,6 +294,11 @@ def __init__( self._scalar_formatter.create_dummy_axis() if mpl_slider_ax is not None: # make mpl_slider + if play_button: + raise ValueError( + "Play Buttons not yet available for matplotlib sliders " + "see https://github.com/ianhi/mpl-interactions/issues/144" + ) slider = mwidgets.Slider( mpl_slider_ax, label=label, From 949925b9fdf8b9add6c247b4851002e7034e0bf8 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Wed, 19 May 2021 15:57:21 -0400 Subject: [PATCH 10/15] widgets: restructure inheritance --- mpl_interactions/widgets.py | 205 +++++++++++++++++++++++------------- 1 file changed, 133 insertions(+), 72 deletions(-) diff --git a/mpl_interactions/widgets.py b/mpl_interactions/widgets.py index d16ad98c..cdf73927 100644 --- a/mpl_interactions/widgets.py +++ b/mpl_interactions/widgets.py @@ -32,6 +32,7 @@ from matplotlib.widgets import RangeSlider, SliderBase except ImportError: from ._widget_backfill import RangeSlider, SliderBase +import matplotlib.widgets as mwidgets __all__ = [ "scatter_selector", @@ -41,6 +42,7 @@ "SliderWrapper", "IntSlider", "IndexSlider", + "CategoricalWrapper", ] @@ -82,6 +84,7 @@ def __init__(self, ax, x, y, pickradius=5, which_button=1, **kwargs): def _init_val(self): self.val = (0, (self._x[0], self._y[0])) + self.value = (0, (self._x[0], self._y[0])) def _on_pick(self, event): if event.mouseevent.button == self._button: @@ -90,7 +93,7 @@ def _on_pick(self, event): y = self._y[idx] self._process(idx, (x, y)) - def _process(idx, val): + def _process(self, idx, val): self._observers.process("picked", idx, val) def on_changed(self, func): @@ -119,6 +122,7 @@ class scatter_selector_index(scatter_selector): def _init_val(self): self.val = 0 + self.value = 0 def _process(self, idx, val): self._observers.process("picked", idx) @@ -150,6 +154,7 @@ class scatter_selector_value(scatter_selector): def _init_val(self): self.val = (self._x[0], self._y[0]) + self.value = (self._x[0], self._y[0]) def _process(self, idx, val): self._observers.process("picked", val) @@ -196,7 +201,43 @@ def on_changed(self, func): ] -class SliderWrapper(HasTraits): +class HasTraitsSmallShiftTab(HasTraits): + def __dir__(self): + # hide all the cruft from traitlets for shift+Tab + return [i for i in super().__dir__() if i not in _gross_traits] + + +class WidgetWrapper(HasTraitsSmallShiftTab): + value = Any() + + def __init__(self, mpl_widget, **kwargs) -> None: + super().__init__(self, **kwargs) + self._mpl = mpl_widget + self._callbacks = [] + + def on_changed(self, callback): + # callback registry? + self._callbacks.append(callback) + + def _get_widget_for_display(self): + if self._mpl: + return None + else: + return self._raw_widget + + def _ipython_display_(self): + if self._mpl: + pass + else: + display(self._get_widget_for_display()) + + @observe("value") + def _on_changed(self, change): + for c in self._callbacks: + c(change["new"]) + + +class SliderWrapper(WidgetWrapper): """ A warpper class that provides a consistent interface for both ipywidgets and matplotlib sliders. @@ -208,24 +249,25 @@ class SliderWrapper(HasTraits): step = Union([Int(), Float(allow_none=True)]) label = Unicode() - def __init__(self, slider, readout_format=None, setup_value_callbacks=True): - super().__init__() - self._raw_slider = slider + def __init__(self, slider, readout_format=None, setup_value_callbacks=True, **kwargs): + self._mpl = isinstance(slider, (mwidgets.Slider, SliderBase)) + super().__init__(self, **kwargs) + self._raw_widget = slider + # eventually we can just rely on SliderBase here # for now keep both for compatibility with mpl < 3.4 - self._mpl = isinstance(slider, (mwidgets.Slider, SliderBase)) if self._mpl: - self.observe(lambda change: setattr(self._raw_slider, "valmin", change["new"]), "min") - self.observe(lambda change: setattr(self._raw_slider, "valmax", change["new"]), "max") - self.observe(lambda change: self._raw_slider.label.set_text(change["new"]), "label") + self.observe(lambda change: setattr(self._raw_widget, "valmin", change["new"]), "min") + self.observe(lambda change: setattr(self._raw_widget, "valmax", change["new"]), "max") + self.observe(lambda change: self._raw_widget.label.set_text(change["new"]), "label") if setup_value_callbacks: - self.observe(lambda change: self._raw_slider.set_val(change["new"]), "value") - self._raw_slider.on_changed(lambda val: setattr(self, "value", val)) - self.value = self._raw_slider.val - self.min = self._raw_slider.valmin - self.max = self._raw_slider.valmax - self.step = self._raw_slider.valstep - self.label = self._raw_slider.label.get_text() + self.observe(lambda change: self._raw_widget.set_val(change["new"]), "value") + self._raw_widget.on_changed(lambda val: setattr(self, "value", val)) + self.value = self._raw_widget.val + self.min = self._raw_widget.valmin + self.max = self._raw_widget.valmax + self.step = self._raw_widget.valstep + self.label = self._raw_widget.label.get_text() else: if setup_value_callbacks: link((slider, "value"), (self, "value")) @@ -233,48 +275,55 @@ def __init__(self, slider, readout_format=None, setup_value_callbacks=True): link((slider, "max"), (self, "max")) link((slider, "step"), (self, "step")) link((slider, "description"), (self, "label")) - self._callbacks = [] - @observe("value") - def _on_changed(self, change): - for c in self._callbacks: - c(change["new"]) - def on_changed(self, callback): - # callback registry? - self._callbacks.append(callback) +class IntSlider(SliderWrapper): + min = Int() + max = Int() + value = Int() - def _get_widget_for_display(self): - return self._raw_slider - def _ipython_display_(self): - if self._mpl: - pass - else: - display(self._get_widget_for_display()) +class SelectionWrapper(WidgetWrapper): + index = Int() + values = Array() + max_index = Int() - def __dir__(self): - # hide all the cruft from traitlets for shfit+Tab - return [i for i in super().__dir__() if i not in _gross_traits] + def __init__(self, values, mpl_ax=None, **kwargs) -> None: + super().__init__(mpl_ax is not None, **kwargs) + self.values = values + self.value = self.values[self.index] + @validate("value") + def _validate_value(self, proposal): + if not proposal["value"] in self.values: + raise TraitError( + f"{proposal['value']} is not in the set of values for this index slider." + " To see or change the set of valid values use the `.values` attribute" + ) + # call `int` because traitlets can't handle np int64 + self.index = int(np.where(self.values == proposal["value"])[0][0]) -class IntSlider(SliderWrapper): - min = Int() - max = Int() - value = Int() + return proposal["value"] + + @observe("index") + def _obs_index(self, change): + # call .item because traitlets is unhappy with numpy types + self.value = self.values[change["new"]].item() + + @validate("values") + def _validate_values(self, proposal): + values = proposal["value"] + if values.ndim > 1: + raise TraitError("Expected 1d array but got an array with shape %s" % (values.shape)) + self.max_index = values.shape[0] + return values -class IndexSlider(IntSlider): +class IndexSlider(SelectionWrapper): """ A slider class to index through an array of values. """ - index = Int() - max_index = Int() - value = Any() - values = Array() - # gotta make values traitlike - traittypes? - def __init__( self, values, label="", mpl_slider_ax=None, readout_format=None, play_button=False ): @@ -288,6 +337,7 @@ def __init__( mpl_slider_ax : matplotlib.axes or None If *None* an ipywidgets slider will be created """ + super().__init__(values, mpl_ax=mpl_slider_ax) self.values = np.atleast_1d(values) self.readout_format = readout_format self._scalar_formatter = ScalarFormatter(useOffset=False) @@ -313,8 +363,8 @@ def onchange(val): slider.valtext.set_text(self._format_value(self.values[int(val)])) slider.on_changed(onchange) - self.values elif widgets: + # i've basically recreated the ipywidgets.SelectionSlider here. slider = widgets.IntSlider( 0, 0, self.values.shape[0] - 1, step=1, readout=False, description=label ) @@ -337,8 +387,7 @@ def onchange(val): link((slider, "max"), (self, "max_index")) else: raise ValueError("mpl_slider_ax cannot be None if ipywidgets is not available") - super().__init__(slider, setup_value_callbacks=False) - self.value = self.values[self.index] + self._raw_widget = slider def _format_value(self, value): if self.readout_format is None: @@ -349,34 +398,46 @@ def _format_value(self, value): return self.readout_format.format(value) def _get_widget_for_display(self): + if self._mpl: + return None if self._play_button: if self._play_button_on_left: - return widgets.HBox([self._play_button, self._raw_slider, self._readout]) + return widgets.HBox([self._play_button, self._raw_widget, self._readout]) else: - return widgets.HBox([self._raw_slider, self._readout, self._play_button]) - return widgets.HBox([self._raw_slider, self._readout]) + return widgets.HBox([self._raw_widget, self._readout, self._play_button]) + return widgets.HBox([self._raw_widget, self._readout]) - @validate("value") - def _validate_value(self, proposal): - if not proposal["value"] in self.values: - raise TraitError( - f"{proposal['value']} is not in the set of values for this index slider." - " To see or change the set of valid values use the `.values` attribute" - ) - # call `int` because traitlets can't handle np int64 - self.index = int(np.where(self.values == proposal["value"])[0][0]) - return proposal["value"] +# A vendored version of ipywidgets.fixed - included so don't need to depend on ipywidgets +# https://github.com/jupyter-widgets/ipywidgets/blob/e0d41f6f02324596a282bc9e4650fd7ba63c0004/ipywidgets/widgets/interaction.py#L546 +class fixed(HasTraitsSmallShiftTab): + """A pseudo-widget whose value is fixed and never synced to the client.""" - @observe("index") - def _obs_index(self, change): - # call .item because traitlets is unhappy with numpy types - self.value = self.values[change["new"]].item() + value = Any(help="Any Python object") + description = Unicode("", help="Any Python object") - @validate("values") - def _validate_values(self, proposal): - values = proposal["value"] - if values.ndim > 1: - raise TraitError("Expected 1d array but got an array with shape %s" % (values.shape)) - self.max_index = values.shape[0] - return values + def __init__(self, value, **kwargs): + super().__init__(value=value, **kwargs) + + def get_interact_value(self): + """Return the value for this widget which should be passed to + interactive functions. Custom widgets can change this method + to process the raw value ``self.value``. + """ + return self.value + + +class CategoricalWrapper(SelectionWrapper): + def __init__(self, values, mpl_ax=None, **kwargs): + super().__init__(values, mpl_ax=mpl_ax, **kwargs) + + if mpl_ax is not None: + self._raw_widget = mwidgets.RadioButtons(mpl_ax, values) + + def on_changed(label): + self.index = self._raw_widget.active + + self._raw_widget.on_changed(on_changed) + else: + self._raw_widget = widgets.Select(options=values) + link((self._raw_widget, "index"), (self, "index")) From 9d3314e63cae8f5dddb54ea681a11dc7e519025d Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Wed, 19 May 2021 15:57:55 -0400 Subject: [PATCH 11/15] Accept more types of widgets --- mpl_interactions/controller.py | 35 +++++++++++++------ mpl_interactions/helpers.py | 62 ++++++++++++++++++++++++++++------ 2 files changed, 76 insertions(+), 21 deletions(-) diff --git a/mpl_interactions/controller.py b/mpl_interactions/controller.py index 283ba242..f26f0650 100644 --- a/mpl_interactions/controller.py +++ b/mpl_interactions/controller.py @@ -7,14 +7,13 @@ _not_ipython = True pass from collections import defaultdict + from .helpers import ( create_slider_format_dict, - kwarg_to_ipywidget, - kwarg_to_mpl_widget, maybe_create_mpl_controls_axes, kwarg_to_widget, + maybe_get_widget_for_display, notebook_backend, - process_mpl_widget, ) from functools import partial from collections.abc import Iterable @@ -50,6 +49,8 @@ def __init__( self.vbox = widgets.VBox([]) else: self.control_figures = [] # storage for figures made of matplotlib sliders + if widgets: + self.vbox = widgets.VBox([]) self.use_cache = use_cache self.kwargs = kwargs self.slider_format_strings = create_slider_format_dict(slider_formats) @@ -68,7 +69,6 @@ def add_kwargs( play_buttons=None, allow_duplicates=False, index_kwargs=None, - use_ipywidgets=True, ): """ If you pass a redundant kwarg it will just be overwritten @@ -101,7 +101,7 @@ def add_kwargs( for k, v in slider_formats.items(): self.slider_format_strings[k] = v - if not use_ipywidgets: + if not self.use_ipywidgets: axes, fig = maybe_create_mpl_controls_axes(kwargs) if fig is not None: self.control_figures.append((fig)) @@ -122,23 +122,38 @@ def add_kwargs( # else: ax = axes.pop() control = kwarg_to_widget(k, v, ax, play_button=_play_buttons[k]) + # TODO: make the try except silliness less ugly + # the complexity of hiding away the val vs value vs whatever needs to + # be hidden away somewhere - but probably not here if k in index_kwargs: self.params[k] = control.index - control.observe(partial(self._slider_updated, key=k), names="index") + try: + control.observe(partial(self._slider_updated, key=k), names="index") + except AttributeError: + self._setup_mpl_widget_callback(control, k) else: self.params[k] = control.value - control.observe(partial(self._slider_updated, key=k), names="value") + try: + control.observe(partial(self._slider_updated, key=k), names="value") + except AttributeError: + self._setup_mpl_widget_callback(control, k) if control: self.controls[k] = control if ax is None: - self.vbox.children = list(self.vbox.children) + [ - control._get_widget_for_display() - ] + disp = maybe_get_widget_for_display(control) + if disp is not None: + self.vbox.children = list(self.vbox.children) + [disp] if k == "vmin_vmax": self.params["vmin"] = self.params["vmin_vmax"][0] self.params["vmax"] = self.params["vmin_vmax"][1] + def _setup_mpl_widget_callback(self, widget, key): + def on_changed(val): + self._slider_updated({"new": val}, key=key) + + widget.on_changed(on_changed) + def _slider_updated(self, change, key): """ gotta also give the indices in order to support hyperslicer without horrifying contortions diff --git a/mpl_interactions/helpers.py b/mpl_interactions/helpers.py index d7e00e35..7defc13e 100644 --- a/mpl_interactions/helpers.py +++ b/mpl_interactions/helpers.py @@ -1,7 +1,9 @@ from collections import defaultdict from collections.abc import Callable, Iterable from functools import partial -from .widgets import IndexSlider + +from ipywidgets.widgets.widget_float import FloatLogSlider +from .widgets import CategoricalWrapper, IndexSlider, WidgetWrapper, scatter_selector from numbers import Number import matplotlib.widgets as mwidgets @@ -11,14 +13,29 @@ import ipywidgets as widgets from IPython.display import display as ipy_display except ImportError: - pass + widgets = None from matplotlib import get_backend from matplotlib.pyplot import gca, gcf, figure from numpy.distutils.misc_util import is_sequence -from .widgets import RangeSlider +try: + from matplotlib.widgets import RangeSlider, SliderBase +except ImportError: + from ._widget_backfill import RangeSlider, SliderBase +from .widgets import RangeSlider, fixed, SliderWrapper from .utils import ioff +if widgets: + _slider_types = ( + mwidgets.Slider, + widgets.IntSlider, + widgets.FloatSlider, + widgets.FloatLogSlider, + ) + # _categorical_types = (mwidgets.RadioButtons, widgets.RadioButtons, widgets.FloatSlider, widgets.FloatLogSlider) +else: + _slider_types = mwidgets.Slider + __all__ = [ "decompose_bbox", "update_datalim_from_xy", @@ -39,6 +56,7 @@ "gogogo_figure", "gogogo_display", "maybe_create_mpl_controls_axes", + "maybe_get_widget_for_display", "eval_xy", "choose_fmt_str", "kwarg_to_widget", @@ -287,25 +305,37 @@ def kwarg_to_widget(key, val, mpl_widget_ax=None, play_button=False): else: # fixed parameter # TODO: for mpl as well - return widgets.fixed(val) + return fixed(val) else: val = list(val) - # TODO: categorical - Make wrappers here! + return CategoricalWrapper(val, mpl_widget_ax) + # # TODO: categorical - Make wrappers here! # if len(val) <= 3: # selector = widgets.RadioButtons(options=val) # else: # selector = widgets.Select(options=val) # selector.observe(partial(update, values=val), names="index") # return val[0], selector - if isinstance(val, widgets.Widget) or isinstance(val, widgets.fixed): + if isinstance(val, WidgetWrapper): + return val + elif isinstance(val, scatter_selector): + return val + elif isinstance(val, _slider_types): + return SliderWrapper(val) + # TODO: categorical types + # elif isinstance(val, _categorical_types): + # return CategoricalWrapper(val) + # TODO: add a _range_slider_types + elif widgets and isinstance(val, (widgets.Widget, widgets.fixed, fixed)): if not hasattr(val, "value"): raise TypeError( "widgets passed as parameters must have the `value` trait." "But the widget passed for {key} does not have a `.value` attribute" ) - if isinstance(val, widgets.fixed): - return val + return val + # if isinstance(val, widgets.fixed): + # return val # TODO: elif ( # isinstance(val, widgets.Select) # or isinstance(val, widgets.SelectionSlider) @@ -353,13 +383,23 @@ def kwarg_to_widget(key, val, mpl_widget_ax=None, play_button=False): if val.ndim > 1: raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar") if len(val) == 1: - # don't need to create a slider - # TODO: make fixed available for mpl as well. - return widgets.fixed(val) + return fixed(val) else: return IndexSlider(val, key, mpl_widget_ax, play_button=play_button) +def maybe_get_widget_for_display(w): + """ + Check if an object can be included in an ipywidgets HBox and if so return + the approriate object + """ + if isinstance(w, WidgetWrapper): + return w._get_widget_for_display() + elif widgets and isinstance(w, widgets.Widget): + return w + return None + + def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None): """ Parameters From 49adef3e12dc37b2f8c3a53cab2b9ce52bfc13c6 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Wed, 19 May 2021 16:09:53 -0400 Subject: [PATCH 12/15] adapt hyperslicer to new callback system --- mpl_interactions/controller.py | 35 +++++++++++++++++++++++++++++----- mpl_interactions/generic.py | 4 +++- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/mpl_interactions/controller.py b/mpl_interactions/controller.py index f26f0650..3d13d8e3 100644 --- a/mpl_interactions/controller.py +++ b/mpl_interactions/controller.py @@ -30,6 +30,7 @@ def __init__( play_button_pos="right", use_ipywidgets=None, use_cache=True, + index_kwargs=[], **kwargs ): # it might make sense to also accept kwargs as a straight up arg @@ -60,7 +61,7 @@ def __init__( self.indices = defaultdict(lambda: 0) self._update_funcs = defaultdict(list) self._user_callbacks = defaultdict(list) - self.add_kwargs(kwargs, slider_formats, play_buttons) + self.add_kwargs(kwargs, slider_formats, play_buttons, index_kwargs=index_kwargs) def add_kwargs( self, @@ -349,6 +350,7 @@ def gogogo_controls( play_buttons, extra_controls=None, allow_dupes=False, + index_kwargs=[], ): if controls or (extra_controls and not all([e is None for e in extra_controls])): if extra_controls is not None: @@ -363,7 +365,13 @@ def gogogo_controls( # it was indexed by the user when passed in extra_keys = controls[1] controls = controls[0] - controls.add_kwargs(kwargs, slider_formats, play_buttons, allow_duplicates=allow_dupes) + controls.add_kwargs( + kwargs, + slider_formats, + play_buttons, + allow_duplicates=allow_dupes, + index_kwargs=index_kwargs, + ) params = {k: controls.params[k] for k in list(kwargs.keys()) + list(extra_keys)} elif isinstance(controls, list): # collected from extra controls @@ -382,14 +390,31 @@ def gogogo_controls( raise ValueError("Only one controls object may be used per function") # now we are garunteed to only have a single entry in controls, so it's ok to pop controls = controls.pop() - controls.add_kwargs(kwargs, slider_formats, play_buttons, allow_duplicates=allow_dupes) + controls.add_kwargs( + kwargs, + slider_formats, + play_buttons, + allow_duplicates=allow_dupes, + index_kwargs=index_kwargs, + ) params = {k: controls.params[k] for k in list(kwargs.keys()) + list(extra_keys)} else: - controls.add_kwargs(kwargs, slider_formats, play_buttons, allow_duplicates=allow_dupes) + controls.add_kwargs( + kwargs, + slider_formats, + play_buttons, + allow_duplicates=allow_dupes, + index_kwargs=index_kwargs, + ) params = controls.params return controls, params else: - controls = Controls(slider_formats=slider_formats, play_buttons=play_buttons, **kwargs) + controls = Controls( + slider_formats=slider_formats, + play_buttons=play_buttons, + index_kwargs=index_kwargs, + **kwargs + ) params = controls.params if display_controls: controls.display() diff --git a/mpl_interactions/generic.py b/mpl_interactions/generic.py index af91cc8c..09093868 100644 --- a/mpl_interactions/generic.py +++ b/mpl_interactions/generic.py @@ -688,6 +688,7 @@ def hyperslicer( play_buttons, extra_ctrls, allow_dupes=True, + index_kwargs=list(kwargs.keys()), ) if vmin_vmax is not None: params.pop("vmin_vmax") @@ -700,7 +701,8 @@ def vmin(**kwargs): def vmax(**kwargs): return kwargs["vmax"] - def update(params, indices, cache): + def update(params, cache): + indices = params if title is not None: ax.set_title(title.format(**params)) From fcacb227e30b44647d29588b28d4e6a46e7969c1 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Wed, 19 May 2021 16:11:09 -0400 Subject: [PATCH 13/15] add new dependencies --- setup.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.cfg b/setup.cfg index c86472b7..477b5fd0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,6 +35,8 @@ platforms = Linux, Mac OS X, Windows python_requires = >=3.6, <3.10 install_requires = matplotlib >= 3.3 + traitlets + traittypes packages = find: [options.extras_require] From a9cc416bcc67bc7026d76929a0b7fb211a52a949 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Wed, 19 May 2021 16:15:30 -0400 Subject: [PATCH 14/15] widgetWrapper base: fix traitlets error --- mpl_interactions/widgets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mpl_interactions/widgets.py b/mpl_interactions/widgets.py index cdf73927..63e0316f 100644 --- a/mpl_interactions/widgets.py +++ b/mpl_interactions/widgets.py @@ -211,7 +211,7 @@ class WidgetWrapper(HasTraitsSmallShiftTab): value = Any() def __init__(self, mpl_widget, **kwargs) -> None: - super().__init__(self, **kwargs) + super().__init__(self) self._mpl = mpl_widget self._callbacks = [] From a5c4809273d8fe3ddc4834074440ba82b3d7d9b5 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Wed, 19 May 2021 16:28:50 -0400 Subject: [PATCH 15/15] update save animation for new sliders --- mpl_interactions/controller.py | 48 +++++++++++++++------------------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/mpl_interactions/controller.py b/mpl_interactions/controller.py index 3d13d8e3..7e2fbefc 100644 --- a/mpl_interactions/controller.py +++ b/mpl_interactions/controller.py @@ -7,6 +7,7 @@ _not_ipython = True pass from collections import defaultdict +from mpl_interactions.widgets import IndexSlider, SliderWrapper from .helpers import ( create_slider_format_dict, @@ -278,40 +279,33 @@ def save_animation( anim : matplotlib.animation.FuncAniation """ slider = self.controls[param] - ipywidgets_slider = False - if "Box" in str(slider.__class__): - for obj in slider.children: - if "Slider" in str(obj.__class__): - slider = obj - - if isinstance(slider, mSlider): - min_ = slider.valmin - max_ = slider.valmax - if slider.valstep is None: + # at this point every slider should be wrapped by at least a .widgets.WidgetWrapper + if isinstance(slider, IndexSlider): + N = len(slider.values) + + def f(i): + slider.index = i + return [] + + elif isinstance(slider, SliderWrapper): + min = slider.min + max = slider.max + if slider.step is None: n_steps = N_frames if N_frames else 200 - step = (max_ - min_) / n_steps + step = (max - min) / n_steps else: step = slider.valstep - elif "Slider" in str(slider.__class__): - ipywidgets_slider = True - min_ = slider.min - max_ = slider.max - step = slider.step + N = int((max - min) / step) + + def f(i): + slider.value = min + step * i + return [] + else: raise NotImplementedError( - "Cannot save animation for slider of type %s".format(slider.__class__.__name__) + "Cannot save animation for param of type %s".format(type(slider)) ) - N = int((max_ - min_) / step) - - def f(i): - val = min_ + step * i - if ipywidgets_slider: - slider.value = val - else: - slider.set_val(val) - return [] - repeat = func_anim_kwargs.pop("repeat", False) anim = FuncAnimation(fig, f, frames=N, interval=interval, repeat=repeat, **func_anim_kwargs) # draw then stop necessary to prevent an extra loop after finished saving