Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tests/test_fastplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import xarray as xr
import trajan as _
import matplotlib.pyplot as plt


def test_fastplot(test_data, plot):
ds = xr.open_dataset(test_data / 'xr_spotter_bulk_test_data.nc')

# let us check that the conversion worked well, and that all main functionalities are working

# plotting without centering
plt.figure()
ds.traj.plot.fastplot()

# plotting with centering
plt.figure()
ds.traj.plot.fastplot(center_lon_circmean=True)

if plot:
plt.show()

41 changes: 41 additions & 0 deletions trajan/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import cartopy.crs as ccrs
import numpy as np
import xarray as xr
import scipy as sp

from .land import add_land

Expand Down Expand Up @@ -135,6 +136,46 @@ def set_up_map(
def __call__(self, *args, **kwargs):
return self.lines(*args, **kwargs)

def fastplot(self, center_lon_circmean=False):
"""
Do a fast plot of the data. This uses PlateCarree, no transform on the plotting,
and circmean to determine the optimal central longitude to use.

This only plots markers, not lines, to avoid wrapping lines on global datasets.

Args:

center_lon_circmean: bool, whether to use the circmean of the longitude
of the data to set the central_longitude in PlateCarree (if True), or
the default (0) if False.
"""

if self.__cartesian__:
x = self.ds.traj.tx.values.T
y = self.ds.traj.ty.values.T
else:
x = self.ds.traj.tlon.values.T
y = self.ds.traj.tlat.values.T

if center_lon_circmean:
all_y = x.flatten()
mid_y = 180.0 / np.pi * sp.stats.circmean(all_y * np.pi / 180.0, nan_policy="omit")
else:
mid_y = 0.0

ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=mid_y))
ax.coastlines()

if len(x.shape) == 2:
for crrt_buoy_idx in range(x.shape[1]):
plt.plot(x[:, crrt_buoy_idx], y[:, crrt_buoy_idx], linestyle="none", marker=".", markersize=1)
elif len(x.shape) == 1:
plt.plot(x, y, linestyle="none", marker=".", markersize=1)
else:
raise ValueError(f"the position data to plot has shape {x.shape = }; unclear how to plot this!")

gl = ax.gridlines(transform=ccrs.PlateCarree(central_longitude=mid_y), draw_labels=True)

def lines(self, *args, **kwargs):
"""
Plot the trajectory lines.
Expand Down