Skip to content

Commit 41299ac

Browse files
committed
add: get_xyprofile; draft: stats
1 parent d9994f6 commit 41299ac

File tree

6 files changed

+362
-63
lines changed

6 files changed

+362
-63
lines changed

notebooks/0-example-with-warper.ipynb

Lines changed: 14 additions & 14 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ description = "Conformal mapping-based warping of neuronal arbor morphologies."
99
authors = []
1010
requires-python = ">=3.9.0"
1111
dependencies = [
12+
"alphashape>=1.3.1",
1213
"numpy>=2.0.2",
1314
"pandas>=2.2.3",
1415
"pygridfit>=0.1.4",

pywarper/arbor.py

Lines changed: 92 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
from numpy.linalg import lstsq
6+
from scipy.ndimage import gaussian_filter
67
from scipy.special import i0
78

89

@@ -308,13 +309,13 @@ def segment_lengths(
308309
parent = edges[:, 1].astype(int) - 1
309310

310311
density = np.zeros(nodes.shape[0], dtype=float)
311-
mid = nodes.copy()
312+
mid = nodes.copy()
312313

313-
vec = nodes[parent] - nodes[child]
314-
seg_len = np.linalg.norm(vec, axis=1)
314+
vec = nodes[parent] - nodes[child]
315+
seg_len = np.linalg.norm(vec, axis=1)
315316

316-
density[child] = seg_len
317-
mid[child] = nodes[child] + 0.5 * vec
317+
density[child] = seg_len
318+
mid[child] = nodes[child] + 0.5 * vec
318319

319320
return density, mid
320321

@@ -336,27 +337,27 @@ def gridder1d(
336337
# Constants
337338
# ------------------------------------------------------------------
338339
alpha, W, err = 2, 5, 1e-3
339-
S = int(np.ceil(0.91 / err / alpha))
340+
S = int(np.ceil(0.91 / err / alpha))
340341
beta = np.pi * np.sqrt((W / alpha * (alpha - 0.5))**2 - 0.8)
341342

342343
# ------------------------------------------------------------------
343344
# Pre-computed Kaiser–Bessel lookup table (LUT)
344345
# ------------------------------------------------------------------
345-
s = np.linspace(-1, 1, 2 * S * W + 1)
346-
F_kbZ = i0(beta * np.sqrt(1 - s**2))
346+
s = np.linspace(-1, 1, 2 * S * W + 1)
347+
F_kbZ = i0(beta * np.sqrt(1 - s**2))
347348
F_kbZ /= F_kbZ.max()
348349

349350
# ------------------------------------------------------------------
350351
# Fourier transform of the 1-D kernel
351352
# ------------------------------------------------------------------
352-
Gz = alpha * n
353-
z = np.arange(-Gz // 2, Gz // 2)
354-
arg = (np.pi * W * z / Gz)**2 - beta**2
353+
Gz = alpha * n
354+
z = np.arange(-Gz // 2, Gz // 2)
355+
arg = (np.pi * W * z / Gz)**2 - beta**2
355356

356-
kbZ = np.empty_like(arg, dtype=float)
357+
kbZ = np.empty_like(arg, dtype=float)
357358
pos, neg = arg > 1e-12, arg < -1e-12
358-
kbZ[pos] = np.sin(np.sqrt(arg[pos])) / np.sqrt(arg[pos])
359-
kbZ[neg] = np.sinh(np.sqrt(-arg[neg])) / np.sqrt(-arg[neg])
359+
kbZ[pos] = np.sin(np.sqrt(arg[pos])) / np.sqrt(arg[pos])
360+
kbZ[neg] = np.sinh(np.sqrt(-arg[neg])) / np.sqrt(-arg[neg])
360361
kbZ[~(pos | neg)] = 1.0
361362
kbZ *= np.sqrt(Gz)
362363

@@ -367,13 +368,13 @@ def gridder1d(
367368
out = np.zeros(n_os, dtype=float)
368369

369370
centre = n_os / 2 + 1 # 1-based like MATLAB
370-
nz = centre + n_os * z_samples # fractional indices
371+
nz = centre + n_os * z_samples # fractional indices
371372

372373
half_w = (W - 1) // 2
373374
for lz in range(-half_w, half_w + 1):
374375
nzt = np.round(nz + lz).astype(int)
375376
zpos = S * ((nz - nzt) + W / 2)
376-
kw = F_kbZ[np.round(zpos).astype(int)]
377+
kw = F_kbZ[np.round(zpos).astype(int)]
377378

378379
nzt = np.clip(nzt, 0, n_os - 1) # clamp out-of-range
379380
np.add.at(out, nzt, density * kw)
@@ -384,12 +385,12 @@ def gridder1d(
384385
# myifft → de-apodise → abs(myfft3)
385386
# ------------------------------------------------------------------
386387
u = n
387-
f = np.fft.ifftshift(np.fft.ifft(np.fft.ifftshift(out))) * np.sqrt(u)
388-
f = f[int(np.ceil((f.size - u) / 2)) : int(np.ceil((f.size + u) / 2))]
388+
f = np.fft.ifftshift(np.fft.ifft(np.fft.ifftshift(out))) * np.sqrt(u)
389+
f = f[int(np.ceil((f.size - u) / 2)) : int(np.ceil((f.size + u) / 2))]
389390

390391
f /= kbZ[u // 2 : 3 * u // 2] # de-apodisation
391392

392-
F = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(f))) / np.sqrt(f.size)
393+
F = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(f))) / np.sqrt(f.size)
393394
return np.abs(F)
394395

395396
# =====================================================================
@@ -402,7 +403,7 @@ def get_zprofile(
402403
z_window: Optional[list[float]] = None,
403404
on_sac_pos: float = 0.0,
404405
off_sac_pos: float = 12.0,
405-
grid_point_count: int = 120,
406+
nbins: int = 120,
406407
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
407408
"""
408409
Compute a 1-D depth profile (length per z-bin) from a warped arbor.
@@ -428,7 +429,7 @@ def get_zprofile(
428429
on_sac_pos, off_sac_pos
429430
Desired positions of the starburst layers in the *final* profile
430431
(µm). Defaults reproduce the numbers quoted in Sümbül et al. 2014.
431-
grid_point_count
432+
nbins
432433
Number of evenly-spaced output bins along z.
433434
434435
Returns
@@ -445,7 +446,7 @@ def get_zprofile(
445446
"""
446447

447448
# 0) decide the common span
448-
dz_onoff = off_sac_pos - on_sac_pos # 12 µm by default
449+
dz_onoff = off_sac_pos - on_sac_pos # 12 µm by default
449450
if z_window is None:
450451
z_min, z_max = None, None # auto-span
451452
else:
@@ -456,36 +457,95 @@ def get_zprofile(
456457
warped_arbor["edges"])
457458

458459
vz_on, vz_off = warped_arbor["medVZmin"], warped_arbor["medVZmax"]
459-
rel_depth = (nodes[:, 2] / z_res - vz_on) / (vz_off - vz_on) # 0→ON, 1→OFF
460-
z_phys = on_sac_pos + rel_depth * dz_onoff # µm in global frame
460+
rel_depth = (nodes[:, 2] / z_res - vz_on) / (vz_off - vz_on) # 0→ON, 1→OFF
461+
z_phys = on_sac_pos + rel_depth * dz_onoff # µm in global frame
461462

462463

463464
# 2) decide bin edges *once*
464465
if z_min is None or z_max is None:
465466
# grow just enough to contain this cell, then round to one bin
466-
z_min = np.floor(z_phys.min() / dz_onoff * grid_point_count) * dz_onoff / grid_point_count
467-
z_max = np.ceil (z_phys.max() / dz_onoff * grid_point_count) * dz_onoff / grid_point_count
467+
z_min = np.floor(z_phys.min() / dz_onoff * nbins) * dz_onoff / nbins
468+
z_max = np.ceil (z_phys.max() / dz_onoff * nbins) * dz_onoff / nbins
468469

469-
bin_edges = np.linspace(z_min, z_max, grid_point_count + 1)
470+
bin_edges = np.linspace(z_min, z_max, nbins + 1)
470471

471472
# 3) histogram-based z profile
472473
z_hist, _ = np.histogram(z_phys, bins=bin_edges, weights=density)
473-
z_hist *= density.sum() / (z_hist.sum() * z_res)
474+
z_hist *= density.sum() / (z_hist.sum() * z_res)
474475

475476

476477
# 4) Kaiser–Bessel gridded version (needs centred –0.5…0.5 inputs)
477-
centre = (z_min + z_max) / 2
478+
centre = (z_min + z_max) / 2
478479
halfspan = (z_max - z_min) / 2
479480
z_samples = (z_phys - centre) / halfspan # now in [-1, 1]
480481

481-
z_dist = gridder1d(z_samples / 2, density, grid_point_count) # /2 → [-0.5, 0.5]
482-
z_dist *= density.sum() / (z_dist.sum() * z_res)
482+
z_dist = gridder1d(z_samples / 2, density, nbins) # /2 → [-0.5, 0.5]
483+
z_dist *= density.sum() / (z_dist.sum() * z_res)
483484

484485
# 5) bin centres & rescaled arbor
485-
x_um = 0.5 * (bin_edges[1:] + bin_edges[:-1]) # centre of each bin
486+
x_um = 0.5 * (bin_edges[1:] + bin_edges[:-1]) # centre of each bin
486487

487488
nodes_norm = warped_arbor["nodes"].copy()
488489
nodes_norm[:, 2] = z_phys
489490
normed_arbor = {**warped_arbor, "nodes": nodes_norm}
490491

491492
return x_um, z_dist, z_hist, normed_arbor
493+
494+
def get_xyprofile(
495+
warped_arbor: dict,
496+
xy_window: Optional[list[float]] = None,
497+
nbins: int = 20,
498+
sigma_bins: float = 1.0,
499+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
500+
"""
501+
2-D dendritic-length density on a fixed XY grid (no per-cell rotation).
502+
503+
Parameters
504+
----------
505+
warped_arbor
506+
output of ``warp_arbor()`` (nodes in µm).
507+
xy_window
508+
(xmin, xmax, ymin, ymax) in µm that *all* cells
509+
share. If ``None`` use this arbor's tight bounding box.
510+
nbins
511+
number of bins along X **and** Y (default 20).
512+
513+
Returns
514+
-------
515+
x_um: (nbins,) µm
516+
bin centres along X
517+
y_um: (nbins,) µm
518+
bin centres along Y
519+
xy_dist: (nbins, nbins) µm
520+
smoothed dendritic length per bin
521+
xy_hist: (nbins, nbins) µm
522+
histogram-based dendritic length per bin
523+
"""
524+
525+
# 1) edge lengths and mid-points (same helper you already have)
526+
density, mid = segment_lengths(warped_arbor["nodes"],
527+
warped_arbor["edges"])
528+
529+
# 2) decide the common window
530+
if xy_window is None:
531+
xmin, xmax = mid[:, 0].min(), mid[:, 0].max()
532+
ymin, ymax = mid[:, 1].min(), mid[:, 1].max()
533+
else:
534+
xmin, xmax, ymin, ymax = xy_window
535+
536+
# 3) 2-D histogram weighted by edge length and density
537+
xy_hist, x_edges, y_edges = np.histogram2d(
538+
mid[:, 0], mid[:, 1],
539+
bins=[nbins, nbins],
540+
range=[[xmin, xmax], [ymin, ymax]],
541+
weights=density
542+
)
543+
544+
xy_dist = gaussian_filter(xy_hist, sigma=sigma_bins, mode='nearest')
545+
xy_dist *= density.sum() / xy_dist.sum() # keep Σ = total length
546+
547+
# 5) bin centres for plotting
548+
x = 0.5 * (x_edges[:-1] + x_edges[1:])
549+
y = 0.5 * (y_edges[:-1] + y_edges[1:])
550+
551+
return x, y, xy_dist, xy_hist

pywarper/stats.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import numpy as np
2+
from alphashape import alphashape
3+
4+
5+
def get_convex_hull(points: np.ndarray) -> np.ndarray:
6+
"""Get the convex hull of a set of points."""
7+
if len(points) < 3:
8+
return points
9+
hull = alphashape(points, alpha=0)
10+
return np.array(hull.exterior.xy).T

pywarper/warper.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import Union
1+
from typing import Optional, Union
22

33
import numpy as np
44
import pandas as pd
55

6-
from pywarper.arbor import get_zprofile, warp_arbor
6+
from pywarper.arbor import get_xyprofile, get_zprofile, warp_arbor
77
from pywarper.surface import fit_surface, warp_surface
88
from pywarper.utils import read_arbor_trace
99

@@ -31,13 +31,10 @@ def __init__(
3131
on_sac: Union[dict[str, np.ndarray], tuple[np.ndarray, np.ndarray, np.ndarray]],
3232
swc_path: str,
3333
*,
34-
smoothness: int = 15,
35-
conformal_jump: int = 2,
3634
voxel_resolution: list[float] = [0.4, 0.4, 0.5],
3735
verbose: bool = False,
3836
) -> None:
39-
self.smoothness = smoothness
40-
self.conformal_jump = conformal_jump
37+
4138
self.voxel_resolution = voxel_resolution
4239
self.verbose = verbose
4340
self.swc_path = swc_path
@@ -52,21 +49,21 @@ def __init__(
5249
# ---------------------------------------------------------------------
5350
# public pipeline ------------------------------------------------------
5451
# ---------------------------------------------------------------------
55-
def fit_surfaces(self) -> "Warper":
52+
def fit_surfaces(self, smoothness: int = 15) -> "Warper":
5653
"""Fit ON / OFF SAC meshes with *pygridfit*."""
5754
if self.verbose:
5855
print("[Warper] Fitting OFF‑SAC surface …")
5956
self.vz_off, *_ = fit_surface(
60-
x=self.off_sac[0], y=self.off_sac[1], z=self.off_sac[2], smoothness=self.smoothness
57+
x=self.off_sac[0], y=self.off_sac[1], z=self.off_sac[2], smoothness=smoothness
6158
)
6259
if self.verbose:
6360
print("[Warper] Fitting ON‑SAC surface …")
6461
self.vz_on, *_ = fit_surface(
65-
x=self.on_sac[0], y=self.on_sac[1], z=self.on_sac[2], smoothness=self.smoothness
62+
x=self.on_sac[0], y=self.on_sac[1], z=self.on_sac[2], smoothness=smoothness
6663
)
6764
return self
6865

69-
def build_mapping(self) -> "Warper":
66+
def build_mapping(self, conformal_jump: int = 2) -> "Warper":
7067
"""Create the quasi‑conformal surface mapping."""
7168
if self.vz_off is None or self.vz_on is None:
7269
raise RuntimeError("Surfaces not fitted. Call fit_surfaces() first.")
@@ -81,12 +78,12 @@ def build_mapping(self) -> "Warper":
8178
self.vz_on,
8279
self.vz_off,
8380
bounds,
84-
conformal_jump=self.conformal_jump,
81+
conformal_jump=conformal_jump,
8582
verbose=self.verbose,
8683
)
8784
return self
8885

89-
def warp_arbor(self) -> "Warper":
86+
def warp_arbor(self, conformal_jump: int = 2) -> "Warper":
9087
"""Apply the mapping to the arbor."""
9188
if self.mapping is None:
9289
raise RuntimeError("Mapping missing. Call build_mapping() first.")
@@ -98,22 +95,38 @@ def warp_arbor(self) -> "Warper":
9895
self.radii,
9996
self.mapping,
10097
voxel_resolution=self.voxel_resolution,
101-
conformal_jump=self.conformal_jump,
98+
conformal_jump=conformal_jump,
10299
verbose=self.verbose,
103100
)
104101
return self
105102

106103
# convenience helpers --------------------------------------------------
107-
def get_arbor_denstiy(self, z_res: float = 0.5, z_window: list[float] = [-30, 30]) -> "Warper":
104+
def get_arbor_denstiy(
105+
self,
106+
z_res: float = 0.5,
107+
z_window: list[float] = [-30, 30],
108+
z_nbins: int = 120,
109+
xy_window: Optional[list[float]] = None,
110+
xy_nbins: int = 20,
111+
xy_sigma_bins: float = 1.
112+
) -> "Warper":
108113
"""Return depth profile as in *get_zprofile*."""
109114
if self.warped_arbor is None:
110115
raise RuntimeError("Arbor not warped yet. Call warp().")
111-
x, z_dist, z_hist, normed_arbor = get_zprofile(self.warped_arbor, z_res=z_res, z_window=z_window)
112-
self.x: np.ndarray = x
116+
z_x, z_dist, z_hist, normed_arbor = get_zprofile(self.warped_arbor, z_res=z_res, z_window=z_window)
117+
self.z_x: np.ndarray = z_x
113118
self.z_dist: np.ndarray = z_dist
114119
self.z_hist: np.ndarray = z_hist
115120
self.normed_arbor: dict = normed_arbor
116121

122+
xy_x, xy_y, xy_dist, xy_hist = get_xyprofile(
123+
self.warped_arbor, xy_window=xy_window, nbins=xy_nbins, sigma_bins=xy_sigma_bins
124+
)
125+
self.xy_x: np.ndarray = xy_x
126+
self.xy_y: np.ndarray = xy_y
127+
self.xy_dist: np.ndarray = xy_dist
128+
self.xy_hist: np.ndarray = xy_hist
129+
117130
return self
118131

119132
def save(self, out_path: str) -> None:
@@ -151,3 +164,13 @@ def _as_xyz(data) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
151164
if isinstance(data, (tuple, list)) and len(data) == 3:
152165
return map(np.asarray, data) # type: ignore[arg-type]
153166
raise TypeError("SAC data must be a mapping with keys x/y/z or a 3‑tuple of arrays.")
167+
168+
169+
def stats(self):
170+
171+
"""Return the statistics of the warped arbor."""
172+
if self.warped_arbor is None:
173+
raise RuntimeError("Arbor not warped yet. Call warp().")
174+
175+
# Calculate the statistics
176+
##

0 commit comments

Comments
 (0)