@@ -206,54 +206,69 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
206
206
weights : torch.Tensor of shape (n_timepoints,), optional
207
207
Only included if weights are not `None`.
208
208
"""
209
- group_id = self ._group_ids [index ]
210
-
211
- if self ._group :
212
- mask = self ._groups [group_id ]
209
+ time = self .time
210
+ feature_cols = self .feature_cols
211
+ _target = self ._target
212
+ _known = self ._known
213
+ _static = self ._static
214
+ _group = self ._group
215
+ _groups = self ._groups
216
+ _group_ids = self ._group_ids
217
+ weight = self .weight
218
+ data_future = self .data_future
219
+
220
+ group_id = _group_ids [index ]
221
+
222
+ if _group :
223
+ mask = _groups [group_id ]
213
224
data = self .data .loc [mask ]
214
225
else :
215
226
data = self .data
216
227
217
- cutoff_time = data [self .time ].max ()
228
+ cutoff_time = data [time ].max ()
229
+
230
+ data_vals = data [time ].values
231
+ data_tgt_vals = data [_target ].values
232
+ data_feat_vals = data [feature_cols ].values
218
233
219
234
result = {
220
- "t" : data [ self . time ]. values ,
221
- "y" : torch .tensor (data [ self . _target ]. values ),
222
- "x" : torch .tensor (data [ self . feature_cols ]. values ),
235
+ "t" : data_vals ,
236
+ "y" : torch .tensor (data_tgt_vals ),
237
+ "x" : torch .tensor (data_feat_vals ),
223
238
"group" : torch .tensor ([hash (str (group_id ))]),
224
- "st" : torch .tensor (data [self . _static ].iloc [0 ].values if self . _static else []),
239
+ "st" : torch .tensor (data [_static ].iloc [0 ].values if _static else []),
225
240
"cutoff_time" : cutoff_time ,
226
241
}
227
242
228
- if self . data_future is not None :
229
- if self . _group :
230
- future_mask = self .data_future .groupby (self . _group ).groups [group_id ]
243
+ if data_future is not None :
244
+ if _group :
245
+ future_mask = self .data_future .groupby (_group ).groups [group_id ]
231
246
future_data = self .data_future .loc [future_mask ]
232
247
else :
233
248
future_data = self .data_future
234
249
235
- combined_times = np . concatenate (
236
- [ data [ self . time ]. values , future_data [ self . time ]. values ]
237
- )
250
+ data_fut_vals = future_data [ time ]. values
251
+
252
+ combined_times = np . concatenate ([ data_vals , data_fut_vals ] )
238
253
combined_times = np .unique (combined_times )
239
254
combined_times .sort ()
240
255
241
256
num_timepoints = len (combined_times )
242
- x_merged = np .full ((num_timepoints , len (self . feature_cols )), np .nan )
243
- y_merged = np .full ((num_timepoints , len (self . _target )), np .nan )
257
+ x_merged = np .full ((num_timepoints , len (feature_cols )), np .nan )
258
+ y_merged = np .full ((num_timepoints , len (_target )), np .nan )
244
259
245
260
current_time_indices = {t : i for i , t in enumerate (combined_times )}
246
- for i , t in enumerate (data [ self . time ]. values ):
261
+ for i , t in enumerate (data_vals ):
247
262
idx = current_time_indices [t ]
248
- x_merged [idx ] = data [ self . feature_cols ]. values [i ]
249
- y_merged [idx ] = data [ self . _target ]. values [i ]
263
+ x_merged [idx ] = data_feat_vals [i ]
264
+ y_merged [idx ] = data_tgt_vals [i ]
250
265
251
- for i , t in enumerate (future_data [ self . time ]. values ):
266
+ for i , t in enumerate (data_fut_vals ):
252
267
if t in current_time_indices :
253
268
idx = current_time_indices [t ]
254
- for j , col in enumerate (self . _known ):
255
- if col in self . feature_cols :
256
- feature_idx = self . feature_cols .index (col )
269
+ for j , col in enumerate (_known ):
270
+ if col in feature_cols :
271
+ feature_idx = feature_cols .index (col )
257
272
x_merged [idx , feature_idx ] = future_data [col ].values [i ]
258
273
259
274
result .update (
@@ -264,17 +279,17 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
264
279
}
265
280
)
266
281
267
- if self . weight :
282
+ if weight :
268
283
if self .data_future is not None and self .weight in self .data_future .columns :
269
284
weights_merged = np .full (num_timepoints , np .nan )
270
- for i , t in enumerate (data [ self . time ]. values ):
285
+ for i , t in enumerate (data_vals ):
271
286
idx = current_time_indices [t ]
272
- weights_merged [idx ] = data [self . weight ].values [i ]
287
+ weights_merged [idx ] = data [weight ].values [i ]
273
288
274
- for i , t in enumerate (future_data [ self . time ]. values ):
289
+ for i , t in enumerate (data_fut_vals ):
275
290
if t in current_time_indices and self .weight in future_data .columns :
276
291
idx = current_time_indices [t ]
277
- weights_merged [idx ] = future_data [self . weight ].values [i ]
292
+ weights_merged [idx ] = future_data [weight ].values [i ]
278
293
279
294
result ["weights" ] = torch .tensor (weights_merged , dtype = torch .float32 )
280
295
else :
0 commit comments