Skip to content

Commit fd11c33

Browse files
Simplified affine matrix calculation for datashader (#491)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 31b1ee3 commit fd11c33

File tree

2 files changed

+4
-42
lines changed

2 files changed

+4
-42
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from scanpy._settings import settings as sc_settings
2020
from spatialdata import get_extent, get_values, join_spatialelement_table
2121
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
22-
from spatialdata.transformations import get_transformation, set_transformation
22+
from spatialdata.transformations import set_transformation
2323
from spatialdata.transformations.transformations import Identity
2424
from xarray import DataTree
2525

@@ -44,7 +44,6 @@
4444
_get_colors_for_categorical_obs,
4545
_get_extent_and_range_for_datashader_canvas,
4646
_get_linear_colormap,
47-
_get_transformation_matrix_for_datashader,
4847
_hex_no_alpha,
4948
_is_coercable_to_float,
5049
_map_color_seg,
@@ -186,10 +185,9 @@ def _render_shapes(
186185
sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())
187186

188187
# apply transformations to the individual points
189-
element_trans = get_transformation(sdata_filt.shapes[element], to_coordinate_system=coordinate_system)
190-
tm = _get_transformation_matrix_for_datashader(element_trans)
188+
tm = trans.get_matrix()
191189
transformed_element = sdata_filt.shapes[element].transform(
192-
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2]
190+
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm.T)[:, :2]
193191
)
194192
transformed_element = ShapesModel.parse(
195193
gpd.GeoDataFrame(

src/spatialdata_plot/pl/utils.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,8 @@
6565
from spatialdata._core.query.relational_query import _locate_value
6666
from spatialdata._types import ArrayLike
6767
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement
68-
69-
# from spatialdata.transformations.transformations import Scale
70-
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Translation
71-
from spatialdata.transformations import Sequence as SDSequence
7268
from spatialdata.transformations.operations import get_transformation
69+
from spatialdata.transformations.transformations import Scale
7370
from xarray import DataArray, DataTree
7471

7572
from spatialdata_plot._logging import logger
@@ -2381,39 +2378,6 @@ def _prepare_transformation(
23812378
return trans, trans_data
23822379

23832380

2384-
def _get_datashader_trans_matrix_of_single_element(
2385-
trans: Identity | Scale | Affine | MapAxis | Translation,
2386-
) -> npt.NDArray[Any]:
2387-
flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
2388-
tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y"))
2389-
2390-
if isinstance(trans, Identity):
2391-
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
2392-
if isinstance(trans, (Scale | Affine)):
2393-
# idea: "flip the y-axis", apply transformation, flip back
2394-
flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix
2395-
return flip_and_transform
2396-
if isinstance(trans, MapAxis):
2397-
# no flipping needed
2398-
return tm
2399-
# for a Translation, we need the transposed transformation matrix
2400-
tm_T = tm.T
2401-
assert isinstance(tm_T, np.ndarray)
2402-
return tm_T
2403-
2404-
2405-
def _get_transformation_matrix_for_datashader(
2406-
trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence,
2407-
) -> npt.NDArray[Any]:
2408-
"""Get the affine matrix needed to transform shapes for rendering with datashader."""
2409-
if isinstance(trans, SDSequence):
2410-
tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
2411-
for x in trans.transformations:
2412-
tm = tm @ _get_datashader_trans_matrix_of_single_element(x)
2413-
return tm
2414-
return _get_datashader_trans_matrix_of_single_element(trans)
2415-
2416-
24172381
def _datashader_map_aggregate_to_color(
24182382
agg: DataArray,
24192383
cmap: str | list[str] | ListedColormap,

0 commit comments

Comments
 (0)