Skip to content

Commit f8c94e6

Browse files
committed
simplify TimeSeries.__getitem__
1 parent 28df3c3 commit f8c94e6

File tree

1 file changed

+44
-29
lines changed

1 file changed

+44
-29
lines changed

pytorch_forecasting/data/timeseries/_timeseries_v2.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -206,54 +206,69 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
206206
weights : torch.Tensor of shape (n_timepoints,), optional
207207
Only included if weights are not `None`.
208208
"""
209-
group_id = self._group_ids[index]
210-
211-
if self._group:
212-
mask = self._groups[group_id]
209+
time = self.time
210+
feature_cols = self.feature_cols
211+
_target = self._target
212+
_known = self._known
213+
_static = self._static
214+
_group = self._group
215+
_groups = self._groups
216+
_group_ids = self._group_ids
217+
weight = self.weight
218+
data_future = self.data_future
219+
220+
group_id = _group_ids[index]
221+
222+
if _group:
223+
mask = _groups[group_id]
213224
data = self.data.loc[mask]
214225
else:
215226
data = self.data
216227

217-
cutoff_time = data[self.time].max()
228+
cutoff_time = data[time].max()
229+
230+
data_vals = data[time].values
231+
data_tgt_vals = data[_target].values
232+
data_feat_vals = data[feature_cols].values
218233

219234
result = {
220-
"t": data[self.time].values,
221-
"y": torch.tensor(data[self._target].values),
222-
"x": torch.tensor(data[self.feature_cols].values),
235+
"t": data_vals,
236+
"y": torch.tensor(data_tgt_vals),
237+
"x": torch.tensor(data_feat_vals),
223238
"group": torch.tensor([hash(str(group_id))]),
224-
"st": torch.tensor(data[self._static].iloc[0].values if self._static else []),
239+
"st": torch.tensor(data[_static].iloc[0].values if _static else []),
225240
"cutoff_time": cutoff_time,
226241
}
227242

228-
if self.data_future is not None:
229-
if self._group:
230-
future_mask = self.data_future.groupby(self._group).groups[group_id]
243+
if data_future is not None:
244+
if _group:
245+
future_mask = self.data_future.groupby(_group).groups[group_id]
231246
future_data = self.data_future.loc[future_mask]
232247
else:
233248
future_data = self.data_future
234249

235-
combined_times = np.concatenate(
236-
[data[self.time].values, future_data[self.time].values]
237-
)
250+
data_fut_vals = future_data[time].values
251+
252+
combined_times = np.concatenate([data_vals, data_fut_vals])
238253
combined_times = np.unique(combined_times)
239254
combined_times.sort()
240255

241256
num_timepoints = len(combined_times)
242-
x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan)
243-
y_merged = np.full((num_timepoints, len(self._target)), np.nan)
257+
x_merged = np.full((num_timepoints, len(feature_cols)), np.nan)
258+
y_merged = np.full((num_timepoints, len(_target)), np.nan)
244259

245260
current_time_indices = {t: i for i, t in enumerate(combined_times)}
246-
for i, t in enumerate(data[self.time].values):
261+
for i, t in enumerate(data_vals):
247262
idx = current_time_indices[t]
248-
x_merged[idx] = data[self.feature_cols].values[i]
249-
y_merged[idx] = data[self._target].values[i]
263+
x_merged[idx] = data_feat_vals[i]
264+
y_merged[idx] = data_tgt_vals[i]
250265

251-
for i, t in enumerate(future_data[self.time].values):
266+
for i, t in enumerate(data_fut_vals):
252267
if t in current_time_indices:
253268
idx = current_time_indices[t]
254-
for j, col in enumerate(self._known):
255-
if col in self.feature_cols:
256-
feature_idx = self.feature_cols.index(col)
269+
for j, col in enumerate(_known):
270+
if col in feature_cols:
271+
feature_idx = feature_cols.index(col)
257272
x_merged[idx, feature_idx] = future_data[col].values[i]
258273

259274
result.update(
@@ -264,17 +279,17 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
264279
}
265280
)
266281

267-
if self.weight:
282+
if weight:
268283
if self.data_future is not None and self.weight in self.data_future.columns:
269284
weights_merged = np.full(num_timepoints, np.nan)
270-
for i, t in enumerate(data[self.time].values):
285+
for i, t in enumerate(data_vals):
271286
idx = current_time_indices[t]
272-
weights_merged[idx] = data[self.weight].values[i]
287+
weights_merged[idx] = data[weight].values[i]
273288

274-
for i, t in enumerate(future_data[self.time].values):
289+
for i, t in enumerate(data_fut_vals):
275290
if t in current_time_indices and self.weight in future_data.columns:
276291
idx = current_time_indices[t]
277-
weights_merged[idx] = future_data[self.weight].values[i]
292+
weights_merged[idx] = future_data[weight].values[i]
278293

279294
result["weights"] = torch.tensor(weights_merged, dtype=torch.float32)
280295
else:

0 commit comments

Comments
 (0)