Skip to content

Commit c5023a4

Browse files
authored
Add support for bert-sklearn #minor (#74)
1 parent cc6e992 commit c5023a4

10 files changed

+140
-9
lines changed

docs/examples/plot_bert.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=====================
4+
BERT sklearn
5+
=====================
6+
7+
In order to use `bert-sklearn <https://github.com/charles9n/bert-sklearn>`_ with HiClass, some of scikit-learn's checks need to be disabled.
8+
The reason is that BERT expects text as input for the features, but scikit-learn expects numerical features.
9+
Hence, the checks will fail.
10+
To disable scikit-learn's checks, we can simply pass the parameter :code:`bert=True` to the constructor of the local hierarchical classifier.
11+
"""
12+
from bert_sklearn import BertClassifier
13+
from hiclass import LocalClassifierPerParentNode
14+
15+
# Define data
16+
X_train = X_test = [
17+
"Batman",
18+
"Rorschach",
19+
]
20+
Y_train = [
21+
["Action", "The Dark Night"],
22+
["Action", "Watchmen"],
23+
]
24+
25+
# Use BERT for every node
26+
bert = BertClassifier()
27+
classifier = LocalClassifierPerParentNode(
28+
local_classifier=bert,
29+
bert=True,
30+
)
31+
32+
# Train local classifier per parent node
33+
classifier.fit(X_train, Y_train)
34+
35+
# Predict
36+
predictions = classifier.predict(X_test)
37+
print(predictions)

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ matplotlib==3.5.2
88
pandas==1.4.2
99
ray==1.13.0
1010
numpy<1.24
11+
git+https://github.com/charles9n/bert-sklearn.git@master

hiclass/HierarchicalClassifier.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
edge_list: str = None,
6666
replace_classifiers: bool = True,
6767
n_jobs: int = 1,
68+
bert: bool = False,
6869
classifier_abbreviation: str = "",
6970
):
7071
"""
@@ -87,6 +88,8 @@ def __init__(
8788
n_jobs : int, default=1
8889
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
8990
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
91+
bert : bool, default=False
92+
If True, skip scikit-learn's input checks and do not pass sample_weight to the local classifiers, so that text features can be used with BERT.
9093
classifier_abbreviation : str, default=""
9194
The abbreviation of the local hierarchical classifier to be displayed during logging.
9295
"""
@@ -95,6 +98,7 @@ def __init__(
9598
self.edge_list = edge_list
9699
self.replace_classifiers = replace_classifiers
97100
self.n_jobs = n_jobs
101+
self.bert = bert
98102
self.classifier_abbreviation = classifier_abbreviation
99103

100104
def fit(self, X, y, sample_weight=None):
@@ -130,9 +134,13 @@ def _pre_fit(self, X, y, sample_weight):
130134
# Check that X and y have correct shape
131135
# and convert them to np.ndarray if need be
132136

133-
self.X_, self.y_ = self._validate_data(
134-
X, y, multi_output=True, accept_sparse="csr"
135-
)
137+
if not self.bert:
138+
self.X_, self.y_ = self._validate_data(
139+
X, y, multi_output=True, accept_sparse="csr"
140+
)
141+
else:
142+
self.X_ = np.array(X)
143+
self.y_ = np.array(y)
136144

137145
if sample_weight is not None:
138146
self.sample_weight_ = _check_sample_weight(sample_weight, X)

hiclass/LocalClassifierPerLevel.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(
4747
edge_list: str = None,
4848
replace_classifiers: bool = True,
4949
n_jobs: int = 1,
50+
bert: bool = False,
5051
):
5152
"""
5253
Initialize a local classifier per level.
@@ -68,6 +69,8 @@ def __init__(
6869
n_jobs : int, default=1
6970
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
7071
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
72+
bert : bool, default=False
73+
If True, skip scikit-learn's input checks and do not pass sample_weight to the local classifiers, so that text features can be used with BERT.
7174
"""
7275
super().__init__(
7376
local_classifier=local_classifier,
@@ -76,6 +79,7 @@ def __init__(
7679
replace_classifiers=replace_classifiers,
7780
n_jobs=n_jobs,
7881
classifier_abbreviation="LCPL",
82+
bert=bert,
7983
)
8084

8185
def fit(self, X, y, sample_weight=None):
@@ -135,7 +139,10 @@ def predict(self, X):
135139
check_is_fitted(self)
136140

137141
# Input validation
138-
X = check_array(X, accept_sparse="csr")
142+
if not self.bert:
143+
X = check_array(X, accept_sparse="csr")
144+
else:
145+
X = np.array(X)
139146

140147
# Initialize array that holds predictions
141148
y = np.empty((X.shape[0], self.max_levels_), dtype=self.dtype_)
@@ -242,7 +249,10 @@ def _fit_classifier(self, level, separator):
242249
unique_y = np.unique(y)
243250
if len(unique_y) == 1 and self.replace_classifiers:
244251
classifier = ConstantClassifier()
245-
classifier.fit(X, y, sample_weight)
252+
if not self.bert:
253+
classifier.fit(X, y, sample_weight)
254+
else:
255+
classifier.fit(X, y)
246256
return classifier
247257

248258
@staticmethod

hiclass/LocalClassifierPerNode.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
edge_list: str = None,
4343
replace_classifiers: bool = True,
4444
n_jobs: int = 1,
45+
bert: bool = False,
4546
):
4647
"""
4748
Initialize a local classifier per node.
@@ -74,6 +75,8 @@ def __init__(
7475
n_jobs : int, default=1
7576
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
7677
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
78+
bert : bool, default=False
79+
If True, skip scikit-learn's input checks and do not pass sample_weight to the local classifiers, so that text features can be used with BERT.
7780
"""
7881
super().__init__(
7982
local_classifier=local_classifier,
@@ -82,6 +85,7 @@ def __init__(
8285
replace_classifiers=replace_classifiers,
8386
n_jobs=n_jobs,
8487
classifier_abbreviation="LCPN",
88+
bert=bert,
8589
)
8690
self.binary_policy = binary_policy
8791

@@ -145,7 +149,10 @@ def predict(self, X):
145149
check_is_fitted(self)
146150

147151
# Input validation
148-
X = check_array(X, accept_sparse="csr")
152+
if not self.bert:
153+
X = check_array(X, accept_sparse="csr")
154+
else:
155+
X = np.array(X)
149156

150157
# Initialize array that holds predictions
151158
y = np.empty((X.shape[0], self.max_levels_), dtype=self.dtype_)
@@ -233,7 +240,10 @@ def _fit_classifier(self, node):
233240
unique_y = np.unique(y)
234241
if len(unique_y) == 1 and self.replace_classifiers:
235242
classifier = ConstantClassifier()
236-
classifier.fit(X, y, sample_weight)
243+
if not self.bert:
244+
classifier.fit(X, y, sample_weight)
245+
else:
246+
classifier.fit(X, y)
237247
return classifier
238248

239249
def _clean_up(self):

hiclass/LocalClassifierPerParentNode.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
edge_list: str = None,
4141
replace_classifiers: bool = True,
4242
n_jobs: int = 1,
43+
bert: bool = False,
4344
):
4445
"""
4546
Initialize a local classifier per parent node.
@@ -61,6 +62,8 @@ def __init__(
6162
n_jobs : int, default=1
6263
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
6364
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
65+
bert : bool, default=False
66+
If True, skip scikit-learn's input checks and do not pass sample_weight to the local classifiers, so that text features can be used with BERT.
6467
"""
6568
super().__init__(
6669
local_classifier=local_classifier,
@@ -69,6 +72,7 @@ def __init__(
6972
replace_classifiers=replace_classifiers,
7073
n_jobs=n_jobs,
7174
classifier_abbreviation="LCPPN",
75+
bert=bert,
7276
)
7377

7478
def fit(self, X, y, sample_weight=None):
@@ -128,7 +132,10 @@ def predict(self, X):
128132
check_is_fitted(self)
129133

130134
# Input validation
131-
X = check_array(X, accept_sparse="csr")
135+
if not self.bert:
136+
X = check_array(X, accept_sparse="csr")
137+
else:
138+
X = np.array(X)
132139

133140
# Initialize array that holds predictions
134141
y = np.empty((X.shape[0], self.max_levels_), dtype=self.dtype_)
@@ -203,7 +210,10 @@ def _fit_classifier(self, node):
203210
unique_y = np.unique(y)
204211
if len(unique_y) == 1 and self.replace_classifiers:
205212
classifier = ConstantClassifier()
206-
classifier.fit(X, y, sample_weight)
213+
if not self.bert:
214+
classifier.fit(X, y, sample_weight)
215+
else:
216+
classifier.fit(X, y)
207217
return classifier
208218

209219
def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False):

tests/test_HierarchicalClassifier.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,13 @@ def test_make_leveled_non_iterable_y(noniterable_y):
214214
def test_fit_classifier():
215215
with pytest.raises(NotImplementedError):
216216
HierarchicalClassifier._fit_classifier(None, None)
217+
218+
219+
def test_pre_fit_bert():
220+
classifier = HierarchicalClassifier()
221+
classifier.logger_ = logging.getLogger("HC")
222+
classifier.bert = True
223+
x = [[0, 1], [2, 3]]
224+
y = [["a", "b"], ["c", "d"]]
225+
sample_weight = None
226+
classifier._pre_fit(x, y, sample_weight)

tests/test_LocalClassifierPerLevel.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sklearn.utils.estimator_checks import parametrize_with_checks
1111
from sklearn.utils.validation import check_is_fitted
1212
from hiclass import LocalClassifierPerLevel
13+
from hiclass.ConstantClassifier import ConstantClassifier
1314

1415

1516
@parametrize_with_checks([LocalClassifierPerLevel()])
@@ -180,3 +181,17 @@ def test_empty_levels(empty_levels):
180181
lcppn.root_,
181182
]
182183
assert_array_equal(ground_truth, predictions)
184+
185+
186+
def test_fit_bert():
187+
bert = ConstantClassifier()
188+
lcpn = LocalClassifierPerLevel(
189+
local_classifier=bert,
190+
bert=True,
191+
)
192+
X = ["Text 1", "Text 2"]
193+
y = ["a", "a"]
194+
lcpn.fit(X, y)
195+
check_is_fitted(lcpn)
196+
predictions = lcpn.predict(X)
197+
assert_array_equal(y, predictions)

tests/test_LocalClassifierPerNode.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from hiclass import LocalClassifierPerNode
1414
from hiclass.BinaryPolicy import ExclusivePolicy
15+
from hiclass.ConstantClassifier import ConstantClassifier
1516

1617

1718
@parametrize_with_checks([LocalClassifierPerNode()])
@@ -242,3 +243,17 @@ def test_empty_levels(empty_levels):
242243
lcppn.root_,
243244
]
244245
assert_array_equal(ground_truth, predictions)
246+
247+
248+
def test_fit_bert():
249+
bert = ConstantClassifier()
250+
lcpn = LocalClassifierPerNode(
251+
local_classifier=bert,
252+
bert=True,
253+
)
254+
X = ["Text 1", "Text 2"]
255+
y = ["a", "a"]
256+
lcpn.fit(X, y)
257+
check_is_fitted(lcpn)
258+
predictions = lcpn.predict(X)
259+
assert_array_equal(y, predictions)

tests/test_LocalClassifierPerParentNode.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sklearn.utils.validation import check_is_fitted
1313

1414
from hiclass import LocalClassifierPerParentNode
15+
from hiclass.ConstantClassifier import ConstantClassifier
1516

1617

1718
@parametrize_with_checks([LocalClassifierPerParentNode()])
@@ -226,3 +227,17 @@ def test_empty_levels(empty_levels):
226227
lcppn.root_,
227228
]
228229
assert_array_equal(ground_truth, predictions)
230+
231+
232+
def test_bert():
233+
bert = ConstantClassifier()
234+
lcpn = LocalClassifierPerParentNode(
235+
local_classifier=bert,
236+
bert=True,
237+
)
238+
X = ["Text 1", "Text 2"]
239+
y = ["a", "a"]
240+
lcpn.fit(X, y)
241+
check_is_fitted(lcpn)
242+
predictions = lcpn.predict(X)
243+
assert_array_equal(y, predictions)

0 commit comments

Comments
 (0)