diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py index 98d16c920..7739dcc2a 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -1430,25 +1430,34 @@ def _data_to_tensors(self, data: pd.DataFrame) -> dict[str, torch.Tensor]: time index """ - def _to_tensor(cols, long=True) -> torch.Tensor: + def _to_tensor(cols, long=True, real=False) -> torch.Tensor: """Convert data[cols] to torch tensor. Converts sub-frames to numpy and then to torch tensor. Makes the following choices for types: - * float columns are converted to torch.float - * integer columns are converted to torch.int64 or torch.long, - depending on the long argument + - real is True: + * the sub-frame is converted to a torch.float32 tensor + - long is True (and real is False): + * the sub-frame is converted to a torch.long tensor + - real is False and long is False: + * if all columns are integer or boolean, the sub-frame is + converted to a torch.int64 tensor + * if one column is a float, the sub-frame is converted to + a torch.float32 tensor """ if not isinstance(cols, list) and cols not in data.columns: return None if isinstance(cols, list) and len(cols) == 0: dtypekind = "f" elif isinstance(cols, list): # and len(cols) > 0 - dtypekind = data.dtypes[cols[0]].kind + # dtypekind = data.dtypes[cols[0]].kind + dtypekind = np.result_type(*data[cols].dtypes.to_list()).kind else: dtypekind = data.dtypes[cols].kind - if not long: + if real: + return torch.tensor(data[cols].to_numpy(np.float64), dtype=torch.float) + elif not long: return torch.tensor(data[cols].to_numpy(np.int64), dtype=torch.int64) elif dtypekind in "bi": return torch.tensor(data[cols].to_numpy(np.int64), dtype=torch.long) diff --git a/tests/test_data/test_timeseries.py b/tests/test_data/test_timeseries.py index 0b1b0ce74..f681c23f1 100644 --- a/tests/test_data/test_timeseries.py +++ b/tests/test_data/test_timeseries.py @@ -678,3 +678,41 @@ def distance_to_weights(dist): if idx > 100: break print(a) + + +def test_correct_dtype_inference(): + # Create a small dataset + data = pd.DataFrame( + { + "time_idx": np.arange(30), + "value": np.sin(np.arange(30) / 5) + np.random.normal(scale=1, size=30), + "group": ["A"] * 30, + } + ) + + # Define the dataset + dataset = TimeSeriesDataSet( + data.copy(), + time_idx="time_idx", + target="value", + group_ids=["group"], + static_categoricals=["group"], + max_encoder_length=4, + max_prediction_length=2, + time_varying_unknown_reals=["value"], + target_normalizer=None, + # WATCH THIS + time_varying_known_reals=["time_idx"], + scalers=dict(time_idx=None), + ) + + # and the dataloader + dataloader = dataset.to_dataloader(batch_size=8) + + x, y = next(iter(dataset)) + # real features must be real + assert x["x_cont"].dtype is torch.float + + x, y = next(iter(dataloader)) + # real features must be real + assert x["encoder_cont"].dtype is torch.float