5 changes: 5 additions & 0 deletions setup.py
@@ -39,6 +39,10 @@
"modernnca": [
"category_encoders",
],
"dpdt": [
# TODO: PyPI package is not available yet; install via a PEP 508 direct reference for now
"dpdt @ git+https://github.com/KohlerHECTOR/DPDTreeEstimator.git",
],
}

benchmark_requires = []
@@ -51,6 +55,7 @@
"tabdpt",
"tabm",
"modernnca",
"dpdt",
]:
benchmark_requires += extras_require[extra_package]
benchmark_requires = list(set(benchmark_requires))
2 changes: 2 additions & 0 deletions tabrepo/benchmark/models/ag/__init__.py
@@ -8,6 +8,7 @@
from tabrepo.benchmark.models.ag.tabm.tabm_model import TabMModel
from tabrepo.benchmark.models.ag.tabpfnv2.tabpfnv2_client_model import TabPFNV2ClientModel
from tabrepo.benchmark.models.ag.tabpfnv2.tabpfnv2_model import TabPFNV2Model
from tabrepo.benchmark.models.ag.dpdt.dpdt_model import BoostedDPDTModel

__all__ = [
"ExplainableBoostingMachineModel",
@@ -18,4 +19,5 @@
"TabMModel",
"TabPFNV2ClientModel",
"TabPFNV2Model",
"BoostedDPDTModel"
]
Empty file added tabrepo/benchmark/models/ag/dpdt/__init__.py
Empty file.
88 changes: 88 additions & 0 deletions tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
@@ -0,0 +1,88 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
from autogluon.common.utils.resource_utils import ResourceManager
from autogluon.core.models import AbstractModel

if TYPE_CHECKING:
import pandas as pd


class BoostedDPDTModel(AbstractModel):
ag_key = "BOOSTEDDPDT"
ag_name = "boosted_dpdt"

def get_model_cls(self):
from dpdt import AdaBoostDPDT

if self.problem_type in ["binary", "multiclass"]:
model_cls = AdaBoostDPDT
else:
raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
return model_cls

def _fit(self, X: pd.DataFrame, y: pd.Series, num_cpus: int = 1, **kwargs):
model_cls = self.get_model_cls()
hyp = self._get_model_params()
if num_cpus < 1:
Collaborator: I think num_cpus would never be below 1, did you want to do <=?

Author: Hello, I will remove it. It is just from experience with the joblib library, where one writes n_jobs = -1 to use all available CPUs.

Collaborator: Ah, yes. Here num_cpus might be the string "auto" in edge cases (not within TabArena benchmarks).

num_cpus = 'best'
self.model = model_cls(
**hyp,
n_jobs=num_cpus,
)
X = self.preprocess(X)
self.model = self.model.fit(
X=X,
y=y,
)

def _set_default_params(self):
default_params = {
"random_state": 42,
}
for param, val in default_params.items():
self._set_default_param_value(param, val)

@classmethod
def supported_problem_types(cls) -> list[str] | None:
return ["binary", "multiclass"]

def _get_default_resources(self) -> tuple[int, int]:
# logical=False is faster in training
num_cpus = ResourceManager.get_cpu_count_psutil(logical=False)
num_gpus = 0
return num_cpus, num_gpus

def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
hyperparameters = self._get_model_params()
return self.estimate_memory_usage_static(X=X, problem_type=self.problem_type, num_classes=self.num_classes, hyperparameters=hyperparameters, **kwargs)

@classmethod
def _estimate_memory_usage_static(
cls,
*,
X: pd.DataFrame,
hyperparameters: dict | None = None,
**kwargs,
) -> int:
if hyperparameters is None:
hyperparameters = {}

# Rough estimate: dataset memory scaled by the width of the first CART layer.
# The fallback width of 8 used here is an assumption for when cart_nodes_list is not set.
dataset_size_mem_est = 10 * hyperparameters.get("cart_nodes_list", (8,))[0] * get_approximate_df_mem_usage(X).sum()
baseline_overhead_mem_est = 3e8 # 300 MB generic overhead

mem_estimate = dataset_size_mem_est + baseline_overhead_mem_est

return int(mem_estimate)

@classmethod
def _class_tags(cls):
return {"can_estimate_memory_usage_static": True}

def _more_tags(self) -> dict:
"""DPDT does not yet support refit full."""
return {"can_refit_full": False}
2 changes: 2 additions & 0 deletions tabrepo/benchmark/models/model_register.py
@@ -13,6 +13,7 @@
TabMModel,
TabPFNV2ClientModel,
TabPFNV2Model,
BoostedDPDTModel,
)

tabrepo_model_register: ModelRegistry = copy.deepcopy(ag_model_registry)
@@ -26,6 +27,7 @@
TabDPTModel,
TabMModel,
ModernNCAModel,
BoostedDPDTModel,
]

for _model_cls in _models_to_add:
Empty file added tabrepo/models/dpdt/__init__.py
Empty file.
58 changes: 58 additions & 0 deletions tabrepo/models/dpdt/generate.py
@@ -0,0 +1,58 @@
from autogluon.common.space import Categorical
import numpy as np

from tabrepo.benchmark.models.ag.dpdt.dpdt_model import BoostedDPDTModel
from tabrepo.utils.config_utils import ConfigGenerator

name = 'BoostedDPDT'
manual_configs = [
{},
]

# Search space adapted from the paper's configuration.

# Generate 1000 samples from log-normal distribution
# Parameters: mu = log(0.01), sigma = log(10.0)
mu = float(np.log(0.01))
sigma = float(np.log(10.0))
samples = np.random.lognormal(mean=mu, sigma=sigma, size=1000)

# Generate 1000 samples from q_log_uniform_values distribution
# Parameters: min=1.5, max=50.5, q=1
min_val = 1.5
max_val = 50.5
q = 1
# Generate log-uniform samples and quantize
log_min = np.log(min_val)
log_max = np.log(max_val)
log_uniform_samples = np.random.uniform(log_min, log_max, size=1000)
min_samples_leaf_samples = np.round(np.exp(log_uniform_samples) / q) * q
min_samples_leaf_samples = np.clip(min_samples_leaf_samples, min_val, max_val).astype(int)

# Generate 1000 samples for min_weight_fraction_leaf
# Values: [0.0, 0.01], probabilities: [0.95, 0.05]
min_weight_fraction_leaf_samples = np.random.choice([0.0, 0.01], size=1000, p=[0.95, 0.05])

# Generate 1000 samples for max_features
# Values: ["sqrt", "log2", 10000], probabilities: [0.5, 0.25, 0.25]
max_features_samples = np.random.choice(["sqrt", "log2", 10000], size=1000, p=[0.5, 0.25, 0.25])

search_space = {
'learning_rate': Categorical(*samples), # log_normal distribution equivalent
'n_estimators': 1000, # Fixed value as per old config
'max_depth': Categorical(2, 2, 2, 2, 3, 3, 3, 3, 3, 3),  # repetition weights the draw: P(2)=0.4, P(3)=0.6
'min_samples_split': Categorical(*np.random.choice([2, 3], size=1000, p=[0.95, 0.05])),
'min_impurity_decrease': Categorical(*np.random.choice([0, 0.01, 0.02, 0.05], size=1000, p=[0.85, 0.05, 0.05, 0.05])),
'cart_nodes_list': Categorical((8, 4), (4, 8), (16, 2), (4, 4, 2)),
'min_samples_leaf': Categorical(*min_samples_leaf_samples), # q_log_uniform equivalent
'min_weight_fraction_leaf': Categorical(*min_weight_fraction_leaf_samples),
'max_features': Categorical(*max_features_samples),
'random_state': Categorical(0, 1, 2, 3, 4)
}

gen_boosteddpdt = ConfigGenerator(model_cls=BoostedDPDTModel, manual_configs=manual_configs, search_space=search_space)


def generate_configs_boosted_dpdt(num_random_configs=200):
config_generator = ConfigGenerator(name=name, manual_configs=manual_configs, search_space=search_space)
return config_generator.generate_all_configs(num_random_configs=num_random_configs)
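For context, a short hedged sketch of how these configs might be produced, assuming ConfigGenerator.generate_all_configs returns a list of config dicts (as the function above implies):

# Illustrative usage of the generator defined above.
from tabrepo.models.dpdt.generate import generate_configs_boosted_dpdt

configs = generate_configs_boosted_dpdt(num_random_configs=200)
# Expected: the manual default config plus 200 randomly sampled configs.
print(len(configs))
print(configs[0])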
1 change: 1 addition & 0 deletions tabrepo/models/utils.py
@@ -46,6 +46,7 @@ def get_configs_generator_from_name(model_name: str):
# "TabPFN": lambda: importlib.import_module("tabrepo.models.tabpfn.generate").gen_tabpfn, # not supported in TabArena
"TabPFNv2": lambda: importlib.import_module("tabrepo.models.tabpfnv2.generate").gen_tabpfnv2,
"XGBoost": lambda: importlib.import_module("tabrepo.models.xgboost.generate").gen_xgboost,
"BoostedDPDT": lambda: importlib.import_module("tabrepo.models.dpdt.generate").gen_boosteddpdt,
}

if model_name not in name_to_import_map:
17 changes: 17 additions & 0 deletions tst/benchmark/models/test_dpdt.py
@@ -0,0 +1,17 @@
import pytest


def test_dpdt():
model_hyperparameters = {"n_estimators": 2, "cart_nodes_list": (4, 3)}

try:
from autogluon.tabular.testing import FitHelper
from tabrepo.benchmark.models.ag.dpdt.dpdt_model import BoostedDPDTModel
model_cls = BoostedDPDTModel
FitHelper.verify_model(model_cls=model_cls, model_hyperparameters=model_hyperparameters)
except ImportError as err:
pytest.skip(
f"Import Error, skipping test... "
f"Ensure you have the proper dependencies installed to run this test:\n"
f"{err}"
)