Skip to content

Commit 2dba031

Browse files
committed
Prototyping more sophisticated bounds handling for temporal averaging
1 parent 8824b32 commit 2dba031

File tree

1 file changed

+206
-2
lines changed

1 file changed

+206
-2
lines changed

xcdat/temporal.py

Lines changed: 206 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from datetime import datetime
5-
from itertools import chain
5+
from itertools import chain, product
66
from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union, get_args
77

88
import cf_xarray # noqa: F401
@@ -17,7 +17,7 @@
1717

1818
from xcdat import bounds # noqa: F401
1919
from xcdat._logger import _setup_custom_logger
20-
from xcdat.axis import get_dim_coords
20+
from xcdat.axis import get_dim_coords, get_dim_keys
2121
from xcdat.dataset import _get_data_var
2222

2323
logger = _setup_custom_logger(__name__)
@@ -2025,6 +2025,210 @@ def _calculate_departures(
20252025
return ds_departs
20262026

20272027

2028+
def compute_monthly_average(self, data_var):
    """
    Average a data variable onto a regular monthly time axis.

    The work is delegated to the temporal accessor in four steps: the
    source time bounds are checked for low-to-high ordering, the target
    monthly bounds are generated, the per-slice weights for those bounds
    are computed, and the weighted average is taken.

    Parameters
    ----------
    data_var : str
        The key of the data variable.

    Returns
    -------
    xr.Dataset
        Dataset with the computed monthly average.

    Raises
    ------
    ValueError
        If any source time bounds are not ordered from low to high.

    Notes
    -----
    The target months always run from January through December of each
    year present in the source time axis, even when the source data
    starts after January or ends before December. A potential
    enhancement would be to tailor the bounds to the source dataset,
    e.g., start the monthly output in March 2010 when the source
    dataset starts in March 2010.
    """
    dataset = self._dataset.copy()

    # Fail early when the source bounds are not ordered low-to-high.
    dataset.temporal.ensure_bounds_order()

    # Only the bounds are needed here; the centered time axis is discarded.
    _, monthly_bnds = dataset.temporal.generate_monthly_bounds()

    # One row of weights per target month, one column per source time slice.
    monthly_weights = dataset.temporal.get_temporal_weights(monthly_bnds)

    return dataset.temporal._experimental_averager(
        data_var, monthly_weights, monthly_bnds
    )
2064+
2065+
2066+
def _experimental_averager(self, data_var, weights, target_bnds):
    """
    Compute weighted time-period averages of a data variable.

    Parameters
    ----------
    data_var : str
        The key of the data variable.

    weights : xr.DataArray
        The weight each source time slice contributes to each target
        time slice, laid out as [target_time, source_time]. The target
        dimension must be named ``"target_time"``; it is renamed to the
        dataset's time key after averaging.

    target_bnds : xr.DataArray
        The time bounds for the target time slices.

    Returns
    -------
    xr.Dataset
        A copy of the dataset in which the original time dimension has
        been dropped, ``data_var`` has been replaced by its temporal
        average, and the target bounds are stored under
        ``<time key> + "_bnds"``.
    """
    source = self._dataset.copy()
    time_dim = get_dim_keys(source, 'T')
    original_dims = source[data_var].dims

    weighted_var = source[data_var].weighted(weights)

    # Reduce over the source time dimension, keeping variable attributes.
    with xr.set_options(keep_attrs=True):
        averaged = weighted_var.mean(dim=time_dim)

    # The remaining "target_time" dimension takes over the original time
    # name, and the dimension order is restored to that of the source.
    averaged = averaged.rename({'target_time': time_dim})
    averaged = averaged.transpose(*original_dims)

    # Everything defined on the original time dimension is obsolete once
    # the variable has been averaged, so it is dropped. Assigning the
    # averaged variable brings the new time dimension and its coordinates
    # into the output dataset.
    result = source.copy().drop_dims(time_dim)
    result[data_var] = averaged
    result[time_dim + '_bnds'] = target_bnds
    return result
2111+
2112+
2113+
def get_temporal_weights(self, target_bnds):
    """Compute the temporal weights for a set of target time bounds.

    The weight of a source time slice for a given target slice is the
    length of time the two intervals overlap. Source slices that lie
    entirely outside a target interval receive a weight of zero.

    Parameters
    ----------
    target_bnds : xr.DataArray
        The bounds for target time averages, with one [lower, upper]
        pair per target slice. It must carry a coordinate named after
        the dataset's time key.

    Returns
    -------
    xr.DataArray
        The temporal weights that should be applied to the source data
        to produce time averaged data corresponding to the target time
        bounds. Dimensions are ``("target_time", <time key>)``; values
        are overlap durations converted to ``timedelta64[ns]`` and stored
        in a float array.
    """
    ds = self._dataset.copy()

    # Use the dataset's own time key throughout (rather than a hard-coded
    # "time") so the weights line up with the data variable's time
    # dimension whatever it is called.
    time_key = get_dim_keys(ds, 'T')
    source_bnds = ds.cf.get_bounds(time_key).values
    source_lower = source_bnds[:, 0]
    source_upper = source_bnds[:, 1]
    target_time = target_bnds[time_key]

    # One row per target slice, one column per source time slice.
    weights = np.zeros((len(target_bnds), len(ds[time_key])))

    for i, tbnd in enumerate(target_bnds.values):
        target_lower, target_upper = tbnd[0], tbnd[1]

        # Clip each source interval to the target interval. Intervals
        # that do not overlap the target collapse to zero width because
        # both of their ends are clamped to the same target edge.
        lower = np.minimum(np.maximum(source_lower, target_lower), target_upper)
        upper = np.maximum(np.minimum(source_upper, target_upper), target_lower)

        # The overlap duration is the weight.
        weights[i, :] = (upper - lower).astype("timedelta64[ns]")

    return xr.DataArray(
        data=weights,
        dims=['target_time', time_key],
        coords={
            'target_time': target_time.values,
            time_key: ds[time_key].values,
        },
    )
2159+
2160+
2161+
def generate_monthly_bounds(self):
    """Generates monthly time bounds and the corresponding time axis
    for a dataset.

    This method will generate monthly time bounds, e.g.,
        [["2010-01-01 00:00:00", "2010-02-01 00:00:00"],
         ["2010-02-01 00:00:00", "2010-03-01 00:00:00"],
         ["2010-03-01 00:00:00", "2010-04-01 00:00:00"],
         ...]

    and a time axis centered within each pair of bounds, e.g.,
        ["2010-01-16 12:00:00",
         "2010-02-15 00:00:00",
         "2010-03-16 12:00:00",
         ...]

    for a dataset. Twelve months (January through December) are
    generated, in chronological order, for every year that appears in
    the dataset's time axis.

    Returns
    -------
    monthly_time : xr.DataArray
        The centered time axis corresponding to the generated bounds.
        Its ``bounds`` attribute is set to ``<time key> + "_bnds"``.

    monthly_bnds : xr.DataArray
        The generated monthly bounds.

    Notes
    -----
    The time values must expose a ``.year`` attribute and their type
    must be constructible as ``type(year, month, day)``.
    NOTE(review): this appears to assume cftime/datetime-like objects
    rather than ``numpy.datetime64`` values -- confirm against callers.
    """
    ds = self._dataset.copy()
    time_key = get_dim_keys(ds, 'T')
    time_values = ds[time_key].values

    # Sort the unique years: set iteration order is arbitrary, and an
    # unsorted list would produce a non-monotonic target time axis.
    years = sorted({t.year for t in time_values})

    # Build the new time objects with the same type as the source axis.
    ttype = type(time_values[0])

    # Create target time bounds and time axis.
    monthly_bnds = []
    monthly_time = []
    for year, month in product(years, range(1, 13)):
        lower_bnd = ttype(year, month, 1)
        upper_bnd = ds.bounds._add_months_to_timestep(lower_bnd, ttype, 1)
        monthly_bnds.append([lower_bnd, upper_bnd])
        monthly_time.append(lower_bnd + (upper_bnd - lower_bnd) / 2.0)

    # Generate xarray DataArray objects.
    monthly_time = xr.DataArray(
        data=monthly_time,
        dims=[time_key],
        coords={time_key: monthly_time},
    )
    monthly_time.encoding = ds[time_key].encoding
    # Attach the bounds attribute to the axis that is actually returned;
    # previously it was set on a copy that was then discarded.
    monthly_time.attrs['bounds'] = time_key + '_bnds'

    monthly_bnds = xr.DataArray(
        data=monthly_bnds,
        dims=[time_key, 'bnds'],
        coords={time_key: monthly_time},
    )
    monthly_bnds.encoding = ds[time_key].encoding
    return monthly_time, monthly_bnds
2215+
2216+
2217+
def ensure_bounds_order(self):
    """Checks that every pair of time bounds is ordered [earlier, later].

    The bounds are only validated, never reordered. A pair whose lower
    bound equals its upper bound is also rejected.

    Raises
    ------
    ValueError
        If there are any bounds that are out of order.
    """
    dataset = self._dataset.copy()
    bounds_pairs = dataset.bounds.get_bounds("T").values

    # Stop at the first pair whose lower bound is not strictly below
    # its upper bound.
    if any(pair[0] >= pair[1] for pair in bounds_pairs):
        raise ValueError('Time bounds are not ordered from low-to-high')
2230+
2231+
20282232
def _infer_freq(time_coords: xr.DataArray) -> Frequency:
20292233
"""Infers the time frequency from the coordinates.
20302234

0 commit comments

Comments
 (0)