Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,22 +496,26 @@ def generate_curve_label(
self,
index_tuple: list[int],
axis_lookup: list[str],
) -> str:
*,
skip_text: bool = False,
) -> str | float:
"""Get a meaningful label for a subset of data.

Used when plotting 1D arrays out of a multidimensional array.

Parameters
----------
index_tuple : list[int]
indices of the 1D data array position in the ND array
indices of the 1D data array position in the ND array.
axis_lookup : list[str]
Names of the axes to use
Names of the axes to use.
skip_text : bool, optional
If set to true, omits the text parts of the label. By default False.

Returns
-------
str
A string label for the plot legend.
str | float
A string label for the plot legend or a number for Text plotter.

"""
if self._n_dim < 2:
Expand Down Expand Up @@ -540,13 +544,17 @@ def generate_curve_label(
picked_value = round(picked_value, 1)

label += f"{axis_label}={picked_value} {axis_unit}, "
if skip_text:
return float(picked_value)
return label.rstrip(", ")

def curves_vs_axis(
self,
x_axis_details: tuple[str, str],
max_limit: int = 1,
) -> list[np.ndarray]:
*,
skip_label_text: bool = False,
) -> dict[int, np.ndarray]:
"""Prepare a set of curves for plotting.

Parameters
Expand All @@ -555,10 +563,12 @@ def curves_vs_axis(
Name and original unit of the primary plotting axis
max_limit : int, optional
Maximum number of curves allowed by plotter, by default 1
skip_label_text: bool, optional
Whether to skip the axis name and unit in the curve label, by default False.

Returns
-------
list[np.ndarray]
dict[int, np.ndarray]
List of data arrays ready for plotting

"""
Expand Down Expand Up @@ -610,6 +620,7 @@ def curves_vs_axis(
self._curve_labels[index_tuple] = self.generate_curve_label(
index_tuple,
label_lookup,
skip_text=skip_label_text,
)

return self._curves
Expand Down
74 changes: 56 additions & 18 deletions MDANSE_GUI/Src/MDANSE_GUI/Tabs/Plotters/Text.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
if TYPE_CHECKING:
from qtpy.QtWidgets import QTextBrowser

from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext, SingleDataset
from MDANSE_GUI.Tabs.Models.PlottingContext import (
PlotArgs,
PlottingContext,
SingleDataset,
)


class DatasetFormatter:
Expand Down Expand Up @@ -82,7 +86,7 @@ def take_new_input(self, pc: PlottingContext):

for databundle in self._plotting_context.datasets().values():
header, data = self.process_data(
databundle.dataset,
databundle,
main_axis=databundle.main_axis,
)
self._new_text.append(
Expand All @@ -97,7 +101,7 @@ def datasets_for_csv(self):
return ["No data selected"]

for databundle in self._plotting_context.datasets().values():
yield self.process_data(databundle.dataset, main_axis=databundle.main_axis)
yield self.process_data(databundle, main_axis=databundle.main_axis)

def make_dataset_header(self, dataset: SingleDataset, comment_character="#"):
"""Return the dataset informartion as text.
Expand Down Expand Up @@ -141,7 +145,12 @@ def join_for_gui(
data = data_array

text_data = "\n".join(
separator.join(str(round(x, self._rounding_prec)) for x in line)
separator.join(
str(round(x, self._rounding_prec))
if hasattr(x, "__round__")
else str(x)
for x in line
)
for line in data
)

Expand All @@ -154,21 +163,22 @@ def join_for_gui(

def process_data(
self,
dataset: SingleDataset,
databundle: PlotArgs,
main_axis: str | None = None,
):
"""Wrapper for approriately handling ND data."""
"""Wrapper for appropriately handling ND data."""
dataset = databundle.dataset

if dataset._n_dim == 1:
return self.process_1D_data(dataset)
return self.process_1D_data(databundle)
if dataset._n_dim == 2:
return self.process_2D_data(dataset, main_axis=main_axis)
return self.process_2D_data(databundle, main_axis=main_axis)

return self.process_ND_data(dataset)
return self.process_ND_data(databundle)

def process_1D_data(
self,
dataset: SingleDataset,
databundle: PlotArgs,
) -> tuple[list[str], Iterator[Iterator[float]]]:
"""Turn a 1D array into text.

Expand All @@ -178,8 +188,8 @@ def process_1D_data(

Parameters
----------
dataset : SingleDataset
A SingleDataset read from an .MDA file (HDF5).
dataset : PlotArgs
A SingleDataset with the GUI parameters.

Returns
-------
Expand All @@ -189,6 +199,7 @@ def process_1D_data(
A data table with 2 columns.

"""
dataset = databundle.dataset
header_lines, _ = self.make_dataset_header(
dataset, comment_character=self._comment
)
Expand All @@ -214,16 +225,16 @@ def process_1D_data(

def process_2D_data(
self,
dataset: SingleDataset,
databundle: PlotArgs,
*,
main_axis: str | None = None,
) -> tuple[list[str], Iterator[Iterator[float]]]:
"""Convert a 2D data array into text.

Parameters
----------
dataset : SingleDataset
A SingleDataset read from an .MDA file (HDF5).
dataset : PlotArgs
A SingleDataset with the GUI parameters.
main_axis : str or None
Main axis to plot.

Expand All @@ -243,6 +254,8 @@ def process_2D_data(
v

"""
dataset = databundle.dataset

header_lines, comment_char = self.make_dataset_header(
dataset, comment_character=self._comment
)
Expand All @@ -267,18 +280,43 @@ def process_2D_data(

LOG.debug(f"process_2D_data: axis {ax_key} has length {len(axis)}")

rc = "column" if n == flip_array else "row"
rc = "column" if n else "row"
header_lines.append(
f"{comment_char} first {rc} is {ax_key} in units {new_unit}"
)

LOG.debug(f"Data shape: {dataset._data.shape}")
try:
best_unit, best_axis = (
dataset._axes_units[databundle.main_axis],
databundle.main_axis,
)
except KeyError:
best_unit, best_axis = dataset.longest_axis()

if self._is_preview:
curves_limit = self._preview_columns if flip_array else self._preview_lines
else:
curves_limit = (
dataset._data.shape[0] if flip_array else dataset._data.shape[1]
)

multi_curves = np.vstack(
list(
dataset.curves_vs_axis(
(best_unit, best_axis), max_limit=curves_limit, skip_label_text=True
).values()
)
)
# Add corner nil
xaxis = prepend(0.0, new_axes[axis_numbers[1]].flat)
xaxis = prepend("_", new_axes[axis_numbers[flip_array]].flat)

# Add axes to data
data_lines = zip(new_axes[axis_numbers[0]].flat, dataset.data, strict=True)
data_lines = zip(
dataset._curve_labels.values(),
multi_curves,
strict=True,
)

# Put xaxis in
temp = prepend(xaxis, data_lines)
Expand Down
Loading