Skip to content

Commit 17c38aa

Browse files
author
Sonja Stockhaus
committed
add shapes parameter to render shapes as hex/circle/square
1 parent c5cb734 commit 17c38aa

File tree

4 files changed

+129
-7
lines changed

4 files changed

+129
-7
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import OrderedDict
66
from copy import deepcopy
77
from pathlib import Path
8-
from typing import Any
8+
from typing import Any, Literal
99

1010
import matplotlib.pyplot as plt
1111
import numpy as np
@@ -170,6 +170,7 @@ def render_shapes(
170170
method: str | None = None,
171171
table_name: str | None = None,
172172
table_layer: str | None = None,
173+
shape: Literal["circle", "hex", "square"] | None = None,
173174
**kwargs: Any,
174175
) -> sd.SpatialData:
175176
"""
@@ -232,6 +233,9 @@ def render_shapes(
232233
table_layer: str | None
233234
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
234235
:attr:`sdata.table.X` is used for coloring.
236+
shape: Literal["circle", "hex", "square"] | None
237+
If None (default), the shapes are rendered as they are. Else, if either of "circle", "hex" or "square" is
238+
specified, the shapes are converted to a circle/hexagon/square before rendering.
235239
236240
**kwargs : Any
237241
Additional arguments for customization. This can include:
@@ -276,6 +280,7 @@ def render_shapes(
276280
scale=scale,
277281
table_name=table_name,
278282
table_layer=table_layer,
283+
shape=shape,
279284
method=method,
280285
ds_reduction=kwargs.get("datashader_reduction"),
281286
)
@@ -304,6 +309,7 @@ def render_shapes(
304309
transfunc=kwargs.get("transfunc"),
305310
table_name=param_values["table_name"],
306311
table_layer=param_values["table_layer"],
312+
shape=param_values["shape"],
307313
zorder=n_steps,
308314
method=param_values["method"],
309315
ds_reduction=param_values["ds_reduction"],

src/spatialdata_plot/pl/render.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
)
3636
from spatialdata_plot.pl.utils import (
3737
_ax_show_and_transform,
38+
_convert_shapes,
3839
_create_image_from_datashader_result,
3940
_datashader_aggregate_with_function,
4041
_datashader_map_aggregate_to_color,
@@ -160,6 +161,12 @@ def _render_shapes(
160161
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)
161162

162163
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
164+
# convert shapes if necessary
165+
if render_params.shape is not None:
166+
current_type = shapes["geometry"].type
167+
if not (render_params.shape == "circle" and (current_type == "Point").all()):
168+
logger.info(f"Converting {shapes.shape[0]} shapes to {render_params.shape}.")
169+
shapes = _convert_shapes(shapes, render_params.shape)
163170

164171
# Determine which method to use for rendering
165172
method = render_params.method
@@ -188,9 +195,7 @@ def _render_shapes(
188195
# apply transformations to the individual points
189196
element_trans = get_transformation(sdata_filt.shapes[element], to_coordinate_system=coordinate_system)
190197
tm = _get_transformation_matrix_for_datashader(element_trans)
191-
transformed_element = sdata_filt.shapes[element].transform(
192-
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2]
193-
)
198+
transformed_element = shapes.transform(lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2])
194199
transformed_element = ShapesModel.parse(
195200
gpd.GeoDataFrame(
196201
data=sdata_filt.shapes[element].drop("geometry", axis=1),

src/spatialdata_plot/pl/render_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ class ShapesRenderParams:
9090
zorder: int = 0
9191
table_name: str | None = None
9292
table_layer: str | None = None
93+
shape: Literal["circle", "hex", "square"] | None = None
9394
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
9495

9596

src/spatialdata_plot/pl/utils.py

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import math
34
import os
45
import warnings
56
from collections import OrderedDict
@@ -51,6 +52,7 @@
5152
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
5253
from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation
5354
from scanpy.plotting.palettes import default_20, default_28, default_102
55+
from scipy.spatial import ConvexHull
5456
from skimage.color import label2rgb
5557
from skimage.morphology import erosion, square
5658
from skimage.segmentation import find_boundaries
@@ -1709,6 +1711,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
17091711
if size < 0:
17101712
raise ValueError("Parameter 'size' must be a positive number.")
17111713

1714+
if element_type == "shapes" and (shape := param_dict.get("shape")) is not None:
1715+
if not isinstance(shape, str):
1716+
raise TypeError("Parameter 'shape' must be a String from ['circle', 'hex', 'square'] if not None.")
1717+
if shape not in ["circle", "hex", "square"]:
1718+
raise ValueError(
1719+
f"'{shape}' is not supported for 'shape', please choose from[None, 'circle', 'hex', 'square']."
1720+
)
1721+
17121722
table_name = param_dict.get("table_name")
17131723
table_layer = param_dict.get("table_layer")
17141724
if table_name and not isinstance(param_dict["table_name"], str):
@@ -1920,6 +1930,7 @@ def _validate_shape_render_params(
19201930
scale: float | int,
19211931
table_name: str | None,
19221932
table_layer: str | None,
1933+
shape: Literal["circle", "hex", "square"] | None,
19231934
method: str | None,
19241935
ds_reduction: str | None,
19251936
) -> dict[str, dict[str, Any]]:
@@ -1939,6 +1950,7 @@ def _validate_shape_render_params(
19391950
"scale": scale,
19401951
"table_name": table_name,
19411952
"table_layer": table_layer,
1953+
"shape": shape,
19421954
"method": method,
19431955
"ds_reduction": ds_reduction,
19441956
}
@@ -1959,6 +1971,7 @@ def _validate_shape_render_params(
19591971
element_params[el]["norm"] = param_dict["norm"]
19601972
element_params[el]["scale"] = param_dict["scale"]
19611973
element_params[el]["table_layer"] = param_dict["table_layer"]
1974+
element_params[el]["shape"] = param_dict["shape"]
19621975

19631976
element_params[el]["color"] = param_dict["color"]
19641977

@@ -2086,7 +2099,7 @@ def _validate_image_render_params(
20862099
def _get_wanted_render_elements(
20872100
sdata: SpatialData,
20882101
sdata_wanted_elements: list[str],
2089-
params: (ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams),
2102+
params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams,
20902103
cs: str,
20912104
element_type: Literal["images", "labels", "points", "shapes"],
20922105
) -> tuple[list[str], list[str], bool]:
@@ -2243,7 +2256,7 @@ def _create_image_from_datashader_result(
22432256

22442257

22452258
def _datashader_aggregate_with_function(
2246-
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
2259+
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
22472260
cvs: Canvas,
22482261
spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame,
22492262
col_for_color: str | None,
@@ -2307,7 +2320,7 @@ def _datashader_aggregate_with_function(
23072320

23082321

23092322
def _datshader_get_how_kw_for_spread(
2310-
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
2323+
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
23112324
) -> str:
23122325
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
23132326
reduction = reduction or "sum"
@@ -2478,3 +2491,100 @@ def _hex_no_alpha(hex: str) -> str:
24782491
return "#" + hex_digits[:6]
24792492

24802493
raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'")
2494+
2495+
2496+
def _convert_shapes(shapes: GeoDataFrame, target_shape: str) -> GeoDataFrame:
2497+
"""Convert the shapes stored in a GeoDataFrame (geometry column) to the target_shape."""
2498+
2499+
# define individual conversion methods
2500+
def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
2501+
vertices = [
2502+
(center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle)))
2503+
for angle in range(0, 360, 60)
2504+
]
2505+
return shapely.Polygon(vertices), None
2506+
2507+
def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
2508+
vertices = [
2509+
(center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle)))
2510+
for angle in range(45, 360, 90)
2511+
]
2512+
return shapely.Polygon(vertices), None
2513+
2514+
def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]:
2515+
return center, radius
2516+
2517+
def _polygon_to_hexagon(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
2518+
center, radius = _polygon_to_circle(polygon)
2519+
return _circle_to_hexagon(center, radius)
2520+
2521+
def _polygon_to_square(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
2522+
center, radius = _polygon_to_circle(polygon)
2523+
return _circle_to_square(center, radius)
2524+
2525+
def _polygon_to_circle(polygon: shapely.Polygon) -> tuple[shapely.Point, float]:
2526+
coords = np.array(polygon.exterior.coords)
2527+
circle_points = coords[ConvexHull(coords).vertices]
2528+
center = np.mean(circle_points, axis=0)
2529+
radius = max(np.linalg.norm(p - center) for p in circle_points)
2530+
assert isinstance(radius, float) # shut up mypy
2531+
return shapely.Point(center), radius
2532+
2533+
def _multipolygon_to_hexagon(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
2534+
center, radius = _multipolygon_to_circle(multipolygon)
2535+
return _circle_to_hexagon(center, radius)
2536+
2537+
def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
2538+
center, radius = _multipolygon_to_circle(multipolygon)
2539+
return _circle_to_square(center, radius)
2540+
2541+
def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Point, float]:
2542+
coords = []
2543+
for polygon in multipolygon.geoms:
2544+
coords.extend(polygon.exterior.coords)
2545+
points = np.array(coords)
2546+
circle_points = points[ConvexHull(points).vertices]
2547+
center = np.mean(circle_points, axis=0)
2548+
radius = max(np.linalg.norm(p - center) for p in circle_points)
2549+
assert isinstance(radius, float) # shut up mypy
2550+
return shapely.Point(center), radius
2551+
2552+
# define dict with all conversion methods
2553+
if target_shape == "circle":
2554+
conversion_methods = {
2555+
"Point": _circle_to_circle,
2556+
"Polygon": _polygon_to_circle,
2557+
"Multipolygon": _multipolygon_to_circle,
2558+
}
2559+
pass
2560+
elif target_shape == "hex":
2561+
conversion_methods = {
2562+
"Point": _circle_to_hexagon,
2563+
"Polygon": _polygon_to_hexagon,
2564+
"Multipolygon": _multipolygon_to_hexagon,
2565+
}
2566+
else:
2567+
conversion_methods = {
2568+
"Point": _circle_to_square,
2569+
"Polygon": _polygon_to_square,
2570+
"Multipolygon": _multipolygon_to_square,
2571+
}
2572+
2573+
# convert every shape
2574+
for i in range(shapes.shape[0]):
2575+
if shapes["geometry"][i].type == "Point":
2576+
converted, radius = conversion_methods["Point"](shapes["geometry"][i], shapes["radius"][i]) # type: ignore
2577+
elif shapes["geometry"][i].type == "Polygon":
2578+
converted, radius = conversion_methods["Polygon"](shapes["geometry"][i]) # type: ignore
2579+
elif shapes["geometry"][i].type == "MultiPolygon":
2580+
converted, radius = conversion_methods["Multipolygon"](shapes["geometry"][i]) # type: ignore
2581+
else:
2582+
error_type = shapes["geometry"][i].type
2583+
raise ValueError(f"Converting shape {error_type} to {target_shape} is not supported.")
2584+
shapes["geometry"][i] = converted
2585+
if radius is not None:
2586+
if "radius" not in shapes.columns:
2587+
shapes["radius"] = np.nan
2588+
shapes["radius"][i] = radius
2589+
2590+
return shapes

0 commit comments

Comments
 (0)