From 75ead9b6b5f3622d49f4af613e1e699cc76ae9ab Mon Sep 17 00:00:00 2001 From: dorzhey Date: Sun, 3 Aug 2025 18:56:24 -0400 Subject: [PATCH 1/4] Refactor BasePlot & subclasses per issue 3718 --- src/scanpy/plotting/_baseplot_class.py | 93 +++++++++++++++++++++----- src/scanpy/plotting/_dotplot.py | 73 +++++++------------- src/scanpy/plotting/_matrixplot.py | 38 +++-------- src/scanpy/plotting/_stacked_violin.py | 53 ++++++--------- 4 files changed, 131 insertions(+), 126 deletions(-) diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py index b643510509..988544f511 100644 --- a/src/scanpy/plotting/_baseplot_class.py +++ b/src/scanpy/plotting/_baseplot_class.py @@ -13,13 +13,9 @@ from .. import logging as logg from .._compat import old_positionals from .._utils import _empty -from ._anndata import ( - VarGroups, - _plot_dendrogram, - _plot_var_groups_brackets, - _prepare_dataframe, - _reorder_categories_after_dendrogram, -) +from ..get._aggregated import aggregate +from ._anndata import (VarGroups, _plot_dendrogram, _plot_var_groups_brackets, + _prepare_dataframe, _reorder_categories_after_dendrogram) from ._utils import check_colornorm, make_grid_spec if TYPE_CHECKING: @@ -144,8 +140,9 @@ def __init__( # noqa: PLR0913 del var_group_labels, var_group_positions self.var_group_rotation = var_group_rotation self.width, self.height = figsize if figsize is not None else (None, None) - - self.categories, self.obs_tidy = _prepare_dataframe( + + # still need this as pandas handles this procedure more optimally + self.categories, obs_tidy = _prepare_dataframe( adata, self.var_names, groupby, @@ -155,6 +152,17 @@ def __init__( # noqa: PLR0913 layer=layer, gene_symbols=gene_symbols, ) + # we are going to save a view of adata as we still need it for filtering in dotplot by expression_cutoff and mean_only_expressed + # also AnnData is a little lighter than DataFrame + # and we can replace self.adata as it is used elsewhere + self._group_key = obs_tidy.index.name + self._view = AnnData( + X=obs_tidy.values, + obs=obs_tidy.index.to_frame(index=False), + var=pd.DataFrame(index=var_names), + ) + + if len(self.categories) > self.MAX_NUM_CATEGORIES: warn( f"Over {self.MAX_NUM_CATEGORIES} categories found. " @@ -164,16 +172,16 @@ def __init__( # noqa: PLR0913 ) if categories_order is not None and ( - set(self.obs_tidy.index.categories) != set(categories_order) + set(self.categories) != set(categories_order) ): logg.error( "Please check that the categories given by " "the `order` parameter match the categories that " "want to be reordered.\n\n" "Mismatch: " - f"{set(self.obs_tidy.index.categories).difference(categories_order)}\n\n" + f"{set(self.categories).difference(categories_order)}\n\n" f"Given order categories: {categories_order}\n\n" - f"{groupby} categories: {list(self.obs_tidy.index.categories)}\n" + f"{groupby} categories: {list(self.categories)}\n" ) return @@ -397,11 +405,11 @@ def add_totals( _sort = sort is not None _ascending = sort == "ascending" - counts_df = self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending) + counts_df = self._view.obs[self._group_key].value_counts(sort=_sort, ascending=_ascending) if _sort: - self.categories_order = counts_df.index - + self.categories_order = list(counts_df.index) + self.plot_group_extra = { "kind": "group_totals", "width": size, @@ -411,6 +419,61 @@ def add_totals( } return self + + def _agg_df(self, func, mask: np.ndarray | None = None) -> pd.DataFrame: + """ + Aggregate self._view by self._group_key, running `func` + (or list of funcs) on the X‐matrix. Returns a DataFrame + (or dict of DataFrames) with index=self.categories, columns=self.var_names. + If mask is provided, it should be shape (n_groups, n_vars) and will + overwrite view.X before aggregating (useful for dot‐cutoff logic). + """ + # make a fresh copy so we never mutate the master view + view = self._view.copy() + if mask is not None: + view.X = mask.astype(view.X.dtype) + + ag = aggregate( + view, + by=self._group_key, + func=func, + axis="obs", + ) + # if single func, return one DataFrame + if isinstance(func, str): + arr = ag.layers[func] + return pd.DataFrame(arr, index=self.categories, columns=self.var_names) + # if multiple, return a dict of DataFrames + out = {} + for f in func: + arr = ag.layers[f] + out[f] = pd.DataFrame(arr, index=self.categories, columns=self.var_names) + return out + + def _scale_df( + self, + standard_scale: Literal["var", "group"] | None = None, + df : pd.DataFrame | None = None + ): + """ + Performs scaling of `df` based on `standard_scale` parameter + """ + if standard_scale == "obs": + standard_scale = "group" + msg = "`standard_scale='obs'` is deprecated, use `standard_scale='group'` instead" + warnings.warn(msg, FutureWarning, stacklevel=2) + if standard_scale == "group": + df = df.sub(df.min(1), axis=0) + df = df.div(df.max(1), axis=0).fillna(0) + elif standard_scale == "var": + df -= df.min(0) + df = (df / df.max(0)).fillna(0) + elif standard_scale is None: + pass + else: + logg.warning("Unknown type for standard_scale, ignored") + return df + @old_positionals("cmap") def style(self, *, cmap: Colormap | str | None | Empty = _empty) -> Self: r"""Set visual style parameters. diff --git a/src/scanpy/plotting/_dotplot.py b/src/scanpy/plotting/_dotplot.py index 8c15574a1d..e7fb126a6a 100644 --- a/src/scanpy/plotting/_dotplot.py +++ b/src/scanpy/plotting/_dotplot.py @@ -9,15 +9,10 @@ from .._compat import old_positionals from .._settings import settings from .._utils import _doc_params, _empty +from ..get._aggregated import aggregate from ._baseplot_class import BasePlot, doc_common_groupby_plot_args from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm -from ._utils import ( - _dk, - check_colornorm, - fix_kwds, - make_grid_spec, - savefig_or_show, -) +from ._utils import _dk, check_colornorm, fix_kwds, make_grid_spec, savefig_or_show if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -187,46 +182,26 @@ def __init__( # noqa: PLR0913 norm=norm, **kwds, ) - - # for if category defined by groupby (if any) compute for each var_name - # 1. the fraction of cells in the category having a value >expression_cutoff - # 2. the mean value over the category - - # 1. compute fraction of cells having value > expression_cutoff - # transform obs_tidy into boolean matrix using the expression_cutoff - obs_bool = self.obs_tidy > expression_cutoff - - # compute the sum per group which in the boolean matrix this is the number - # of values >expression_cutoff, and divide the result by the total number of - # values in the group (given by `count()`) if dot_size_df is None: - dot_size_df = ( - obs_bool.groupby(level=0, observed=True).sum() - / obs_bool.groupby(level=0, observed=True).count() - ) + if expression_cutoff > 0: + mask = (self._view.X > expression_cutoff).astype(self._view.X.dtype) + dot_size_df = self._agg_df("mean", mask=mask) + else: + df_all = self._agg_df("count_nonzero") + # count_nonzero → raw counts, divide by group sizes + group_sizes = self._view.obs[self._group_key].value_counts().loc[self.categories].values + dot_size_df = df_all.div(group_sizes, axis=0) if dot_color_df is None: - # 2. compute mean expression value value - if mean_only_expressed: - dot_color_df = ( - self.obs_tidy.mask(~obs_bool) - .groupby(level=0, observed=True) - .mean() - .fillna(0) - ) - else: - dot_color_df = self.obs_tidy.groupby(level=0, observed=True).mean() - - if standard_scale == "group": - dot_color_df = dot_color_df.sub(dot_color_df.min(1), axis=0) - dot_color_df = dot_color_df.div(dot_color_df.max(1), axis=0).fillna(0) - elif standard_scale == "var": - dot_color_df -= dot_color_df.min(0) - dot_color_df = (dot_color_df / dot_color_df.max(0)).fillna(0) - elif standard_scale is None: - pass + if mean_only_expressed and expression_cutoff > 0: + mask = (self._view.X > expression_cutoff) + df_sum = self._agg_df("sum", mask=mask) + expr_counts = dot_size_df.values * group_sizes[:, None] + dot_color_df = df_sum.div(expr_counts).fillna(0) else: - logg.warning("Unknown type for standard_scale, ignored") + dot_color_df = self._agg_df("mean") + + dot_color_df = self._scale_df(standard_scale, dot_color_df) else: # check that both matrices have the same shape if dot_color_df.shape != dot_size_df.shape: @@ -255,12 +230,12 @@ def __init__( # noqa: PLR0913 # using the order from the doc_size_df dot_color_df = dot_color_df.loc[dot_size_df.index][dot_size_df.columns] - self.dot_color_df, self.dot_size_df = ( - df.loc[ - categories_order if categories_order is not None else self.categories - ] - for df in (dot_color_df, dot_size_df) - ) + + + # reorder rows + self.dot_size_df = dot_size_df.loc[categories_order if categories_order is not None else self.categories] + self.dot_color_df = dot_color_df.loc[categories_order if categories_order is not None else self.categories] + self.standard_scale = standard_scale # Set default style parameters diff --git a/src/scanpy/plotting/_matrixplot.py b/src/scanpy/plotting/_matrixplot.py index e767e224cc..22783761f2 100644 --- a/src/scanpy/plotting/_matrixplot.py +++ b/src/scanpy/plotting/_matrixplot.py @@ -10,11 +10,7 @@ from .._settings import settings from .._utils import _doc_params, _empty from ._baseplot_class import BasePlot, doc_common_groupby_plot_args -from ._docs import ( - doc_common_plot_args, - doc_show_save_ax, - doc_vboundnorm, -) +from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm from ._utils import _dk, check_colornorm, fix_kwds, savefig_or_show if TYPE_CHECKING: @@ -167,29 +163,15 @@ def __init__( # noqa: PLR0913 ) if values_df is None: - # compute mean value - values_df = ( - self.obs_tidy.groupby(level=0, observed=True) - .mean() - .loc[ - self.categories_order - if self.categories_order is not None - else self.categories - ] - ) - - if standard_scale == "group": - values_df = values_df.sub(values_df.min(1), axis=0) - values_df = values_df.div(values_df.max(1), axis=0).fillna(0) - elif standard_scale == "var": - values_df -= values_df.min(0) - values_df = (values_df / values_df.max(0)).fillna(0) - elif standard_scale is None: - pass - else: - logg.warning("Unknown type for standard_scale, ignored") - - self.values_df = values_df + values_df = self._agg_df("mean") + + values_df = self._scale_df(standard_scale, values_df) + + self.values_df = values_df.loc[ + categories_order + if categories_order is not None + else self.categories + ] self.cmap = self.DEFAULT_COLORMAP self.edge_color = self.DEFAULT_EDGE_COLOR diff --git a/src/scanpy/plotting/_stacked_violin.py b/src/scanpy/plotting/_stacked_violin.py index 404605b5a2..5f5ab9a9e3 100644 --- a/src/scanpy/plotting/_stacked_violin.py +++ b/src/scanpy/plotting/_stacked_violin.py @@ -15,13 +15,8 @@ from .._utils import _doc_params, _empty from ._baseplot_class import BasePlot, doc_common_groupby_plot_args from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm -from ._utils import ( - _deprecated_scale, - _dk, - check_colornorm, - make_grid_spec, - savefig_or_show, -) +from ._utils import (_deprecated_scale, _dk, check_colornorm, make_grid_spec, + savefig_or_show) if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -225,22 +220,11 @@ def __init__( # noqa: PLR0913 norm=norm, **kwds, ) - - if standard_scale == "obs": - standard_scale = "group" - msg = "`standard_scale='obs'` is deprecated, use `standard_scale='group'` instead" - warnings.warn(msg, FutureWarning, stacklevel=2) - if standard_scale == "group": - self.obs_tidy = self.obs_tidy.sub(self.obs_tidy.min(1), axis=0) - self.obs_tidy = self.obs_tidy.div(self.obs_tidy.max(1), axis=0).fillna(0) - elif standard_scale == "var": - self.obs_tidy -= self.obs_tidy.min(0) - self.obs_tidy = (self.obs_tidy / self.obs_tidy.max(0)).fillna(0) - elif standard_scale is None: - pass - else: - logg.warning("Unknown type for standard_scale, ignored") - + # scale before aggregation + X = self._view.X.astype(float) + X = self._scale_df(standard_scale, X) + # replace view.X with the scaled values (NaNs => 0) + self._view.X = np.nan_to_num(X) # Set default style parameters self.cmap = self.DEFAULT_COLORMAP self.row_palette = self.DEFAULT_ROW_PALETTE @@ -386,22 +370,23 @@ def _mainplot(self, ax: Axes): # work on a copy of the dataframes. This is to avoid changes # on the original data frames after repetitive calls to the # StackedViolin object, for example once with swap_axes and other without - _matrix = self.obs_tidy.copy() - + _matrix = pd.DataFrame( + self._view.X, + index=self._view.obs[self._group_key], + columns=self.var_names + ) + if self.var_names_idx_order is not None: _matrix = _matrix.iloc[:, self.var_names_idx_order] # get mean values for color and transform to color values # using colormap - _color_df = ( - _matrix.groupby(level=0, observed=True) - .median() - .loc[ - self.categories_order - if self.categories_order is not None - else self.categories - ] - ) + _color_df = self._agg_df("median").loc[ + self.categories_order + if self.categories_order is not None + else self.categories + ] + if self.are_axes_swapped: _color_df = _color_df.T From 2c587bca8720cd2ad9619a048fe69b92838380aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 3 Aug 2025 23:17:33 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scanpy/plotting/_baseplot_class.py | 25 +++++++++++++++---------- src/scanpy/plotting/_dotplot.py | 24 +++++++++++++++--------- src/scanpy/plotting/_matrixplot.py | 7 ++----- src/scanpy/plotting/_stacked_violin.py | 24 +++++++++++++----------- 4 files changed, 45 insertions(+), 35 deletions(-) diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py index 988544f511..0e23d92a55 100644 --- a/src/scanpy/plotting/_baseplot_class.py +++ b/src/scanpy/plotting/_baseplot_class.py @@ -14,8 +14,13 @@ from .._compat import old_positionals from .._utils import _empty from ..get._aggregated import aggregate -from ._anndata import (VarGroups, _plot_dendrogram, _plot_var_groups_brackets, - _prepare_dataframe, _reorder_categories_after_dendrogram) +from ._anndata import ( + VarGroups, + _plot_dendrogram, + _plot_var_groups_brackets, + _prepare_dataframe, + _reorder_categories_after_dendrogram, +) from ._utils import check_colornorm, make_grid_spec if TYPE_CHECKING: @@ -140,7 +145,7 @@ def __init__( # noqa: PLR0913 del var_group_labels, var_group_positions self.var_group_rotation = var_group_rotation self.width, self.height = figsize if figsize is not None else (None, None) - + # still need this as pandas handles this procedure more optimally self.categories, obs_tidy = _prepare_dataframe( adata, @@ -162,7 +167,6 @@ def __init__( # noqa: PLR0913 var=pd.DataFrame(index=var_names), ) - if len(self.categories) > self.MAX_NUM_CATEGORIES: warn( f"Over {self.MAX_NUM_CATEGORIES} categories found. " @@ -405,11 +409,13 @@ def add_totals( _sort = sort is not None _ascending = sort == "ascending" - counts_df = self._view.obs[self._group_key].value_counts(sort=_sort, ascending=_ascending) + counts_df = self._view.obs[self._group_key].value_counts( + sort=_sort, ascending=_ascending + ) if _sort: self.categories_order = list(counts_df.index) - + self.plot_group_extra = { "kind": "group_totals", "width": size, @@ -419,7 +425,6 @@ def add_totals( } return self - def _agg_df(self, func, mask: np.ndarray | None = None) -> pd.DataFrame: """ Aggregate self._view by self._group_key, running `func` @@ -449,11 +454,11 @@ def _agg_df(self, func, mask: np.ndarray | None = None) -> pd.DataFrame: arr = ag.layers[f] out[f] = pd.DataFrame(arr, index=self.categories, columns=self.var_names) return out - + def _scale_df( self, standard_scale: Literal["var", "group"] | None = None, - df : pd.DataFrame | None = None + df: pd.DataFrame | None = None, ): """ Performs scaling of `df` based on `standard_scale` parameter @@ -473,7 +478,7 @@ def _scale_df( else: logg.warning("Unknown type for standard_scale, ignored") return df - + @old_positionals("cmap") def style(self, *, cmap: Colormap | str | None | Empty = _empty) -> Self: r"""Set visual style parameters. diff --git a/src/scanpy/plotting/_dotplot.py b/src/scanpy/plotting/_dotplot.py index e7fb126a6a..d3c1d1f458 100644 --- a/src/scanpy/plotting/_dotplot.py +++ b/src/scanpy/plotting/_dotplot.py @@ -9,7 +9,6 @@ from .._compat import old_positionals from .._settings import settings from .._utils import _doc_params, _empty -from ..get._aggregated import aggregate from ._baseplot_class import BasePlot, doc_common_groupby_plot_args from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm from ._utils import _dk, check_colornorm, fix_kwds, make_grid_spec, savefig_or_show @@ -184,23 +183,28 @@ def __init__( # noqa: PLR0913 ) if dot_size_df is None: if expression_cutoff > 0: - mask = (self._view.X > expression_cutoff).astype(self._view.X.dtype) + mask = (expression_cutoff < self._view.X).astype(self._view.X.dtype) dot_size_df = self._agg_df("mean", mask=mask) else: df_all = self._agg_df("count_nonzero") # count_nonzero → raw counts, divide by group sizes - group_sizes = self._view.obs[self._group_key].value_counts().loc[self.categories].values + group_sizes = ( + self._view.obs[self._group_key] + .value_counts() + .loc[self.categories] + .values + ) dot_size_df = df_all.div(group_sizes, axis=0) if dot_color_df is None: if mean_only_expressed and expression_cutoff > 0: - mask = (self._view.X > expression_cutoff) + mask = expression_cutoff < self._view.X df_sum = self._agg_df("sum", mask=mask) expr_counts = dot_size_df.values * group_sizes[:, None] dot_color_df = df_sum.div(expr_counts).fillna(0) else: dot_color_df = self._agg_df("mean") - + dot_color_df = self._scale_df(standard_scale, dot_color_df) else: # check that both matrices have the same shape @@ -230,11 +234,13 @@ def __init__( # noqa: PLR0913 # using the order from the doc_size_df dot_color_df = dot_color_df.loc[dot_size_df.index][dot_size_df.columns] - - # reorder rows - self.dot_size_df = dot_size_df.loc[categories_order if categories_order is not None else self.categories] - self.dot_color_df = dot_color_df.loc[categories_order if categories_order is not None else self.categories] + self.dot_size_df = dot_size_df.loc[ + categories_order if categories_order is not None else self.categories + ] + self.dot_color_df = dot_color_df.loc[ + categories_order if categories_order is not None else self.categories + ] self.standard_scale = standard_scale diff --git a/src/scanpy/plotting/_matrixplot.py b/src/scanpy/plotting/_matrixplot.py index 22783761f2..965dc242f8 100644 --- a/src/scanpy/plotting/_matrixplot.py +++ b/src/scanpy/plotting/_matrixplot.py @@ -5,7 +5,6 @@ import numpy as np from matplotlib import colormaps, rcParams -from .. import logging as logg from .._compat import old_positionals from .._settings import settings from .._utils import _doc_params, _empty @@ -164,13 +163,11 @@ def __init__( # noqa: PLR0913 if values_df is None: values_df = self._agg_df("mean") - + values_df = self._scale_df(standard_scale, values_df) self.values_df = values_df.loc[ - categories_order - if categories_order is not None - else self.categories + categories_order if categories_order is not None else self.categories ] self.cmap = self.DEFAULT_COLORMAP diff --git a/src/scanpy/plotting/_stacked_violin.py b/src/scanpy/plotting/_stacked_violin.py index 5f5ab9a9e3..774affed83 100644 --- a/src/scanpy/plotting/_stacked_violin.py +++ b/src/scanpy/plotting/_stacked_violin.py @@ -9,14 +9,18 @@ from matplotlib.colors import is_color_like from packaging.version import Version -from .. import logging as logg from .._compat import old_positionals from .._settings import settings from .._utils import _doc_params, _empty from ._baseplot_class import BasePlot, doc_common_groupby_plot_args from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm -from ._utils import (_deprecated_scale, _dk, check_colornorm, make_grid_spec, - savefig_or_show) +from ._utils import ( + _deprecated_scale, + _dk, + check_colornorm, + make_grid_spec, + savefig_or_show, +) if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -371,22 +375,20 @@ def _mainplot(self, ax: Axes): # on the original data frames after repetitive calls to the # StackedViolin object, for example once with swap_axes and other without _matrix = pd.DataFrame( - self._view.X, - index=self._view.obs[self._group_key], - columns=self.var_names + self._view.X, index=self._view.obs[self._group_key], columns=self.var_names ) - + if self.var_names_idx_order is not None: _matrix = _matrix.iloc[:, self.var_names_idx_order] # get mean values for color and transform to color values # using colormap _color_df = self._agg_df("median").loc[ - self.categories_order - if self.categories_order is not None + self.categories_order + if self.categories_order is not None else self.categories - ] - + ] + if self.are_axes_swapped: _color_df = _color_df.T From 132e343d4bfdf95c6bb78679f244044301afb4c2 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 8 Aug 2025 10:41:55 +0200 Subject: [PATCH 3/4] fix types --- src/scanpy/plotting/_baseplot_class.py | 45 +++++++++++++++----------- src/scanpy/plotting/_dotplot.py | 2 +- src/scanpy/plotting/_matrixplot.py | 2 +- src/scanpy/plotting/_stacked_violin.py | 2 +- 4 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py index 0e23d92a55..d1cead241b 100644 --- a/src/scanpy/plotting/_baseplot_class.py +++ b/src/scanpy/plotting/_baseplot_class.py @@ -2,11 +2,14 @@ from __future__ import annotations +import warnings from collections.abc import Mapping -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, NamedTuple, overload from warnings import warn import numpy as np +import pandas as pd +from anndata import AnnData from matplotlib import colormaps, gridspec from matplotlib import pyplot as plt @@ -24,15 +27,14 @@ from ._utils import check_colornorm, make_grid_spec if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterable, Sequence from typing import Literal, Self - import pandas as pd - from anndata import AnnData from matplotlib.axes import Axes from matplotlib.colors import Colormap, Normalize from .._utils import Empty + from ..get._aggregated import AggType from ._utils import ColorLike, _AxesSubplot _VarNames = str | Sequence[str] @@ -425,13 +427,24 @@ def add_totals( } return self - def _agg_df(self, func, mask: np.ndarray | None = None) -> pd.DataFrame: - """ - Aggregate self._view by self._group_key, running `func` - (or list of funcs) on the X‐matrix. Returns a DataFrame - (or dict of DataFrames) with index=self.categories, columns=self.var_names. - If mask is provided, it should be shape (n_groups, n_vars) and will - overwrite view.X before aggregating (useful for dot‐cutoff logic). + @overload + def _agg_df( + self, func: AggType, mask: np.ndarray | None = None + ) -> pd.DataFrame: ... + + @overload + def _agg_df( + self, func: Iterable[AggType], mask: np.ndarray | None = None + ) -> dict[str, pd.DataFrame]: ... + + def _agg_df( + self, func: AggType | Iterable[AggType], mask: np.ndarray | None = None + ) -> pd.DataFrame | dict[str, pd.DataFrame]: + """Aggregate `self._view` by `self._group_key`. + + Run `func` on X and eturn a DataFrame (or dict of DataFrames) with `index=self.categories`, `columns=self.var_names`. + If `mask` is provided, it should be shape `(n_groups, n_vars)` and will + overwrite view.X before aggregating (useful for dot-cutoff logic). """ # make a fresh copy so we never mutate the master view view = self._view.copy() @@ -456,13 +469,9 @@ def _agg_df(self, func, mask: np.ndarray | None = None) -> pd.DataFrame: return out def _scale_df( - self, - standard_scale: Literal["var", "group"] | None = None, - df: pd.DataFrame | None = None, - ): - """ - Performs scaling of `df` based on `standard_scale` parameter - """ + self, df: pd.DataFrame, standard_scale: Literal["var", "group", None] = None + ) -> pd.DataFrame: + """Scale `df` based on `standard_scale` parameter.""" if standard_scale == "obs": standard_scale = "group" msg = "`standard_scale='obs'` is deprecated, use `standard_scale='group'` instead" diff --git a/src/scanpy/plotting/_dotplot.py b/src/scanpy/plotting/_dotplot.py index d3c1d1f458..4849da0ce8 100644 --- a/src/scanpy/plotting/_dotplot.py +++ b/src/scanpy/plotting/_dotplot.py @@ -205,7 +205,7 @@ def __init__( # noqa: PLR0913 else: dot_color_df = self._agg_df("mean") - dot_color_df = self._scale_df(standard_scale, dot_color_df) + dot_color_df = self._scale_df(dot_color_df, standard_scale) else: # check that both matrices have the same shape if dot_color_df.shape != dot_size_df.shape: diff --git a/src/scanpy/plotting/_matrixplot.py b/src/scanpy/plotting/_matrixplot.py index 965dc242f8..b6f30c3089 100644 --- a/src/scanpy/plotting/_matrixplot.py +++ b/src/scanpy/plotting/_matrixplot.py @@ -164,7 +164,7 @@ def __init__( # noqa: PLR0913 if values_df is None: values_df = self._agg_df("mean") - values_df = self._scale_df(standard_scale, values_df) + values_df = self._scale_df(values_df, standard_scale) self.values_df = values_df.loc[ categories_order if categories_order is not None else self.categories diff --git a/src/scanpy/plotting/_stacked_violin.py b/src/scanpy/plotting/_stacked_violin.py index 774affed83..bb41b3a2bb 100644 --- a/src/scanpy/plotting/_stacked_violin.py +++ b/src/scanpy/plotting/_stacked_violin.py @@ -226,7 +226,7 @@ def __init__( # noqa: PLR0913 ) # scale before aggregation X = self._view.X.astype(float) - X = self._scale_df(standard_scale, X) + X = self._scale_df(X, standard_scale) # replace view.X with the scaled values (NaNs => 0) self._view.X = np.nan_to_num(X) # Set default style parameters From 5f62698a5152ccae3d5ad806f356d8b91282a751 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 8 Aug 2025 11:03:28 +0200 Subject: [PATCH 4/4] fix warning --- src/scanpy/plotting/_baseplot_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py index d1cead241b..68ceecbee0 100644 --- a/src/scanpy/plotting/_baseplot_class.py +++ b/src/scanpy/plotting/_baseplot_class.py @@ -165,7 +165,7 @@ def __init__( # noqa: PLR0913 self._group_key = obs_tidy.index.name self._view = AnnData( X=obs_tidy.values, - obs=obs_tidy.index.to_frame(index=False), + obs={self._group_key: obs_tidy.index}, var=pd.DataFrame(index=var_names), )