
Commit 6076e10

refactor: remove the use of enums
1 parent a6ef573 commit 6076e10
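
With the enums gone, the tree and forest constructors take plain strings in place of the removed TreeType values. A minimal usage sketch against the new signatures in the diff below; the toy data is made up for illustration, and the import path is assumed from the file name:

import numpy as np
from neuralnetlib.ensemble import DecisionTree, RandomForest  # assumed import path

# toy data, for illustration only
X = np.random.rand(100, 4)
y = (X[:, 0] > 0.5).astype(int)

# tree_type is now the plain string "classifier" or "regressor",
# not the removed TreeType enum
tree = DecisionTree(tree_type="classifier", max_depth=5, random_state=42)
tree.fit(X, y)

forest = RandomForest(n_estimators=50, tree_type="classifier", random_state=42)
forest.fit(X, y)
predictions = forest.predict(X)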

File tree

1 file changed: 23 additions, 43 deletions


neuralnetlib/ensemble.py

@@ -1,5 +1,4 @@
 import numpy as np
-from enum import Enum
 
 
 class IsolationTree:
@@ -122,13 +121,8 @@ def fit_predict(self, X: np.ndarray) -> np.ndarray:
         return self.fit(X).predict(X)
 
 
-class TreeType(Enum):
-    CLASSIFIER = "classifier"
-    REGRESSOR = "regressor"
-
-
 class DecisionTree:
-    def __init__(self, tree_type: TreeType = TreeType.CLASSIFIER, max_depth: int = None,
+    def __init__(self, tree_type: str = "classifier", max_depth: int = None,
                  min_samples_split: int = 2, min_samples_leaf: int = 1,
                  max_features: int = None, random_state: int = None):
         self.tree_type = tree_type
@@ -161,7 +155,7 @@ def _best_split(self, X, y, features):
         best_feature = None
         best_threshold = None
 
-        current_metric = self._gini(y) if self.tree_type == TreeType.CLASSIFIER else self._mse(y)
+        current_metric = self._gini(y) if self.tree_type == "classifier" else self._mse(y)
 
         for feature in features:
             thresholds = np.unique(X[:, feature])
@@ -173,7 +167,7 @@ def _best_split(self, X, y, features):
                 if np.sum(left_mask) < self.min_samples_leaf or np.sum(right_mask) < self.min_samples_leaf:
                     continue
 
-                if self.tree_type == TreeType.CLASSIFIER:
+                if self.tree_type == "classifier":
                     left_metric = self._gini(y[left_mask])
                     right_metric = self._gini(y[right_mask])
                 else:
@@ -200,7 +194,7 @@ def _build_tree(self, X, y, depth=0):
         if (self.max_depth is not None and depth >= self.max_depth) or \
            n_samples < self.min_samples_split or \
           n_samples < 2 * self.min_samples_leaf:
-            if self.tree_type == TreeType.CLASSIFIER:
+            if self.tree_type == "classifier":
                 unique, counts = np.unique(y, return_counts=True)
                 node.prediction = unique[np.argmax(counts)]
             else:
@@ -213,7 +207,7 @@ def _build_tree(self, X, y, depth=0):
         feature, threshold = self._best_split(X, y, features)
 
         if feature is None:
-            if self.tree_type == TreeType.CLASSIFIER:
+            if self.tree_type == "classifier":
                 unique, counts = np.unique(y, return_counts=True)
                 node.prediction = unique[np.argmax(counts)]
             else:
@@ -238,7 +232,7 @@ def _predict_single(self, x, node):
             return self._predict_single(x, node.right)
 
     def fit(self, X, y):
-        if self.tree_type == TreeType.CLASSIFIER:
+        if self.tree_type == "classifier":
             self.n_classes = len(np.unique(y))
             y = y.astype(int)
 
@@ -247,13 +241,13 @@ def fit(self, X, y):
 
     def predict(self, X):
         predictions = np.array([self._predict_single(x, self.root) for x in X])
-        if self.tree_type == TreeType.CLASSIFIER:
+        if self.tree_type == "classifier":
             return predictions.astype(int)
         return predictions
 
 
 class RandomForest:
-    def __init__(self, n_estimators: int = 100, tree_type: TreeType = TreeType.CLASSIFIER,
+    def __init__(self, n_estimators: int = 100, tree_type: str = "classifier",
                  max_depth: int = None, min_samples_split: int = 2, min_samples_leaf: int = 1,
                  max_features: str | int = "sqrt", bootstrap: bool = True,
                  random_state: int = None):
@@ -281,7 +275,7 @@ def fit(self, X, y):
         n_samples, n_features = X.shape
         max_features = self._get_max_features(n_features)
 
-        if self.tree_type == TreeType.CLASSIFIER:
+        if self.tree_type == "classifier":
             y = y.astype(int)
             self.classes_ = np.unique(y)
 
@@ -309,7 +303,7 @@ def fit(self, X, y):
 
     def predict(self, X):
         predictions = np.array([tree.predict(X) for tree in self.trees])
-        if self.tree_type == TreeType.CLASSIFIER:
+        if self.tree_type == "classifier":
             mode_predictions = []
             for sample_pred in predictions.T:
                 values, counts = np.unique(sample_pred.astype(int), return_counts=True)
@@ -318,8 +312,6 @@ def predict(self, X):
         return np.mean(predictions, axis=0)
 
 
-import numpy as np
-
 class DecisionStump:
     def __init__(self):
         self.feature_idx = None
@@ -419,14 +411,7 @@ def score_samples(self, X):
         return np.sum([stump.alpha * stump.predict(X) for stump in self.stumps], axis=0)
 
 
-import numpy as np
-from enum import Enum
-
-class GBMTask(Enum):
-    REGRESSION = "regression"
-    BINARY_CLASSIFICATION = "binary_classification"
-
-class DecisionTree:
+class DecisionTreeGBM:
     def __init__(self, max_depth=3, min_samples_split=2):
         self.max_depth = max_depth
         self.min_samples_split = min_samples_split
@@ -516,9 +501,9 @@ def predict(self, X):
         return np.array([self._predict_sample(x, self.root) for x in X])
 
 class GradientBoostingMachine:
-    def __init__(self, task=GBMTask.REGRESSION, n_estimators=100, learning_rate=0.1,
+    def __init__(self, task="regression", n_estimators=100, learning_rate=0.1,
                  max_depth=3, min_samples_split=2, subsample=1.0, random_state=None):
-        self.task = task if isinstance(task, GBMTask) else GBMTask(task)
+        self.task = task
         self.n_estimators = n_estimators
         self.learning_rate = learning_rate
         self.max_depth = max_depth
@@ -533,7 +518,7 @@ def _sigmoid(self, x):
         return 1 / (1 + np.exp(-x))
 
     def _compute_residuals(self, y_true, y_pred):
-        if self.task == GBMTask.REGRESSION:
+        if self.task == "regression":
            return y_true - y_pred
         else:
            p = self._sigmoid(y_pred)
@@ -548,7 +533,7 @@ def _sample_indices(self, n_samples):
     def fit(self, X, y):
         n_samples = X.shape[0]
 
-        if self.task == GBMTask.REGRESSION:
+        if self.task == "regression":
             self.initial_prediction = np.mean(y)
         else:
             y = np.where(y <= 0, 0, 1)
@@ -560,7 +545,7 @@ def fit(self, X, y):
             residuals = self._compute_residuals(y, F)
             indices = self._sample_indices(n_samples)
 
-            tree = DecisionTree(
+            tree = DecisionTreeGBM(
                 max_depth=self.max_depth,
                 min_samples_split=self.min_samples_split
             )
@@ -578,12 +563,12 @@ def predict(self, X):
         for tree in self.trees:
             predictions += self.learning_rate * tree.predict(X)
 
-        if self.task == GBMTask.BINARY_CLASSIFICATION:
+        if self.task == "binary_classification":
             return (self._sigmoid(predictions) >= 0.5).astype(int)
         return predictions
 
     def predict_proba(self, X):
-        if self.task != GBMTask.BINARY_CLASSIFICATION:
+        if self.task != "binary_classification":
             raise ValueError("predict_proba is only available for binary classification")
 
         predictions = np.full(X.shape[0], self.initial_prediction)
@@ -595,11 +580,6 @@ def predict_proba(self, X):
         return np.vstack([1 - proba, proba]).T
 
 
-
-class XGBoostObjective(Enum):
-    REG_SQUAREDERROR = "reg:squarederror"
-    BINARY_LOGISTIC = "binary:logistic"
-
 class XGBoostNode:
     def __init__(self):
         self.feature_idx: int = None
@@ -710,7 +690,7 @@ def __init__(self, objective: str = "reg:squarederror", n_estimators: int = 100,
                  min_child_weight: float = 1.0, subsample: float = 1.0,
                  colsample_bytree: float = 1.0, lambda_: float = 1.0,
                  gamma: float = 0.0, random_state: int = None):
-        self.objective = XGBoostObjective(objective)
+        self.objective = objective
         self.n_estimators = n_estimators
         self.learning_rate = learning_rate
         self.max_depth = max_depth
@@ -729,7 +709,7 @@ def _sigmoid(self, x: np.ndarray) -> np.ndarray:
         return 1 / (1 + np.exp(-x))
 
     def _compute_gradients(self, y: np.ndarray, pred: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
-        if self.objective == XGBoostObjective.REG_SQUAREDERROR:
+        if self.objective == "reg:squarederror":
             grad = pred - y
             hess = np.ones_like(y)
         else:
@@ -757,7 +737,7 @@ def _subsample_data(self, X: np.ndarray, y: np.ndarray,
         return X, y, grad, hess
 
     def fit(self, X: np.ndarray, y: np.ndarray) -> 'XGBoost':
-        if self.objective == XGBoostObjective.BINARY_LOGISTIC:
+        if self.objective == "binary:logistic":
             y = (y > 0).astype(np.float64)
             self.base_score = np.log(np.mean(y) / (1 - np.mean(y) + 1e-6))
         else:
@@ -790,12 +770,12 @@ def predict(self, X: np.ndarray) -> np.ndarray:
         for tree in self.trees:
             predictions += self.learning_rate * tree.predict(X)
 
-        if self.objective == XGBoostObjective.BINARY_LOGISTIC:
+        if self.objective == "binary:logistic":
             return (self._sigmoid(predictions) >= 0.5).astype(int)
         return predictions
 
     def predict_proba(self, X: np.ndarray) -> np.ndarray:
-        if self.objective != XGBoostObjective.BINARY_LOGISTIC:
+        if self.objective != "binary:logistic":
             raise ValueError("predict_proba is only available for binary classification")
 
         predictions = np.full(X.shape[0], self.base_score)
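
The boosting classes follow the same pattern: GradientBoostingMachine now takes task="regression" or "binary_classification", and XGBoost takes objective="reg:squarederror" or "binary:logistic", mirroring the removed GBMTask and XGBoostObjective values. A sketch of the new call sites, again with made-up data and an assumed import path:

import numpy as np
from neuralnetlib.ensemble import GradientBoostingMachine, XGBoost  # assumed import path

# toy binary-classification data, for illustration only
X = np.random.rand(200, 3)
y = (X[:, 0] + X[:, 1] > 1.0).astype(int)

# task/objective are plain strings now
gbm = GradientBoostingMachine(task="binary_classification", n_estimators=50)
gbm.fit(X, y)
proba = gbm.predict_proba(X)  # defined only for binary classification

xgb = XGBoost(objective="binary:logistic", n_estimators=50)
xgb.fit(X, y)
labels = xgb.predict(X)

One consequence visible in the diff: the old constructors validated their argument (GBMTask(task), XGBoostObjective(objective)) and raised on unknown values, whereas a misspelled string now falls through silently to the else branches.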
