Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 87 additions & 10 deletions src/scanpy/plotting/_baseplot_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@

from __future__ import annotations

import warnings
from collections.abc import Mapping
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, NamedTuple, overload
from warnings import warn

import numpy as np
import pandas as pd
from anndata import AnnData
from matplotlib import colormaps, gridspec
from matplotlib import pyplot as plt

from .. import logging as logg
from .._compat import old_positionals
from .._utils import _empty
from ..get._aggregated import aggregate
from ._anndata import (
VarGroups,
_plot_dendrogram,
Expand All @@ -23,15 +27,14 @@
from ._utils import check_colornorm, make_grid_spec

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from typing import Literal, Self

import pandas as pd
from anndata import AnnData
from matplotlib.axes import Axes
from matplotlib.colors import Colormap, Normalize

from .._utils import Empty
from ..get._aggregated import AggType
from ._utils import ColorLike, _AxesSubplot

_VarNames = str | Sequence[str]
Expand Down Expand Up @@ -145,7 +148,8 @@ def __init__( # noqa: PLR0913
self.var_group_rotation = var_group_rotation
self.width, self.height = figsize if figsize is not None else (None, None)

self.categories, self.obs_tidy = _prepare_dataframe(
# still need this as pandas handles this procedure more optimally
self.categories, obs_tidy = _prepare_dataframe(
adata,
self.var_names,
groupby,
Expand All @@ -155,6 +159,16 @@ def __init__( # noqa: PLR0913
layer=layer,
gene_symbols=gene_symbols,
)
# we are going to save a view of adata as we still need it for filtering in dotplot by expression_cutoff and mean_only_expressed
# also AnnData is a little lighter than DataFrame
# and we can replace self.adata as it is used elsewhere
self._group_key = obs_tidy.index.name
self._view = AnnData(
X=obs_tidy.values,
obs={self._group_key: obs_tidy.index},
var=pd.DataFrame(index=var_names),
)

if len(self.categories) > self.MAX_NUM_CATEGORIES:
warn(
f"Over {self.MAX_NUM_CATEGORIES} categories found. "
Expand All @@ -164,16 +178,16 @@ def __init__( # noqa: PLR0913
)

if categories_order is not None and (
set(self.obs_tidy.index.categories) != set(categories_order)
set(self.categories) != set(categories_order)
):
logg.error(
"Please check that the categories given by "
"the `order` parameter match the categories that "
"want to be reordered.\n\n"
"Mismatch: "
f"{set(self.obs_tidy.index.categories).difference(categories_order)}\n\n"
f"{set(self.categories).difference(categories_order)}\n\n"
f"Given order categories: {categories_order}\n\n"
f"{groupby} categories: {list(self.obs_tidy.index.categories)}\n"
f"{groupby} categories: {list(self.categories)}\n"
)
return

Expand Down Expand Up @@ -397,10 +411,12 @@ def add_totals(

_sort = sort is not None
_ascending = sort == "ascending"
counts_df = self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending)
counts_df = self._view.obs[self._group_key].value_counts(
sort=_sort, ascending=_ascending
)

if _sort:
self.categories_order = counts_df.index
self.categories_order = list(counts_df.index)

self.plot_group_extra = {
"kind": "group_totals",
Expand All @@ -411,6 +427,67 @@ def add_totals(
}
return self

@overload
def _agg_df(
self, func: AggType, mask: np.ndarray | None = None
) -> pd.DataFrame: ...

@overload
def _agg_df(
self, func: Iterable[AggType], mask: np.ndarray | None = None
) -> dict[str, pd.DataFrame]: ...

def _agg_df(
self, func: AggType | Iterable[AggType], mask: np.ndarray | None = None
) -> pd.DataFrame | dict[str, pd.DataFrame]:
"""Aggregate `self._view` by `self._group_key`.

Run `func` on X and eturn a DataFrame (or dict of DataFrames) with `index=self.categories`, `columns=self.var_names`.
If `mask` is provided, it should be shape `(n_groups, n_vars)` and will
overwrite view.X before aggregating (useful for dot-cutoff logic).
"""
# make a fresh copy so we never mutate the master view
view = self._view.copy()
if mask is not None:
view.X = mask.astype(view.X.dtype)

ag = aggregate(
view,
by=self._group_key,
func=func,
axis="obs",
)
# if single func, return one DataFrame
if isinstance(func, str):
arr = ag.layers[func]
return pd.DataFrame(arr, index=self.categories, columns=self.var_names)
# if multiple, return a dict of DataFrames
out = {}
for f in func:
arr = ag.layers[f]
out[f] = pd.DataFrame(arr, index=self.categories, columns=self.var_names)
return out

def _scale_df(
self, df: pd.DataFrame, standard_scale: Literal["var", "group", None] = None
) -> pd.DataFrame:
"""Scale `df` based on `standard_scale` parameter."""
if standard_scale == "obs":
standard_scale = "group"
msg = "`standard_scale='obs'` is deprecated, use `standard_scale='group'` instead"
warnings.warn(msg, FutureWarning, stacklevel=2)
if standard_scale == "group":
df = df.sub(df.min(1), axis=0)
df = df.div(df.max(1), axis=0).fillna(0)
elif standard_scale == "var":
df -= df.min(0)
df = (df / df.max(0)).fillna(0)
elif standard_scale is None:
pass
else:
logg.warning("Unknown type for standard_scale, ignored")
return df

@old_positionals("cmap")
def style(self, *, cmap: Colormap | str | None | Empty = _empty) -> Self:
r"""Set visual style parameters.
Expand Down
79 changes: 30 additions & 49 deletions src/scanpy/plotting/_dotplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,7 @@
from .._utils import _doc_params, _empty
from ._baseplot_class import BasePlot, doc_common_groupby_plot_args
from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm
from ._utils import (
_dk,
check_colornorm,
fix_kwds,
make_grid_spec,
savefig_or_show,
)
from ._utils import _dk, check_colornorm, fix_kwds, make_grid_spec, savefig_or_show

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
Expand Down Expand Up @@ -187,46 +181,31 @@ def __init__( # noqa: PLR0913
norm=norm,
**kwds,
)

# for if category defined by groupby (if any) compute for each var_name
# 1. the fraction of cells in the category having a value >expression_cutoff
# 2. the mean value over the category

# 1. compute fraction of cells having value > expression_cutoff
# transform obs_tidy into boolean matrix using the expression_cutoff
obs_bool = self.obs_tidy > expression_cutoff

# compute the sum per group which in the boolean matrix this is the number
# of values >expression_cutoff, and divide the result by the total number of
# values in the group (given by `count()`)
if dot_size_df is None:
dot_size_df = (
obs_bool.groupby(level=0, observed=True).sum()
/ obs_bool.groupby(level=0, observed=True).count()
)
if expression_cutoff > 0:
mask = (expression_cutoff < self._view.X).astype(self._view.X.dtype)
dot_size_df = self._agg_df("mean", mask=mask)
else:
df_all = self._agg_df("count_nonzero")
# count_nonzero → raw counts, divide by group sizes
group_sizes = (
self._view.obs[self._group_key]
.value_counts()
.loc[self.categories]
.values
)
dot_size_df = df_all.div(group_sizes, axis=0)

if dot_color_df is None:
# 2. compute mean expression value value
if mean_only_expressed:
dot_color_df = (
self.obs_tidy.mask(~obs_bool)
.groupby(level=0, observed=True)
.mean()
.fillna(0)
)
else:
dot_color_df = self.obs_tidy.groupby(level=0, observed=True).mean()

if standard_scale == "group":
dot_color_df = dot_color_df.sub(dot_color_df.min(1), axis=0)
dot_color_df = dot_color_df.div(dot_color_df.max(1), axis=0).fillna(0)
elif standard_scale == "var":
dot_color_df -= dot_color_df.min(0)
dot_color_df = (dot_color_df / dot_color_df.max(0)).fillna(0)
elif standard_scale is None:
pass
if mean_only_expressed and expression_cutoff > 0:
mask = expression_cutoff < self._view.X
df_sum = self._agg_df("sum", mask=mask)
expr_counts = dot_size_df.values * group_sizes[:, None]
dot_color_df = df_sum.div(expr_counts).fillna(0)
else:
logg.warning("Unknown type for standard_scale, ignored")
dot_color_df = self._agg_df("mean")

dot_color_df = self._scale_df(dot_color_df, standard_scale)
else:
# check that both matrices have the same shape
if dot_color_df.shape != dot_size_df.shape:
Expand Down Expand Up @@ -255,12 +234,14 @@ def __init__( # noqa: PLR0913
# using the order from the doc_size_df
dot_color_df = dot_color_df.loc[dot_size_df.index][dot_size_df.columns]

self.dot_color_df, self.dot_size_df = (
df.loc[
categories_order if categories_order is not None else self.categories
]
for df in (dot_color_df, dot_size_df)
)
# reorder rows
self.dot_size_df = dot_size_df.loc[
categories_order if categories_order is not None else self.categories
]
self.dot_color_df = dot_color_df.loc[
categories_order if categories_order is not None else self.categories
]

self.standard_scale = standard_scale

# Set default style parameters
Expand Down
35 changes: 7 additions & 28 deletions src/scanpy/plotting/_matrixplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,11 @@
import numpy as np
from matplotlib import colormaps, rcParams

from .. import logging as logg
from .._compat import old_positionals
from .._settings import settings
from .._utils import _doc_params, _empty
from ._baseplot_class import BasePlot, doc_common_groupby_plot_args
from ._docs import (
doc_common_plot_args,
doc_show_save_ax,
doc_vboundnorm,
)
from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm
from ._utils import _dk, check_colornorm, fix_kwds, savefig_or_show

if TYPE_CHECKING:
Expand Down Expand Up @@ -167,29 +162,13 @@ def __init__( # noqa: PLR0913
)

if values_df is None:
# compute mean value
values_df = (
self.obs_tidy.groupby(level=0, observed=True)
.mean()
.loc[
self.categories_order
if self.categories_order is not None
else self.categories
]
)
values_df = self._agg_df("mean")

values_df = self._scale_df(values_df, standard_scale)

if standard_scale == "group":
values_df = values_df.sub(values_df.min(1), axis=0)
values_df = values_df.div(values_df.max(1), axis=0).fillna(0)
elif standard_scale == "var":
values_df -= values_df.min(0)
values_df = (values_df / values_df.max(0)).fillna(0)
elif standard_scale is None:
pass
else:
logg.warning("Unknown type for standard_scale, ignored")

self.values_df = values_df
self.values_df = values_df.loc[
categories_order if categories_order is not None else self.categories
]

self.cmap = self.DEFAULT_COLORMAP
self.edge_color = self.DEFAULT_EDGE_COLOR
Expand Down
41 changes: 14 additions & 27 deletions src/scanpy/plotting/_stacked_violin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from matplotlib.colors import is_color_like
from packaging.version import Version

from .. import logging as logg
from .._compat import old_positionals
from .._settings import settings
from .._utils import _doc_params, _empty
Expand Down Expand Up @@ -225,22 +224,11 @@ def __init__( # noqa: PLR0913
norm=norm,
**kwds,
)

if standard_scale == "obs":
standard_scale = "group"
msg = "`standard_scale='obs'` is deprecated, use `standard_scale='group'` instead"
warnings.warn(msg, FutureWarning, stacklevel=2)
if standard_scale == "group":
self.obs_tidy = self.obs_tidy.sub(self.obs_tidy.min(1), axis=0)
self.obs_tidy = self.obs_tidy.div(self.obs_tidy.max(1), axis=0).fillna(0)
elif standard_scale == "var":
self.obs_tidy -= self.obs_tidy.min(0)
self.obs_tidy = (self.obs_tidy / self.obs_tidy.max(0)).fillna(0)
elif standard_scale is None:
pass
else:
logg.warning("Unknown type for standard_scale, ignored")

# scale before aggregation
X = self._view.X.astype(float)
X = self._scale_df(X, standard_scale)
# replace view.X with the scaled values (NaNs => 0)
self._view.X = np.nan_to_num(X)
# Set default style parameters
self.cmap = self.DEFAULT_COLORMAP
self.row_palette = self.DEFAULT_ROW_PALETTE
Expand Down Expand Up @@ -386,22 +374,21 @@ def _mainplot(self, ax: Axes):
# work on a copy of the dataframes. This is to avoid changes
# on the original data frames after repetitive calls to the
# StackedViolin object, for example once with swap_axes and other without
_matrix = self.obs_tidy.copy()
_matrix = pd.DataFrame(
self._view.X, index=self._view.obs[self._group_key], columns=self.var_names
)

if self.var_names_idx_order is not None:
_matrix = _matrix.iloc[:, self.var_names_idx_order]

# get mean values for color and transform to color values
# using colormap
_color_df = (
_matrix.groupby(level=0, observed=True)
.median()
.loc[
self.categories_order
if self.categories_order is not None
else self.categories
]
)
_color_df = self._agg_df("median").loc[
self.categories_order
if self.categories_order is not None
else self.categories
]

if self.are_axes_swapped:
_color_df = _color_df.T

Expand Down
Loading