diff --git a/setup.py b/setup.py index 52867f50..c72107e0 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,9 @@ "modernnca": [ "category_encoders", ], + "beta": [ + "category_encoders", + ], } benchmark_requires = [] @@ -53,6 +56,7 @@ "tabdpt", "tabm", "modernnca", + "beta", ]: benchmark_requires += extras_require[extra_package] benchmark_requires = list(set(benchmark_requires)) diff --git a/tabrepo/benchmark/models/ag/__init__.py b/tabrepo/benchmark/models/ag/__init__.py index 4cfded7c..39725986 100644 --- a/tabrepo/benchmark/models/ag/__init__.py +++ b/tabrepo/benchmark/models/ag/__init__.py @@ -1,15 +1,19 @@ from __future__ import annotations +from tabrepo.benchmark.models.ag.beta.beta_model import BetaModel from tabrepo.benchmark.models.ag.ebm.ebm_model import ExplainableBoostingMachineModel from tabrepo.benchmark.models.ag.modernnca.modernnca_model import ModernNCAModel from tabrepo.benchmark.models.ag.realmlp.realmlp_model import RealMLPModel from tabrepo.benchmark.models.ag.tabdpt.tabdpt_model import TabDPTModel from tabrepo.benchmark.models.ag.tabicl.tabicl_model import TabICLModel from tabrepo.benchmark.models.ag.tabm.tabm_model import TabMModel -from tabrepo.benchmark.models.ag.tabpfnv2.tabpfnv2_client_model import TabPFNV2ClientModel +from tabrepo.benchmark.models.ag.tabpfnv2.tabpfnv2_client_model import ( + TabPFNV2ClientModel, +) from tabrepo.benchmark.models.ag.tabpfnv2.tabpfnv2_model import TabPFNV2Model __all__ = [ + "BetaModel", "ExplainableBoostingMachineModel", "ModernNCAModel", "RealMLPModel", diff --git a/tabrepo/benchmark/models/ag/beta/__init__.py b/tabrepo/benchmark/models/ag/beta/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tabrepo/benchmark/models/ag/beta/beta_model.py b/tabrepo/benchmark/models/ag/beta/beta_model.py new file mode 100644 index 00000000..99f61372 --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/beta_model.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import logging +import shutil +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +import torch +from autogluon.common.utils.resource_utils import ResourceManager +from autogluon.core.models import AbstractModel +from autogluon.features.generators import LabelEncoderFeatureGenerator +from sklearn.impute import SimpleImputer + +if TYPE_CHECKING: + import pandas as pd + +logger = logging.getLogger(__name__) + + +class BetaModel(AbstractModel): + ag_key = "BETA" + ag_name = "BetaTabPFN" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.cat_col_names_ = None + self.has_num_cols = None + self.num_prep_ = None + + def _fit( + self, + X: pd.DataFrame, + y: pd.Series, + X_val: pd.DataFrame = None, + y_val: pd.Series = None, + time_limit: float | None = None, + num_cpus: int = 1, + num_gpus: float = 0, + **kwargs, + ): + from sklearn.pipeline import Pipeline + + from tabrepo.benchmark.models.ag.beta.deps.talent_utils import ( + get_deep_args, + set_seeds, + ) + from tabrepo.benchmark.models.ag.beta.talent_beta_method import BetaMethod + from tabrepo.benchmark.models.ag.modernnca.modernnca_model import ( + RTDLQuantileTransformer, + ) + + device = "cpu" if num_gpus == 0 else "cuda:0" + if (device == "cuda:0") and (not torch.cuda.is_available()): + raise AssertionError( + "Fit specified to use GPU, but CUDA is not available on this machine. 
" + "Please switch to CPU usage instead.", + ) + + # Format data for TALENT code + X = self.preprocess(X, is_train=True) + if X_val is not None: + X_val = self.preprocess(X_val) + else: + raise ValueError("Validation data (X_val) must be provided for BetaModel.") + + self.num_prep_ = Pipeline( + steps=[ + ("qt", RTDLQuantileTransformer()), + ("imp", SimpleImputer(add_indicator=True)), + ] + ) + self.has_num_cols = bool(set(X.columns) - set(self.cat_col_names_)) + ds_parts = {} + for part, X_data, y_data in [("train", X, y), ("val", X_val, y_val)]: + tensors = {} + + tensors["x_cat"] = X_data[self.cat_col_names_].to_numpy() + if self.has_num_cols: + x_cont_np = X_data.drop(columns=self.cat_col_names_).to_numpy( + dtype=np.float32 + ) + if part == "train": + self.num_prep_.fit(x_cont_np) + tensors["x_cont"] = self.num_prep_.transform(x_cont_np) + else: + tensors["x_cont"] = np.empty((len(X_data), 0), dtype=np.float32) + + if self.problem_type == "regression": + tensors["y"] = y_data.to_numpy(np.float32) + else: + tensors["y"] = y_data.to_numpy(np.int32) + ds_parts[part] = tensors + + data = [ + {part: ds_parts[part][tens_name] for part in ["train", "val"]} + for tens_name in ["x_cont", "x_cat", "y"] + ] + info = { + "task_type": "binclass" + if self.problem_type == "binary" + else self.problem_type, + "n_num_features": ds_parts["train"]["x_cont"].shape[1], + "n_cat_features": ds_parts["train"]["x_cat"].shape[1], + } + if info["n_num_features"] == 0: + data[0] = None + if info["n_cat_features"] == 0: + data[1] = None + data = tuple(data) + + # Set up model + hyp = self._get_model_params() + args, _, _ = get_deep_args() + set_seeds(hyp["random_state"]) + args.device = device + args.max_epoch = hyp["max_epoch"] + args.batch_size = hyp["batch_size"] + # TODO: come up with solution to set this based on dataset size + max_context_size = hyp.get("max_context_size", 1000) + + if info["n_num_features"] > 200: + # Use less K as otherwise exploding memory constraints + args.config["model"]["k"] = 10 + + args.time_to_fit_in_seconds = time_limit + args.early_stopping_metric = self.stopping_metric + + save_path = self.path + "/tmp_model" + Path(save_path).mkdir(parents=True, exist_ok=True) + args.save_path = str(save_path) + + self.model = BetaMethod(args, self.problem_type == "regression", max_context_size=max_context_size) + self.model.fit(data=data, info=info, train=True, model_name="best-val") + shutil.rmtree(save_path, ignore_errors=True) + + def _predict_proba(self, X, **kwargs) -> np.ndarray: + X = self.preprocess(X, **kwargs).copy() + + # TALENT Format + tensors = {} + tensors["x_cat"] = X[self.cat_col_names_].to_numpy() + tensors["x_cont"] = ( + self.num_prep_.transform( + X.drop(columns=X[self.cat_col_names_]).to_numpy(dtype=np.float32) + ) + if self.has_num_cols + else np.empty((len(X), 0), dtype=np.float32) + ) + if self.problem_type == "regression": + tensors["y"] = np.zeros(tensors["x_cat"].shape[0]) + else: + tensors["y"] = np.zeros(tensors["x_cat"].shape[0], dtype=np.int32) + data = [{"test": tensors[tens_name]} for tens_name in ["x_cont", "x_cat", "y"]] + for i in range(2): + if data[i]["test"].size == 0: + data[i] = None + data = tuple(data) + + # AG Predict Output + y_pred = self.model.predict(data=data, info=None, model_name="best-val") + if self.problem_type == "regression": + return y_pred.numpy() + + y_pred_proba = torch.softmax(y_pred, dim=-1).numpy() + if y_pred_proba.shape[1] == 2: + y_pred_proba = y_pred_proba[:, 1] + return self._convert_proba_to_unified_form(y_pred_proba) + + 
def _preprocess( + self, + X: pd.DataFrame, + is_train: bool = False, + bool_to_cat: bool = False, + impute_bool: bool = True, + **kwargs, + ) -> pd.DataFrame: + X = super()._preprocess(X, **kwargs) + + # Ordinal Encoding of cat features but keep as cat + if is_train: + self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0) + self._feature_generator.fit(X=X) + if self._feature_generator.features_in: + X = X.copy() + X[self._feature_generator.features_in] = self._feature_generator.transform( + X=X + ) + if self.cat_col_names_ is None: + self.cat_col_names_ = self._feature_generator.features_in[:] + else: + self.cat_col_names_ = [] + + return X + + def _set_default_params(self): + default_params = { + "random_state": 0, + "max_epoch": 200, + "batch_size": 1024, + } + for param, val in default_params.items(): + self._set_default_param_value(param, val) + + @classmethod + def supported_problem_types(cls) -> list[str] | None: + return ["binary", "multiclass"] + + def _get_default_resources(self) -> tuple[int, int]: + import torch + + # logical=False is faster in training + num_cpus = ResourceManager.get_cpu_count_psutil(logical=False) + num_gpus = 1 if torch.cuda.is_available() else 0 + return num_cpus, num_gpus + + @classmethod + def _class_tags(cls): + return {"can_estimate_memory_usage_static": False} + + def _more_tags(self) -> dict: + return {"can_refit_full": False} diff --git a/tabrepo/benchmark/models/ag/beta/deps/__init__.py b/tabrepo/benchmark/models/ag/beta/deps/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tabrepo/benchmark/models/ag/beta/deps/tabm_utils.py b/tabrepo/benchmark/models/ag/beta/deps/tabm_utils.py new file mode 100644 index 00000000..73338b39 --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/deps/tabm_utils.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +from typing import Literal + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +_TORCH = None + + +@torch.inference_mode() +def _init_scaling_by_sections( + weight: Tensor, + distribution: Literal["normal", "random-signs"], + init_sections: list[int], +) -> None: + """Initialize the (typically, first) scaling in a special way. + + For a given efficient emsemble member, all weights within one section + are initialized with the same value. + Typically, one section corresponds to one feature. 
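+    For example, with init_sections=[3, 2], columns 0-2 of each row share one
+    sampled value and columns 3-4 share another; values still differ across rows
+    (i.e. across ensemble members).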
+ """ + assert weight.ndim == 2 + # print(weight.shape) + # print(init_sections) + assert weight.shape[1] == sum(init_sections) + + if distribution == "normal": + init_fn_ = nn.init.normal_ + elif distribution == "random-signs": + init_fn_ = init_random_signs_ + else: + raise ValueError(f"Unknown distribution: {distribution}") + + section_bounds = [0, *torch.tensor(init_sections).cumsum(0).tolist()] + for i in range(len(init_sections)): + w = torch.empty((len(weight), 1), dtype=weight.dtype, device=weight.device) + init_fn_(w) + weight[:, section_bounds[i] : section_bounds[i + 1]] = w + + +def _torch(): + global _TORCH + if _TORCH is None: + import torch + + _TORCH = torch + return _TORCH + + +def is_oom_exception(err: RuntimeError) -> bool: + return isinstance(err, _torch().cuda.OutOfMemoryError) or any( + x in str(err) + for x in [ + "CUDA out of memory", + "CUBLAS_STATUS_ALLOC_FAILED", + "CUDA error: out of memory", + ] + ) + + +# ====================================================================================== +# Initialization +# ====================================================================================== +def init_rsqrt_uniform_(x: Tensor, d: int) -> Tensor: + assert d > 0 + d_rsqrt = d**-0.5 + return nn.init.uniform_(x, -d_rsqrt, d_rsqrt) + + +@torch.inference_mode() +def init_random_signs_(x: Tensor) -> Tensor: + return x.bernoulli_(0.5).mul_(2).add_(-1) + + +# ====================================================================================== +# Modules +# ====================================================================================== +class Identity(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + return x + + +class Mean(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + return x.mean(dim=self.dim) + + +class ScaleEnsemble(nn.Module): + def __init__( + self, + k: int, + d: int, + *, + init: Literal["ones", "normal", "random-signs"], + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.empty(k, d)) + self._weight_init = init + self.reset_parameters() + + def reset_parameters(self) -> None: + if self._weight_init == "ones": + nn.init.ones_(self.weight) + elif self._weight_init == "normal": + nn.init.normal_(self.weight) + elif self._weight_init == "random-signs": + init_random_signs_(self.weight) + else: + raise ValueError(f"Unknown weight_init: {self._weight_init}") + + def forward(self, x: Tensor) -> Tensor: + assert x.ndim >= 2 + return x * self.weight + + +class ElementwiseAffineEnsemble(nn.Module): + def __init__( + self, + k: int, + d: int, + *, + weight: bool, + bias: bool, + weight_init: Literal["ones", "normal", "random-signs"], + ) -> None: + assert weight or bias + super().__init__() + self.weight = nn.Parameter(torch.empty(k, d)) if weight else None + self.bias = nn.Parameter(torch.empty(k, d)) if bias else None + self._weight_init = weight_init + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.weight is not None: + if self._weight_init == "ones": + nn.init.ones_(self.weight) + elif self._weight_init == "normal": + nn.init.normal_(self.weight) + elif self._weight_init == "random-signs": + init_random_signs_(self.weight) + else: + raise ValueError(f"Unknown weight_init: {self._weight_init}") + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + # x.shape == (B, K, D) + assert x.ndim == 3 + return ( 
+ x * self.weight + if self.bias is None + else x + self.bias + if self.weight is None + else torch.addcmul(self.bias, self.weight, x) + ) + + +class LinearEfficientEnsemble(nn.Module): + """This layer is a more configurable version of the "BatchEnsemble" layer + from the paper + "BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning" + (link: https://arxiv.org/abs/2002.06715). + + First, this layer allows to select only some of the "ensembled" parts: + - the input scaling (r_i in the BatchEnsemble paper) + - the output scaling (s_i in the BatchEnsemble paper) + - the output bias (not mentioned in the BatchEnsemble paper, + but is presented in public implementations) + + Second, the initialization of the scaling weights is configurable + through the `scaling_init` argument. + """ + + r: None | Tensor + s: None | Tensor + bias: None | Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + *, + k: int, + ensemble_scaling_in: bool, + ensemble_scaling_out: bool, + ensemble_bias: bool, + scaling_init: Literal["ones", "random-signs"], + ): + assert k > 0 + if ensemble_bias: + assert bias + super().__init__() + + self.weight = nn.Parameter(torch.empty(out_features, in_features)) + self.register_parameter( + "r", + ( + nn.Parameter(torch.empty(k, in_features)) + if ensemble_scaling_in + else None + ), # type: ignore[code] + ) + self.register_parameter( + "s", + ( + nn.Parameter(torch.empty(k, out_features)) + if ensemble_scaling_out + else None + ), # type: ignore[code] + ) + self.register_parameter( + "bias", + ( + nn.Parameter(torch.empty(out_features)) # type: ignore[code] + if bias and not ensemble_bias + else nn.Parameter(torch.empty(k, out_features)) + if ensemble_bias + else None + ), + ) + + self.in_features = in_features + self.out_features = out_features + self.k = k + self.scaling_init = scaling_init + + self.reset_parameters() + + def reset_parameters(self): + init_rsqrt_uniform_(self.weight, self.in_features) + scaling_init_fn = {"ones": nn.init.ones_, "random-signs": init_random_signs_}[ + self.scaling_init + ] + if self.r is not None: + scaling_init_fn(self.r) + if self.s is not None: + scaling_init_fn(self.s) + if self.bias is not None: + bias_init = torch.empty( + # NOTE: the shape of bias_init is (out_features,) not (k, out_features). + # It means that all biases have the same initialization. + # This is similar to having one shared bias plus + # k zero-initialized non-shared biases. + self.out_features, + dtype=self.weight.dtype, + device=self.weight.device, + ) + bias_init = init_rsqrt_uniform_(bias_init, self.in_features) + with torch.inference_mode(): + self.bias.copy_(bias_init) + + def forward(self, x: Tensor) -> Tensor: + # x.shape == (B, K, D) + assert x.ndim == 3 + + # >>> The equation (5) from the BatchEnsemble paper (arXiv v2). 
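+        # i.e. y = (((x * r) @ W.T) * s) + bias, where the optional per-member
+        # scalings r and s emulate k distinct weight matrices while sharing a
+        # single dense W.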
+ if self.r is not None: + x = x * self.r + x = x @ self.weight.T + if self.s is not None: + x = x * self.s + # <<< + + if self.bias is not None: + x = x + self.bias + return x + + +def make_efficient_ensemble(module: nn.Module, **kwargs) -> None: + for name, submodule in list(module.named_children()): + if isinstance(submodule, nn.Linear): + module.add_module( + name, + LinearEfficientEnsemble( + in_features=submodule.in_features, + out_features=submodule.out_features, + bias=submodule.bias is not None, + **kwargs, + ), + ) + else: + make_efficient_ensemble(submodule, **kwargs) + + +class OneHotEncoding0d(nn.Module): + # Input: (*, n_cat_features=len(cardinalities)) + # Output: (*, sum(cardinalities)) + + def __init__(self, cardinalities: list[int]) -> None: + super().__init__() + self._cardinalities = cardinalities + + def forward(self, x: Tensor) -> Tensor: + assert x.ndim >= 1 + assert x.shape[-1] == len(self._cardinalities) + + return torch.cat( + [ + # NOTE + # This is a quick hack to support out-of-vocabulary categories. + # + # Recall that lib.data.transform_cat encodes categorical features + # as follows: + # - In-vocabulary values receive indices from `range(cardinality)`. + # - All out-of-vocabulary values (i.e. new categories in validation + # and test data that are not presented in the training data) + # receive the index `cardinality`. + # + # As such, the line below will produce the standard one-hot encoding for + # known categories, and the all-zeros encoding for unknown categories. + # This may not be the best approach to deal with unknown values, + # but should be enough for our purposes. + F.one_hot(x[..., i], cardinality + 1)[..., :-1] + for i, cardinality in enumerate(self._cardinalities) + ], + -1, + ) diff --git a/tabrepo/benchmark/models/ag/beta/deps/tabpfn/__init__.py b/tabrepo/benchmark/models/ag/beta/deps/tabpfn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tabrepo/benchmark/models/ag/beta/deps/tabpfn/encoders.py b/tabrepo/benchmark/models/ag/beta/deps/tabpfn/encoders.py new file mode 100644 index 00000000..3f7d0b73 --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/deps/tabpfn/encoders.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import math + +import torch +from torch import nn + + +def torch_masked_mean(x, mask, dim=0, return_share_of_ignored_values=False): + """Returns the mean of a torch tensor and only considers the elements, where the mask is true. + If return_share_of_ignored_values is true it returns a second tensor with the percentage of ignored values + because of the mask. + """ + num = torch.where(mask, torch.full_like(x, 1), torch.full_like(x, 0)).sum(dim=dim) + value = torch.where(mask, x, torch.full_like(x, 0)).sum(dim=dim) + if return_share_of_ignored_values: + return value / num, 1.0 - num / x.shape[dim] + return value / num + + +def torch_masked_std(x, mask, dim=0): + """Returns the std of a torch tensor and only considers the elements, where the mask is true. + If get_mean is true it returns as a first Tensor the mean and as a second tensor the std. 
+ """ + num = torch.where(mask, torch.full_like(x, 1), torch.full_like(x, 0)).sum(dim=dim) + value = torch.where(mask, x, torch.full_like(x, 0)).sum(dim=dim) + mean = value / num + mean_broadcast = torch.repeat_interleave(mean.unsqueeze(dim), x.shape[dim], dim=dim) + quadratic_difference_from_mean = torch.square( + torch.where(mask, mean_broadcast - x, torch.full_like(x, 0)) + ) + return torch.sqrt(torch.sum(quadratic_difference_from_mean, dim=dim) / (num - 1)) + + +def torch_nanmean(x, dim=0, return_nanshare=False): + return torch_masked_mean( + x, ~torch.isnan(x), dim=dim, return_share_of_ignored_values=return_nanshare + ) + + +def torch_nanstd(x, dim=0): + return torch_masked_std(x, ~torch.isnan(x), dim=dim) + + +def normalize_data(data, normalize_positions=-1): + if normalize_positions > 0: + mean = torch_nanmean(data[:normalize_positions], dim=0) + std = torch_nanstd(data[:normalize_positions], dim=0) + 0.000001 + else: + mean = torch_nanmean(data, dim=0) + std = torch_nanstd(data, dim=0) + 0.000001 + data = (data - mean) / std + return torch.clip(data, min=-100, max=100) + + +class _PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.0): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + self.d_model = d_model + self.device_test_tensor = nn.Parameter(torch.tensor(1.0)) + + def forward(self, x): # T x B x num_features + assert self.d_model % x.shape[-1] * 2 == 0 + d_per_feature = self.d_model // x.shape[-1] + pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device) + # position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + interval_size = 10 + div_term = ( + (1.0 / interval_size) + * 2 + * math.pi + * torch.exp( + torch.arange( + 0, d_per_feature, 2, device=self.device_test_tensor.device + ).float() + * math.log(math.sqrt(2)) + ) + ) + # print(div_term/2/math.pi) + pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term) + pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term) + return self.dropout(pe).view(x.shape[0], x.shape[1], self.d_model) + + +Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize) + + +class EmbeddingEncoder(nn.Module): + def __init__(self, num_features, em_size, num_embs=100): + super().__init__() + self.num_embs = num_embs + self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True) + self.init_weights(0.1) + self.min_max = (-2, +2) + + @property + def width(self): + return self.min_max[1] - self.min_max[0] + + def init_weights(self, initrange): + self.embeddings.weight.data.uniform_(-initrange, initrange) + + def discretize(self, x): + split_size = self.width / self.num_embs + return (x - self.min_max[0] // split_size).int().clamp(0, self.num_embs - 1) + + def forward(self, x): # T x B x num_features + x_idxs = self.discretize(x) + x_idxs += ( + torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs + ) + # print(x_idxs,self.embeddings.weight.shape) + return self.embeddings(x_idxs).mean(-2) + + +Linear = nn.Linear +MLP = lambda num_features, emsize: nn.Sequential( + nn.Linear(num_features + 1, emsize * 2), nn.ReLU(), nn.Linear(emsize * 2, emsize) +) + + +class NanHandlingEncoder(nn.Module): + def __init__(self, num_features, emsize, keep_nans=True): + super().__init__() + self.num_features = 2 * num_features if keep_nans else num_features + self.emsize = emsize + self.keep_nans = keep_nans + self.layer = nn.Linear(self.num_features, self.emsize) + + def forward(self, x): + if self.keep_nans: + x = torch.cat( + [ + torch.nan_to_num(x, nan=0.0), 
+ normalize_data( + torch.isnan(x) * -1 + + torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1 + + torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2 + ), + ], + -1, + ) + else: + x = torch.nan_to_num(x, nan=0.0) + return self.layer(x) + + +class Linear(nn.Linear): + def __init__(self, num_features, emsize, replace_nan_by_zero=False): + super().__init__(num_features, emsize) + self.num_features = num_features + self.emsize = emsize + self.replace_nan_by_zero = replace_nan_by_zero + + def forward(self, x): + if self.replace_nan_by_zero: + x = torch.nan_to_num(x, nan=0.0) + return super().forward(x) + + def __setstate__(self, state): + super().__setstate__(state) + self.__dict__.setdefault("replace_nan_by_zero", True) + + +class Conv(nn.Module): + def __init__(self, input_size, emsize): + super().__init__() + self.convs = torch.nn.ModuleList( + [nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)] + ) + self.linear = nn.Linear(64, emsize) + + def forward(self, x): + size = math.isqrt(x.shape[-1]) + assert size * size == x.shape[-1] + x = x.reshape(*x.shape[:-1], 1, size, size) + for conv in self.convs: + if x.shape[-1] < 4: + break + x = conv(x) + x.relu_() + x = nn.AdaptiveAvgPool2d((1, 1))(x).squeeze(-1).squeeze(-1) + return self.linear(x) + + +class CanEmb(nn.Embedding): + def __init__( + self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs + ): + assert embedding_dim % num_features == 0 + embedding_dim = embedding_dim // num_features + super().__init__(num_embeddings, embedding_dim, *args, **kwargs) + + def forward(self, x): + lx = x.long() + assert (lx == x).all(), "CanEmb only works with tensors of whole numbers" + x = super().forward(lx) + return x.view(*x.shape[:-2], -1) + + +class RegressionEmbedding(nn.Module): + def __init__(self, num_features, emsize): + super().__init__() + self.num_features = num_features + self.emsize = emsize + self.layer = nn.Parameter(torch.ones(num_features, emsize)) + + def forward(self, x): + return self.layer[None] * x[:, :, None].squeeze(-2) + + +def get_Canonical(num_classes): + return lambda num_features, emsize: CanEmb(num_features, num_classes, emsize) + + +def get_Embedding(num_embs_per_feature=100): + return lambda num_features, emsize: EmbeddingEncoder( + num_features, emsize, num_embs=num_embs_per_feature + ) + + +def get_Regression(num_features, emsize=512): + return RegressionEmbedding(num_features, emsize) diff --git a/tabrepo/benchmark/models/ag/beta/deps/tabpfn/layer.py b/tabrepo/benchmark/models/ag/beta/deps/tabpfn/layer.py new file mode 100644 index 00000000..8c1cc452 --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/deps/tabpfn/layer.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +from functools import partial + +import torch +from torch.nn import functional as F +from torch.nn.modules.transformer import ( + Dropout, + LayerNorm, + Linear, + Module, + MultiheadAttention, + Optional, + Tensor, + _get_activation_fn, +) +from torch.utils.checkpoint import checkpoint + + +class TransformerEncoderLayer(Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + This standard encoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. 
+ + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False``. + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + + Alternatively, when ``batch_first`` is ``True``: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) + >>> src = torch.rand(32, 10, 512) + >>> out = encoder_layer(src) + """ + + __constants__ = ["batch_first"] + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + layer_norm_eps=1e-5, + batch_first=False, + pre_norm=False, + device=None, + dtype=None, + recompute_attn=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # print(dtype) + super().__init__() + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs + ) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) + + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.pre_norm = pre_norm + self.recompute_attn = recompute_attn + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = F.relu + super().__setstate__(state) + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. 
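+
+        Note: in this variant ``src_mask`` may also be an int (the single evaluation
+        position used for efficient train/eval masking) or a tuple of three masks
+        when global attention tokens are used.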
+ """ + src_ = self.norm1(src) if self.pre_norm else src + d_type = src.dtype + if isinstance(src_mask, tuple): + # global attention setup + assert not self.self_attn.batch_first + assert src_key_padding_mask is None + + global_src_mask, trainset_src_mask, valset_src_mask = src_mask + + num_global_tokens = global_src_mask.shape[0] + num_train_tokens = trainset_src_mask.shape[0] + + global_tokens_src = src_[:num_global_tokens] + train_tokens_src = src_[ + num_global_tokens : num_global_tokens + num_train_tokens + ] + global_and_train_tokens_src = src_[: num_global_tokens + num_train_tokens] + eval_tokens_src = src_[num_global_tokens + num_train_tokens :] + + attn = ( + partial(checkpoint, self.self_attn) + if self.recompute_attn + else self.self_attn + ) + + global_tokens_src2 = attn( + global_tokens_src, + global_and_train_tokens_src, + global_and_train_tokens_src, + None, + True, + global_src_mask, + )[0] + train_tokens_src2 = attn( + train_tokens_src, + global_tokens_src, + global_tokens_src, + None, + True, + trainset_src_mask, + )[0] + eval_tokens_src2 = attn( + eval_tokens_src, src_, src_, None, True, valset_src_mask + )[0] + + src2 = torch.cat( + [global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0 + ) + elif isinstance(src_mask, int): + assert src_key_padding_mask is None + single_eval_position = src_mask + src_left = self.self_attn( + src_[:single_eval_position], + src_[:single_eval_position], + src_[:single_eval_position], + )[0] + src_right = self.self_attn( + src_[single_eval_position:], + src_[:single_eval_position], + src_[:single_eval_position], + )[0] + src2 = torch.cat([src_left, src_right], dim=0) + elif self.recompute_attn: + src2 = checkpoint( + self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask + )[0] + else: + src2 = self.self_attn( + src_, + src_, + src_, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = src + self.dropout1(src2) + if not self.pre_norm: + src = self.norm1(src) + + src_ = self.norm2(src) if self.pre_norm else src + src2 = self.linear2(self.dropout(self.activation(self.linear1(src_)))) + src = src.to(d_type) + self.dropout2(src2).to(d_type) + + if not self.pre_norm: + src = self.norm2(src) + + return src diff --git a/tabrepo/benchmark/models/ag/beta/deps/tabpfn/transformer.py b/tabrepo/benchmark/models/ag/beta/deps/tabpfn/transformer.py new file mode 100644 index 00000000..33307c92 --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/deps/tabpfn/transformer.py @@ -0,0 +1,372 @@ +from __future__ import annotations + +import math + +import torch +from tabrepo.benchmark.models.ag.beta.deps.tabpfn import layer +from torch import Tensor, nn +from torch.nn import Module, TransformerEncoder + + +def bool_mask_to_att_mask(mask): + return ( + mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0) + ) + + +class SeqBN(nn.Module): + def __init__(self, d_model): + super().__init__() + self.bn = nn.BatchNorm1d(d_model) + self.d_model = d_model + + def forward(self, x): + assert self.d_model == x.shape[-1] + flat_x = x.view(-1, self.d_model) + flat_x = self.bn(flat_x) + return flat_x.view(*x.shape) + + +class TransformerModel(nn.Module): + def __init__( + self, + encoder, + n_out, + ninp, + nhead, + nhid, + nlayers, + dropout=0.0, + style_encoder=None, + y_encoder=None, + pos_encoder=None, + decoder=None, + input_normalization=False, + init_method=None, + pre_norm=False, + activation="gelu", + recompute_attn=False, + num_global_att_tokens=0, + full_attention=False, + 
all_layers_same_init=False, + efficient_eval_masking=True, + ): + super().__init__() + self.model_type = "Transformer" + encoder_layer_creator = lambda: layer.TransformerEncoderLayer( + ninp, + nhead, + nhid, + dropout, + activation=activation, + pre_norm=pre_norm, + recompute_attn=recompute_attn, + ) + # encoder_layer_creator = lambda: layer.FlashAttentionTransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation, + # pre_norm=pre_norm, recompute_attn=recompute_attn) + self.transformer_encoder = ( + TransformerEncoder(encoder_layer_creator(), nlayers) + if all_layers_same_init + else TransformerEncoderDiffInit(encoder_layer_creator, nlayers) + ) + self.ninp = ninp + self.encoder = encoder + self.y_encoder = y_encoder + self.pos_encoder = pos_encoder + self.decoder = ( + decoder(ninp, nhid, n_out) + if decoder is not None + else nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out)) + ) + self.input_ln = SeqBN(ninp) if input_normalization else None + self.style_encoder = style_encoder + self.init_method = init_method + if num_global_att_tokens is not None: + assert not full_attention + self.global_att_embeddings = ( + nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None + ) + self.full_attention = full_attention + self.efficient_eval_masking = efficient_eval_masking + + self.n_out = n_out + self.nhid = nhid + + self.init_weights() + + def __setstate__(self, state): + super().__setstate__(state) + self.__dict__.setdefault("efficient_eval_masking", False) + + @staticmethod + def generate_square_subsequent_mask(sz): + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + return bool_mask_to_att_mask(mask) + + @staticmethod + def generate_D_q_matrix(sz, query_size): + train_size = sz - query_size + mask = torch.zeros(sz, sz) == 0 + mask[:, train_size:].zero_() + mask |= torch.eye(sz) == 1 + return bool_mask_to_att_mask(mask) + + @staticmethod + def generate_global_att_query_matrix( + num_global_att_tokens, seq_len, num_query_tokens + ): + train_size = seq_len + num_global_att_tokens - num_query_tokens + sz = seq_len + num_global_att_tokens + mask = torch.zeros(num_query_tokens, sz) == 0 + mask[:, train_size:].zero_() + mask[:, train_size:] |= torch.eye(num_query_tokens) == 1 + return bool_mask_to_att_mask(mask) + + @staticmethod + def generate_global_att_trainset_matrix( + num_global_att_tokens, seq_len, num_query_tokens + ): + seq_len + num_global_att_tokens - num_query_tokens + trainset_size = seq_len - num_query_tokens + mask = torch.zeros(trainset_size, num_global_att_tokens) == 0 + # mask[:,num_global_att_tokens:].zero_() + # mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1 + return bool_mask_to_att_mask(mask) + + @staticmethod + def generate_global_att_globaltokens_matrix( + num_global_att_tokens, seq_len, num_query_tokens + ): + mask = ( + torch.zeros( + num_global_att_tokens, + num_global_att_tokens + seq_len - num_query_tokens, + ) + == 0 + ) + return bool_mask_to_att_mask(mask) + + def init_weights(self): + # if isinstance(self.encoder,EmbeddingEncoder): + # self.encoder.weight.data.uniform_(-initrange, initrange) + # self.decoder.bias.data.zero_() + # self.decoder.weight.data.uniform_(-initrange, initrange) + if self.init_method is not None: + self.apply(self.init_method) + for layer in self.transformer_encoder.layers: + nn.init.zeros_(layer.linear2.weight) + nn.init.zeros_(layer.linear2.bias) + attns = ( + layer.self_attn + if isinstance(layer.self_attn, nn.ModuleList) + else [layer.self_attn] + ) + for 
attn in attns: + nn.init.zeros_(attn.out_proj.weight) + nn.init.zeros_(attn.out_proj.bias) + + def forward(self, src, src_mask=None, single_eval_pos=None): + assert isinstance(src, tuple), ( + "inputs (src) have to be given as (x,y) or (style,x,y) tuple" + ) + + if len(src) == 2: # (x,y) and no style + src = (None, *src) + + style_src, x_src, y_src = src + # print(self.encoder) + x_src = self.encoder(x_src) + + y_src = self.y_encoder( + y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src + ) + style_src = ( + self.style_encoder(style_src).unsqueeze(0) + if self.style_encoder + else torch.tensor([], device=x_src.device) + ) + global_src = ( + torch.tensor([], device=x_src.device) + if self.global_att_embeddings is None + else self.global_att_embeddings.weight.unsqueeze(1).repeat( + 1, x_src.shape[1], 1 + ) + ) + + if src_mask is not None: + assert self.global_att_embeddings is None or isinstance(src_mask, tuple) + if src_mask is None: + if self.global_att_embeddings is None: + full_len = len(x_src) + len(style_src) + if self.full_attention: + src_mask = bool_mask_to_att_mask( + torch.ones((full_len, full_len), dtype=torch.bool) + ).to(x_src.device) + elif self.efficient_eval_masking: + src_mask = single_eval_pos + len(style_src) + else: + src_mask = self.generate_D_q_matrix( + full_len, len(x_src) - single_eval_pos + ).to(x_src.device) + else: + src_mask_args = ( + self.global_att_embeddings.num_embeddings, + len(x_src) + len(style_src), + len(x_src) + len(style_src) - single_eval_pos, + ) + src_mask = ( + self.generate_global_att_globaltokens_matrix(*src_mask_args).to( + x_src.device + ), + self.generate_global_att_trainset_matrix(*src_mask_args).to( + x_src.device + ), + self.generate_global_att_query_matrix(*src_mask_args).to( + x_src.device + ), + ) + # print(x_src.shape, y_src.shape) + + train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos] + # print(train_x.dtype) + src = torch.cat( + [global_src, style_src, train_x, x_src[single_eval_pos:]], 0 + ).to(train_x.dtype) + + if self.input_ln is not None: + src = self.input_ln(src) + + if self.pos_encoder is not None: + src = self.pos_encoder(src) + # print(src.dtype) + output = self.transformer_encoder(src, src_mask) + output = self.decoder(output) + # print(output.shape) + # print(output[single_eval_pos+len(style_src)+(self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]) + return output[ + single_eval_pos + + len(style_src) + + ( + self.global_att_embeddings.num_embeddings + if self.global_att_embeddings + else 0 + ) : + ] + + @torch.no_grad() + def init_from_small_model(self, small_model): + assert isinstance(self.decoder, nn.Linear) + assert isinstance(self.encoder, (nn.Linear, nn.Sequential)) + assert isinstance(self.y_encoder, (nn.Linear, nn.Sequential)) + + def set_encoder_weights(my_encoder, small_model_encoder): + my_encoder_linear, small_encoder_linear = ( + (my_encoder, small_model_encoder) + if isinstance(my_encoder, nn.Linear) + else (my_encoder[-1], small_model_encoder[-1]) + ) + small_in_dim = small_encoder_linear.out_features + my_encoder_linear.weight.zero_() + my_encoder_linear.bias.zero_() + my_encoder_linear.weight[:small_in_dim] = small_encoder_linear.weight + my_encoder_linear.bias[:small_in_dim] = small_encoder_linear.bias + + set_encoder_weights(self.encoder, small_model.encoder) + set_encoder_weights(self.y_encoder, small_model.y_encoder) + + small_in_dim = small_model.decoder.in_features + + self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight + 
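+        # the output dimension is unchanged (only the input width grows), so the
+        # small model's bias can be reused directly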
self.decoder.bias = small_model.decoder.bias + + for my_layer, small_layer in zip( + self.transformer_encoder.layers, small_model.transformer_encoder.layers + ): + small_hid_dim = small_layer.linear1.out_features + my_in_dim = my_layer.linear1.in_features + + # packed along q,k,v order in first dim + my_in_proj_w = my_layer.self_attn.in_proj_weight + small_in_proj_w = small_layer.self_attn.in_proj_weight + + my_in_proj_w.view(3, my_in_dim, my_in_dim)[ + :, :small_in_dim, :small_in_dim + ] = small_in_proj_w.view(3, small_in_dim, small_in_dim) + my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:, :small_in_dim] = ( + small_layer.self_attn.in_proj_bias.view(3, small_in_dim) + ) + + my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = ( + small_layer.self_attn.out_proj.weight + ) + my_layer.self_attn.out_proj.bias[:small_in_dim] = ( + small_layer.self_attn.out_proj.bias + ) + + my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = ( + small_layer.linear1.weight + ) + my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias + + my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = ( + small_layer.linear2.weight + ) + my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias + + my_layer.norm1.weight[:small_in_dim] = ( + math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight + ) + my_layer.norm2.weight[:small_in_dim] = ( + math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight + ) + + my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias + my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias + + +class TransformerEncoderDiffInit(Module): + r"""TransformerEncoder is a stack of N encoder layers. + + Args: + encoder_layer_creator: a function generating objects of TransformerEncoderLayer class without args (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + """ + + __constants__ = ["norm"] + + def __init__(self, encoder_layer_creator, num_layers, norm=None): + super().__init__() + self.layers = nn.ModuleList( + [encoder_layer_creator() for _ in range(num_layers)] + ) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Tensor | None = None, + src_key_padding_mask: Tensor | None = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. 
+ """ + output = src + + for mod in self.layers: + output = mod( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + + if self.norm is not None: + output = self.norm(output) + + return output diff --git a/tabrepo/benchmark/models/ag/beta/deps/tabpfn/utils.py b/tabrepo/benchmark/models/ag/beta/deps/tabpfn/utils.py new file mode 100644 index 00000000..1c3e259d --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/deps/tabpfn/utils.py @@ -0,0 +1,1181 @@ +from __future__ import annotations + +import io +import os +import pickle +import random +from functools import partial +from pathlib import Path + +import numpy as np +import torch +from sklearn.preprocessing import ( + PowerTransformer, + QuantileTransformer, + RobustScaler, +) +from tabrepo.benchmark.models.ag.beta.deps.tabpfn import encoders, transformer +from torch.utils.checkpoint import checkpoint + + +def torch_masked_mean(x, mask, dim=0, return_share_of_ignored_values=False): + """Returns the mean of a torch tensor and only considers the elements, where the mask is true. + If return_share_of_ignored_values is true it returns a second tensor with the percentage of ignored values + because of the mask. + """ + num = torch.where(mask, torch.full_like(x, 1), torch.full_like(x, 0)).sum(dim=dim) + value = torch.where(mask, x, torch.full_like(x, 0)).sum(dim=dim) + if return_share_of_ignored_values: + return value / num, 1.0 - num / x.shape[dim] + return value / num + + +def torch_masked_std(x, mask, dim=0): + """Returns the std of a torch tensor and only considers the elements, where the mask is true. + If get_mean is true it returns as a first Tensor the mean and as a second tensor the std. + """ + num = torch.where(mask, torch.full_like(x, 1), torch.full_like(x, 0)).sum(dim=dim) + value = torch.where(mask, x, torch.full_like(x, 0)).sum(dim=dim) + mean = value / num + mean_broadcast = torch.repeat_interleave(mean.unsqueeze(dim), x.shape[dim], dim=dim) + quadratic_difference_from_mean = torch.square( + torch.where(mask, mean_broadcast - x, torch.full_like(x, 0)) + ) + return torch.sqrt(torch.sum(quadratic_difference_from_mean, dim=dim) / (num - 1)) + + +def torch_nanmean(x, dim=0, return_nanshare=False): + return torch_masked_mean( + x, ~torch.isnan(x), dim=dim, return_share_of_ignored_values=return_nanshare + ) + + +def torch_nanstd(x, dim=0): + return torch_masked_std(x, ~torch.isnan(x), dim=dim) + + +def normalize_data(data, normalize_positions=-1): + if normalize_positions > 0: + mean = torch_nanmean(data[:normalize_positions], dim=0) + std = torch_nanstd(data[:normalize_positions], dim=0) + 0.000001 + else: + mean = torch_nanmean(data, dim=0) + std = torch_nanstd(data, dim=0) + 0.000001 + data = (data - mean) / std + return torch.clip(data, min=-100, max=100) + + +def normalize_by_used_features_f( + x, num_features_used, num_features, normalize_with_sqrt=False +): + if normalize_with_sqrt: + return x / (num_features_used / num_features) ** (1 / 2) + return x / (num_features_used / num_features) + + +def to_ranking_low_mem(data): + x = torch.zeros_like(data) + for col in range(data.shape[-1]): + x_ = data[:, :, col] >= data[:, :, col].unsqueeze(-2) + x_ = x_.sum(0) + x[:, :, col] = x_ + return x + + +def remove_outliers(X, n_sigma=4, normalize_positions=-1): + # Expects T, B, H + assert len(X.shape) == 3, "X must be T,B,H" + + data = X if normalize_positions == -1 else X[:normalize_positions] + + data_mean, data_std = torch_nanmean(data, dim=0), torch_nanstd(data, dim=0) + cut_off = data_std * n_sigma + lower, upper 
= data_mean - cut_off, data_mean + cut_off + + mask = (data <= upper) & (data >= lower) & ~torch.isnan(data) + data_mean, data_std = torch_masked_mean(data, mask), torch_masked_std(data, mask) + + cut_off = data_std * n_sigma + lower, upper = data_mean - cut_off, data_mean + cut_off + + X = torch.maximum(-torch.log(1 + torch.abs(X)) + lower, X) + return torch.minimum(torch.log(1 + torch.abs(X)) + upper, X) + # print(ds[1][data < lower, col], ds[1][data > upper, col], ds[1][~np.isnan(data), col].shape, data_mean, data_std) + + +def load_model_only_inference(path, filename, device): + """Loads a saved model from the specified position. This function only restores inference capabilities and + cannot be used for further training. + """ + model_state, optimizer_state, config_sample = torch.load( + os.path.join(path, filename), map_location="cpu", weights_only=False, + ) + + if ( + ( + "nan_prob_no_reason" in config_sample + and config_sample["nan_prob_no_reason"] > 0.0 + ) + or ( + "nan_prob_a_reason" in config_sample + and config_sample["nan_prob_a_reason"] > 0.0 + ) + or ( + "nan_prob_unknown_reason" in config_sample + and config_sample["nan_prob_unknown_reason"] > 0.0 + ) + ): + encoder = encoders.NanHandlingEncoder + else: + encoder = partial(encoders.Linear, replace_nan_by_zero=True) + + n_out = config_sample["max_num_classes"] + device = device if torch.cuda.is_available() else "cpu:0" + encoder = encoder(config_sample["num_features"], config_sample["emsize"]) + + nhid = config_sample["emsize"] * config_sample["nhid_factor"] + y_encoder_generator = ( + encoders.get_Canonical(config_sample["max_num_classes"]) + if config_sample.get("canonical_y_encoder", False) + else encoders.Linear + ) + + assert config_sample["max_num_classes"] > 2 + loss = torch.nn.CrossEntropyLoss( + reduction="none", weight=torch.ones(int(config_sample["max_num_classes"])) + ) + + model = transformer.TransformerModel( + encoder, + n_out, + config_sample["emsize"], + config_sample["nhead"], + nhid, + config_sample["nlayers"], + y_encoder=y_encoder_generator(1, config_sample["emsize"]), + dropout=config_sample["dropout"], + efficient_eval_masking=config_sample["efficient_eval_masking"], + ) + + # print(f"Using a Transformer with {sum(p.numel() for p in model.parameters()) / 1000 / 1000:.{2}f} M parameters") + + model.criterion = loss + module_prefix = "module." + model_state = {k.replace(module_prefix, ""): v for k, v in model_state.items()} + model.load_state_dict(model_state) + model.to(device) + model.eval() + + return (float("inf"), float("inf"), model), config_sample # no loss measured + + +def load_model_only_inference_regression(path, filename, device): + """Loads a saved model from the specified position. This function only restores inference capabilities and + cannot be used for further training. 
+ """ + model_state, optimizer_state, config_sample = torch.load( + os.path.join(path, filename), map_location="cpu" + ) + # file_path = '/data1/Benchmark/T1/model/models/models_diff/prior_diff_real_checkpoint_multiclass_12_30_2024_23_42_32_n_0_epoch_52.cpkt' + # model_state, optimizer_state, config_sample = torch.load(file_path, map_location='cpu') + if ( + ( + "nan_prob_no_reason" in config_sample + and config_sample["nan_prob_no_reason"] > 0.0 + ) + or ( + "nan_prob_a_reason" in config_sample + and config_sample["nan_prob_a_reason"] > 0.0 + ) + or ( + "nan_prob_unknown_reason" in config_sample + and config_sample["nan_prob_unknown_reason"] > 0.0 + ) + ): + encoder = encoders.NanHandlingEncoder + else: + encoder = partial(encoders.Linear, replace_nan_by_zero=True) + n_out = config_sample["max_num_classes"] + device = device if torch.cuda.is_available() else "cpu:0" + encoder = encoder(config_sample["num_features"], config_sample["emsize"]) + + nhid = config_sample["emsize"] * config_sample["nhid_factor"] + y_encoder_generator = ( + encoders.get_Canonical(config_sample["max_num_classes"]) + if config_sample.get("canonical_y_encoder", False) + else encoders.Linear + ) + + assert config_sample["max_num_classes"] > 2 + loss = torch.nn.CrossEntropyLoss( + reduction="none", weight=torch.ones(int(config_sample["max_num_classes"])) + ) + + model = transformer.TransformerModel( + encoder, + n_out, + config_sample["emsize"], + config_sample["nhead"], + nhid, + config_sample["nlayers"], + y_encoder=y_encoder_generator(1, config_sample["emsize"]), + dropout=config_sample["dropout"], + efficient_eval_masking=config_sample["efficient_eval_masking"], + ) + + # print(f"Using a Transformer with {sum(p.numel() for p in model.parameters()) / 1000 / 1000:.{2}f} M parameters") + # y_encoder_generator = y_encoder_generator(1, config_sample['emsize']) + model.criterion = loss + module_prefix = "module." 
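+    # state dicts saved from (Distributed)DataParallel runs prefix every key with
+    # "module."; strip it so the keys match this un-wrapped model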
+ # model.y_encoder = encoders.get_Regression(1) + + model_state = {k.replace(module_prefix, ""): v for k, v in model_state.items()} + model.load_state_dict(model_state) + + # model.decoder = nn.Sequential(nn.Linear(512, nhid), nn.GELU(), nn.Linear(nhid, 1)) + model.to(device) + model.eval() + + return (float("inf"), float("inf"), model), config_sample # no loss measured + + +def fix_loaded_config_sample(loaded_config_sample, config): + def copy_to_sample(*k): + t, s = loaded_config_sample, config + for k_ in k[:-1]: + t = t[k_] + s = s[k_] + t[k[-1]] = s[k[-1]] + + copy_to_sample("num_features_used") + copy_to_sample("num_classes") + copy_to_sample( + "differentiable_hyperparameters", "prior_mlp_activations", "choice_values" + ) + + +def load_config_sample(path, template_config): + model_state, optimizer_state, loaded_config_sample = torch.load( + path, map_location="cpu" + ) + fix_loaded_config_sample(loaded_config_sample, template_config) + return loaded_config_sample + + +def get_default_spec(test_datasets, valid_datasets): + bptt = 10000 + eval_positions = [ + 1000, + 2000, + 3000, + 4000, + 5000, + ] # list(2 ** np.array([4, 5, 6, 7, 8, 9, 10, 11, 12])) + max_features = max( + [X.shape[1] for (_, X, _, _, _, _) in test_datasets] + + [X.shape[1] for (_, X, _, _, _, _) in valid_datasets] + ) + max_splits = 5 + + return bptt, eval_positions, max_features, max_splits + + +def get_mlp_prior_hyperparameters(config): + from tabpfn.priors.utils import gamma_sampler_f + + config = { + hp: (next(iter(config[hp].values()))) + if type(config[hp]) is dict + else config[hp] + for hp in config + } + + if "random_feature_rotation" not in config: + config["random_feature_rotation"] = True + + if "prior_sigma_gamma_k" in config: + sigma_sampler = gamma_sampler_f( + config["prior_sigma_gamma_k"], config["prior_sigma_gamma_theta"] + ) + config["init_std"] = sigma_sampler + if "prior_noise_std_gamma_k" in config: + noise_std_sampler = gamma_sampler_f( + config["prior_noise_std_gamma_k"], config["prior_noise_std_gamma_theta"] + ) + config["noise_std"] = noise_std_sampler + + return config + + +def get_gp_mix_prior_hyperparameters(config): + return { + "lengthscale_concentration": config["prior_lengthscale_concentration"], + "nu": config["prior_nu"], + "outputscale_concentration": config["prior_outputscale_concentration"], + "categorical_data": config["prior_y_minmax_norm"], + "y_minmax_norm": config["prior_lengthscale_concentration"], + "noise_concentration": config["prior_noise_concentration"], + "noise_rate": config["prior_noise_rate"], + } + + +def get_gp_prior_hyperparameters(config): + return { + hp: (next(iter(config[hp].values()))) + if type(config[hp]) is dict + else config[hp] + for hp in config + } + + +def get_meta_gp_prior_hyperparameters(config): + from tabpfn.priors.utils import trunc_norm_sampler_f + + config = { + hp: (next(iter(config[hp].values()))) + if type(config[hp]) is dict + else config[hp] + for hp in config + } + + if "outputscale_mean" in config: + outputscale_sampler = trunc_norm_sampler_f( + config["outputscale_mean"], + config["outputscale_mean"] * config["outputscale_std_f"], + ) + config["outputscale"] = outputscale_sampler + if "lengthscale_mean" in config: + lengthscale_sampler = trunc_norm_sampler_f( + config["lengthscale_mean"], + config["lengthscale_mean"] * config["lengthscale_std_f"], + ) + config["lengthscale"] = lengthscale_sampler + + return config + + +def get_uniform_single_eval_pos_sampler(max_len, min_len=0): + """Just sample any evaluation position with the 
same weight + :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`. + """ + return lambda: random.choices(range(min_len, max_len))[0] + + +def get_model( + config, + device, + should_train=True, + verbose=False, + state_dict=None, + epoch_callback=None, +): + import math + + from tabpfn import encoders, priors + from tabpfn.train import Losses, train + + extra_kwargs = {} + verbose_train, verbose_prior = verbose >= 1, verbose >= 2 + config["verbose"] = verbose_prior + + if "aggregate_k_gradients" not in config or config["aggregate_k_gradients"] is None: + config["aggregate_k_gradients"] = math.ceil( + config["batch_size"] + * ( + (config["nlayers"] * config["emsize"] * config["bptt"] * config["bptt"]) + / 10824640000 + ) + ) + + config["num_steps"] = math.ceil( + config["num_steps"] * config["aggregate_k_gradients"] + ) + config["batch_size"] = math.ceil( + config["batch_size"] / config["aggregate_k_gradients"] + ) + config["recompute_attn"] = config.get("recompute_attn", False) + + def make_get_batch(model_proto, **extra_kwargs): + def new_get_batch( + batch_size, + seq_len, + num_features, + hyperparameters, + device, + model_proto=model_proto, + **kwargs, + ): + kwargs = {**extra_kwargs, **kwargs} # new args overwrite pre-specified args + return model_proto.get_batch( + batch_size=batch_size, + seq_len=seq_len, + device=device, + hyperparameters=hyperparameters, + num_features=num_features, + **kwargs, + ) + + return new_get_batch + + if config["prior_type"] == "prior_bag": + # Prior bag combines priors + get_batch_gp = make_get_batch(priors.fast_gp) + get_batch_mlp = make_get_batch(priors.mlp) + if config.get("flexible"): + get_batch_gp = make_get_batch( + priors.flexible_categorical, get_batch=get_batch_gp + ) + get_batch_mlp = make_get_batch( + priors.flexible_categorical, get_batch=get_batch_mlp + ) + prior_bag_hyperparameters = { + "prior_bag_get_batch": (get_batch_gp, get_batch_mlp), + "prior_bag_exp_weights_1": 2.0, + } + prior_hyperparameters = { + **get_mlp_prior_hyperparameters(config), + **get_gp_prior_hyperparameters(config), + **prior_bag_hyperparameters, + } + model_proto = priors.prior_bag + else: + if config["prior_type"] == "mlp": + prior_hyperparameters = get_mlp_prior_hyperparameters(config) + model_proto = priors.mlp + elif config["prior_type"] == "gp": + prior_hyperparameters = get_gp_prior_hyperparameters(config) + model_proto = priors.fast_gp + elif config["prior_type"] == "gp_mix": + prior_hyperparameters = get_gp_mix_prior_hyperparameters(config) + model_proto = priors.fast_gp_mix + else: + raise Exception() + + if config.get("flexible"): + get_batch_base = make_get_batch(model_proto) + extra_kwargs["get_batch"] = get_batch_base + model_proto = priors.flexible_categorical + + if config.get("flexible"): + prior_hyperparameters["normalize_labels"] = True + prior_hyperparameters["check_is_compatible"] = True + prior_hyperparameters["prior_mlp_scale_weights_sqrt"] = ( + config["prior_mlp_scale_weights_sqrt"] + if "prior_mlp_scale_weights_sqrt" in prior_hyperparameters + else None + ) + prior_hyperparameters["rotate_normalized_labels"] = ( + config["rotate_normalized_labels"] + if "rotate_normalized_labels" in prior_hyperparameters + else True + ) + + use_style = False + + if config.get("differentiable"): + get_batch_base = make_get_batch(model_proto, **extra_kwargs) + extra_kwargs = { + "get_batch": get_batch_base, + "differentiable_hyperparameters": config["differentiable_hyperparameters"], + } + model_proto = priors.differentiable_prior + 
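+        # the differentiable prior supplies a style input, so the StyleEncoder is
+        # enabled below via style_encoder_generator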
use_style = True + print(f"Using style prior: {use_style}") + + if ( + ("nan_prob_no_reason" in config and config["nan_prob_no_reason"] > 0.0) + or ("nan_prob_a_reason" in config and config["nan_prob_a_reason"] > 0.0) + or ( + "nan_prob_unknown_reason" in config + and config["nan_prob_unknown_reason"] > 0.0 + ) + ): + encoder = encoders.NanHandlingEncoder + else: + encoder = partial(encoders.Linear, replace_nan_by_zero=True) + + if config["max_num_classes"] == 2: + loss = Losses.bce + elif config["max_num_classes"] > 2: + loss = Losses.ce(config["max_num_classes"]) + + False if "multiclass_loss_type" not in config else ( + config["multiclass_loss_type"] == "compatible" + ) + config["multiclass_type"] = config.get("multiclass_type", "rank") + config["mix_activations"] = config.get("mix_activations", False) + + config["bptt_extra_samples"] = config.get("bptt_extra_samples", None) + config["eval_positions"] = ( + [int(config["bptt"] * 0.95)] + if config["bptt_extra_samples"] is None + else [int(config["bptt"])] + ) + + epochs = 0 if not should_train else config["epochs"] + # print('MODEL BUILDER', model_proto, extra_kwargs['get_batch']) + return train( + model_proto.DataLoader, + loss, + encoder, + style_encoder_generator=encoders.StyleEncoder if use_style else None, + emsize=config["emsize"], + nhead=config["nhead"], + # For unsupervised learning change to NanHandlingEncoder + y_encoder_generator=encoders.get_Canonical(config["max_num_classes"]) + if config.get("canonical_y_encoder", False) + else encoders.Linear, + pos_encoder_generator=None, + batch_size=config["batch_size"], + nlayers=config["nlayers"], + nhid=config["emsize"] * config["nhid_factor"], + epochs=epochs, + warmup_epochs=20, + bptt=config["bptt"], + gpu_device=device, + dropout=config["dropout"], + steps_per_epoch=config["num_steps"], + single_eval_pos_gen=get_uniform_single_eval_pos_sampler( + config.get("max_eval_pos", config["bptt"]), + min_len=config.get("min_eval_pos", 0), + ), + load_weights_from_this_state_dict=state_dict, + aggregate_k_gradients=config["aggregate_k_gradients"], + recompute_attn=config["recompute_attn"], + epoch_callback=epoch_callback, + bptt_extra_samples=config["bptt_extra_samples"], + train_mixed_precision=config["train_mixed_precision"], + extra_prior_kwargs_dict={ + "num_features": config["num_features"], + "hyperparameters": prior_hyperparameters, + # , 'dynamic_batch_size': 1 if ('num_global_att_tokens' in config and config['num_global_att_tokens']) else 2 + "batch_size_per_gp_sample": config.get("batch_size_per_gp_sample", None), + **extra_kwargs, + }, + lr=config["lr"], + verbose=verbose_train, + weight_decay=config.get("weight_decay", 0.0), + ) + + +def load_model(path, filename, device, eval_positions, verbose): + # TODO: This function only restores evaluation functionality but training canät be continued. It is also not flexible. + # print('Loading....') + print("!! 
Warning: GPyTorch must be installed !!") + model_state, optimizer_state, config_sample = torch.load( + os.path.join(path, filename), map_location="cpu" + ) + if ( + "differentiable_hyperparameters" in config_sample + and "prior_mlp_activations" in config_sample["differentiable_hyperparameters"] + ): + config_sample["differentiable_hyperparameters"]["prior_mlp_activations"][ + "choice_values_used" + ] = config_sample["differentiable_hyperparameters"]["prior_mlp_activations"][ + "choice_values" + ] + config_sample["differentiable_hyperparameters"]["prior_mlp_activations"][ + "choice_values" + ] = [ + torch.nn.Tanh + for k in config_sample["differentiable_hyperparameters"][ + "prior_mlp_activations" + ]["choice_values"] + ] + + config_sample["categorical_features_sampler"] = lambda: lambda x: ([], [], []) + config_sample["num_features_used_in_training"] = config_sample["num_features_used"] + config_sample["num_features_used"] = lambda: config_sample["num_features"] + config_sample["num_classes_in_training"] = config_sample["num_classes"] + config_sample["num_classes"] = 2 + config_sample["batch_size_in_training"] = config_sample["batch_size"] + config_sample["batch_size"] = 1 + config_sample["bptt_in_training"] = config_sample["bptt"] + config_sample["bptt"] = 10 + config_sample["bptt_extra_samples_in_training"] = config_sample[ + "bptt_extra_samples" + ] + config_sample["bptt_extra_samples"] = None + + # print('Memory', str(get_gpu_memory())) + + model = get_model(config_sample, device=device, should_train=False, verbose=verbose) + module_prefix = "module." + model_state = {k.replace(module_prefix, ""): v for k, v in model_state.items()} + model[2].load_state_dict(model_state) + model[2].to(device) + # model[2].eval() + + return model, config_sample + + +def load_model_workflow( + i, + e, + add_name, + base_path, + device="cpu", + eval_addition="", + only_inference=True, + if_regression=False, +): + """Workflow for loading a model and setting appropriate parameters for diffable hparam tuning. 
+ + :param i: + :param e: + :param eval_positions_valid: + :param add_name: + :param base_path: + :param device: + :param eval_addition: + :return: + """ + + def get_file(e): + """Returns the different paths of model_file, model_path and results_file.""" + model_file = ( + f"models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt" + ) + model_path = os.path.join(base_path, model_file) + # print('Evaluate ', model_path) + results_file = os.path.join( + base_path, + f"models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl", + ) + return model_file, model_path, results_file + + def check_file(e): + model_file, model_path, results_file = get_file(e) + if not Path(model_path).is_file(): # or Path(results_file).is_file(): + print( + "We have to download the TabPFN, as there is no checkpoint at ", + model_path, + ) + print("It has about 100MB, so this might take a moment.") + import requests + + url = "https://github.com/PriorLabs/TabPFN/raw/refs/tags/v1.0.0/tabpfn/models_diff/prior_diff_real_checkpoint_n_0_epoch_42.cpkt" + # print('hhh') + r = requests.get(url, allow_redirects=True) + # print('hhh') + os.makedirs(os.path.dirname(model_path), exist_ok=True) + open(model_path, "wb").write(r.content) + return model_file, model_path, results_file + + model_file = None + if e == -1: + for e_ in range(100, -1, -1): + model_file_, model_path_, results_file_ = check_file(e_) + if model_file_ is not None: + e = e_ + model_file, model_path, results_file = ( + model_file_, + model_path_, + results_file_, + ) + break + else: + model_file, model_path, results_file = check_file(e) + + if model_file is None: + model_file, model_path, results_file = get_file(e) + raise Exception("No checkpoint found at " + str(model_path)) + + # print(f'Loading {model_file}') + if only_inference: + # print('Loading model that can be used for inference only') + model, c = load_model_only_inference(base_path, model_file, device) + if if_regression: + model, c = load_model_only_inference_regression( + base_path, model_file, device + ) + + else: + # until now also only capable of inference + model, c = load_model( + base_path, model_file, device, eval_positions=[], verbose=False + ) + # model, c = load_model(base_path, model_file, device, eval_positions=[], verbose=False) + + return model, c, results_file + + +def load_model_workflow_reg( + i, e, add_name, base_path, device="cpu", eval_addition="", only_inference=True +): + """Workflow for loading a model and setting appropriate parameters for diffable hparam tuning. 
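# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): the checkpoint and
# results naming scheme produced by the get_file() helpers in the
# load_model_workflow* functions here. The concrete values of i, e, add_name
# and eval_addition below are made up for illustration.
i, e, add_name, eval_addition = 0, 42, "", "valid"
model_file = f"models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt"
results_file = f"models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl"
# -> models_diff/prior_diff_real_checkpoint_n_0_epoch_42.cpkt
#    models_diff/prior_diff_real_results_n_0_epoch_42_valid.pkl
# This matches the file name of the public TabPFN checkpoint that check_file()
# downloads when nothing is found locally.
# ---------------------------------------------------------------------------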
+ + :param i: + :param e: + :param eval_positions_valid: + :param add_name: + :param base_path: + :param device: + :param eval_addition: + :return: + """ + + def get_file(e): + """Returns the different paths of model_file, model_path and results_file.""" + model_file = ( + f"models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt" + ) + model_path = os.path.join(base_path, model_file) + # print('Evaluate ', model_path) + results_file = os.path.join( + base_path, + f"models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl", + ) + return model_file, model_path, results_file + + def check_file(e): + model_file, model_path, results_file = get_file(e) + if not Path(model_path).is_file(): # or Path(results_file).is_file(): + print( + "We have to download the TabPFN, as there is no checkpoint at ", + model_path, + ) + print("It has about 100MB, so this might take a moment.") + import requests + + url = "https://github.com/automl/TabPFN/raw/main/tabpfn/models_diff/prior_diff_real_checkpoint_n_0_epoch_42.cpkt" + print("hhh") + r = requests.get(url, allow_redirects=True) + print("hhh") + os.makedirs(os.path.dirname(model_path), exist_ok=True) + open(model_path, "wb").write(r.content) + return model_file, model_path, results_file + + model_file = None + if e == -1: + for e_ in range(100, -1, -1): + model_file_, model_path_, results_file_ = check_file(e_) + if model_file_ is not None: + e = e_ + model_file, model_path, results_file = ( + model_file_, + model_path_, + results_file_, + ) + break + else: + model_file, model_path, results_file = check_file(e) + + if model_file is None: + model_file, model_path, results_file = get_file(e) + raise Exception("No checkpoint found at " + str(model_path)) + + # print(f'Loading {model_file}') + if only_inference: + # print('Loading model that can be used for inference only') + model, c = load_model_only_inference_regression(base_path, model_file, device) + + else: + # until now also only capable of inference + model, c = load_model( + base_path, model_file, device, eval_positions=[], verbose=False + ) + # model, c = load_model(base_path, model_file, device, eval_positions=[], verbose=False) + + return model, c, results_file + + +class CustomUnpickler(pickle.Unpickler): + def find_class(self, module, name): + if name == "Manager": + from settings import Manager + + return Manager + try: + return self.find_class_cpu(module, name) + except: + return None + + def find_class_cpu(self, module, name): + if module == "torch.storage" and name == "_load_from_bytes": + return lambda b: torch.load(io.BytesIO(b), map_location="cpu") + return super().find_class(module, name) + + +import time + + +def transformer_predict( + model, + eval_xs, + eval_ys, + eval_position, + device="cpu", + max_features=100, + style=None, + inference_mode=False, + num_classes=2, + extend_features=True, + normalize_with_test=False, + normalize_to_ranking=False, + softmax_temperature=0.0, + multiclass_decoder="permutation", + preprocess_transform="mix", + categorical_feats=None, + feature_shift_decoder=False, + N_ensemble_configurations=10, + batch_size_inference=16, + differentiable_hps_as_style=False, + average_logits=True, + fp16_inference=False, + normalize_with_sqrt=False, + seed=0, + no_grad=True, + return_logits=False, + **kwargs, +): + """:param model: + :param eval_xs: + :param eval_ys: + :param eval_position: + :param rescale_features: + :param device: + :param max_features: + :param style: + :param inference_mode: + :param num_classes: + :param 
extend_features: + :param normalize_to_ranking: + :param softmax_temperature: + :param multiclass_decoder: + :param preprocess_transform: + :param categorical_feats: + :param feature_shift_decoder: + :param N_ensemble_configurations: + :param average_logits: + :param normalize_with_sqrt: + :param metric_used: + :return: + """ + if categorical_feats is None: + categorical_feats = [] + num_classes = len(torch.unique(eval_ys)) + + # N_ensemble_configurations=32 + def predict(eval_xs, eval_ys, used_style, softmax_temperature, return_logits): + # Initialize results array size S, B, Classes + + # no_grad disables inference_mode, because otherwise the gradients are lost + inference_mode_call = ( + torch.inference_mode() if inference_mode and no_grad else NOP() + ) + with inference_mode_call: + time.time() + output = model( + ( + used_style.repeat(eval_xs.shape[1], 1) + if used_style is not None + else None, + eval_xs, + eval_ys.float(), + ), + single_eval_pos=eval_position, + )[:, :, 0:num_classes] + + output = output[:, :, 0:num_classes] / torch.exp(softmax_temperature) + if not return_logits: + output = torch.nn.functional.softmax(output, dim=-1) + # else: + # output[:, :, 1] = model((style.repeat(eval_xs.shape[1], 1) if style is not None else None, eval_xs, eval_ys.float()), + # single_eval_pos=eval_position) + + # output[:, :, 1] = torch.sigmoid(output[:, :, 1]).squeeze(-1) + # output[:, :, 0] = 1 - output[:, :, 1] + + # print('RESULTS', eval_ys.shape, torch.unique(eval_ys, return_counts=True), output.mean(axis=0)) + # print(output) + return output + + def preprocess_input(eval_xs, preprocess_transform): + import warnings + + if eval_xs.shape[1] > 1: + raise Exception("Transforms only allow one batch dim - TODO") + + if eval_xs.shape[2] > max_features: + eval_xs = eval_xs[ + :, + :, + sorted(np.random.choice(eval_xs.shape[2], max_features, replace=False)), + ] + + if preprocess_transform != "none": + if preprocess_transform in {"power", "power_all"}: + pt = PowerTransformer(standardize=True) + elif preprocess_transform in {"quantile", "quantile_all"}: + pt = QuantileTransformer(output_distribution="normal") + elif preprocess_transform in {"robust", "robust_all"}: + pt = RobustScaler(unit_variance=True) + + # eval_xs, eval_ys = normalize_data(eval_xs), normalize_data(eval_ys) + eval_xs = normalize_data( + eval_xs, normalize_positions=-1 if normalize_with_test else eval_position + ) + + # Removing empty features + eval_xs = eval_xs[:, 0, :] + sel = [ + len(torch.unique(eval_xs[0 : eval_ys.shape[0], col])) > 1 + for col in range(eval_xs.shape[1]) + ] + eval_xs = eval_xs[:, sel] + + warnings.simplefilter("error") + if preprocess_transform != "none": + eval_xs = eval_xs.cpu().numpy() + feats = ( + set(range(eval_xs.shape[1])) + if "all" in preprocess_transform + else set(range(eval_xs.shape[1])) - set(categorical_feats) + ) + for col in feats: + try: + pt.fit(eval_xs[0:eval_position, col : col + 1]) + trans = pt.transform(eval_xs[:, col : col + 1]) + # print(scipy.stats.spearmanr(trans[~np.isnan(eval_xs[:, col:col+1])], eval_xs[:, col:col+1][~np.isnan(eval_xs[:, col:col+1])])) + eval_xs[:, col : col + 1] = trans + except: + pass + eval_xs = torch.tensor(eval_xs).float() + warnings.simplefilter("default") + + eval_xs = eval_xs.unsqueeze(1) + + # TODO: Caution there is information leakage when to_ranking is used, we should not use it + eval_xs = ( + remove_outliers( + eval_xs, + normalize_positions=-1 if normalize_with_test else eval_position, + ) + if not normalize_to_ranking + else 
normalize_data(to_ranking_low_mem(eval_xs)) + ) + # Rescale X + eval_xs = normalize_by_used_features_f( + eval_xs, + eval_xs.shape[-1], + max_features, + normalize_with_sqrt=normalize_with_sqrt, + ) + + return eval_xs.to(device) + + eval_xs, eval_ys = eval_xs.to(device), eval_ys.to(device) + eval_ys = eval_ys[:eval_position] + # print(eval_xs[eval_position:]) + model.to(device) + + model.eval() + + import itertools + + if not differentiable_hps_as_style: + style = None + + if style is not None: + style = style.to(device) + style = style.unsqueeze(0) if len(style.shape) == 1 else style + num_styles = style.shape[0] + softmax_temperature = ( + softmax_temperature + if softmax_temperature.shape + else softmax_temperature.unsqueeze(0).repeat(num_styles) + ) + else: + num_styles = 1 + style = None + softmax_temperature = torch.log(torch.tensor([0.8])) + + styles_configurations = range(num_styles) + + def get_preprocess(i): + if i == 0: + return "power_all" + # if i == 1: + # return 'robust_all' + if i == 1: + return "none" + return None + + preprocess_transform_configurations = ( + ["none", "power_all"] + if preprocess_transform == "mix" + else [preprocess_transform] + ) + + if seed is not None: + torch.manual_seed(seed) + + feature_shift_configurations = ( + torch.randperm(eval_xs.shape[2]) if feature_shift_decoder else [0] + ) + class_shift_configurations = ( + torch.randperm(len(torch.unique(eval_ys))) + if multiclass_decoder == "permutation" + else [0] + ) + + ensemble_configurations = list( + itertools.product(class_shift_configurations, feature_shift_configurations) + ) + # default_ensemble_config = ensemble_configurations[0] + + rng = random.Random(seed) + rng.shuffle(ensemble_configurations) + ensemble_configurations = list( + itertools.product( + ensemble_configurations, + preprocess_transform_configurations, + styles_configurations, + ) + ) + ensemble_configurations = ensemble_configurations[0:N_ensemble_configurations] + # if N_ensemble_configurations == 1: + # ensemble_configurations = [default_ensemble_config] + + output = None + + eval_xs_transformed = {} + inputs, labels = [], [] + time.time() + for ensemble_configuration in ensemble_configurations: + ( + (class_shift_configuration, feature_shift_configuration), + preprocess_transform_configuration, + styles_configuration, + ) = ensemble_configuration + + style_ = ( + style[styles_configuration : styles_configuration + 1, :] + if style is not None + else style + ) + softmax_temperature_ = softmax_temperature[styles_configuration] + + eval_xs_, eval_ys_ = eval_xs.clone(), eval_ys.clone() + # print(preprocess_transform_configuration) + if preprocess_transform_configuration in eval_xs_transformed: + eval_xs_ = eval_xs_transformed[preprocess_transform_configuration].clone() + else: + eval_xs_ = preprocess_input( + eval_xs_, preprocess_transform=preprocess_transform_configuration + ) + if no_grad: + eval_xs_ = eval_xs_.detach() + eval_xs_transformed[preprocess_transform_configuration] = eval_xs_ + + # eval_ys_ = ((eval_ys_ + class_shift_configuration) % num_classes).float() + # print(class_shift_configuration) + eval_ys_ = ((eval_ys_ + class_shift_configuration) % num_classes).float() + + eval_xs_ = torch.cat( + [ + eval_xs_[..., feature_shift_configuration:], + eval_xs_[..., :feature_shift_configuration], + ], + dim=-1, + ) + + # Extend X + if extend_features: + eval_xs_ = torch.cat( + [ + eval_xs_, + torch.zeros( + ( + eval_xs_.shape[0], + eval_xs_.shape[1], + max_features - eval_xs_.shape[2], + ) + ).to(device), + ], + -1, + ) + 
inputs += [eval_xs_] + labels += [eval_ys_] + # print(eval_xs_) + inputs = torch.cat(inputs, 1) + inputs = torch.split(inputs, batch_size_inference, dim=1) + labels = torch.cat(labels, 1) + labels = torch.split(labels, batch_size_inference, dim=1) + # print(inputs[0].shape, labels[0].shape) + # print('PREPROCESSING TIME', str(time.time() - start)) + outputs = [] + time.time() + for batch_input, batch_label in zip(inputs, labels): + # preprocess_transform_ = preprocess_transform if styles_configuration % 2 == 0 else 'none' + print(batch_input.shape, batch_label.shape) + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="None of the inputs have requires_grad=True. Gradients will be None", + ) + warnings.filterwarnings( + "ignore", + message="torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling.", + ) + if device == "cpu": + output_batch = checkpoint( + predict, + batch_input, + batch_label, + style_, + softmax_temperature_, + True, + ) + else: + with torch.cuda.amp.autocast(enabled=fp16_inference): + output_batch = checkpoint( + predict, + batch_input, + batch_label, + style_, + softmax_temperature_, + True, + ) + outputs += [output_batch] + # print('MODEL INFERENCE TIME ('+str(batch_input.device)+' vs '+device+', '+str(fp16_inference)+')', str(time.time()-start)) + + outputs = torch.cat(outputs, 1) + for i, ensemble_configuration in enumerate(ensemble_configurations): + ( + (class_shift_configuration, feature_shift_configuration), + preprocess_transform_configuration, + styles_configuration, + ) = ensemble_configuration + output_ = outputs[:, i : i + 1, :] + output_ = torch.cat( + [ + output_[..., class_shift_configuration:], + output_[..., :class_shift_configuration], + ], + dim=-1, + ) + + # output_ = predict(eval_xs, eval_ys, style_, preprocess_transform_) + if not average_logits and not return_logits: + # transforms every ensemble_configuration into a probability -> equal contribution of every configuration + output_ = torch.nn.functional.softmax(output_, dim=-1) + output = output_ if output is None else output + output_ + + output = output / len(ensemble_configurations) + # if average_logits and not return_logits: + # output = torch.nn.functional.softmax(output, dim=-1) + + return torch.transpose(output, 0, 1) + + +def get_params_from_config(c): + return { + "max_features": c["num_features"], + "rescale_features": c["normalize_by_used_features"], + "normalize_to_ranking": c["normalize_to_ranking"], + "normalize_with_sqrt": c.get("normalize_with_sqrt", False), + } diff --git a/tabrepo/benchmark/models/ag/beta/deps/tabr_utils.py b/tabrepo/benchmark/models/ag/beta/deps/tabr_utils.py new file mode 100644 index 00000000..4098296e --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/deps/tabr_utils.py @@ -0,0 +1,323 @@ +from __future__ import annotations + +import math + +import torch +from torch import Tensor, nn +from torch.nn.parameter import Parameter + + +# adapted from https://github.com/yandex-research/tabular-dl-tabr +def _initialize_embeddings(weight: Tensor, d: int | None) -> None: + if d is None: + d = weight.shape[-1] + d_sqrt_inv = 1 / math.sqrt(d) + nn.init.uniform_(weight, a=-d_sqrt_inv, b=d_sqrt_inv) + + +def make_trainable_vector(d: int) -> Parameter: + x = torch.empty(d) + _initialize_embeddings(x, None) + return Parameter(x) + + +class CLSEmbedding(nn.Module): + def __init__(self, d_embedding: int) -> None: + super().__init__() + self.weight = make_trainable_vector(d_embedding) + + def 
forward(self, x: Tensor) -> Tensor: + assert x.ndim == 3 + assert x.shape[-1] == len(self.weight) + return torch.cat([self.weight.expand(len(x), 1, -1), x], dim=1) + + +class ResNet(nn.Module): + def __init__( + self, + *, + d_in: None | int = None, + d_out: None | int = None, + n_blocks: int, + d_block: int, + dropout: float, + d_hidden_multiplier: float | int, + n_linear_layers_per_block: int = 2, + activation: str = "ReLU", + normalization: str, + first_normalization: bool, + ) -> None: + assert n_linear_layers_per_block in (1, 2) + if n_linear_layers_per_block == 1: + assert d_hidden_multiplier == 1 + super().__init__() + + Activation = getattr(nn, activation) + Normalization = ( + Identity if normalization == "none" else getattr(nn, normalization) + ) + d_hidden = int(d_block * d_hidden_multiplier) + + self.proj = None if d_in is None else nn.Linear(d_in, d_block) + self.blocks = nn.ModuleList( + [ + nn.Sequential( + Normalization(d_block) if first_normalization else Identity(), + ( + nn.Linear(d_block, d_hidden) + if n_linear_layers_per_block == 2 + else nn.Linear(d_block, d_block) + ), + Activation(), + nn.Dropout(dropout), + ( + nn.Linear(d_hidden, d_block) + if n_linear_layers_per_block == 2 + else Identity() + ), + ) + for _ in range(n_blocks) + ] + ) + self.preoutput = nn.Sequential(Normalization(d_block), Activation()) + self.output = None if d_out is None else nn.Linear(d_block, d_out) + + def forward(self, x: Tensor) -> Tensor: + if self.proj is not None: + x = self.proj(x) + for block in self.blocks: + x = x + block(x) + x = self.preoutput(x) + if self.output is not None: + x = x + self.output(x) + return x + + +class LinearEmbeddings(nn.Module): + def __init__(self, n_features: int, d_embedding: int, bias: bool = True): + super().__init__() + self.weight = Parameter(Tensor(n_features, d_embedding)) + self.bias = Parameter(Tensor(n_features, d_embedding)) if bias else None + self.reset_parameters() + + def reset_parameters(self) -> None: + for parameter in [self.weight, self.bias]: + if parameter is not None: + _initialize_embeddings(parameter, parameter.shape[-1]) + + def forward(self, x: Tensor) -> Tensor: + assert x.ndim == 2 + x = self.weight[None] * x[..., None] + if self.bias is not None: + x = x + self.bias[None] + return x + + +class PeriodicEmbeddings(nn.Module): + def __init__( + self, n_features: int, n_frequencies: int, frequency_scale: float + ) -> None: + super().__init__() + self.frequencies = Parameter( + torch.normal(0.0, frequency_scale, (n_features, n_frequencies)) + ) + + def forward(self, x: Tensor) -> Tensor: + assert x.ndim == 2 + x = 2 * torch.pi * self.frequencies[None] * x[..., None] + return torch.cat([torch.cos(x), torch.sin(x)], -1) + + +class NLinear(nn.Module): + def __init__( + self, n_features: int, d_in: int, d_out: int, bias: bool = True + ) -> None: + super().__init__() + self.weight = Parameter(Tensor(n_features, d_in, d_out)) + self.bias = Parameter(Tensor(n_features, d_out)) if bias else None + with torch.no_grad(): + for i in range(n_features): + layer = nn.Linear(d_in, d_out) + self.weight[i] = layer.weight.T + if self.bias is not None: + self.bias[i] = layer.bias + + def forward(self, x): + assert x.ndim == 3 + x = x[..., None] * self.weight[None] + x = x.sum(-2) + if self.bias is not None: + x = x + self.bias[None] + return x + + +class LREmbeddings(nn.Sequential): + """The LR embeddings from the paper 'On Embeddings for Numerical Features in Tabular Deep Learning'.""" # noqa: E501 + + def __init__(self, n_features: int, 
d_embedding: int) -> None: + super().__init__(LinearEmbeddings(n_features, d_embedding), nn.ReLU()) + + +class PLREmbeddings(nn.Sequential): + """The PLR embeddings from the paper 'On Embeddings for Numerical Features in Tabular Deep Learning'. + + Additionally, the 'lite' option is added. Setting it to `False` gives you the original PLR + embedding from the above paper. We noticed that `lite=True` makes the embeddings + noticeably more lightweight without critical performance loss, and we used that for our model. + """ # noqa: E501 + + def __init__( + self, + n_features: int, + n_frequencies: int, + frequency_scale: float, + d_embedding: int, + lite: bool, + ) -> None: + super().__init__( + PeriodicEmbeddings(n_features, n_frequencies, frequency_scale), + ( + nn.Linear(2 * n_frequencies, d_embedding) + if lite + else NLinear(n_features, 2 * n_frequencies, d_embedding) + ), + nn.ReLU(), + ) + + +# class MLP(nn.Module): +# class Block(nn.Module): +# def __init__( +# self, +# *, +# d_in: int, +# d_out: int, +# bias: bool, +# activation: str, +# dropout: float, +# ) -> None: +# super().__init__() +# self.linear = nn.Linear(d_in, d_out, bias) +# self.activation = make_module(activation) +# self.dropout = nn.Dropout(dropout) + +# def forward(self, x: Tensor) -> Tensor: +# return self.dropout(self.activation(self.linear(x))) + +# Head = nn.Linear + +# def __init__( +# self, +# *, +# d_in: int, +# d_out: Optional[int], +# n_blocks: int, +# d_layer: int, +# activation: str, +# dropout: float, +# ) -> None: +# assert n_blocks > 0 +# super().__init__() + +# self.blocks = nn.Sequential( +# *[ +# MLP.Block( +# d_in=d_layer if block_i else d_in, +# d_out=d_layer, +# bias=True, +# activation=activation, +# dropout=dropout, +# ) +# for block_i in range(n_blocks) +# ] +# ) +# self.head = None if d_out is None else MLP.Head(d_layer, d_out) + +# @property +# def d_out(self) -> int: +# return ( +# self.blocks[-1].linear.out_features # type: ignore[code] +# if self.head is None +# else self.head.out_features +# ) + +# def forward(self, x: Tensor) -> Tensor: +# x = self.blocks(x) +# if self.head is not None: +# x = self.head(x) +# return x + + +class MLP(nn.Module): + def __init__( + self, + *, + d_in: None | int = None, + d_out: None | int = None, + n_blocks: int, + d_block: int, + dropout: float, + activation: str = "SELU", + ) -> None: + super().__init__() + + d_first = d_block if d_in is None else d_in + self.blocks = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(d_first if i == 0 else d_block, d_block), + getattr(nn, activation)(), + nn.Dropout(dropout), + ) + for i in range(n_blocks) + ] + ) + self.output = None if d_out is None else nn.Linear(d_block, d_out) + + def forward(self, x: Tensor) -> Tensor: + for block in self.blocks: + x = block(x) + if self.output is not None: + x = self.output(x) + return x + + +_CUSTOM_MODULES = { + x.__name__: x + for x in [ + LinearEmbeddings, + LREmbeddings, + PLREmbeddings, + MLP, + ] +} + + +def make_module(spec, *args, **kwargs) -> nn.Module: + """>>> make_module('ReLU') + >>> make_module(nn.ReLU) + >>> make_module('Linear', 1, out_features=2) + >>> make_module((lambda *args: nn.Linear(*args)), 1, out_features=2) + >>> make_module({'type': 'Linear', 'in_features' 1}, out_features=2). 
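# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): make_module, whose
# body follows this docstring, resolves either a torch.nn class name or one of
# the custom modules registered in _CUSTOM_MODULES above. The hyperparameter
# values below are made up; only the class and function names come from this file.
import torch
from tabrepo.benchmark.models.ag.beta.deps.tabr_utils import make_module

emb = make_module(
    {"type": "PLREmbeddings", "n_frequencies": 16, "frequency_scale": 0.1, "lite": True},
    n_features=5,
    d_embedding=24,
)
x = torch.randn(32, 5)      # (batch, n_features)
print(emb(x).shape)         # torch.Size([32, 5, 24]) -> per-feature embeddings
# ---------------------------------------------------------------------------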
+ """ + if isinstance(spec, str): + Module = getattr(nn, spec, None) + if Module is None: + Module = _CUSTOM_MODULES[spec] + else: + assert spec not in _CUSTOM_MODULES + return make_module(Module, *args, **kwargs) + if isinstance(spec, dict): + assert not (set(spec) & set(kwargs)) + spec = spec.copy() + return make_module(spec.pop("type"), *args, **spec, **kwargs) + if callable(spec): + return spec(*args, **kwargs) + raise ValueError() + + +def make_module1(type: str, *args, **kwargs) -> nn.Module: + Module = getattr(nn, type, None) + if Module is None: + Module = _CUSTOM_MODULES[type] + return Module(*args, **kwargs) diff --git a/tabrepo/benchmark/models/ag/beta/deps/talent_data.py b/tabrepo/benchmark/models/ag/beta/deps/talent_data.py new file mode 100644 index 00000000..a45ae227 --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/deps/talent_data.py @@ -0,0 +1,675 @@ +from __future__ import annotations + +import dataclasses as dc +import json +import os +import typing as ty +from copy import deepcopy +from pathlib import Path + +import category_encoders +import numpy as np +import sklearn.preprocessing +import torch +import torch.nn.functional as F +from sklearn.impute import SimpleImputer +from torch.utils.data import ( + DataLoader, + Dataset as TorchDataset, +) + +BINCLASS = "binclass" +MULTICLASS = "multiclass" +REGRESSION = "regression" + +ArrayDict = dict[str, np.ndarray] + + +def raise_unknown(unknown_what: str, unknown_value: ty.Any): + raise ValueError(f"Unknown {unknown_what}: {unknown_value}") + + +def load_json(path): + return json.loads(Path(path).read_text()) + + +@dc.dataclass +class Dataset: + N: ArrayDict | None + C: ArrayDict | None + y: ArrayDict + info: dict[str, ty.Any] + + @property + def is_binclass(self) -> bool: + return self.info["task_type"] == BINCLASS + + @property + def is_multiclass(self) -> bool: + return self.info["task_type"] == MULTICLASS + + @property + def is_regression(self) -> bool: + return self.info["task_type"] == REGRESSION + + @property + def n_num_features(self) -> int: + return self.info["n_num_features"] + + @property + def n_cat_features(self) -> int: + return self.info["n_cat_features"] + + @property + def n_features(self) -> int: + return self.n_num_features + self.n_cat_features + + def size(self, part: str) -> int: + """Return the size of the dataset partition. + + Args: + - part: str + + Returns: int + """ + X = self.N if self.N is not None else self.C + assert X is not None + return len(X[part]) + + +THIS_PATH = os.path.dirname(__file__) +DATA_PATH = os.path.abspath(os.path.join(THIS_PATH, "..", "..", "..")) + + +def dataname_to_numpy(dataset_name, dataset_path): + """Load the dataset from the numpy files. + + :param dataset_name: str + :param dataset_path: str + :return: Tuple[ArrayDict, ArrayDict, ArrayDict, Dict[str, Any]] + """ + dir_ = Path(os.path.join(DATA_PATH, dataset_path, dataset_name)) + + def load(item) -> ArrayDict: + return { + x: ty.cast( + "np.ndarray", np.load(dir_ / f"{item}_{x}.npy", allow_pickle=True) + ) + for x in ["train", "val", "test"] + } + + return ( + load("N") if dir_.joinpath("N_train.npy").exists() else None, + load("C") if dir_.joinpath("C_train.npy").exists() else None, + load("y"), + load_json(dir_ / "info.json"), + ) + + +def get_dataset(dataset_name, dataset_path): + """Load the dataset from the numpy files. 
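# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): the on-disk layout
# that dataname_to_numpy() above expects under DATA_PATH/<dataset_path>/<name>.
# The toy data and directory below are made up; the N_* and C_* files are
# optional, while y_* and info.json are required.
import json
import tempfile
from pathlib import Path

import numpy as np

root = Path(tempfile.mkdtemp()) / "toy_dataset"
root.mkdir(parents=True)
rng = np.random.default_rng(0)
for split, n in [("train", 80), ("val", 20), ("test", 20)]:
    np.save(root / f"N_{split}.npy", rng.normal(size=(n, 3)))              # numerical features
    np.save(root / f"C_{split}.npy", rng.choice(["a", "b"], size=(n, 2)))  # categorical features
    np.save(root / f"y_{split}.npy", rng.integers(0, 2, size=n))           # labels
(root / "info.json").write_text(
    json.dumps({"task_type": "binclass", "n_num_features": 3, "n_cat_features": 2})
)
# get_dataset("toy_dataset", dataset_path) would then return
# ((N_trainval, C_trainval, y_trainval), (N_test, C_test, y_test), info),
# provided the folder lives under the module-relative DATA_PATH.
# ---------------------------------------------------------------------------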
+ + :param dataset_name: str + :param dataset_path: str + :return: Tuple[ArrayDict, ArrayDict, ArrayDict, Dict[str, Any]] + """ + N, C, y, info = dataname_to_numpy(dataset_name, dataset_path) + N_trainval = ( + None + if N is None + else {key: N[key] for key in ["train", "val"]} + if "train" in N and "val" in N + else None + ) + N_test = ( + None + if N is None + else {key: N[key] for key in ["test"]} + if "test" in N + else None + ) + + C_trainval = ( + None + if C is None + else {key: C[key] for key in ["train", "val"]} + if "train" in C and "val" in C + else None + ) + C_test = ( + None + if C is None + else {key: C[key] for key in ["test"]} + if "test" in C + else None + ) + + y_trainval = {key: y[key] for key in ["train", "val"]} + y_test = {key: y[key] for key in ["test"]} + + # tune hyper-parameters + train_val_data = (N_trainval, C_trainval, y_trainval) + test_data = (N_test, C_test, y_test) + return train_val_data, test_data, info + + +def data_nan_process( + N_data, + C_data, + num_nan_policy, + cat_nan_policy, + num_new_value=None, + imputer=None, + cat_new_value=None, +): + """Process the NaN values in the dataset. + + :param N_data: ArrayDict + :param C_data: ArrayDict + :param num_nan_policy: str + :param cat_nan_policy: str + :param num_new_value: Optional[np.ndarray] + :param imputer: Optional[SimpleImputer] + :param cat_new_value: Optional[str] + :return: Tuple[ArrayDict, ArrayDict, Optional[np.ndarray], Optional[SimpleImputer], Optional[str]] + """ + if N_data is None: + N = None + else: + N = deepcopy(N_data) + if "train" in N_data: + if N["train"].ndim == 1: + N = {k: v.reshape(-1, 1) for k, v in N.items()} + elif N["test"].ndim == 1: + N = {k: v.reshape(-1, 1) for k, v in N.items()} + N = {k: v.astype(float) for k, v in N.items()} + num_nan_masks = {k: np.isnan(v) for k, v in N.items()} + if any(x.any() for x in num_nan_masks.values()): + if num_new_value is None: + if num_nan_policy == "mean": + num_new_value = np.nanmean(N_data["train"], axis=0) + elif num_nan_policy == "median": + num_new_value = np.nanmedian(N_data["train"], axis=0) + else: + raise_unknown("numerical NaN policy", num_nan_policy) + for k, v in N.items(): + num_nan_indices = np.where(num_nan_masks[k]) + v[num_nan_indices] = np.take(num_new_value, num_nan_indices[1]) + if C_data is None: + C = None + else: + assert cat_nan_policy == "new" + C = deepcopy(C_data) + if "train" in C_data: + if C["train"].ndim == 1: + C = {k: v.reshape(-1, 1) for k, v in C.items()} + elif C["test"].ndim == 1: + C = {k: v.reshape(-1, 1) for k, v in C.items()} + C = {k: v.astype(str) for k, v in C.items()} + # assume the cat nan condition + cat_nan_masks = { + k: np.isnan(v) + if np.issubdtype(v.dtype, np.number) + else np.isin(v, ["nan", "NaN", "", None]) + for k, v in C.items() + } + if any(x.any() for x in cat_nan_masks.values()): + if cat_nan_policy == "new": + if cat_new_value is None: + cat_new_value = "___null___" + imputer = None + elif cat_nan_policy == "most_frequent": + if imputer is None: + cat_new_value = None + imputer = SimpleImputer(strategy="most_frequent") + imputer.fit(C["train"]) + else: + raise_unknown("categorical NaN policy", cat_nan_policy) + if imputer: + C = {k: imputer.transform(v) for k, v in C.items()} + else: + for k, v in C.items(): + cat_nan_indices = np.where(cat_nan_masks[k]) + v[cat_nan_indices] = cat_new_value + + return (N, C, num_new_value, imputer, cat_new_value) + + +def num_enc_process( + N_data, num_policy, n_bins=2, y_train=None, is_regression=False, encoder=None +): + """Process 
the numerical features in the dataset. + + :param N_data: ArrayDict + :param num_policy: str + :param n_bins: int + :param y_train: Optional[np.ndarray] + :param is_regression: bool + :param encoder: Optional[PiecewiseLinearEncoding] + :return: Tuple[ArrayDict, Optional[PiecewiseLinearEncoding]] + """ + from tabrepo.benchmark.models.ag.beta.deps.talent_num_embeddings import ( + BinsEncoding, + JohnsonEncoding, + PiecewiseLinearEncoding, + UnaryEncoding, + compute_bins, + ) + + if N_data is not None: + if num_policy == "none": + return N_data, None + + if num_policy == "Q_PLE": + for item in N_data: + N_data[item] = torch.from_numpy(N_data[item]) + if encoder is None: + bins = compute_bins( + N_data["train"], + n_bins=n_bins, + tree_kwargs=None, + y=None, + regression=None, + ) + encoder = PiecewiseLinearEncoding(bins) + for item in N_data: + N_data[item] = encoder(N_data[item]).cpu().numpy() + + elif num_policy == "T_PLE": + for item in N_data: + N_data[item] = torch.from_numpy(N_data[item]) + if encoder is None: + tree_kwargs = {"min_samples_leaf": 64, "min_impurity_decrease": 1e-4} + bins = compute_bins( + N_data["train"], + n_bins=n_bins, + tree_kwargs=tree_kwargs, + y=torch.from_numpy(y_train), + regression=is_regression, + ) + encoder = PiecewiseLinearEncoding(bins) + for item in N_data: + N_data[item] = encoder(N_data[item]).cpu().numpy() + elif num_policy == "Q_Unary": + for item in N_data: + N_data[item] = torch.from_numpy(N_data[item]) + if encoder is None: + bins = compute_bins( + N_data["train"], + n_bins=n_bins, + tree_kwargs=None, + y=None, + regression=None, + ) + encoder = UnaryEncoding(bins) + for item in N_data: + N_data[item] = encoder(N_data[item]).cpu().numpy() + elif num_policy == "T_Unary": + for item in N_data: + N_data[item] = torch.from_numpy(N_data[item]) + if encoder is None: + tree_kwargs = {"min_samples_leaf": 64, "min_impurity_decrease": 1e-4} + bins = compute_bins( + N_data["train"], + n_bins=n_bins, + tree_kwargs=tree_kwargs, + y=torch.from_numpy(y_train), + regression=is_regression, + ) + encoder = UnaryEncoding(bins) + for item in N_data: + N_data[item] = encoder(N_data[item]).cpu().numpy() + elif num_policy == "Q_bins": + for item in N_data: + N_data[item] = torch.from_numpy(N_data[item]) + if encoder is None: + bins = compute_bins( + N_data["train"], + n_bins=n_bins, + tree_kwargs=None, + y=None, + regression=None, + ) + encoder = BinsEncoding(bins) + for item in N_data: + N_data[item] = encoder(N_data[item]).cpu().numpy() + elif num_policy == "T_bins": + for item in N_data: + N_data[item] = torch.from_numpy(N_data[item]) + if encoder is None: + tree_kwargs = {"min_samples_leaf": 64, "min_impurity_decrease": 1e-4} + bins = compute_bins( + N_data["train"], + n_bins=n_bins, + tree_kwargs=tree_kwargs, + y=torch.from_numpy(y_train), + regression=is_regression, + ) + encoder = BinsEncoding(bins) + for item in N_data: + N_data[item] = encoder(N_data[item]).cpu().numpy() + elif num_policy == "Q_Johnson": + for item in N_data: + N_data[item] = torch.from_numpy(N_data[item]) + if encoder is None: + bins = compute_bins( + N_data["train"], + n_bins=n_bins, + tree_kwargs=None, + y=None, + regression=None, + ) + encoder = JohnsonEncoding(bins) + for item in N_data: + N_data[item] = encoder(N_data[item]).cpu().numpy() + elif num_policy == "T_Johnson": + for item in N_data: + N_data[item] = torch.from_numpy(N_data[item]) + if encoder is None: + tree_kwargs = {"min_samples_leaf": 64, "min_impurity_decrease": 1e-4} + bins = compute_bins( + N_data["train"], + 
n_bins=n_bins, + tree_kwargs=tree_kwargs, + y=torch.from_numpy(y_train), + regression=is_regression, + ) + encoder = JohnsonEncoding(bins) + for item in N_data: + N_data[item] = encoder(N_data[item]).cpu().numpy() + + return N_data, encoder + return N_data, None + + +def data_enc_process( + N_data, + C_data, + cat_policy, + y_train=None, + ord_encoder=None, + mode_values=None, + cat_encoder=None, +): + """Process the categorical features in the dataset. + + :param N_data: ArrayDict + :param C_data: ArrayDict + :param cat_policy: str + :param y_train: Optional[np.ndarray] + :param ord_encoder: Optional[OrdinalEncoder] + :param mode_values: Optional[List[int]] + :param cat_encoder: Optional[OneHotEncoder] + :return: Tuple[ArrayDict, ArrayDict, Optional[OrdinalEncoder], Optional[List[int]], Optional[OneHotEncoder]] + """ + if C_data is not None: + unknown_value = np.iinfo("int64").max - 3 + if ord_encoder is None: + ord_encoder = sklearn.preprocessing.OrdinalEncoder( + handle_unknown="use_encoded_value", + unknown_value=unknown_value, + dtype="int64", + ).fit(C_data["train"]) + C_data = {k: ord_encoder.transform(v) for k, v in C_data.items()} + + # for valset and testset, the unknown value is replaced by the mode value of the column + if mode_values is not None: + assert "test" in C_data + for column_idx in range(C_data["test"].shape[1]): + C_data["test"][:, column_idx][ + C_data["test"][:, column_idx] == unknown_value + ] = mode_values[column_idx] + elif "val" in C_data: + mode_values = [ + np.argmax(np.bincount(column[column != unknown_value])) + if np.any(column == unknown_value) + else column[0] + for column in C_data["train"].T + ] + for column_idx in range(C_data["val"].shape[1]): + C_data["val"][:, column_idx][ + C_data["val"][:, column_idx] == unknown_value + ] = mode_values[column_idx] + + if cat_policy == "indices": + result = (N_data, C_data) + return result[0], result[1], ord_encoder, mode_values, cat_encoder + # use other encoding if we will treat categorical features as numerical + if cat_policy == "ordinal": + cat_encoder = ord_encoder + elif cat_policy == "ohe": + if cat_encoder is None: + cat_encoder = sklearn.preprocessing.OneHotEncoder( + handle_unknown="ignore", sparse_output=False, dtype="float64" + ) + cat_encoder.fit(C_data["train"]) + C_data = {k: cat_encoder.transform(v) for k, v in C_data.items()} + elif cat_policy == "binary": + if cat_encoder is None: + cat_encoder = category_encoders.BinaryEncoder() + cat_encoder.fit(C_data["train"].astype(str)) + C_data = { + k: cat_encoder.transform(v.astype(str)).values + for k, v in C_data.items() + } + elif cat_policy == "hash": + if cat_encoder is None: + cat_encoder = category_encoders.HashingEncoder() + cat_encoder.fit(C_data["train"].astype(str)) + C_data = { + k: cat_encoder.transform(v.astype(str)).values + for k, v in C_data.items() + } + elif cat_policy == "loo": + if cat_encoder is None: + cat_encoder = category_encoders.LeaveOneOutEncoder() + cat_encoder.fit(C_data["train"].astype(str), y_train) + C_data = { + k: cat_encoder.transform(v.astype(str)).values + for k, v in C_data.items() + } + elif cat_policy == "target": + if cat_encoder is None: + cat_encoder = category_encoders.TargetEncoder() + cat_encoder.fit(C_data["train"].astype(str), y_train) + C_data = { + k: cat_encoder.transform(v.astype(str)).values + for k, v in C_data.items() + } + elif cat_policy == "catboost": + if cat_encoder is None: + cat_encoder = category_encoders.CatBoostEncoder() + cat_encoder.fit(C_data["train"].astype(str), y_train) + C_data = 
{ + k: cat_encoder.transform(v.astype(str)).values + for k, v in C_data.items() + } + elif cat_policy == "tabr_ohe": + if cat_encoder is None: + cat_encoder = sklearn.preprocessing.OneHotEncoder( + handle_unknown="ignore", sparse_output=False, dtype="float64" + ) + cat_encoder.fit(C_data["train"]) + C_data = {k: cat_encoder.transform(v) for k, v in C_data.items()} + result = (N_data, C_data) + return result[0], result[1], ord_encoder, mode_values, cat_encoder + else: + raise_unknown("categorical encoding policy", cat_policy) + if N_data is None: + result = (C_data, None) + else: + result = ({x: np.hstack((N_data[x], C_data[x])) for x in N_data}, None) + return result[0], result[1], ord_encoder, mode_values, cat_encoder + return N_data, C_data, None, None, None + + +def data_norm_process(N_data, normalization, seed, normalizer=None): + """Process the normalization of the dataset. + + :param N_data: ArrayDict + :param normalization: str + :param seed: int + :param normalizer: Optional[TransformerMixin] + :return: Tuple[ArrayDict, Optional[TransformerMixin]] + """ + if N_data is None or normalization == "none": + return N_data, None + + if normalizer is None: + N_data_train = N_data["train"].copy() + + if normalization == "standard": + normalizer = sklearn.preprocessing.StandardScaler() + elif normalization == "minmax": + normalizer = sklearn.preprocessing.MinMaxScaler() + elif normalization == "quantile": + normalizer = sklearn.preprocessing.QuantileTransformer( + output_distribution="normal", + n_quantiles=max(min(N_data["train"].shape[0] // 30, 1000), 10), + random_state=seed, + ) + elif normalization == "maxabs": + normalizer = sklearn.preprocessing.MaxAbsScaler() + elif normalization == "power": + normalizer = sklearn.preprocessing.PowerTransformer(method="yeo-johnson") + elif normalization == "robust": + normalizer = sklearn.preprocessing.RobustScaler() + else: + raise_unknown("normalization", normalization) + normalizer.fit(N_data_train) + + result = {k: normalizer.transform(v) for k, v in N_data.items()} + return result, normalizer + + +def data_label_process(y_data, is_regression, info=None, encoder=None): + """Process the labels in the dataset. + + :param y_data: ArrayDict + :param is_regression: bool + :param info: Optional[Dict[str, Any]] + :param encoder: Optional[LabelEncoder] + :return: Tuple[ArrayDict, Dict[str, Any], Optional[LabelEncoder]] + """ + y = deepcopy(y_data) + if is_regression: + y = {k: v.astype(float) for k, v in y.items()} + if info is None: + mean, std = y_data["train"].mean(), y_data["train"].std() + else: + mean, std = info["mean"], info["std"] + y = {k: (v - mean) / std for k, v in y.items()} + info = {"policy": "mean_std", "mean": mean, "std": std} + return y, info, None + # classification + if encoder is None: + encoder = sklearn.preprocessing.LabelEncoder().fit(y["train"]) + y = {k: encoder.transform(v) for k, v in y.items()} + return y, {"policy": "none"}, encoder + + +def data_loader_process(is_regression, X, Y, y_info, device, batch_size, is_train): + """Process the data loader. 
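# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): minimal use of the
# normalization and label helpers defined above. The array values are made up;
# the import path is the module added by this patch.
import numpy as np

from tabrepo.benchmark.models.ag.beta.deps.talent_data import (
    data_label_process,
    data_norm_process,
)

N = {"train": np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]),
     "val": np.array([[1.5, 15.0]])}
y = {"train": np.array([0.5, 1.5, 2.5]), "val": np.array([1.0])}

# Fit a StandardScaler on the train split and reuse it for every other split.
N_norm, normalizer = data_norm_process(N, normalization="standard", seed=0)

# Regression labels get standardized with the train mean/std; y_info keeps the
# statistics needed to undo the scaling at prediction time.
y_norm, y_info, _ = data_label_process(y, is_regression=True)
print(y_info["policy"], y_info["mean"], round(y_info["std"], 3))  # mean_std 1.5 0.816
# ---------------------------------------------------------------------------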
+ + :param is_regression: bool + :param X: Tuple[ArrayDict, ArrayDict] + :param Y: ArrayDict + :param y_info: Dict[str, Any] + :param device: torch.device + :param batch_size: int + :param is_train: bool + :return: Tuple[ArrayDict, ArrayDict, ArrayDict, DataLoader, DataLoader, Callable] + """ + X = tuple(None if x is None else to_tensors(x) for x in X) + Y = to_tensors(Y) + + X = tuple(None if x is None else {k: v.to(device) for k, v in x.items()} for x in X) + Y = {k: v.to(device) for k, v in Y.items()} + + if X[0] is not None: + X = ({k: v.double() for k, v in X[0].items()}, X[1]) + + if is_regression: + Y = {k: v.double() for k, v in Y.items()} + else: + Y = {k: v.long() for k, v in Y.items()} + + loss_fn = F.mse_loss if is_regression else F.cross_entropy + + if is_train: + trainset = TData(is_regression, X, Y, y_info, "train") + valset = TData(is_regression, X, Y, y_info, "val") + train_loader = DataLoader( + dataset=trainset, batch_size=batch_size, shuffle=True, num_workers=0 + ) + val_loader = DataLoader( + dataset=valset, batch_size=batch_size, shuffle=False, num_workers=0 + ) + return X[0], X[1], Y, train_loader, val_loader, loss_fn + testset = TData(is_regression, X, Y, y_info, "test") + test_loader = DataLoader( + dataset=testset, batch_size=batch_size, shuffle=False, num_workers=0 + ) + return X[0], X[1], Y, test_loader, loss_fn + + +def to_tensors(data: ArrayDict) -> dict[str, torch.Tensor]: + """Convert the numpy arrays to torch tensors. + + :param data: ArrayDict + :return: Dict[str, torch.Tensor] + """ + return {k: torch.as_tensor(v) for k, v in data.items()} + + +def get_categories(X_cat: dict[str, torch.Tensor] | None) -> list[int] | None: + """Get the categories for each categorical feature. + + :param X_cat: Optional[Dict[str, torch.Tensor]] + :return: Optional[List[int]] + """ + return ( + None + if X_cat is None + else [ + len(set(X_cat["train"][:, i].tolist())) + for i in range(X_cat["train"].shape[1]) + ] + ) + + +class TData(TorchDataset): + def __init__(self, is_regression, X, Y, y_info, part): + assert part in ["train", "val", "test"] + X_num, X_cat = X + self.X_num = X_num[part] if X_num is not None else None + self.X_cat = X_cat[part] if X_cat is not None else None + self.Y, self.y_info = Y[part], y_info + + # self.num_class = 1 if is_regression else torch.unique(Y['train']).shape[0] + + def get_dim_in(self): + return 0 if self.X_num is None else self.X_num.shape[1] + + def get_categories(self): + return ( + None + if self.X_cat is None + else [ + len(set(self.X_cat[:, i].cpu().tolist())) + for i in range(self.X_cat.shape[1]) + ] + ) + + def __len__(self): + return len(self.Y) + + def __getitem__(self, i): + if self.X_num is not None and self.X_cat is not None: + data = (self.X_num[i], self.X_cat[i]) + elif self.X_cat is not None and self.X_num is None: + data, label = self.X_cat[i], self.Y[i] + else: + data, label = self.X_num[i], self.Y[i] + label = self.Y[i] + return data, label diff --git a/tabrepo/benchmark/models/ag/beta/deps/talent_methods_base.py b/tabrepo/benchmark/models/ag/beta/deps/talent_methods_base.py new file mode 100644 index 00000000..5915ff79 --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/deps/talent_methods_base.py @@ -0,0 +1,473 @@ +from __future__ import annotations + +import abc +import os.path as osp +import time + +import numpy as np +import sklearn.metrics as skm +import torch +from tqdm import tqdm + +from tabrepo.benchmark.models.ag.beta.deps.talent_data import ( + Dataset, + data_enc_process, + data_label_process, + 
data_loader_process, + data_nan_process, + data_norm_process, + get_categories, + num_enc_process, +) +from tabrepo.benchmark.models.ag.beta.deps.talent_utils import ( + Averager, + Timer, + get_device, + set_seeds, +) + +# from schedulefree import ScheduleFreeWrapper,AdamWScheduleFree + + +def check_softmax(logits): + """Check if the logits are already probabilities, and if not, convert them to probabilities. + + :param logits: np.ndarray of shape (N, C) with logits + :return: np.ndarray of shape (N, C) with probabilities + """ + # Check if any values are outside the [0, 1] range and Ensure they sum to 1 + if np.any((logits < 0) | (logits > 1)) or ( + not np.allclose(logits.sum(axis=-1), 1, atol=1e-5) + ): + exps = np.exp( + logits - np.max(logits, axis=1, keepdims=True) + ) # stabilize by subtracting max + return exps / np.sum(exps, axis=1, keepdims=True) + return logits + + +class Method(metaclass=abc.ABCMeta): + def __init__(self, args, is_regression): + """:param args: argparse object + :param is_regression: bool, whether the task is regression or not + """ + self.args = args + print(args.config) + self.is_regression = is_regression + self.D = None + + self.train_step = 0 + self.val_count = 0 + self.continue_training = True + self.timer = Timer() + + self.trlog = {} + self.trlog["args"] = vars(args) + self.trlog["train_loss"] = [] + self.trlog["best_epoch"] = 0 + if self.is_regression: + self.trlog["best_res"] = 1e10 + else: + self.trlog["best_res"] = 0 + + self.args.device = get_device() + + def reset_stats_withconfig(self, config): + """Reset the training statistics with a new configuration. + + :param config: dict, new configuration + """ + set_seeds(self.args.seed) + self.train_step = 0 + self.val_count = 0 + self.continue_training = True + self.timer = Timer() + self.config = self.args.config = config + + # train statistics + self.trlog = {} + self.trlog["args"] = vars(self.args) + self.trlog["train_loss"] = [] + self.trlog["best_epoch"] = 0 + if self.is_regression: + self.trlog["best_res"] = 1e10 + else: + self.trlog["best_res"] = 0 + + def data_format(self, is_train=True, N=None, C=None, y=None): + """Format the data for training or testing. 
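# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): check_softmax() above
# leaves well-formed probabilities untouched and applies a numerically stable
# softmax to anything else. The values below are made up.
import numpy as np

from tabrepo.benchmark.models.ag.beta.deps.talent_methods_base import check_softmax

logits = np.array([[2.0, 0.0], [0.0, 2.0]])   # raw logits -> softmax is applied
probs = np.array([[0.9, 0.1], [0.2, 0.8]])    # rows already in [0, 1] and sum to 1 -> returned as-is
print(check_softmax(logits))                  # ~[[0.881, 0.119], [0.119, 0.881]]
print(check_softmax(probs))                   # unchanged
# ---------------------------------------------------------------------------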
+ + :param is_train: bool, whether the data is for training or testing + :param N: dict, numerical data + :param C: dict, categorical data + :param y: dict, labels + """ + if is_train: + self.N, self.C, self.num_new_value, self.imputer, self.cat_new_value = ( + data_nan_process( + self.N, self.C, self.args.num_nan_policy, self.args.cat_nan_policy + ) + ) + self.y, self.y_info, self.label_encoder = data_label_process( + self.y, self.is_regression + ) + self.N, self.num_encoder = num_enc_process( + self.N, + num_policy=self.args.num_policy, + n_bins=self.args.config["training"]["n_bins"], + y_train=self.y["train"], + is_regression=self.is_regression, + ) + self.N, self.C, self.ord_encoder, self.mode_values, self.cat_encoder = ( + data_enc_process(self.N, self.C, self.args.cat_policy, self.y["train"]) + ) + self.N, self.normalizer = data_norm_process( + self.N, self.args.normalization, self.args.seed + ) + + if self.is_regression: + self.d_out = 1 + else: + self.d_out = len(np.unique(self.y["train"])) + self.d_in = 0 if self.N is None else self.N["train"].shape[1] + self.categories = get_categories(self.C) + ( + self.N, + self.C, + self.y, + self.train_loader, + self.val_loader, + self.criterion, + ) = data_loader_process( + self.is_regression, + (self.N, self.C), + self.y, + self.y_info, + self.args.device, + self.args.batch_size, + is_train=True, + ) + + else: + N_test, C_test, _, _, _ = data_nan_process( + N, + C, + self.args.num_nan_policy, + self.args.cat_nan_policy, + self.num_new_value, + self.imputer, + self.cat_new_value, + ) + y_test, _, _ = data_label_process( + y, self.is_regression, self.y_info, self.label_encoder + ) + N_test, _ = num_enc_process( + N_test, + num_policy=self.args.num_policy, + n_bins=self.args.config["training"]["n_bins"], + y_train=None, + encoder=self.num_encoder, + ) + N_test, C_test, _, _, _ = data_enc_process( + N_test, + C_test, + self.args.cat_policy, + None, + self.ord_encoder, + self.mode_values, + self.cat_encoder, + ) + N_test, _ = data_norm_process( + N_test, self.args.normalization, self.args.seed, self.normalizer + ) + _, _, _, self.test_loader, _ = data_loader_process( + self.is_regression, + (N_test, C_test), + y_test, + self.y_info, + self.args.device, + self.args.batch_size, + is_train=False, + ) + if N_test is not None and C_test is not None: + self.N_test, self.C_test = N_test["test"], C_test["test"] + elif N_test is None and C_test is not None: + self.N_test, self.C_test = None, C_test["test"] + else: + self.N_test, self.C_test = N_test["test"], None + self.y_test = y_test["test"] + + def fit(self, data, info, train=True, config=None): + """Fit the method to the data. 
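# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): the data/info layout
# consumed by fit() here -- a (N, C, y) tuple of {"train": ..., "val": ...}
# arrays (N or C may be None) plus an info dict matching the Dataset fields.
# The toy arrays and the concrete Method subclass are made up.
import numpy as np

N = {"train": np.random.rand(100, 4).astype(np.float32),
     "val": np.random.rand(25, 4).astype(np.float32)}
C = None                                    # no categorical features in this toy case
y = {"train": np.random.randint(0, 2, 100),
     "val": np.random.randint(0, 2, 25)}
info = {"task_type": "binclass", "n_num_features": 4, "n_cat_features": 0}
# method = SomeTalentMethod(args, is_regression=False)  # hypothetical concrete subclass of Method
# time_cost = method.fit((N, C, y), info, train=True)
# ---------------------------------------------------------------------------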
+ + :param data: tuple, (N, C, y) + :param info: dict, information about the data + :param train: bool, whether to train the method + :param config: dict, configuration for the method + :return: float, time cost + """ + # if the method already fit the dataset, skip these steps (such as the hyper-tune process) + N, C, y = data + self.D = Dataset(N, C, y, info) + self.N, self.C, self.y = self.D.N, self.D.C, self.D.y + self.is_binclass, self.is_multiclass, self.is_regression = ( + self.D.is_binclass, + self.D.is_multiclass, + self.D.is_regression, + ) + self.n_num_features, self.n_cat_features = ( + self.D.n_num_features, + self.D.n_cat_features, + ) + if config is not None: + self.reset_stats_withconfig(config) + self.data_format(is_train=True) + self.construct_model() + self.optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=self.args.config["training"]["lr"], + weight_decay=self.args.config["training"]["weight_decay"], + ) + # self.optimizer = AdamWScheduleFree( + # self.model.parameters(), + # lr=self.args.config['training']['lr'], + # # lr = 0.05, + # # beta=0.9 + # ) + # if not train, skip the training process. such as load the checkpoint and directly predict the results + if not train: + return None + + time_cost = 0 + for epoch in range(self.args.max_epoch): + tic = time.time() + self.train_epoch(epoch) + self.validate(epoch) + elapsed = time.time() - tic + time_cost += elapsed + print(f"Epoch: {epoch}, Time cost: {elapsed}") + if not self.continue_training: + break + torch.save( + {"params": self.model.state_dict()}, + osp.join(self.args.save_path, f"epoch-last-{self.args.seed!s}.pth"), + ) + return time_cost + + def predict(self, data, info, model_name): + """Predict the results of the data. + + :param data: tuple, (N, C, y) + :param info: dict, information about the data + :param model_name: str, name of the model + :return: tuple, (loss, metric, metric_name, predictions) + """ + N, C, y = data + self.model.load_state_dict( + torch.load( + osp.join(self.args.save_path, model_name + f"-{self.args.seed!s}.pth") + )["params"] + ) + print( + "best epoch {}, best val res={:.4f}".format( + self.trlog["best_epoch"], self.trlog["best_res"] + ) + ) + ## Evaluation Stage + self.model.eval() + # self.optimizer.eval() + self.data_format(False, N, C, y) + + test_logit, test_label = [], [] + with torch.no_grad(): + for _i, (X, y) in tqdm(enumerate(self.test_loader)): + if self.N is not None and self.C is not None: + X_num, X_cat = X[0], X[1] + elif self.C is not None and self.N is None: + X_num, X_cat = None, X + else: + X_num, X_cat = X, None + + pred = self.model(X_num, X_cat) + + test_logit.append(pred) + test_label.append(y) + + test_logit = torch.cat(test_logit, 0) + test_label = torch.cat(test_label, 0) + + vl = self.criterion(test_logit, test_label).item() + + vres, metric_name = self.metric(test_logit, test_label, self.y_info) + + print(f"Test: loss={vl:.4f}") + for name, res in zip(metric_name, vres): + print(f"[{name}]={res:.4f}") + + return vl, vres, metric_name, test_logit + + def train_epoch(self, epoch): + """Train the model for one epoch. 
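# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): the checkpoint naming
# used by fit()/predict() above (and by validate() further below). Calling
# predict(..., model_name="best-val") therefore loads the weights saved whenever
# the validation metric improved. The save_path and seed values are made up.
import os.path as osp

save_path, seed = "/tmp/beta_run", 0
best_ckpt = osp.join(save_path, f"best-val-{seed!s}.pth")     # written on validation improvement
last_ckpt = osp.join(save_path, f"epoch-last-{seed!s}.pth")   # written at the end of fit()
print(best_ckpt, last_ckpt)  # /tmp/beta_run/best-val-0.pth /tmp/beta_run/epoch-last-0.pth
# ---------------------------------------------------------------------------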
+ + :param epoch: int, the current epoch + """ + self.model.train() + # self.optimizer.train() + tl = Averager() + for i, (X, y) in enumerate(self.train_loader, 1): + self.train_step = self.train_step + 1 + if self.N is not None and self.C is not None: + X_num, X_cat = X[0], X[1] + elif self.C is not None and self.N is None: + X_num, X_cat = None, X + else: + X_num, X_cat = X, None + + loss = self.criterion(self.model(X_num, X_cat), y) + + tl.add(loss.item()) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + if (i - 1) % 50 == 0 or i == len(self.train_loader): + print( + "epoch {}, train {}/{}, loss={:.4f} lr={:.4g}".format( + epoch, + i, + len(self.train_loader), + loss.item(), + self.optimizer.param_groups[0]["lr"], + ) + ) + del loss + tl = tl.item() + self.trlog["train_loss"].append(tl) + + def validate(self, epoch): + """Validate the model. + + :param epoch: int, the current epoch + """ + print( + "best epoch {}, best val res={:.4f}".format( + self.trlog["best_epoch"], self.trlog["best_res"] + ) + ) + + ## Evaluation Stage + self.model.eval() + # self.optimizer.eval() + test_logit, test_label = [], [] + with torch.no_grad(): + for _i, (X, y) in tqdm(enumerate(self.val_loader)): + if self.N is not None and self.C is not None: + X_num, X_cat = X[0], X[1] + elif self.C is not None and self.N is None: + X_num, X_cat = None, X + else: + X_num, X_cat = X, None + + pred = self.model(X_num, X_cat) + + test_logit.append(pred) + test_label.append(y) + + test_logit = torch.cat(test_logit, 0) + test_label = torch.cat(test_label, 0) + + vl = self.criterion(test_logit, test_label).item() + + if self.is_regression: + task_type = "regression" + measure = np.less_equal + else: + task_type = "classification" + measure = np.greater_equal + + vres, metric_name = self.metric(test_logit, test_label, self.y_info) + + print(f"epoch {epoch}, val, loss={vl:.4f} {task_type} result={vres[0]:.4f}") + if measure(vres[0], self.trlog["best_res"]) or epoch == 0: + self.trlog["best_res"] = vres[0] + self.trlog["best_epoch"] = epoch + torch.save( + {"params": self.model.state_dict()}, + osp.join(self.args.save_path, f"best-val-{self.args.seed!s}.pth"), + ) + self.val_count = 0 + else: + self.val_count += 1 + if self.val_count > 20: + self.continue_training = False + torch.save(self.trlog, osp.join(self.args.save_path, "trlog")) + + def metric(self, predictions, labels, y_info): + """Compute the evaluation metric. 
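+
+        Regression returns (MAE, R2, RMSE); binary and multiclass classification
+        return (Accuracy, Avg_Recall, Avg_Precision, F1, LogLoss, AUC).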
+ + :param predictions: np.ndarray, predictions + :param labels: np.ndarray, labels + :param y_info: dict, information about the labels + :return: tuple, (metric, metric_name) + """ + if not isinstance(labels, np.ndarray): + labels = labels.cpu().numpy() + if not isinstance(predictions, np.ndarray): + predictions = predictions.cpu().numpy() + if self.is_regression: + mae = skm.mean_absolute_error(labels, predictions) + rmse = skm.mean_squared_error(labels, predictions) ** 0.5 + r2 = skm.r2_score(labels, predictions) + if y_info["policy"] == "mean_std": + mae *= y_info["std"] + rmse *= y_info["std"] + return (mae, r2, rmse), ("MAE", "R2", "RMSE") + if self.is_binclass: + # if not softmax, convert to probabilities + predictions = check_softmax(predictions) + accuracy = skm.accuracy_score(labels, predictions.argmax(axis=-1)) + avg_recall = skm.balanced_accuracy_score( + labels, predictions.argmax(axis=-1) + ) + avg_precision = skm.precision_score( + labels, predictions.argmax(axis=-1), average="macro" + ) + f1_score = skm.f1_score( + labels, predictions.argmax(axis=-1), average="binary" + ) + log_loss = skm.log_loss(labels, predictions) + auc = skm.roc_auc_score(labels, predictions[:, 1]) + return (accuracy, avg_recall, avg_precision, f1_score, log_loss, auc), ( + "Accuracy", + "Avg_Recall", + "Avg_Precision", + "F1", + "LogLoss", + "AUC", + ) + if self.is_multiclass: + # if not softmax, convert to probabilities + predictions = check_softmax(predictions) + accuracy = skm.accuracy_score(labels, predictions.argmax(axis=-1)) + avg_recall = skm.balanced_accuracy_score( + labels, predictions.argmax(axis=-1) + ) + avg_precision = skm.precision_score( + labels, predictions.argmax(axis=-1), average="macro" + ) + f1_score = skm.f1_score( + labels, predictions.argmax(axis=-1), average="macro" + ) + log_loss = skm.log_loss(labels, predictions) + auc = skm.roc_auc_score( + labels, predictions, average="macro", multi_class="ovr" + ) + return (accuracy, avg_recall, avg_precision, f1_score, log_loss, auc), ( + "Accuracy", + "Avg_Recall", + "Avg_Precision", + "F1", + "LogLoss", + "AUC", + ) + raise ValueError("Unknown tabular task type") diff --git a/tabrepo/benchmark/models/ag/beta/deps/talent_num_embeddings.py b/tabrepo/benchmark/models/ag/beta/deps/talent_num_embeddings.py new file mode 100644 index 00000000..860b1ffd --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/deps/talent_num_embeddings.py @@ -0,0 +1,628 @@ +from __future__ import annotations + +import math +import warnings +from typing import Any + +import torch +from torch import Tensor, nn +from tqdm import tqdm + +try: + import sklearn.tree as sklearn_tree +except ImportError: + sklearn_tree = None + + +def _check_bins(bins: list[Tensor]) -> None: + if not bins: + raise ValueError("The list of bins must not be empty") + for i, feature_bins in enumerate(bins): + if not isinstance(feature_bins, Tensor): + raise ValueError( + "bins must be a list of PyTorch tensors. " + f"However, for {i=}: {type(feature_bins)=}" + ) + if feature_bins.ndim != 1: + raise ValueError( + "Each item of the bin list must have exactly one dimension." + f" However, for {i=}: {feature_bins.ndim=}" + ) + if len(feature_bins) < 2: + raise ValueError( + "All features must have at least two bin edges." + f" However, for {i=}: {len(feature_bins)=}" + ) + if not feature_bins.isfinite().all(): + raise ValueError( + "Bin edges must not contain nan/inf/-inf." 
+ f" However, this is not true for the {i}-th feature" + ) + if (feature_bins[:-1] >= feature_bins[1:]).any(): + raise ValueError( + "Bin edges must be sorted." + f" However, the for the {i}-th feature, the bin edges are not sorted" + ) + if len(feature_bins) == 2: + warnings.warn( + f"The {i}-th feature has just two bin edges, which means only one bin." + " Strictly speaking, using a single bin for the" + " piecewise-linear encoding should not break anything," + " but it is the same as using sklearn.preprocessing.MinMaxScaler", + stacklevel=2, + ) + + +def compute_bins( + X: torch.Tensor, + n_bins: int = 48, + *, + tree_kwargs: dict[str, Any] | None = None, + y: Tensor | None = None, + regression: bool | None = None, + verbose: bool = False, +) -> list[Tensor]: + # Source: https://github.com/yandex-research/rtdl-num-embeddings/blob/main/package/rtdl_num_embeddings.py + """Compute bin edges for `PiecewiseLinearEmbeddings`. + + **Usage** + + Computing the quantile-based bins (Section 3.2.1 in the paper): + + >>> X_train = torch.randn(10000, 2) + >>> bins = compute_bins(X_train) + + Computing the tree-based bins (Section 3.2.2 in the paper): + + >>> X_train = torch.randn(10000, 2) + >>> y_train = torch.randn(len(X_train)) + >>> bins = compute_bins( + ... X_train, + ... y=y_train, + ... regression=True, + ... tree_kwargs={'min_samples_leaf': 64, 'min_impurity_decrease': 1e-4}, + ... ) + + Args: + X: the training features. + n_bins: the number of bins. + tree_kwargs: keyword arguments for `sklearn.tree.DecisionTreeRegressor` + (if ``regression`` is `True`) or `sklearn.tree.DecisionTreeClassifier` + (if ``regression`` is `False`). + NOTE: requires ``scikit-learn>=1.0,>2`` to be installed. + y: the training labels (must be provided if ``tree`` is not None). + regression: whether the labels are regression labels + (must be provided if ``tree`` is not None). + verbose: if True and ``tree_kwargs`` is not None, than ``tqdm`` + (must be installed) will report the progress while fitting trees. + + Returns: + A list of bin edges for all features. For one feature: + + - the maximum possible number of bin edges is ``n_bins + 1``. + - the minumum possible number of bin edges is ``1``. + """ + if not isinstance(X, Tensor): + raise ValueError(f"X must be a PyTorch tensor, however: {type(X)=}") + if X.ndim != 2: + raise ValueError(f"X must have exactly two dimensions, however: {X.ndim=}") + if X.shape[0] < 2: + raise ValueError(f"X must have at least two rows, however: {X.shape[0]=}") + if X.shape[1] < 1: + raise ValueError(f"X must have at least one column, however: {X.shape[1]=}") + if not X.isfinite().all(): + raise ValueError("X must not contain nan/inf/-inf.") + if (X[0] == X).all(dim=0).any(): + raise ValueError( + "All columns of X must have at least two distinct values." + " However, X contains columns with just one distinct value." + ) + if n_bins <= 1 or n_bins >= len(X): + raise ValueError( + "n_bins must be more than 1, but less than len(X), however:" + f" {n_bins=}, {len(X)=}" + ) + + if tree_kwargs is None: + if y is not None or regression is not None or verbose: + raise ValueError( + "If tree_kwargs is None, then y must be None, regression must be None" + " and verbose must be False" + ) + + # NOTE[DIFF] + # The original implementation in the official paper repository has an + # unintentional divergence from what is written in the paper. 
+ # This package implements the algorithm described in the paper, + # and it is recommended for future work + # (this may affect the optimal number of bins + # reported in the official repository). + # + # Additional notes: + # - this is the line where the divergence happens: + # (the thing is that limiting the number of quantiles by the number of + # distinct values is NOT the same as removing identical quantiles + # after computing them) + # https://github.com/yandex-research/tabular-dl-num-embeddings/blob/c1d9eb63c0685b51d7e1bc081cdce6ffdb8886a8/bin/train4.py#L612C30-L612C30 + # - for the tree-based bins, there is NO such divergence; + bins = [ + q.unique() + for q in torch.quantile( + X, torch.linspace(0.0, 1.0, n_bins + 1).to(X), dim=0 + ).T + ] + _check_bins(bins) + return bins + if sklearn_tree is None: + raise RuntimeError( + "The scikit-learn package is missing." + " See README.md for installation instructions" + ) + if y is None or regression is None: + raise ValueError( + "If tree_kwargs is not None, then y and regression must not be None" + ) + if y.ndim != 1: + raise ValueError(f"y must have exactly one dimension, however: {y.ndim=}") + if len(y) != len(X): + raise ValueError( + f"len(y) must be equal to len(X), however: {len(y)=}, {len(X)=}" + ) + if y is None or regression is None: + raise ValueError( + "If tree_kwargs is not None, then y and regression must not be None" + ) + if "max_leaf_nodes" in tree_kwargs: + raise ValueError( + 'tree_kwargs must not contain the key "max_leaf_nodes"' + " (it will be set to n_bins automatically)." + ) + + if verbose: + if tqdm is None: + raise ImportError("If verbose is True, tqdm must be installed") + tqdm_ = tqdm + else: + tqdm_ = lambda x: x + + if X.device.type != "cpu" or y.device.type != "cpu": + warnings.warn( + "Computing tree-based bins involves the conversion of the input PyTorch" + " tensors to NumPy arrays. The provided PyTorch tensors are not" + " located on CPU, so the conversion has some overhead.", + UserWarning, + stacklevel=2, + ) + X_numpy = X.cpu().numpy() + y_numpy = y.cpu().numpy() + bins = [] + for column in tqdm_(X_numpy.T): + feature_bin_edges = [float(column.min()), float(column.max())] + tree = ( + ( + sklearn_tree.DecisionTreeRegressor + if regression + else sklearn_tree.DecisionTreeClassifier + )(max_leaf_nodes=n_bins, **tree_kwargs) + .fit(column.reshape(-1, 1), y_numpy) + .tree_ + ) + for node_id in range(tree.node_count): + # The following condition is True only for split nodes. Source: + # https://scikit-learn.org/1.0/auto_examples/tree/plot_unveil_tree_structure.html#tree-structure + if tree.children_left[node_id] != tree.children_right[node_id]: + feature_bin_edges.append(float(tree.threshold[node_id])) + bins.append(torch.as_tensor(feature_bin_edges).unique()) + _check_bins(bins) + return [x.to(device=X.device, dtype=X.dtype) for x in bins] + + +class _PiecewiseLinearEncodingImpl(nn.Module): + # NOTE + # 1. DO NOT USE THIS CLASS DIRECTLY (ITS OUTPUT CONTAINS INFINITE VALUES). + # 2. This implementation is not memory efficient for cases when there are many + # features with low number of bins and only few features + # with high number of bins. If this becomes a problem, + # just split features into groups and encode the groups separately. + + # The output of this module has the shape (*batch_dims, n_features, max_n_bins), + # where max_n_bins = max(map(len, bins)) - 1. 
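+    # (e.g., two features with 3 and 5 bin edges give max_n_bins = 4)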
+ # If the i-th feature has the number of bins less than max_n_bins, + # then its piecewise-linear representation is padded with inf as follows: + # [x_1, x_2, ..., x_k, inf, ..., inf] + # where: + # x_1 <= 1.0 + # 0.0 <= x_i <= 1.0 (for i in range(2, k)) + # 0.0 <= x_k + # k == len(bins[i]) - 1 (the number of bins for the i-th feature) + + # If all features have the same number of bins, then there are no infinite values. + + edges: Tensor + width: Tensor + mask: Tensor + + # Source: https://github.com/yandex-research/rtdl-num-embeddings/blob/main/package/rtdl_num_embeddings.py + def __init__(self, bins: list[Tensor]) -> None: + _check_bins(bins) + + super().__init__() + # To stack bins to a tensor, all features must have the same number of bins. + # To achieve that, for each feature with a less-than-max number of bins, + # its bins are padded with additional phantom bins with infinite edges. + max_n_edges = max(len(x) for x in bins) + padding = torch.full( + (max_n_edges,), + math.inf, + dtype=bins[0].dtype, + device=bins[0].device, + ) + edges = torch.row_stack([torch.cat([x, padding])[:max_n_edges] for x in bins]) + + # The rightmost edge is needed only to compute the width of the rightmost bin. + self.register_buffer("edges", edges[:, :-1]) + self.register_buffer("width", edges.diff()) + # mask is false for the padding values. + self.register_buffer( + "mask", + torch.row_stack( + [ + torch.cat( + [ + torch.ones(len(x) - 1, dtype=torch.bool, device=x.device), + torch.zeros( + max_n_edges - 1, dtype=torch.bool, device=x.device + ), + ] + )[: max_n_edges - 1] + for x in bins + ] + ), + ) + self._bin_counts = tuple(len(x) - 1 for x in bins) + self._same_bin_count = all(x == self._bin_counts[0] for x in self._bin_counts) + + def forward(self, x: Tensor) -> Tensor: + if x.ndim < 2: + raise ValueError( + f"The input must have at least two dimensions, however: {x.ndim=}" + ) + + # See Equation 1 in the paper. + x = (x[..., None] - self.edges) / self.width + + # If the number of bins is greater than 1, then, the following rules must + # be applied to a piecewise-linear encoding of a single feature: + # - the leftmost value can be negative, but not greater than 1.0. + # - the rightmost value can be greater than 1.0, but not negative. + # - the intermediate values must stay within [0.0, 1.0]. + n_bins = x.shape[-1] + if n_bins > 1: + if self._same_bin_count: + x = torch.cat( + [ + x[..., :1].clamp_max(1.0), + *([] if n_bins == 2 else [x[..., 1:-1].clamp(0.0, 1.0)]), + x[..., -1:].clamp_min(0.0), + ], + dim=-1, + ) + else: + # In this case, the rightmost values for all features are located + # in different columns. + x = torch.stack( + [ + x[..., i, :] + if count == 1 + else torch.cat( + [ + x[..., i, :1].clamp_max(1.0), + *( + [] + if n_bins == 2 + else [x[..., i, 1 : count - 1].clamp(0.0, 1.0)] + ), + x[..., i, count - 1 : count].clamp_min(0.0), + x[..., i, count:], + ], + dim=-1, + ) + for i, count in enumerate(self._bin_counts) + ], + dim=-2, + ) + return x + + +class PiecewiseLinearEncoding(nn.Module): + """Piecewise-linear encoding. + + **Shape** + + - Input: ``(*, n_features)`` + - Output: ``(*, n_features, total_n_bins)``, + where ``total_n_bins`` is the total number of bins for all features: + ``total_n_bins = sum(len(b) - 1 for b in bins)``. + """ + + # Source: https://github.com/yandex-research/rtdl-num-embeddings/blob/main/package/rtdl_num_embeddings.py + def __init__(self, bins: list[Tensor]) -> None: + """Args: + bins: the bins computed by `compute_bins`. 
+ """ + _check_bins(bins) + + super().__init__() + self.impl = _PiecewiseLinearEncodingImpl(bins) + + def forward(self, x: Tensor) -> Tensor: + x = self.impl(x) + return x.flatten(-2) if self.impl._same_bin_count else x[:, self.impl.mask] + + +class _UnaryEncodingImpl(nn.Module): + edges: Tensor + mask: Tensor + + def __init__(self, bins: list[Tensor]) -> None: + _check_bins(bins) + + super().__init__() + # To stack bins to a tensor, all features must have the same number of bins. + # To achieve that, for each feature with a less-than-max number of bins, + # its bins are padded with additional phantom bins with infinite edges. + max_n_edges = max(len(x) for x in bins) + padding = torch.full( + (max_n_edges,), + math.inf, + dtype=bins[0].dtype, + device=bins[0].device, + ) + edges = torch.row_stack([torch.cat([x, padding])[:max_n_edges] for x in bins]) + + # The rightmost edge is needed only to compute the width of the rightmost bin. + self.register_buffer("edges", edges[:, :-1]) + # mask is false for the padding values. + self.register_buffer( + "mask", + torch.row_stack( + [ + torch.cat( + [ + torch.ones(len(x) - 1, dtype=torch.bool, device=x.device), + torch.zeros( + max_n_edges - 1, dtype=torch.bool, device=x.device + ), + ] + )[: max_n_edges - 1] + for x in bins + ] + ), + ) + self._bin_counts = tuple(len(x) - 1 for x in bins) + self._same_bin_count = all(x == self._bin_counts[0] for x in self._bin_counts) + + def forward(self, x: Tensor) -> Tensor: + if x.ndim < 2: + raise ValueError( + f"The input must have at least two dimensions, however: {x.ndim=}" + ) + + # Compute which bin each value falls into + x = (x[..., None] - self.edges).sign().cumsum(dim=-1) + + # Ensure values are within [0, 1] range for unary encoding + return x.clamp(0, 1) + + +class UnaryEncoding(nn.Module): + """Unary encoding. + + **Shape** + + - Input: ``(*, n_features)`` + - Output: ``(*, n_features, total_n_bins)``, + where ``total_n_bins`` is the total number of bins for all features: + ``total_n_bins = sum(len(b) - 1 for b in bins)``. + """ + + def __init__(self, bins: list[Tensor]) -> None: + """Args: + bins: the bins computed by `compute_bins`. + """ + _check_bins(bins) + + super().__init__() + self.impl = _UnaryEncodingImpl(bins) + + def forward(self, x: Tensor) -> Tensor: + x = self.impl(x) + return x.flatten(-2) if self.impl._same_bin_count else x[:, self.impl.mask] + + +class _JohnsonEncodingImpl(nn.Module): + edges: Tensor + mask: Tensor + + def __init__(self, bins: list[Tensor]) -> None: + _check_bins(bins) + + super().__init__() + # To stack bins to a tensor, all features must have the same number of bins. + # To achieve that, for each feature with a less-than-max number of bins, + # its bins are padded with additional phantom bins with infinite edges. + max_n_edges = max(len(x) for x in bins) + padding = torch.full( + (max_n_edges,), + math.inf, + dtype=bins[0].dtype, + device=bins[0].device, + ) + edges = torch.row_stack([torch.cat([x, padding])[:max_n_edges] for x in bins]) + + # The rightmost edge is needed only to compute the width of the rightmost bin. + self.register_buffer("edges", edges[:, :-1]) + self.register_buffer("width", edges.diff()) + # mask is false for the padding values. 
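+        # (kept for parity with the other encodings: JohnsonEncoding.forward
+        # currently flattens the codes without applying this mask)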
+ self.register_buffer( + "mask", + torch.row_stack( + [ + torch.cat( + [ + torch.ones(len(x) - 1, dtype=torch.bool, device=x.device), + torch.zeros( + max_n_edges - 1, dtype=torch.bool, device=x.device + ), + ] + )[: max_n_edges - 1] + for x in bins + ] + ), + ) + self._bin_counts = tuple(len(x) - 1 for x in bins) + self._same_bin_count = all(x == self._bin_counts[0] for x in self._bin_counts) + + def forward(self, x: Tensor) -> Tensor: + if x.ndim < 2: + raise ValueError( + f"The input must have at least two dimensions, however: {x.ndim=}" + ) + + # Compute which bin each value falls into + bin_indices = torch.stack( + [ + torch.bucketize(x[..., i], self.edges[i], right=True) - 1 + for i in range(x.shape[-1]) + ], + dim=-1, + ) + + # Generate Johnson code for each bin index + max_bin = self.edges.shape[1] + code_length = (max_bin + 1) // 2 + johnson_code = torch.zeros( + *x.shape, code_length, device=x.device, dtype=torch.float32 + ) + for i in range(x.shape[0]): + for j in range(x.shape[1]): + johnson_code[i, j, :] = self.temp_code( + bin_indices[i, j].item(), max_bin + ) + + return johnson_code + + def temp_code(self, num, num_bits): + num_bits = num_bits + 1 if num_bits % 2 != 0 else num_bits + bits = num_bits // 2 + a = torch.zeros([bits], dtype=torch.long) + for i in range(bits): + if bits - i - 1 < num <= num_bits - i - 1: + a[i] = 1 + return a + + +class JohnsonEncoding(nn.Module): + """Johnson encoding. + + **Shape** + + - Input: ``(*, n_features)`` + - Output: ``(*, n_features, total_n_bits)``, + where ``total_n_bits`` is the total number of bits for all features: + ``total_n_bits = sum((len(b) - 1) // 2 for b in bins)``. + """ + + def __init__(self, bins: list[Tensor]) -> None: + """Args: + bins: the bins computed by `compute_bins`. + """ + _check_bins(bins) + + super().__init__() + self.impl = _JohnsonEncodingImpl(bins) + + def forward(self, x: Tensor) -> Tensor: + x = self.impl(x) + return x.flatten(-2) # if self.impl._same_bin_count else x[:, self.impl.mask] + + +class _BinsEncodingImpl(nn.Module): + edges: Tensor + mask: Tensor + + def __init__(self, bins: list[Tensor]) -> None: + _check_bins(bins) + + super().__init__() + # To stack bins to a tensor, all features must have the same number of bins. + # To achieve that, for each feature with a less-than-max number of bins, + # its bins are padded with additional phantom bins with infinite edges. + max_n_edges = max(len(x) for x in bins) + padding = torch.full( + (max_n_edges,), + math.inf, + dtype=bins[0].dtype, + device=bins[0].device, + ) + edges = torch.row_stack([torch.cat([x, padding])[:max_n_edges] for x in bins]) + + # The rightmost edge is needed only to compute the width of the rightmost bin. + self.register_buffer("edges", edges[:, :-1]) + # mask is false for the padding values. 
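+        # (kept for parity with the other encodings: BinsEncoding.forward
+        # returns raw bin indices and does not apply this mask)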
+ self.register_buffer( + "mask", + torch.row_stack( + [ + torch.cat( + [ + torch.ones(len(x) - 1, dtype=torch.bool, device=x.device), + torch.zeros( + max_n_edges - 1, dtype=torch.bool, device=x.device + ), + ] + )[: max_n_edges - 1] + for x in bins + ] + ), + ) + self._bin_counts = tuple(len(x) - 1 for x in bins) + self._same_bin_count = all(x == self._bin_counts[0] for x in self._bin_counts) + + def forward(self, x: Tensor) -> Tensor: + if x.ndim < 2: + raise ValueError( + f"The input must have at least two dimensions, however: {x.ndim=}" + ) + + # Compute which bin each value falls into + return torch.stack( + [ + torch.bucketize(x[..., i], self.edges[i], right=True) - 1 + for i in range(x.shape[-1]) + ], + dim=-1, + ) + + +class BinsEncoding(nn.Module): + """Bins encoding. + + **Shape** + + - Input: ``(*, n_features)`` + - Output: ``(*, n_features, total_n_bins)``, + where ``total_n_bins`` is the total number of bins for all features. + """ + + def __init__(self, bins: list[Tensor]) -> None: + """Args: + bins: the bins computed by `compute_bins`. + """ + _check_bins(bins) + + super().__init__() + self.impl = _BinsEncodingImpl(bins) + + def forward(self, x: Tensor) -> Tensor: + return self.impl(x) diff --git a/tabrepo/benchmark/models/ag/beta/deps/talent_utils.py b/tabrepo/benchmark/models/ag/beta/deps/talent_utils.py new file mode 100644 index 00000000..7cf0ca53 --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/deps/talent_utils.py @@ -0,0 +1,812 @@ +from __future__ import annotations + +import errno +import json +import os +import os.path as osp +import pprint +import random +import shutil +import time + +import numpy as np +import torch + +THIS_PATH = os.path.dirname(__file__) + + +def mkdir(path): + """Create a directory if it does not exist. + + :path: str, path to the directory + """ + try: + os.makedirs(path) + except OSError as exc: # Python >2.5 + if exc.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise + + +def set_gpu(x): + """Set environment variable CUDA_VISIBLE_DEVICES. + + :x: str, GPU id + """ + os.environ["CUDA_VISIBLE_DEVICES"] = x + print("using gpu:", x) + + +def ensure_path(path, remove=True): + """Ensure a path exists. + + path: str, path to the directory + remove: bool, whether to remove the directory if it exists + """ + if os.path.exists(path): + if remove and input(f"{path} exists, remove? ([y]/n)") != "n": + shutil.rmtree(path) + os.mkdir(path) + else: + os.mkdir(path) + + +# --- criteria helper --- +class Averager: + """A simple averager.""" + + def __init__(self): + self.n = 0 + self.v = 0 + + def add(self, x): + """:x: float, value to be added.""" + self.v = (self.v * self.n + x) / (self.n + 1) + self.n += 1 + + def item(self): + return self.v + + +class Timer: + def __init__(self): + self.o = time.time() + + def measure(self, p=1): + """Measure the time since the last call to measure. + + :p: int, period of printing the time + """ + x = (time.time() - self.o) / p + x = int(x) + if x >= 3600: + return f"{x / 3600:.1f}h" + if x >= 60: + return f"{round(x / 60)}m" + return f"{x}s" + + +_utils_pp = pprint.PrettyPrinter() + + +def pprint(x): + _utils_pp.pprint(x) + + +# ---- import from lib.util ----------- +def set_seeds(base_seed: int, one_cuda_seed: bool = False) -> None: + """Set random seeds for reproducibility. 
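+
+    Seeds Python's random module, NumPy and PyTorch (CPU and, when available, CUDA)
+    using small distinct offsets of the base seed.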
+ + :base_seed: int, base seed + :one_cuda_seed: bool, whether to set one seed for all GPUs + """ + assert 0 <= base_seed < 2**32 - 10000 + random.seed(base_seed) + np.random.seed(base_seed + 1) + torch.manual_seed(base_seed + 2) + cuda_seed = base_seed + 3 + if one_cuda_seed: + torch.cuda.manual_seed_all(cuda_seed) + elif torch.cuda.is_available(): + # the following check should never succeed since torch.manual_seed also calls + # torch.cuda.manual_seed_all() inside; but let's keep it just in case + if not torch.cuda.is_initialized(): + torch.cuda.init() + # Source: https://github.com/pytorch/pytorch/blob/2f68878a055d7f1064dded1afac05bb2cb11548f/torch/cuda/random.py#L109 + for i in range(torch.cuda.device_count()): + default_generator = torch.cuda.default_generators[i] + default_generator.manual_seed(cuda_seed + i) + + +def get_device() -> torch.device: + return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +import sklearn.metrics as skm + + +def rmse(y, prediction, y_info): + """:y: np.ndarray, ground truth + :prediction: np.ndarray, prediction + :y_info: dict, information about the target variable + :return: float, root mean squared error. + """ + rmse = skm.mean_squared_error(y, prediction) ** 0.5 # type: ignore[code] + if y_info["policy"] == "mean_std": + rmse *= y_info["std"] + return rmse + + +def load_config(args, config=None, config_name=None): + """Load the config file. + + :args: argparse.Namespace, arguments + :config: dict, config file + :config_name: str, name of the config file + :return: argparse.Namespace, arguments + """ + if config is None: + config_path = os.path.join( + os.path.abspath(os.path.join(THIS_PATH, "..")), + "configs", + args.dataset, + f"{args.model_type if args.config_name is None else args.config_name}.json", + ) + with open(config_path) as fp: + config = json.load(fp) + + # set additional parameters + args.config = config + + # save the config files + with open( + os.path.join( + args.save_path, + "{}.json".format("config" if config_name is None else config_name), + ), + "w", + ) as fp: + args_dict = vars(args) + if "device" in args_dict: + del args_dict["device"] + json.dump(args_dict, fp, sort_keys=True, indent=4) + + return args + + +# parameter search +def sample_parameters(trial, space, base_config): + """Sample hyper-parameters. 
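+
+    Understands the TALENT opt-space conventions: a leading "?" marks an optional
+    hyper-parameter gated by a companion "optional_<label>" categorical flag, and
+    "$"-prefixed entries ("$mlp_d_layers", "$d_token", ...) receive special handling.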
+ + :trial: optuna.trial.Trial, trial + :space: dict, search space + :base_config: dict, base configuration + :return: dict, sampled hyper-parameters + """ + + def get_distribution(distribution_name): + return getattr(trial, f"suggest_{distribution_name}") + + result = {} + for label, subspace in space.items(): + if isinstance(subspace, dict): + result[label] = sample_parameters(trial, subspace, base_config) + else: + assert isinstance(subspace, list) + distribution, *args = subspace + + if distribution.startswith("?"): + default_value = args[0] + result[label] = ( + get_distribution(distribution.lstrip("?"))(label, *args[1:]) + if trial.suggest_categorical(f"optional_{label}", [False, True]) + else default_value + ) + + elif distribution == "$mlp_d_layers": + min_n_layers, max_n_layers, d_min, d_max = args + n_layers = trial.suggest_int("n_layers", min_n_layers, max_n_layers) + suggest_dim = lambda name: trial.suggest_int(name, d_min, d_max) # noqa + d_first = [suggest_dim("d_first")] if n_layers else [] + d_middle = ( + [suggest_dim("d_middle")] * (n_layers - 2) if n_layers > 2 else [] + ) + d_last = [suggest_dim("d_last")] if n_layers > 1 else [] + result[label] = d_first + d_middle + d_last + + elif distribution == "$d_token": + assert len(args) == 2 + try: + n_heads = base_config["model"]["n_heads"] + except KeyError: + n_heads = base_config["model"]["n_latent_heads"] + + for x in args: + assert x % n_heads == 0 + result[label] = trial.suggest_int("d_token", *args, n_heads) # type: ignore[code] + + elif distribution in ["$d_ffn_factor", "$d_hidden_factor"]: + if base_config["model"]["activation"].endswith("glu"): + args = (args[0] * 2 / 3, args[1] * 2 / 3) + result[label] = trial.suggest_uniform("d_ffn_factor", *args) + + else: + result[label] = get_distribution(distribution)(label, *args) + return result + + +def merge_sampled_parameters(config, sampled_parameters): + """Merge the sampled hyper-parameters. + + :config: dict, configuration + :sampled_parameters: dict, sampled hyper-parameters + """ + for k, v in sampled_parameters.items(): + if isinstance(v, dict): + merge_sampled_parameters(config.setdefault(k, {}), v) + else: + # If there are parameters in the default config, the value of the parameter will be overwritten. + config[k] = v + + +def get_classical_args(): + """Get the arguments for classical models. 
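+
+    Defaults are read from configs/classical_configs.json; the save path is derived
+    from the dataset, model type and preprocessing choices.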
+ + :return: argparse.Namespace, arguments + """ + import argparse + import warnings + + warnings.filterwarnings("ignore") + with open("configs/classical_configs.json") as file: + default_args = json.load(file) + parser = argparse.ArgumentParser() + # basic parameters + parser.add_argument("--dataset", type=str, default=default_args["dataset"]) + parser.add_argument( + "--model_type", + type=str, + default=default_args["model_type"], + choices=[ + "LogReg", + "NCM", + "RandomForest", + "xgboost", + "catboost", + "lightgbm", + "svm", + "knn", + "NaiveBayes", + "dummy", + "LinearRegression", + ], + ) + + # optimization parameters + parser.add_argument( + "--normalization", + type=str, + default=default_args["normalization"], + choices=["none", "standard", "minmax", "quantile", "maxabs", "power", "robust"], + ) + parser.add_argument( + "--num_nan_policy", + type=str, + default=default_args["num_nan_policy"], + choices=["mean", "median"], + ) + parser.add_argument( + "--cat_nan_policy", + type=str, + default=default_args["cat_nan_policy"], + choices=["new", "most_frequent"], + ) + parser.add_argument( + "--cat_policy", + type=str, + default=default_args["cat_policy"], + choices=[ + "indices", + "ordinal", + "ohe", + "binary", + "hash", + "loo", + "target", + "catboost", + ], + ) + parser.add_argument( + "--num_policy", + type=str, + default=default_args["num_policy"], + choices=[ + "none", + "Q_PLE", + "T_PLE", + "Q_Unary", + "T_Unary", + "Q_bins", + "T_bins", + "Q_Johnson", + "T_Johnson", + ], + ) + parser.add_argument("--n_bins", type=int, default=default_args["n_bins"]) + parser.add_argument( + "--cat_min_frequency", type=float, default=default_args["cat_min_frequency"] + ) + + # other choices + parser.add_argument("--n_trials", type=int, default=default_args["n_trials"]) + parser.add_argument("--seed_num", type=int, default=default_args["seed_num"]) + parser.add_argument("--gpu", default=default_args["gpu"]) + parser.add_argument("--tune", action="store_true", default=default_args["tune"]) + parser.add_argument("--retune", action="store_true", default=default_args["retune"]) + parser.add_argument( + "--dataset_path", type=str, default=default_args["dataset_path"] + ) + parser.add_argument("--model_path", type=str, default=default_args["model_path"]) + parser.add_argument( + "--evaluate_option", type=str, default=default_args["evaluate_option"] + ) + args = parser.parse_args() + + set_gpu(args.gpu) + save_path1 = "-".join([args.dataset, args.model_type]) + + save_path2 = f"Norm-{args.normalization}" + save_path2 += f"-Nan-{args.num_nan_policy}-{args.cat_nan_policy}" + save_path2 += f"-Cat-{args.cat_policy}" + + if args.cat_min_frequency > 0.0: + save_path2 += f"-CatFreq-{args.cat_min_frequency}" + if args.tune: + save_path1 += "-Tune" + + save_path = osp.join(save_path1, save_path2) + args.save_path = osp.join(args.model_path, save_path) + mkdir(args.save_path) + + # load config parameters + args.seed = 0 + + config_default_path = os.path.join("configs", "default", args.model_type + ".json") + config_opt_path = os.path.join("configs", "opt_space", args.model_type + ".json") + with open(config_default_path) as file: + default_para = json.load(file) + + with open(config_opt_path) as file: + opt_space = json.load(file) + + args.config = default_para[args.model_type] + set_seeds(args.seed) + if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + pprint(vars(args)) + + args.config["fit"]["n_bins"] = args.n_bins + return args, default_para, opt_space + + +def get_deep_args(): + 
"""Get the arguments for deep learning models. + + :return: argparse.Namespace, arguments + """ + import argparse + import warnings + + warnings.filterwarnings("ignore") + + argparse.ArgumentParser() + # basic parameters + default_args = { + "dataset": "yeast", + "model_type": "Beta", + "max_epoch": 200, + "batch_size": 1024, + "normalization": "standard", + "num_nan_policy": "mean", + "cat_nan_policy": "new", + "cat_policy": "indices", + "num_policy": "none", + "n_bins": 2, + "cat_min_frequency": 0.0, + "n_trials": 100, + "seed_num": 5, + "workers": 0, + "gpu": "0", + "tune": False, + "retune": False, + "evaluate_option": "best-val", + "dataset_path": "data", + "model_path": "results_model", + } + from types import SimpleNamespace + + args = SimpleNamespace(**default_args) + + set_gpu(args.gpu) + save_path1 = "-".join([args.dataset, args.model_type]) + save_path2 = f"Epoch{args.max_epoch}BZ{args.batch_size}" + save_path2 += f"-Norm-{args.normalization}" + save_path2 += f"-Nan-{args.num_nan_policy}-{args.cat_nan_policy}" + save_path2 += f"-Cat-{args.cat_policy}" + + if args.cat_min_frequency > 0.0: + save_path2 += f"-CatFreq-{args.cat_min_frequency}" + if args.tune: + save_path1 += "-Tune" + + save_path = osp.join(save_path1, save_path2) + args.save_path = osp.join(args.model_path, save_path) + # mkdir(args.save_path) + + # load config parameters + default_para = { + "Beta": { + "model": { + "arch_type": "tabm-mini", + "k": 16, + "num_embeddings": { + "type": "PLREmbeddings", + "n_frequencies": 72, + "frequency_scale": 0.04, + "d_embedding": 32, + "lite": True, + }, + "backbone": { + "type": "MLP", + "n_blocks": 2, + "d_block": 100, + "dropout": 0.15, + }, + }, + "training": {"lr": 0.003, "weight_decay": 0.02}, + "general": {}, + } + } + opt_space = { + "Beta": { + "model": { + "arch_type": "tabm", + "k": 3, + "num_embeddings": { + "type": "PLREmbeddings", + "n_frequencies": 77, + "frequency_scale": 0.04431360576139521, + "d_embedding": 34, + "lite": True, + }, + "backbone": { + "type": "MLP", + "n_blocks": 2, + "d_block": 256, + "dropout": 0.1, + }, + "temperature": 1, + "sample_rate": 0.5, + }, + "training": {"lr": 0.01, "weight_decay": 1e-5}, + "general": {}, + } + } + args.config = default_para[args.model_type] + + args.seed = 0 + set_seeds(args.seed) + if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + pprint(vars(args)) + + args.config["training"]["n_bins"] = args.n_bins + return args, default_para, opt_space + + +def show_results_classical(args, info, metric_name, results_list, time_list): + """Show the results for classical models. 
+ + :args: argparse.Namespace, arguments + :info: dict, information about the dataset + :metric_name: list, names of the metrics + :results_list: list, list of results + :time_list: list, list of time + """ + metric_arrays = {name: [] for name in metric_name} + + for result in results_list: + for idx, name in enumerate(metric_name): + metric_arrays[name].append(result[idx]) + + metric_arrays["Time"] = time_list + metric_name = (*metric_name, "Time") + + mean_metrics = {name: np.mean(metric_arrays[name]) for name in metric_name} + std_metrics = {name: np.std(metric_arrays[name]) for name in metric_name} + + # Printing results + print(f"{args.model_type}: {args.seed_num} Trials") + for name in metric_name: + if info["task_type"] == "regression" and name != "Time": + formatted_results = ", ".join([f"{e:.8e}" for e in metric_arrays[name]]) + print(f"{name} Results: {formatted_results}") + print(f"{name} MEAN = {mean_metrics[name]:.8e} ± {std_metrics[name]:.8e}") + else: + formatted_results = ", ".join([f"{e:.8f}" for e in metric_arrays[name]]) + print(f"{name} Results: {formatted_results}") + print(f"{name} MEAN = {mean_metrics[name]:.8f} ± {std_metrics[name]:.8f}") + + print("-" * 20, "GPU info", "-" * 20) + if torch.cuda.is_available(): + num_gpus = torch.cuda.device_count() + print(f"{num_gpus} GPU Available.") + for i in range(num_gpus): + gpu_info = torch.cuda.get_device_properties(i) + print(f"GPU {i}: {gpu_info.name}") + print(f" Total Memory: {gpu_info.total_memory / 1024**2} MB") + print(f" Multi Processor Count: {gpu_info.multi_processor_count}") + print(f" Compute Capability: {gpu_info.major}.{gpu_info.minor}") + else: + print("CUDA is unavailable.") + print("-" * 50) + + +def show_results(args, info, metric_name, loss_list, results_list, time_list): + """Show the results for deep learning models. 
+ + :args: argparse.Namespace, arguments + :info: dict, information about the dataset + :metric_name: list, names of the metrics + :loss_list: list, list of loss + :results_list: list, list of results + :time_list: list, list of time + """ + metric_arrays = {name: [] for name in metric_name} + + for result in results_list: + for idx, name in enumerate(metric_name): + metric_arrays[name].append(result[idx]) + + metric_arrays["Time"] = time_list + metric_name = (*metric_name, "Time") + + mean_metrics = {name: np.mean(metric_arrays[name]) for name in metric_name} + std_metrics = {name: np.std(metric_arrays[name]) for name in metric_name} + mean_loss = np.mean(np.array(loss_list)) + + # Printing results + print(f"{args.model_type}: {args.seed_num} Trials") + for name in metric_name: + if info["task_type"] == "regression" and name != "Time": + formatted_results = ", ".join([f"{e:.8e}" for e in metric_arrays[name]]) + print(f"{name} Results: {formatted_results}") + print(f"{name} MEAN = {mean_metrics[name]:.8e} ± {std_metrics[name]:.8e}") + else: + formatted_results = ", ".join([f"{e:.8f}" for e in metric_arrays[name]]) + print(f"{name} Results: {formatted_results}") + print(f"{name} MEAN = {mean_metrics[name]:.8f} ± {std_metrics[name]:.8f}") + + print(f"Mean Loss: {mean_loss:.8e}") + + print("-" * 20, "GPU info", "-" * 20) + if torch.cuda.is_available(): + num_gpus = torch.cuda.device_count() + print(f"{num_gpus} GPU Available.") + for i in range(num_gpus): + gpu_info = torch.cuda.get_device_properties(i) + print(f"GPU {i}: {gpu_info.name}") + print(f" Total Memory: {gpu_info.total_memory / 1024**2} MB") + print(f" Multi Processor Count: {gpu_info.multi_processor_count}") + print(f" Compute Capability: {gpu_info.major}.{gpu_info.minor}") + else: + print("CUDA is unavailable.") + print("-" * 50) + + +def tune_hyper_parameters(args, opt_space, train_val_data, info): + """Tune hyper-parameters. 
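+
+    Runs an Optuna TPE study with n_trials trials over opt_space and caches the best
+    configuration to <save_path>/<model_type>-tuned.json, which is reused on later
+    calls unless retune is set.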
+ + :args: argparse.Namespace, arguments + :opt_space: dict, search space + :train_val_data: tuple, training and validation data + :info: dict, information about the dataset + :return: argparse.Namespace, arguments + """ + import optuna + import optuna.samplers + import optuna.trial + + def objective(trial): + config = {} + try: + opt_space[args.model_type]["training"]["n_bins"] = ["int", 2, 256] + except: + opt_space[args.model_type]["fit"]["n_bins"] = ["int", 2, 256] + merge_sampled_parameters( + config, sample_parameters(trial, opt_space[args.model_type], config) + ) + if args.model_type == "xgboost" and torch.cuda.is_available(): + config["model"]["tree_method"] = "gpu_hist" + config["model"]["gpu_id"] = args.gpu + config["fit"]["verbose"] = False + elif args.model_type == "catboost" and torch.cuda.is_available(): + config["fit"]["logging_level"] = "Silent" + + elif args.model_type == "RandomForest": + config["model"]["max_depth"] = 12 + + if args.model_type in ["resnet"]: + config["model"]["activation"] = "relu" + config["model"]["normalization"] = "batchnorm" + + if args.model_type in ["ftt"]: + config["model"].setdefault("prenormalization", False) + config["model"].setdefault("initialization", "xavier") + config["model"].setdefault("activation", "reglu") + config["model"].setdefault("n_heads", 8) + config["model"].setdefault("d_token", 64) + config["model"].setdefault("token_bias", True) + config["model"].setdefault("kv_compression", None) + config["model"].setdefault("kv_compression_sharing", None) + + if args.model_type in ["excelformer"]: + config["model"].setdefault("prenormalization", False) + config["model"].setdefault("kv_compression", None) + config["model"].setdefault("kv_compression_sharing", None) + config["model"].setdefault("token_bias", True) + config["model"].setdefault("init_scale", 0.01) + config["model"].setdefault("n_heads", 8) + + if args.model_type in ["node"]: + config["model"].setdefault("choice_function", "sparsemax") + config["model"].setdefault("bin_function", "sparsemoid") + + if args.model_type in ["tabr"]: + config["model"]["num_embeddings"].setdefault("type", "PLREmbeddings") + config["model"]["num_embeddings"].setdefault("lite", True) + config["model"].setdefault("d_multiplier", 2.0) + config["model"].setdefault("mixer_normalization", "auto") + config["model"].setdefault("dropout1", 0.0) + config["model"].setdefault("normalization", "LayerNorm") + config["model"].setdefault("activation", "ReLU") + + if args.model_type in ["mlp_plr"]: + config["model"]["num_embeddings"].setdefault("type", "PLREmbeddings") + config["model"]["num_embeddings"].setdefault("lite", True) + + if args.model_type in ["ptarl"]: + config["model"]["n_clusters"] = 20 + config["model"]["regularize"] = "True" + config["general"]["diversity"] = "True" + config["general"]["ot_weight"] = 0.25 + config["general"]["diversity_weight"] = 0.25 + config["general"]["r_weight"] = 0.25 + + if args.model_type in ["modernNCA", "tabm"]: + config["model"]["num_embeddings"].setdefault("type", "PLREmbeddings") + config["model"]["num_embeddings"].setdefault("lite", True) + + if args.model_type in ["tabm"]: + config["model"]["backbone"].setdefault("type", "MLP") + config["model"].setdefault("arch_type", "tabm") + config["model"].setdefault("k", 16) + + if args.model_type in ["danets"]: + config["general"]["k"] = 5 + config["general"]["virtual_batch_size"] = 256 + + if args.model_type in ["dcn2"]: + config["model"]["stacked"] = False + + if args.model_type in ["grownet"]: + config["ensemble_model"]["lr"] 
= 1.0 + config["model"]["sparse"] = False + config["training"]["lr_scaler"] = 3 + + if args.model_type in ["autoint"]: + config["model"].setdefault("prenormalization", False) + config["model"].setdefault("initialization", "xavier") + config["model"].setdefault("activation", "relu") + config["model"].setdefault("n_heads", 8) + config["model"].setdefault("d_token", 64) + config["model"].setdefault("kv_compression", None) + config["model"].setdefault("kv_compression_sharing", None) + + if args.model_type in ["protogate"]: + config["training"].setdefault("lam", 1e-3) + config["training"].setdefault("pred_coef", 1) + config["training"].setdefault("sorting_tau", 16) + config["training"].setdefault("feature_selection", True) + config["model"].setdefault("a", 1) + config["model"].setdefault("sigma", 0.5) + + if args.model_type in ["grande"]: + config["model"].setdefault("from_logits", True) + config["model"].setdefault("use_class_weights", True) + config["model"].setdefault("bootstrap", False) + + if args.model_type in ["amformer"]: + config["model"].setdefault("heads", 8) + config["model"].setdefault("groups", [54, 54, 54, 54]) + config["model"].setdefault("sum_num_per_group", [32, 16, 8, 4]) + config["model"].setdefault("prod_num_per_group", [6, 6, 6, 6]) + config["model"].setdefault("cluster", True) + config["model"].setdefault("target_mode", "mix") + config["model"].setdefault("token_descent", False) + + if config.get("config_type") == "trv4": + if config["model"]["activation"].endswith("glu"): + # This adjustment is needed to keep the number of parameters roughly in the + # same range as for non-glu activations + config["model"]["d_ffn_factor"] *= 2 / 3 + + trial_configs.append(config) + # method.fit(train_val_data, info, train=True, config=config) + # run with this config + try: + method.fit(train_val_data, info, train=True, config=config) + return method.trlog["best_res"] + except Exception as e: + print(e) + return 1e9 if info["task_type"] == "regression" else 0.0 + + if ( + osp.exists(osp.join(args.save_path, f"{args.model_type}-tuned.json")) + and not args.retune + ): + with open( + osp.join(args.save_path, f"{args.model_type}-tuned.json"), "rb" + ) as fp: + args.config = json.load(fp) + else: + # get data property + if info["task_type"] == "regression": + direction = "minimize" + for key in opt_space[args.model_type]["model"]: + if ( + "dropout" in key + and "?" not in opt_space[args.model_type]["model"][key][0] + ): + opt_space[args.model_type]["model"][key][0] = ( + "?" + opt_space[args.model_type]["model"][key][0] + ) + opt_space[args.model_type]["model"][key].insert(1, 0.0) + else: + direction = "maximize" + + method = get_method(args.model_type)(args, info["task_type"] == "regression") + + trial_configs = [] + study = optuna.create_study( + direction=direction, + sampler=optuna.samplers.TPESampler(seed=0), + ) + study.optimize( + objective, + n_trials=args.n_trials, + show_progress_bar=True, + ) + # get best configs + best_trial_id = study.best_trial.number + # update config files + print("Best Hyper-Parameters") + print(trial_configs[best_trial_id]) + args.config = trial_configs[best_trial_id] + with open(osp.join(args.save_path, f"{args.model_type}-tuned.json"), "w") as fp: + json.dump(args.config, fp, sort_keys=True, indent=4) + return args + + +def get_method(model): + """Get the method class. 
+ + :model: str, model name + :return: class, method class + """ + if model == "Beta": + from tabrepo.benchmark.models.ag.beta.talent_beta_method import BetaMethod + + return BetaMethod + raise ValueError(f"Unknown model: {model}. Please check the model name.") diff --git a/tabrepo/benchmark/models/ag/beta/talent_beta_method.py b/tabrepo/benchmark/models/ag/beta/talent_beta_method.py new file mode 100644 index 00000000..910303a8 --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/talent_beta_method.py @@ -0,0 +1,434 @@ +from __future__ import annotations + +import os.path as osp +import time + +import numpy as np +import torch +from autogluon.core.metrics import compute_metric +from tqdm import tqdm + +from tabrepo.benchmark.models.ag.beta.deps.talent_data import ( + Dataset, + data_enc_process, + data_label_process, + data_loader_process, + data_nan_process, + data_norm_process, + get_categories, +) +from tabrepo.benchmark.models.ag.beta.deps.talent_methods_base import ( + Method, + check_softmax, +) + + +def reduce_loss(loss, reduction="mean"): + return ( + loss.mean() + if reduction == "mean" + else loss.sum() + if reduction == "sum" + else loss + ) + + +def loss_fn(_loss_fn, y_pred, y_true): + return _loss_fn(y_pred.flatten(0, 1), y_true.repeat_interleave(y_pred.shape[1])) + + +class BetaMethod(Method): + def __init__(self, args, is_regression, max_context_size: int = 1000): + super().__init__(args, is_regression) + assert args.num_policy == "none" + assert not is_regression + self.max_context_size = max_context_size + + def construct_model(self, model_config=None): + from tabrepo.benchmark.models.ag.beta.talent_beta_model import Beta + + if model_config is None: + model_config = self.args.config["model"] + self.model = Beta( + d_num=self.n_num_features, + cat_cardinalities=self.categories, + d_out=self.d_out, + max_context_size=self.max_context_size, + **model_config, + ).to(self.args.device) + self.trlog["best_res"] = float("-inf") + + def data_format(self, is_train=True, N=None, C=None, y=None): + if is_train: + self.N, self.C, self.num_new_value, self.imputer, self.cat_new_value = ( + data_nan_process( + self.N, self.C, self.args.num_nan_policy, self.args.cat_nan_policy + ) + ) + self.y, self.y_info, self.label_encoder = data_label_process( + self.y, self.is_regression + ) + self.N, self.C, self.ord_encoder, self.mode_values, self.cat_encoder = ( + data_enc_process(self.N, self.C, self.args.cat_policy) + ) + self.n_num_features = self.N["train"].shape[1] if self.N is not None else 0 + self.n_cat_features = self.C["train"].shape[1] if self.C is not None else 0 + self.N, self.normalizer = data_norm_process( + self.N, self.args.normalization, self.args.seed + ) + + if self.is_regression: + self.d_out = 1 + else: + self.d_out = len(np.unique(self.y["train"])) + self.C_features = self.C["train"].shape[1] if self.C is not None else 0 + self.categories = get_categories(self.C) + ( + self.N, + self.C, + self.y, + self.train_loader, + self.val_loader, + self.criterion, + ) = data_loader_process( + self.is_regression, + (self.N, self.C), + self.y, + self.y_info, + self.args.device, + self.args.batch_size, + is_train=True, + ) + + else: + N_test, C_test, _, _, _ = data_nan_process( + N, + C, + self.args.num_nan_policy, + self.args.cat_nan_policy, + self.num_new_value, + self.imputer, + self.cat_new_value, + ) + y_test, _, _ = data_label_process( + y, self.is_regression, self.y_info, self.label_encoder + ) + N_test, C_test, _, _, _ = data_enc_process( + N_test, + C_test, + 
self.args.cat_policy, + None, + self.ord_encoder, + self.mode_values, + self.cat_encoder, + ) + N_test, _ = data_norm_process( + N_test, self.args.normalization, self.args.seed, self.normalizer + ) + _, _, _, self.test_loader, _ = data_loader_process( + self.is_regression, + (N_test, C_test), + y_test, + self.y_info, + self.args.device, + self.args.batch_size, + is_train=False, + ) + if N_test is not None and C_test is not None: + self.N_test, self.C_test = N_test["test"], C_test["test"] + elif N_test is None and C_test is not None: + self.N_test, self.C_test = None, C_test["test"] + else: + self.N_test, self.C_test = N_test["test"], None + self.y_test = y_test["test"] + + def fit(self, data, info, train=True, config=None, model_name=None): + N, C, y = data + # if the method already fit the dataset, skip these steps (such as the hyper-tune process) + if self.D is None: + self.D = Dataset(N, C, y, info) + self.N, self.C, self.y = self.D.N, self.D.C, self.D.y + self.is_binclass, self.is_multiclass, self.is_regression = ( + self.D.is_binclass, + self.D.is_multiclass, + self.D.is_regression, + ) + self.n_num_features, self.n_cat_features = ( + self.D.n_num_features, + self.D.n_cat_features, + ) + + self.data_format(is_train=True) + if config is not None: + self.reset_stats_withconfig(config) + self.construct_model() + self.optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=self.args.config["training"]["lr"], + weight_decay=self.args.config["training"]["weight_decay"], + ) + self.train_size = ( + self.N["train"].shape[0] if self.N is not None else self.C["train"].shape[0] + ) + self.train_indices = torch.arange(self.train_size, device=self.args.device) + # if not train, skip the training process. such as load the checkpoint and directly predict the results + if not train: + return None + + time_cost = 0 + try: + self.pre_validate() + except: + print("Pre-validation failed.") + self.N_train = self.N["train"].cpu().numpy() if self.N is not None else None + self.C_train = self.C["train"].cpu().numpy() if self.C is not None else None + self.y_train = self.y["train"].cpu().numpy() + self.N_val = self.N["val"].cpu().numpy() if self.N is not None else None + self.C_val = self.C["val"].cpu().numpy() if self.C is not None else None + self.y_val = self.y["val"].cpu().numpy() + + loss_list = [] + + N_train_size = len(self.y_train) + + from torch.cuda.amp import GradScaler, autocast + + scaler = GradScaler() + self._start_time = time.time() + + for epoch in range(self.args.max_epoch): + tic = time.time() + import math + + steps = math.ceil(N_train_size / self.args.batch_size) + if self.args.batch_size > N_train_size / 2: + steps = 2 + for _step in tqdm(range(steps)): + self.model.train() + self.optimizer.zero_grad() + with autocast(enabled=True, dtype=torch.float32): + train_logit, train_label = self.model.train_step( + self.N_train, + self.C_train, + self.y_train, + min(self.args.batch_size, N_train_size), + ) + loss = loss_fn(self.criterion, train_logit, train_label) + loss_list.append(loss.item()) + scaler.scale(loss).backward() + scaler.step(self.optimizer) + scaler.update() + if self.early_stop_due_to_timelimit(iteration=epoch): + self.continue_training = False + break + + elapsed = time.time() - tic + self.validate(epoch) + time_cost += elapsed + print(f"Epoch: {epoch}, Time cost: {elapsed}") + if not self.continue_training: + break + + # Save model to class object for use at predict time + model_path = osp.join( + self.args.save_path, model_name + f"-{self.args.seed!s}.pth" + ) + 
saved_state_dict = torch.load(model_path)["params"] + + filtered_saved_state_dict = { + k: v for k, v in saved_state_dict.items() if "TabPFN" not in k + } + + self.model.load_state_dict(filtered_saved_state_dict, strict=False) + self.indexs = np.load( + osp.join( + self.args.save_path, model_name + f"-indexs-{self.args.seed!s}.npy" + ), + allow_pickle=True, + ) + + return time_cost + + def early_stop_due_to_timelimit(self, iteration: int) -> bool: + if iteration > 0 and self.args.time_to_fit_in_seconds is not None: + pred_time_after_next_epoch = ( + (iteration + 1) / iteration * (time.time() - self._start_time) + ) + if pred_time_after_next_epoch >= self.args.time_to_fit_in_seconds: + return True + + return False + + def validate(self, epoch): + self.model.eval() + with torch.no_grad(): + test_logit, indexs = self.model( + self.N_val, + self.C_val, + self.N_train, + self.C_train, + self.y_train, + is_val=True, + ) + test_logit = test_logit.cpu() + test_label = self.y_val + test_logit = test_logit.mean(1).to(torch.float32).cpu().numpy() + + if self.is_regression: + assert 0 + else: + test_logit = check_softmax(test_logit) + + validation_score = compute_metric( + y=test_label, + metric=self.args.early_stopping_metric, + y_pred=test_logit if self.is_regression else test_logit.argmax(axis=-1), + y_pred_proba=test_logit[:, 1] if self.is_binclass else test_logit, + silent=True, + ) + + if validation_score > self.trlog["best_res"]: + self.trlog["best_res"] = validation_score + self.trlog["best_epoch"] = epoch + model_state_dict = self.model.state_dict() + + filtered_state_dict = { + k: v for k, v in model_state_dict.items() if "TabPFN" not in k + } + + torch.save( + {"params": filtered_state_dict}, + osp.join(self.args.save_path, f"best-val-{self.args.seed!s}.pth"), + ) + np.save( + osp.join( + self.args.save_path, f"best-val-indexs-{self.args.seed!s}.npy" + ), + indexs, + ) + self.val_count = 0 + else: + self.val_count += 1 + if self.val_count > 50: + self.continue_training = False + + print( + "best_val_res {}, best_epoch {}".format( + self.trlog["best_res"], self.trlog["best_epoch"] + ) + ) + torch.save(self.trlog, osp.join(self.args.save_path, "trlog")) + + def pre_validate(self): + epoch = -1 + from tabrepo.benchmark.models.ag.beta.talent_tabpfn_model import ( + TabPFNClassifier, + ) + + self.PFN_model = TabPFNClassifier( + device=self.args.device, + seed=self.args.seed, + N_ensemble_configurations=self.args.config["model"]["k"], + ) + if self.N is not None and self.C is not None: + X_train = np.concatenate( + (self.N["train"].cpu().numpy(), self.C["train"].cpu().numpy()), axis=1 + ) + X_val = np.concatenate( + (self.N["val"].cpu().numpy(), self.C["val"].cpu().numpy()), axis=1 + ) + elif self.N is None and self.C is not None: + X_train = self.C["train"].cpu().numpy() + X_val = self.C["val"].cpu().numpy() + else: + X_train = self.N["train"].cpu().numpy() + X_val = self.N["val"].cpu().numpy() + y_train = self.y["train"].cpu().numpy() + y_val = self.y["val"].cpu().numpy() + if y_train.shape[0] > 3000: + # sampled_X and sampled_Y contain sample_size samples maintaining class proportions for the training set + from sklearn.model_selection import train_test_split + + X_train, _, y_train, _ = train_test_split( + X_train, y_train, train_size=3000, stratify=y_train + ) + self.PFN_model.fit(X_train, y_train, overwrite_warning=True) + y_val_predict = self.PFN_model.predict_proba(X_val) + y_val_predict = torch.tensor(y_val_predict) + test_label = y_val + # logits are already in softmax form + test_logit 
= y_val_predict.to(torch.float32).cpu().numpy() + + if self.is_regression: + assert 0 + + validation_score = compute_metric( + y=test_label, + metric=self.args.early_stopping_metric, + y_pred=test_logit if self.is_regression else test_logit.argmax(axis=-1), + y_pred_proba=test_logit[:, 1] if self.is_binclass else test_logit, + silent=True, + ) + + self.trlog["best_res"] = validation_score + self.trlog["best_epoch"] = epoch + self.val_count = 0 + + def PFN_predict(self, data, info, model_name): + N, C, y = data + self.data_format(False, N, C, y) + if self.N_test is not None and self.C_test is not None: + Test_X = np.concatenate((self.N_test, self.C_test), axis=1) + elif self.N_test is None and self.C_test is not None: + Test_X = self.C_test + else: + Test_X = self.N_test + test_logit = self.PFN_model.predict_proba(Test_X) + test_label = self.y_test + vl = self.criterion(torch.tensor(test_logit), torch.tensor(test_label)).item() + vres, metric_name = self.metric(test_logit, test_label, self.y_info) + print(f"Test: loss={vl:.4f}") + for name, res in zip(metric_name, vres): + print(f"[{name}]={res:.4f}") + return vl, vres, metric_name, test_logit + + def predict(self, data, info, model_name): + if self.trlog["best_epoch"] == -1: + return self.PFN_predict(data, info, model_name) + + N, C, y = data + self.data_format(False, N, C, y) + self.model.eval() + import time + + time.time() + batch_size = 4096 + with torch.no_grad(): + num_test_samples = len(self.y_test) # .size(0) + + all_test_logit = [] + + for start_idx in range(0, num_test_samples, batch_size): + end_idx = min(start_idx + batch_size, num_test_samples) + + batch_N_test = ( + self.N_test[start_idx:end_idx] if self.N_test is not None else None + ) + batch_C_test = ( + self.C_test[start_idx:end_idx] if self.C_test is not None else None + ) + batch_logit = self.model( + batch_N_test, + batch_C_test, + self.N_train, + self.C_train, + self.y_train, + is_test=True, + indexs=self.indexs, + ).cpu() + + all_test_logit.append(batch_logit) + + test_logit = torch.cat(all_test_logit, dim=0) + return test_logit.mean(1).to(torch.float32) + # print("Evaluation time cost:", time.time() - start_time) diff --git a/tabrepo/benchmark/models/ag/beta/talent_beta_model.py b/tabrepo/benchmark/models/ag/beta/talent_beta_model.py new file mode 100644 index 00000000..c929792d --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/talent_beta_model.py @@ -0,0 +1,379 @@ +from __future__ import annotations + +import pathlib +import random +from typing import Literal + +import numpy as np +import torch +from torch import nn + +from tabrepo.benchmark.models.ag.beta.deps.tabm_utils import ( + ElementwiseAffineEnsemble, + OneHotEncoding0d, + _init_scaling_by_sections, + make_efficient_ensemble, +) +from tabrepo.benchmark.models.ag.beta.deps.tabpfn.utils import load_model_workflow +from tabrepo.benchmark.models.ag.beta.deps.tabr_utils import ( + MLP, + ResNet, + make_module, + make_module1, +) + + +def _get_first_input_scaling(backbone): + if isinstance(backbone, MLP): + return backbone.blocks[0][0] # type: ignore[code] + if isinstance(backbone, ResNet): + return backbone.blocks[0][1] if backbone.proj is None else backbone.proj # type: ignore[code] + raise RuntimeError(f"Unsupported backbone: {backbone}") + + +class Beta(nn.Module): + def __init__( + self, + *, + d_num: int, + d_out: int, + backbone: dict, + cat_cardinalities: list[int], + num_embeddings: dict | None, + arch_type: Literal[ + # Active + "vanilla", # Simple MLP + "tabm", # BatchEnsemble + separate heads + 
better initialization + "tabm-mini", # Minimal: * weight + # BatchEnsemble + "tabm-naive", + ], + k: None | int = None, + device="cuda:0", + base_path=pathlib.Path(__file__).parent.resolve(), + model_string="", + max_context_size=1000, + ) -> None: + super().__init__() + self.d_out = d_out + self.d_num = d_num + if cat_cardinalities is None: + cat_cardinalities = [] + scaling_init_sections = [] + self.max_context_size = max_context_size + if d_num == 0: + # assert bins is None + self.num_module = None + d_num = 0 + + elif num_embeddings is None: + # assert bins is None + self.num_module = None + scaling_init_sections.extend(1 for _ in range(self.d_num)) + + else: + self.num_module = make_module(num_embeddings, n_features=d_num) + d_num = d_num * num_embeddings["d_embedding"] + scaling_init_sections.extend( + num_embeddings["d_embedding"] for _ in range(self.d_num) + ) + + self.cat_module = ( + OneHotEncoding0d(cat_cardinalities) if cat_cardinalities else None + ) + scaling_init_sections.extend(cat_cardinalities) + d_cat = sum(cat_cardinalities) + + # >>> Backbone + d_flat = d_num + d_cat + self.affine_ensemble = None + # self.scaling_layer = ScaleLayer(k, d_flat) + self.backbone = make_module1(d_in=d_flat, **backbone) + if arch_type != "vanilla": + assert k is not None + scaling_init = "random-signs" if num_embeddings is None else "normal" + + if arch_type == "tabm-mini": + # The minimal possible efficient ensemble. + self.affine_ensemble = ElementwiseAffineEnsemble( + k, + d_flat, + weight=True, + bias=False, + weight_init=( + "random-signs" if num_embeddings is None else "normal" + ), + ) + _init_scaling_by_sections( + self.affine_ensemble.weight, # type: ignore[code] + scaling_init, + scaling_init_sections, + ) + + elif arch_type == "tabm-naive": + # The original BatchEnsemble. + make_efficient_ensemble( + self.backbone, + k=k, + ensemble_scaling_in=True, + ensemble_scaling_out=True, + ensemble_bias=True, + scaling_init="random-signs", + ) + elif arch_type == "tabm": + # Like BatchEnsemble, but all scalings, except for the first one, + # are initialized with ones. 
+ make_efficient_ensemble( + self.backbone, + k=k, + ensemble_scaling_in=True, + ensemble_scaling_out=True, + ensemble_bias=True, + scaling_init="ones", + ) + _init_scaling_by_sections( + _get_first_input_scaling(self.backbone).r, # type: ignore[code] + scaling_init, + scaling_init_sections, + ) + + else: + raise ValueError(f"Unknown arch_type: {arch_type}") + self.arch_type = arch_type + self.k = k + self.device = device + self.style = None + self.TabPFN, self.c, self.results_file = load_model_workflow( + 0, + 42, + add_name=model_string, + base_path=base_path, + device=device, + eval_addition="", + only_inference=True, + ) + self.TabPFN = self.TabPFN[2] + self.TabPFN.to(torch.float16) + for param in self.TabPFN.parameters(): + param.requires_grad = False + self.max_num_features = self.c["num_features"] + self.max_num_classes = self.c["max_num_classes"] + self.differentiable_hps_as_style = self.c["differentiable_hps_as_style"] + self.index = None + + def train_step(self, x_num, x_cat, y, batch_size): + num_samples = y.shape[0] + if batch_size > num_samples / 2: + batch_size = int(num_samples / 2) + train_indices = random.sample(range(num_samples), batch_size) + candidate_indices = list(range(num_samples)) # if i not in train_indices] + + x_num_train = x_num[train_indices] if x_num is not None else None + x_cat_train = x_cat[train_indices] if x_cat is not None else None + y_train = torch.tensor(y[train_indices], device=self.device).long() + + x_num_candidate = x_num[candidate_indices] if x_num is not None else None + x_cat_candidate = x_cat[candidate_indices] if x_cat is not None else None + y_candidate = y[candidate_indices] + logits = self.forward( + x_num_train, + x_cat_train, + x_num_candidate, + x_cat_candidate, + y_candidate, + is_train=True, + ) + + return logits, y_train + + def forward( + self, + x_num, + x_cat, + candidate_x_num, + candidate_x_cat, + candidate_y, + is_train=False, + is_val=False, + is_test=False, + indexs=None, + ): + candidate_size = candidate_y.shape[0] + input = [] + y_input = [] + index_val = [] + if is_train: + indices = np.random.randint(0, candidate_size, (min(1024, candidate_size),)) + candidate_x_cat_sample = ( + torch.tensor(candidate_x_cat[indices], device=self.device) + if candidate_x_cat is not None + else None + ) + candidate_x_num_sample = ( + torch.tensor(candidate_x_num[indices], device=self.device).float() + if candidate_x_num is not None + else None + ) + candidate_y_sample = ( + torch.tensor(candidate_y[indices], device=self.device).long() + if candidate_y is not None + else None + ) + x_num = ( + torch.tensor(x_num, device=self.device).float() + if x_num is not None + else None + ) + x_cat = ( + torch.tensor(x_cat, device=self.device) if x_cat is not None else None + ) + x = [] + candidate_x = [] + if x_num is not None: + x_num_sample = x_num + x.append( + x_num_sample + if self.num_module is None + else self.num_module(x_num_sample) + ) + candidate_x.append( + candidate_x_num_sample + if self.num_module is None + else self.num_module(candidate_x_num_sample) + ) + if x_cat is None: + assert self.cat_module is None + else: + assert self.cat_module is not None + x_cat_sample = x_cat + x.append(self.cat_module(x_cat_sample)) + candidate_x.append(self.cat_module(candidate_x_cat_sample)) + x = torch.column_stack([x_.flatten(1, -1) for x_ in x]) + candidate_x = torch.column_stack([x_.flatten(1, -1) for x_ in candidate_x]) + if self.k is not None: + x = x[:, None].expand(-1, self.k, -1) # (B, D) -> (B, K, D) + candidate_x = candidate_x[:, 
None].expand( + -1, self.k, -1 + ) # (B, D) -> (B, K, D) + if self.affine_ensemble is not None: + x = self.affine_ensemble(x) + candidate_x = self.affine_ensemble(candidate_x) + else: + assert self.affine_ensemble is None + + x1 = self.backbone(x) + candidate_x1 = self.backbone(candidate_x) + for k in range(self.k): + x = x1[:, k, :] + candidate_x = candidate_x1[:, k, :] + input.append(torch.cat([candidate_x, x], dim=0)) + zeroy_expanded = torch.zeros((x.shape[0]), device=self.device) + y_full = torch.cat([candidate_y_sample, zeroy_expanded], dim=0) + y_input.append(y_full) + input = torch.stack(input) + y_input = torch.stack(y_input) + input = torch.permute(input, (1, 0, 2)).to(torch.float16) + y_input = torch.permute(y_input, (1, 0)).to(torch.float16) + + from torch.cuda.amp import autocast + + with autocast(): + logits = self.TabPFN( + (None, input, y_input), + single_eval_pos=y_input.shape[0] - x.shape[0], + )[:, :, : self.d_out].to(torch.float) + # print(logits.dtype) + else: + val_logits = [] + x_num = ( + torch.tensor(x_num, device=self.device).float() + if x_num is not None + else None + ) + x_cat = ( + torch.tensor(x_cat, device=self.device) if x_cat is not None else None + ) + for k in range(self.k): + if indexs is None: + indices = np.random.randint( + 0, candidate_size, (min(self.max_context_size, candidate_size),) + ) + index_val.append(indices) + else: + indices = indexs[k] + candidate_x_cat_sample = ( + torch.tensor(candidate_x_cat[indices], device=self.device) + if candidate_x_cat is not None + else None + ) + candidate_x_num_sample = ( + torch.tensor(candidate_x_num[indices], device=self.device).float() + if candidate_x_num is not None + else None + ) + candidate_y_sample = ( + torch.tensor(candidate_y[indices], device=self.device).long() + if candidate_y is not None + else None + ) + + x = [] + candidate_x = [] + if x_num is not None: + x_num_sample = x_num + x.append( + x_num_sample + if self.num_module is None + else self.num_module(x_num_sample) + ) + candidate_x.append( + candidate_x_num_sample + if self.num_module is None + else self.num_module(candidate_x_num_sample) + ) + if x_cat is None: + assert self.cat_module is None + else: + assert self.cat_module is not None + x_cat_sample = x_cat + x.append(self.cat_module(x_cat_sample)) + candidate_x.append(self.cat_module(candidate_x_cat_sample)) + x = torch.column_stack([x_.flatten(1, -1) for x_ in x]) + candidate_x = torch.column_stack( + [x_.flatten(1, -1) for x_ in candidate_x] + ) + if self.k is not None: + x = x[:, None].expand(-1, self.k, -1) # (B, D) -> (B, K, D) + candidate_x = candidate_x[:, None].expand( + -1, self.k, -1 + ) # (B, D) -> (B, K, D) + if self.affine_ensemble is not None: + x = self.affine_ensemble(x) + candidate_x = self.affine_ensemble(candidate_x) + else: + assert self.affine_ensemble is None + + x1 = self.backbone(x)[:, k, :] + candidate_x1 = self.backbone(candidate_x)[:, k, :] + input = [torch.cat([candidate_x1, x1], dim=0)] + zeroy_expanded = torch.zeros((x.shape[0]), device=self.device) + y_input = [torch.cat([candidate_y_sample, zeroy_expanded], dim=0)] + input = torch.stack(input) + y_input = torch.stack(y_input) + input = torch.permute(input, (1, 0, 2)).to(torch.float16) + y_input = torch.permute(y_input, (1, 0)).to(torch.float16) + + val_logits.append( + self.TabPFN( + (None, input, y_input), + single_eval_pos=y_input.shape[0] - x.shape[0], + )[:, :, : self.d_out] + ) + # Clean up for next iteration + del x, candidate_x, x1, candidate_x1, input, y_input, zeroy_expanded + 
torch.cuda.empty_cache() + + logits = torch.cat(val_logits, dim=1) + if is_val: + return logits, index_val + return logits diff --git a/tabrepo/benchmark/models/ag/beta/talent_tabpfn_model.py b/tabrepo/benchmark/models/ag/beta/talent_tabpfn_model.py new file mode 100644 index 00000000..1230e263 --- /dev/null +++ b/tabrepo/benchmark/models/ag/beta/talent_tabpfn_model.py @@ -0,0 +1,221 @@ +import numpy as np +import random + +import torch +from torch.utils.checkpoint import checkpoint + +import pickle +import io +import os +import pathlib +from pathlib import Path +from functools import partial + +from sklearn.preprocessing import PowerTransformer, QuantileTransformer, RobustScaler +from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.preprocessing import LabelEncoder +from sklearn.utils.multiclass import check_classification_targets +from sklearn.utils.validation import check_X_y, check_array, check_is_fitted +from sklearn.utils import column_or_1d + +from model.lib.tabpfn.utils import CustomUnpickler +from model.lib.tabpfn.utils import load_model_workflow, get_params_from_config, transformer_predict +# Source: https://github.com/automl/TabPFN + + +class TabPFNClassifier(BaseEstimator, ClassifierMixin): + + models_in_memory = {} + + def __init__(self, device='cpu', base_path=pathlib.Path(__file__).parent.resolve(), model_string='', + N_ensemble_configurations=3, no_preprocess_mode=False, multiclass_decoder='permutation', + feature_shift_decoder=True, only_inference=True, seed=0, no_grad=True, batch_size_inference=32, + subsample_features=False): + """ + Initializes the classifier and loads the model. + Depending on the arguments, the model is either loaded from memory, from a file, or downloaded from the + repository if no model is found. + + Can also be used to compute gradients with respect to the inputs X_train and X_test. Therefore no_grad has to be + set to False and no_preprocessing_mode must be True. Furthermore, X_train and X_test need to be given as + torch.Tensors and their requires_grad parameter must be set to True. + + + :param device: If the model should run on cuda or cpu. + :param base_path: Base path of the directory, from which the folders like models_diff can be accessed. + :param model_string: Name of the model. Used first to check if the model is already in memory, and if not, + tries to load a model with that name from the models_diff directory. It looks for files named as + follows: "prior_diff_real_checkpoint" + model_string + "_n_0_epoch_e.cpkt", where e can be a number + between 100 and 0, and is checked in a descending order. + :param N_ensemble_configurations: The number of ensemble configurations used for the prediction. Thereby the + accuracy, but also the running time, increases with this number. + :param no_preprocess_mode: Specifies whether preprocessing is to be performed. + :param multiclass_decoder: If set to permutation, randomly shifts the classes for each ensemble configuration. + :param feature_shift_decoder: If set to true shifts the features for each ensemble configuration according to a + random permutation. + :param only_inference: Indicates if the model should be loaded to only restore inference capabilities or also + training capabilities. Note that the training capabilities are currently not being fully restored. + :param seed: Seed that is used for the prediction. Allows for a deterministic behavior of the predictions. + :param batch_size_inference: This parameter is a trade-off between performance and memory consumption. 
+ The computation done with different values for batch_size_inference is the same, + but it is split into smaller/larger batches. + :param no_grad: If set to false, allows for the computation of gradients with respect to X_train and X_test. + For this to correctly function no_preprocessing_mode must be set to true. + :param subsample_features: If set to true and the number of features in the dataset exceeds self.max_features (100), + the features are subsampled to self.max_features. + """ + + # Model file specification (Model name, Epoch) + i = 0 + model_key = model_string+'|'+str(device) + if model_key in self.models_in_memory: + model, c, results_file = self.models_in_memory[model_key] + else: + model, c, results_file = load_model_workflow(i, 42, add_name=model_string, base_path=base_path, device=device, + eval_addition='', only_inference=only_inference) + self.models_in_memory[model_key] = (model, c, results_file) + if len(self.models_in_memory) == 2: + print('Multiple models in memory. This might lead to memory issues. Consider calling remove_models_from_memory()') + #style, temperature = self.load_result_minimal(style_file, i, e) + + self.device = device + self.model = model + self.c = c + self.style = None + self.temperature = None + self.N_ensemble_configurations = N_ensemble_configurations + self.base__path = base_path + self.base_path = base_path + self.i = i + self.model_string = model_string + + self.max_num_features = self.c['num_features'] + self.max_num_classes = self.c['max_num_classes'] + self.differentiable_hps_as_style = self.c['differentiable_hps_as_style'] + + self.no_preprocess_mode = no_preprocess_mode + self.feature_shift_decoder = feature_shift_decoder + self.multiclass_decoder = multiclass_decoder + self.only_inference = only_inference + self.seed = seed + self.no_grad = no_grad + self.subsample_features = subsample_features + + assert self.no_preprocess_mode if not self.no_grad else True, \ + "If no_grad is false, no_preprocess_mode must be true, because otherwise no gradient can be computed." + + self.batch_size_inference = batch_size_inference + + def remove_models_from_memory(self): + self.models_in_memory = {} + + def load_result_minimal(self, path, i, e): + with open(path, 'rb') as output: + _, _, _, style, temperature, optimization_route = CustomUnpickler(output).load() + + return style, temperature + + def _validate_targets(self, y): + y_ = column_or_1d(y, warn=True) + check_classification_targets(y) + cls, y = np.unique(y_, return_inverse=True) + if len(cls) < 2: + raise ValueError( + "The number of classes has to be greater than one; got %d class" + % len(cls) + ) + + self.classes_ = cls + + return np.asarray(y, dtype=np.float64, order="C") + + def fit(self, X, y, overwrite_warning=False): + """ + Validates the training set and stores it. 
+ + If clf.no_grad (default is True): + X, y should be of type np.array + else: + X should be of type torch.Tensors (y can be np.array or torch.Tensor) + """ + if self.no_grad: + # Check that X and y have correct shape + X, y = check_X_y(X, y, force_all_finite=False) + # Store the classes seen during fit + y = self._validate_targets(y) + self.label_encoder = LabelEncoder() + y = self.label_encoder.fit_transform(y) + + self.X_ = X + self.y_ = y + + if (X.shape[1] > self.max_num_features): + if self.subsample_features: + print('WARNING: The number of features for this classifier is restricted to ', self.max_num_features, ' and will be subsampled.') + else: + raise ValueError("The number of features for this classifier is restricted to ", self.max_num_features) + if len(np.unique(y)) > self.max_num_classes: + raise ValueError("The number of classes for this classifier is restricted to ", self.max_num_classes) + if X.shape[0] > 1024 and not overwrite_warning: + raise ValueError("⚠️ WARNING: TabPFN is not made for datasets with a trainingsize > 1024. Prediction might take a while, be less reliable. We advise not to run datasets > 10k samples, which might lead to your machine crashing (due to quadratic memory scaling of TabPFN). Please confirm you want to run by passing overwrite_warning=True to the fit function.") + + + # Return the classifier + return self + + def predict_proba(self, X, normalize_with_test=False, return_logits=False): + """ + Predict the probabilities for the input X depending on the training set previously passed in the method fit. + + If no_grad is true in the classifier the function takes X as a numpy.ndarray. If no_grad is false X must be a + torch tensor and is not fully checked. + """ + # Check is fit had been called + check_is_fitted(self) + + # Input validation + if self.no_grad: + X = check_array(X, force_all_finite=False) + X_full = np.concatenate([self.X_, X], axis=0) + X_full = torch.tensor(X_full, device=self.device).float().unsqueeze(1) + else: + assert (torch.is_tensor(self.X_) & torch.is_tensor(X)), "If no_grad is false, this function expects X as " \ + "a tensor to calculate a gradient" + X_full = torch.cat((self.X_, X), dim=0).float().unsqueeze(1).to(self.device) + + if int(torch.isnan(X_full).sum()): + print('X contains nans and the gradient implementation is not designed to handel nans.') + + y_full = np.concatenate([self.y_, np.zeros(shape=X.shape[0])], axis=0) + y_full = torch.tensor(y_full, device=self.device).float().unsqueeze(1) + + eval_pos = self.X_.shape[0] + + prediction = transformer_predict(self.model[2], X_full, y_full, eval_pos, + device=self.device, + style=self.style, + inference_mode=True, + preprocess_transform='none' if self.no_preprocess_mode else 'mix', + normalize_with_test=normalize_with_test, + N_ensemble_configurations=self.N_ensemble_configurations, + softmax_temperature=self.temperature, + multiclass_decoder=self.multiclass_decoder, + feature_shift_decoder=self.feature_shift_decoder, + differentiable_hps_as_style=self.differentiable_hps_as_style, + seed=self.seed, + return_logits=return_logits, + no_grad=self.no_grad, + batch_size_inference=self.batch_size_inference, + **get_params_from_config(self.c)) + prediction_, y_ = prediction.squeeze(0), y_full.squeeze(1).long()[eval_pos:] + + return prediction_.detach().cpu().numpy() if self.no_grad else prediction_ + + def predict(self, X, return_winning_probability=False, normalize_with_test=False): + p = self.predict_proba(X, normalize_with_test=normalize_with_test) + y = np.argmax(p, 
axis=-1)
+        y = self.classes_.take(np.asarray(y, dtype=np.intp))
+        if return_winning_probability:
+            return y, p.max(axis=-1)
+        return y
+
\ No newline at end of file
diff --git a/tabrepo/benchmark/models/model_register.py b/tabrepo/benchmark/models/model_register.py
index 9066e788..66634253 100644
--- a/tabrepo/benchmark/models/model_register.py
+++ b/tabrepo/benchmark/models/model_register.py
@@ -5,6 +5,7 @@
 from autogluon.tabular.registry import ModelRegistry, ag_model_registry
 
 from tabrepo.benchmark.models.ag import (
+    BetaModel,
     ExplainableBoostingMachineModel,
     ModernNCAModel,
     RealMLPModel,
@@ -26,6 +27,7 @@
     TabDPTModel,
     TabMModel,
     ModernNCAModel,
+    BetaModel,
 ]
 
 for _model_cls in _models_to_add:
@@ -43,7 +45,10 @@ def infer_model_cls(model_cls: str, model_register: ModelRegistry = None):
             if real_model_cls.ag_name == model_cls:
                 model_cls = real_model_cls
                 break
-    elif model_cls in [str(real_model_cls.__name__) for real_model_cls in model_register.model_cls_list]:
+    elif model_cls in [
+        str(real_model_cls.__name__)
+        for real_model_cls in model_register.model_cls_list
+    ]:
         for real_model_cls in model_register.model_cls_list:
             if model_cls == str(real_model_cls.__name__):
                 model_cls = real_model_cls
diff --git a/tabrepo/models/beta/__init__.py b/tabrepo/models/beta/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tabrepo/models/beta/generate.py b/tabrepo/models/beta/generate.py
new file mode 100644
index 00000000..ca26da7b
--- /dev/null
+++ b/tabrepo/models/beta/generate.py
@@ -0,0 +1,22 @@
+from __future__ import annotations
+
+from tabrepo.benchmark.models.ag.beta.beta_model import BetaModel
+from tabrepo.utils.config_utils import ConfigGenerator
+
+name = "BETA"
+manual_configs = [
+    {},
+]
+
+gen_beta = ConfigGenerator(
+    model_cls=BetaModel, manual_configs=manual_configs, search_space={}
+)
+
+if __name__ == "__main__":
+    from tabrepo.benchmark.experiment import YamlExperimentSerializer
+
+    print(
+        YamlExperimentSerializer.to_yaml_str(
+            experiments=gen_beta.generate_all_bag_experiments(num_random_configs=0),
+        ),
+    )
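The generate.py module above only defines the default (empty search space) BETA configuration. For context, a small driver script along the following lines could serialize those experiments to a file. This is an illustrative sketch, not part of the patch; it only uses APIs already shown above (gen_beta.generate_all_bag_experiments and YamlExperimentSerializer.to_yaml_str), and the output filename is a placeholder.

# Sketch (not part of this diff): write the default BETA bag experiments to YAML.
from tabrepo.benchmark.experiment import YamlExperimentSerializer
from tabrepo.models.beta.generate import gen_beta

experiments = gen_beta.generate_all_bag_experiments(num_random_configs=0)
yaml_str = YamlExperimentSerializer.to_yaml_str(experiments=experiments)

with open("configs_beta.yaml", "w") as f:  # placeholder output path
    f.write(yaml_str)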
diff --git a/tabrepo/models/utils.py b/tabrepo/models/utils.py
index 5f021f4b..0c8dce12 100644
--- a/tabrepo/models/utils.py
+++ b/tabrepo/models/utils.py
@@ -16,7 +16,10 @@ def convert_numpy_dtypes(data: dict) -> dict:
             converted_data[key] = convert_numpy_dtypes(value)
         elif isinstance(value, list):
             converted_data[key] = [
-                convert_numpy_dtypes({i: v})[i] if isinstance(v, (dict, np.generic)) else v for i, v in enumerate(value)
+                convert_numpy_dtypes({i: v})[i]
+                if isinstance(v, (dict, np.generic))
+                else v
+                for i, v in enumerate(value)
             ]
         else:
             converted_data[key] = value
@@ -28,27 +31,62 @@ def get_configs_generator_from_name(model_name: str):
     import importlib
 
     name_to_import_map = {
-        "CatBoost": lambda: importlib.import_module("tabrepo.models.catboost.generate").gen_catboost,
+        "CatBoost": lambda: importlib.import_module(
+            "tabrepo.models.catboost.generate"
+        ).gen_catboost,
         "EBM": lambda: importlib.import_module("tabrepo.models.ebm.generate").gen_ebm,
-        "ExtraTrees": lambda: importlib.import_module("tabrepo.models.extra_trees.generate").gen_extratrees,
-        "FastaiMLP": lambda: importlib.import_module("tabrepo.models.fastai.generate").gen_fastai,
-        "FTTransformer": lambda: importlib.import_module("tabrepo.models.ftt.generate").gen_fttransformer,
+        "ExtraTrees": lambda: importlib.import_module(
+            "tabrepo.models.extra_trees.generate"
+        ).gen_extratrees,
+        "FastaiMLP": lambda: importlib.import_module(
+            "tabrepo.models.fastai.generate"
+        ).gen_fastai,
+        "FTTransformer": lambda: importlib.import_module(
+            "tabrepo.models.ftt.generate"
+        ).gen_fttransformer,
         "KNN": lambda: importlib.import_module("tabrepo.models.knn.generate").gen_knn,
-        "LightGBM": lambda: importlib.import_module("tabrepo.models.lightgbm.generate").gen_lightgbm,
-        "Linear": lambda: importlib.import_module("tabrepo.models.lr.generate").gen_linear,
-        "ModernNCA": lambda: importlib.import_module("tabrepo.models.modernnca.generate").gen_modernnca,
-        "TorchMLP": lambda: importlib.import_module("tabrepo.models.nn_torch.generate").gen_nn_torch,
-        "RandomForest": lambda: importlib.import_module("tabrepo.models.random_forest.generate").gen_randomforest,
-        "RealMLP": lambda: importlib.import_module("tabrepo.models.realmlp.generate").gen_realmlp,
-        "TabDPT": lambda: importlib.import_module("tabrepo.models.tabdpt.generate").gen_tabdpt,
-        "TabICL": lambda: importlib.import_module("tabrepo.models.tabicl.generate").gen_tabicl,
-        "TabM": lambda: importlib.import_module("tabrepo.models.tabm.generate").gen_tabm,
+        "LightGBM": lambda: importlib.import_module(
+            "tabrepo.models.lightgbm.generate"
+        ).gen_lightgbm,
+        "Linear": lambda: importlib.import_module(
+            "tabrepo.models.lr.generate"
+        ).gen_linear,
+        "ModernNCA": lambda: importlib.import_module(
+            "tabrepo.models.modernnca.generate"
+        ).gen_modernnca,
+        "TorchMLP": lambda: importlib.import_module(
+            "tabrepo.models.nn_torch.generate"
+        ).gen_nn_torch,
+        "RandomForest": lambda: importlib.import_module(
+            "tabrepo.models.random_forest.generate"
+        ).gen_randomforest,
+        "RealMLP": lambda: importlib.import_module(
+            "tabrepo.models.realmlp.generate"
+        ).gen_realmlp,
+        "TabDPT": lambda: importlib.import_module(
+            "tabrepo.models.tabdpt.generate"
+        ).gen_tabdpt,
+        "TabICL": lambda: importlib.import_module(
+            "tabrepo.models.tabicl.generate"
+        ).gen_tabicl,
+        "TabM": lambda: importlib.import_module(
+            "tabrepo.models.tabm.generate"
+        ).gen_tabm,
         # "TabPFN": lambda: importlib.import_module("tabrepo.models.tabpfn.generate").gen_tabpfn,  # not supported in TabArena
-        "TabPFNv2": lambda: importlib.import_module("tabrepo.models.tabpfnv2.generate").gen_tabpfnv2,
-        "XGBoost": lambda: importlib.import_module("tabrepo.models.xgboost.generate").gen_xgboost,
+        "TabPFNv2": lambda: importlib.import_module(
+            "tabrepo.models.tabpfnv2.generate"
+        ).gen_tabpfnv2,
+        "XGBoost": lambda: importlib.import_module(
+            "tabrepo.models.xgboost.generate"
+        ).gen_xgboost,
+        "BETA": lambda: importlib.import_module(
+            "tabrepo.models.beta.generate"
+        ).gen_beta,
     }
 
     if model_name not in name_to_import_map:
-        raise ValueError(f"Model name '{model_name}' is not recognized. Options are: {list(name_to_import_map.keys())}")
+        raise ValueError(
+            f"Model name '{model_name}' is not recognized. Options are: {list(name_to_import_map.keys())}"
+        )
 
     return name_to_import_map[model_name]()
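With the "BETA" entry added to name_to_import_map, the generator can also be resolved by its registry name, mirroring how the other model families are looked up; an unknown name raises the ValueError above with the list of valid options. A minimal sketch, assuming the beta extra dependencies are installed:

# Sketch: look up the BETA config generator by name.
from tabrepo.models.utils import get_configs_generator_from_name

gen_beta = get_configs_generator_from_name("BETA")
experiments = gen_beta.generate_all_bag_experiments(num_random_configs=0)

# The lookup is exact and case-sensitive; a wrong name fails fast with the options listed.
try:
    get_configs_generator_from_name("Beta")
except ValueError as err:
    print(err)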
diff --git a/tst/benchmark/models/test_beta.py b/tst/benchmark/models/test_beta.py
new file mode 100644
index 00000000..fccf2153
--- /dev/null
+++ b/tst/benchmark/models/test_beta.py
@@ -0,0 +1,22 @@
+from __future__ import annotations
+
+import pytest
+
+
+def test_beta_tabpfn():
+    toy_model_params = {"batch_size": 8, "max_epoch": 10}
+    model_hyperparameters = toy_model_params
+    try:
+        from autogluon.tabular.testing import FitHelper
+        from tabrepo.benchmark.models.ag.beta.beta_model import BetaModel
+
+        model_cls = BetaModel
+        FitHelper.verify_model(
+            model_cls=model_cls, model_hyperparameters=model_hyperparameters
+        )
+    except ImportError as err:
+        pytest.skip(
+            f"Import Error, skipping test... "
+            f"Ensure you have the proper dependencies installed to run this test:\n"
+            f"{err}"
+        )
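Beyond the FitHelper smoke test above, the new model can be exercised end to end through AutoGluon once the "beta" extra dependencies are installed. The sketch below is illustrative only and rests on assumptions outside this diff: it uses AutoGluon's generic custom-model convention of passing a model class as a hyperparameters key, a toy sklearn dataset, and the same small batch_size/max_epoch values as the test. BetaModel requires validation data during fit (AutoGluon's default holdout split provides it) and loads TabPFN checkpoint weights on first use, so the first run may download files and will benefit from a GPU.

# Rough sketch under the assumptions stated above; not part of this diff.
import pandas as pd
from sklearn.datasets import load_breast_cancer

from autogluon.tabular import TabularPredictor
from tabrepo.benchmark.models.ag.beta.beta_model import BetaModel

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
train_data = pd.concat([X, y.rename("target")], axis=1)

predictor = TabularPredictor(label="target").fit(
    train_data,
    # Passing the model class as a key follows AutoGluon's custom-model pattern.
    hyperparameters={BetaModel: [{"batch_size": 8, "max_epoch": 10}]},
)
print(predictor.leaderboard())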