|
2 | 2 |
|
3 | 3 | import warnings |
4 | 4 | from datetime import datetime |
5 | | -from itertools import chain |
| 5 | +from itertools import chain, product |
6 | 6 | from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union, get_args |
7 | 7 |
|
8 | 8 | import cf_xarray # noqa: F401 |
|
17 | 17 |
|
18 | 18 | from xcdat import bounds # noqa: F401 |
19 | 19 | from xcdat._logger import _setup_custom_logger |
20 | | -from xcdat.axis import get_dim_coords |
| 20 | +from xcdat.axis import get_dim_coords, get_dim_keys |
21 | 21 | from xcdat.dataset import _get_data_var |
22 | 22 |
|
23 | 23 | logger = _setup_custom_logger(__name__) |
@@ -2025,6 +2025,210 @@ def _calculate_departures( |
2025 | 2025 | return ds_departs |
2026 | 2026 |
|
2027 | 2027 |
|
| 2028 | + def compute_monthly_average(self, data_var): |
| 2029 | + """ |
| 2030 | + Computes monthly averages for a data variable in the dataset. |
| 2031 | +
|
| 2032 | + This method ensures that the dataset's time bounds are |
| 2033 | + ordered correctly, computes the target monthly time bounds |
| 2034 | + and their associated weights, and then computes the monthly average. |
| 2035 | +
|
| 2036 | + Parameters |
| 2037 | + ---------- |
| 2038 | + data_var : str |
| 2039 | + The key of the data variable. |
| 2040 | +
|
| 2041 | + Returns |
| 2042 | + ------- |
| 2043 | + xr.Dataset |
| 2044 | + Dataset with the computed monthly average. |
| 2045 | +
|
| 2046 | + Notes |
| 2047 | + ----- |
| 2048 | + The monthly averages are computed from January through December, |
| 2049 | + but the source dataset may start after January or end before |
| 2050 | + December. A potential enhancement would be to tailor the target |
| 2051 | + bounds to the source dataset. For example, if the source dataset |
| 2052 | + starts in March 2010, the resulting monthly dataset would begin |
| 2053 | + in March 2010. |
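| | +
| | + Examples |
| | + -------- |
| | + A minimal usage sketch; ``ds`` and the variable name ``"ts"`` are |
| | + illustrative and assume a dataset opened with xcdat that has a |
| | + decoded time axis with bounds: |
| | +
| | + >>> import xcdat |
| | + >>> ds = xcdat.open_dataset("input.nc")  # hypothetical file |
| | + >>> ds_monthly = ds.temporal.compute_monthly_average("ts") |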
| 2054 | + """ |
| 2055 | + ds = self._dataset.copy() |
| 2056 | + # ensure source time bounds are ordered correctly |
| 2057 | + ds.temporal.ensure_bounds_order() |
| 2058 | + # get target time and bounds |
| 2059 | + target_time, target_bnds = ds.temporal.generate_monthly_bounds() |
| 2060 | + # get temporal weights |
| 2061 | + weights = ds.temporal.get_temporal_weights(target_bnds) |
| 2062 | + # compute average and return resulting dataset |
| 2063 | + return ds.temporal._experimental_averager(data_var, weights, target_bnds) |
| 2064 | + |
| 2065 | + |
| 2066 | + def _experimental_averager(self, data_var, weights, target_bnds): |
| 2067 | + """ |
| 2068 | + Calculates time period averages for a set of weights and bounds. |
| 2069 | +
|
| 2070 | + Parameters |
| 2071 | + ---------- |
| 2072 | + data_var : str |
| 2073 | + The key of the data variable. |
| 2074 | +
|
| 2075 | + weights : xr.DataArray |
| 2076 | + The weight of each source time slice used to compute the temporal |
| 2077 | + average for each target time slice, with dimensions [target_time, source_time]. |
| 2078 | +
|
| 2079 | + target_bnds : xr.DataArray |
| 2080 | + The time_bnds for the target time slices. |
| 2081 | +
|
| 2082 | + Returns |
| 2083 | + ------- |
| 2084 | + xr.Dataset |
| 2085 | + The dataset with the computed temporal averages. |
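| | +
| | + Notes |
| | + ----- |
| | + Conceptually, ignoring missing values, the weighted mean reduces the |
| | + source time dimension as ``(da * weights).sum(dim) / weights.sum(dim)``; |
| | + ``xr.DataArray.weighted(...).mean(...)`` additionally masks the weights |
| | + of missing values before normalizing. |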
| 2086 | + """ |
| 2087 | + ds = self._dataset.copy() |
| 2088 | + # get time key |
| 2089 | + time_key = get_dim_keys(ds, 'T') |
| 2090 | + # convert to weighted array |
| 2091 | + da_weighted = ds[data_var].weighted(weights) |
| 2092 | + # compute weighted mean |
| 2093 | + with xr.set_options(keep_attrs=True): |
| 2094 | + da_mean = da_weighted.mean(dim=time_key) |
| 2095 | + # revert to original time coordinate name |
| 2096 | + da_mean = da_mean.rename({'target_time': time_key}) |
| 2097 | + # ensure order is the same as original dataset |
| 2098 | + da_mean = da_mean.transpose(*ds[data_var].dims) |
| 2099 | + # create output dataset |
| 2100 | + dsmean = ds.copy() |
| 2101 | + # The original time dimension is dropped from the dataset because |
| 2102 | + # it becomes obsolete after the data variable is averaged. When the |
| 2103 | + # averaged data variable is added to the dataset, the new time dimension |
| 2104 | + # and its associated coordinates are also added. |
| 2105 | + dsmean = dsmean.drop_dims(time_key) |
| 2106 | + # add weighted mean data array to output dataset |
| 2107 | + dsmean[data_var] = da_mean |
| 2108 | + # add the time bounds to the dataset |
| 2109 | + dsmean[time_key + '_bnds'] = target_bnds |
| 2110 | + return dsmean |
| 2111 | + |
| 2112 | + |
| 2113 | + def get_temporal_weights(self, target_bnds): |
| 2114 | + """Compute the temporal weights for a set of target time bounds. |
| 2115 | +
|
| 2116 | + Parameters |
| 2117 | + ---------- |
| 2118 | + target_bnds : xr.DataArray |
| 2119 | + The bounds for target time averages |
| 2120 | +
|
| 2121 | + Returns |
| 2122 | + ------- |
| 2123 | + xr.DataArray |
| 2124 | + The temporal weights that should be applied to the source data |
| 2125 | + to produce time averaged data corresponding to the target time |
| 2126 | + bounds |
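| | +
| | + Notes |
| | + ----- |
| | + Each weight is the length of the overlap between a source time slice |
| | + and a target time slice. As an illustrative example, a source slice |
| | + spanning 2010-01-15 to 2010-02-15 contributes 17 days of weight to the |
| | + January 2010 target slice (2010-01-01 to 2010-02-01), 14 days to the |
| | + February 2010 target slice, and zero weight to every other target slice. |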
| 2127 | + """ |
| 2128 | + ds = self._dataset.copy() |
| 2129 | + # Get time key and source time bounds |
| 2130 | + time_key = get_dim_keys(ds, 'T') |
| 2131 | + source_bnds = ds.cf.get_bounds(time_key).values |
| 2132 | + target_time = target_bnds[time_key] |
| 2133 | + |
| 2134 | + # Preallocate weight matrix |
| 2135 | + weights = np.zeros((len(target_bnds), len(ds[time_key]))) |
| 2136 | + |
| 2137 | + # bounds adjustment |
| 2138 | + for i, tbnd in enumerate(target_bnds.values): |
| 2139 | + # Adjust source bounds to fit within target bounds |
| 2140 | + sbnds = source_bnds.copy() |
| 2141 | + sbnds[:, 0] = np.maximum(sbnds[:, 0], tbnd[0]) # Lower bound adjustment |
| 2142 | + sbnds[:, 1] = np.minimum(sbnds[:, 1], tbnd[1]) # Upper bound adjustment |
| 2143 | + |
| 2144 | + # Clamp slices that fall entirely outside the target range so they get zero weight |
| 2145 | + sbnds[:, 0] = np.minimum(sbnds[:, 0], tbnd[1])  # Lower bound beyond upper target bound |
| 2146 | + sbnds[:, 1] = np.maximum(sbnds[:, 1], tbnd[0])  # Upper bound before lower target bound |
| 2147 | + |
| 2148 | + # Compute weights as the difference between bounds |
| 2149 | + w = (sbnds[:, 1] - sbnds[:, 0]).astype("timedelta64[ns]") |
| 2150 | + weights[i, :] = w |
| 2151 | + |
| 2152 | + # Convert weights to xarray DataArray |
| 2153 | + weights = xr.DataArray( |
| 2154 | + data=weights, |
| 2155 | + dims=['target_time', time_key], |
| 2156 | + coords={'target_time': target_time.values, time_key: ds[time_key].values} |
| 2157 | + ) |
| 2158 | + return weights |
| 2159 | + |
| 2160 | + |
| 2161 | + def generate_monthly_bounds(self): |
| 2162 | + """Generates monthly time bounds and the corresponding time axis |
| 2163 | + for a dataset. |
| 2164 | +
|
| 2165 | + This method will generate monthly time bounds, e.g., |
| 2166 | + [["2010-01-01 00:00:00", "2010-02-01 00:00:00"], |
| 2167 | + ["2010-02-01 00:00:00", "2010-03-01 00:00:00"], |
| 2168 | + ["2010-03-01 00:00:00", "2010-04-01 00:00:00"], |
| 2169 | + ...] |
| 2170 | +
|
| 2171 | + and a time axis, e.g., |
| 2172 | + ["2010-01-16 12:00:00", |
| 2173 | + "2010-02-15 00:00:00", |
| 2174 | + "2010-03-16 12:00:00", |
| 2175 | + ...] |
| 2176 | +
|
| 2177 | + for a dataset. The arrays start with January of the first year in |
| 2178 | + the source dataset and run through December of the final year in |
| 2179 | + the source dataset. |
| 2180 | +
|
| 2181 | + Returns |
| 2182 | + ------- |
| 2183 | + monthly_time : xr.DataArray |
| 2184 | + The centered time axis corresponding to the generated bounds. |
| 2185 | + |
| 2186 | + monthly_bnds : xr.DataArray |
| 2187 | + The generated monthly bounds. |
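| | +
| | + Examples |
| | + -------- |
| | + A minimal usage sketch, assuming ``ds`` is a dataset with a decoded |
| | + time axis: |
| | +
| | + >>> target_time, target_bnds = ds.temporal.generate_monthly_bounds() |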
| 2188 | + """ |
| 2189 | + ds = self._dataset.copy() |
| 2190 | + # get the time dimension key and all years in the source dataset |
| 2191 | + time_key = get_dim_keys(ds, 'T') |
| 2192 | + years = sorted({t.year for t in ds[time_key].values}) |
| 2193 | + # get time type |
| 2194 | + ttype = type(ds[time_key].values[0]) |
| 2195 | + # create target time bounds and time axis |
| 2196 | + monthly_bnds = [] |
| 2197 | + monthly_time = [] |
| 2198 | + for year, month in product(years, range(1, 13)): |
| 2199 | + lower_bnd = ttype(year, month, 1) |
| 2200 | + upper_bnd = ds.bounds._add_months_to_timestep(lower_bnd, ttype, 1) |
| 2201 | + center_time = lower_bnd + (upper_bnd - lower_bnd)/2. |
| 2202 | + monthly_bnds.append([lower_bnd, upper_bnd]) |
| 2203 | + monthly_time.append(center_time) |
| 2204 | + # generate xarray DataArray objects |
| 2205 | + monthly_time = xr.DataArray(data=monthly_time, |
| 2206 | + dims=[time_key], |
| 2207 | + coords={time_key: monthly_time}) |
| 2208 | + monthly_time.encoding = ds[time_key].encoding |
| 2209 | + monthly_time = monthly_time.assign_attrs({'bounds': time_key + '_bnds'}) |
| 2210 | + monthly_bnds = xr.DataArray(data=monthly_bnds, |
| 2211 | + dims=[time_key, 'bnds'], |
| 2212 | + coords={time_key: monthly_time}) |
| 2213 | + monthly_bnds.encoding = ds[time_key].encoding |
| 2214 | + return monthly_time, monthly_bnds |
| 2215 | + |
| 2216 | + |
| 2217 | + def ensure_bounds_order(self): |
| 2218 | + """Ensures that time bounds are ordered [earlier, later] |
| 2219 | +
|
| 2220 | + Raises |
| 2221 | + ------ |
| 2222 | + ValueError |
| 2223 | + If there are any bounds that are out of order. |
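| | +
| | + Examples |
| | + -------- |
| | + A minimal sketch; the call is a no-op when all bounds are already |
| | + ordered [earlier, later]: |
| | +
| | + >>> ds.temporal.ensure_bounds_order() |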
| 2224 | + """ |
| 2225 | + ds = self._dataset.copy() |
| 2226 | + time_bnds = ds.bounds.get_bounds("T") |
| 2227 | + for tbnd in time_bnds.values: |
| 2228 | + if tbnd[0] >= tbnd[1]: |
| 2229 | + raise ValueError('Time bounds are not ordered from low-to-high') |
| 2230 | + |
| 2231 | + |
2028 | 2232 | def _infer_freq(time_coords: xr.DataArray) -> Frequency: |
2029 | 2233 | """Infers the time frequency from the coordinates. |
2030 | 2234 |
|
|