Skip to content

Commit f8b9873

Browse files
Mouhanedg56Mouhaned Chebaane
andauthored
Small fix on hyperparameter search (#150)
* edit modeling.py * add unit test Co-authored-by: Mouhaned Chebaane <[email protected]>
1 parent 8829272 commit f8b9873

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/setfit/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def _from_pretrained(
376376

377377
model_head = multilabel_classifier
378378
else:
379-
model_head = LogisticRegression()
379+
model_head = clf
380380

381381
return SetFitModel(model_body=model_body, model_head=model_head, multi_target_strategy=multi_target_strategy)
382382

tests/test_modeling.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,24 @@ def test_setfit_default_model_head():
5353
assert type(model.model_head) is LogisticRegression
5454

5555

56+
def test_setfit_model_head_params():
57+
params = {
58+
"head_params": {
59+
"max_iter": 200,
60+
"solver": "newton-cg",
61+
}
62+
}
63+
64+
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", **params)
65+
66+
assert type(model.model_head) is LogisticRegression
67+
assert params["head_params"] == {
68+
parameter: value
69+
for parameter, value in model.model_head.get_params(deep=False).items()
70+
if parameter in params["head_params"]
71+
}
72+
73+
5674
def test_setfit_multilabel_one_vs_rest_model_head():
5775
model = SetFitModel.from_pretrained(
5876
"sentence-transformers/paraphrase-albert-small-v2", multi_target_strategy="one-vs-rest"

0 commit comments

Comments
 (0)