@@ -40,6 +40,7 @@ def __init__(
40
40
edge_list : str = None ,
41
41
replace_classifiers : bool = True ,
42
42
n_jobs : int = 1 ,
43
+ bert : bool = False ,
43
44
):
44
45
"""
45
46
Initialize a local classifier per parent node.
@@ -61,6 +62,8 @@ def __init__(
61
62
n_jobs : int, default=1
62
63
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
63
64
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
65
+ bert : bool, default=False
66
+ If True, skip scikit-learn's checks and sample_weight passing for BERT.
64
67
"""
65
68
super ().__init__ (
66
69
local_classifier = local_classifier ,
@@ -69,6 +72,7 @@ def __init__(
69
72
replace_classifiers = replace_classifiers ,
70
73
n_jobs = n_jobs ,
71
74
classifier_abbreviation = "LCPPN" ,
75
+ bert = bert ,
72
76
)
73
77
74
78
def fit (self , X , y , sample_weight = None ):
@@ -128,7 +132,10 @@ def predict(self, X):
128
132
check_is_fitted (self )
129
133
130
134
# Input validation
131
- X = check_array (X , accept_sparse = "csr" )
135
+ if not self .bert :
136
+ X = check_array (X , accept_sparse = "csr" )
137
+ else :
138
+ X = np .array (X )
132
139
133
140
# Initialize array that holds predictions
134
141
y = np .empty ((X .shape [0 ], self .max_levels_ ), dtype = self .dtype_ )
@@ -203,7 +210,10 @@ def _fit_classifier(self, node):
203
210
unique_y = np .unique (y )
204
211
if len (unique_y ) == 1 and self .replace_classifiers :
205
212
classifier = ConstantClassifier ()
206
- classifier .fit (X , y , sample_weight )
213
+ if not self .bert :
214
+ classifier .fit (X , y , sample_weight )
215
+ else :
216
+ classifier .fit (X , y )
207
217
return classifier
208
218
209
219
def _fit_digraph (self , local_mode : bool = False , use_joblib : bool = False ):
0 commit comments