File tree Expand file tree Collapse file tree 2 files changed +19
-1
lines changed Expand file tree Collapse file tree 2 files changed +19
-1
lines changed Original file line number Diff line number Diff line change @@ -376,7 +376,7 @@ def _from_pretrained(
376376
377377 model_head = multilabel_classifier
378378 else :
379- model_head = LogisticRegression ()
379+ model_head = clf
380380
381381 return SetFitModel (model_body = model_body , model_head = model_head , multi_target_strategy = multi_target_strategy )
382382
Original file line number Diff line number Diff line change @@ -53,6 +53,24 @@ def test_setfit_default_model_head():
5353 assert type (model .model_head ) is LogisticRegression
5454
5555
56+ def test_setfit_model_head_params ():
57+ params = {
58+ "head_params" : {
59+ "max_iter" : 200 ,
60+ "solver" : "newton-cg" ,
61+ }
62+ }
63+
64+ model = SetFitModel .from_pretrained ("sentence-transformers/paraphrase-albert-small-v2" , ** params )
65+
66+ assert type (model .model_head ) is LogisticRegression
67+ assert params ["head_params" ] == {
68+ parameter : value
69+ for parameter , value in model .model_head .get_params (deep = False ).items ()
70+ if parameter in params ["head_params" ]
71+ }
72+
73+
5674def test_setfit_multilabel_one_vs_rest_model_head ():
5775 model = SetFitModel .from_pretrained (
5876 "sentence-transformers/paraphrase-albert-small-v2" , multi_target_strategy = "one-vs-rest"
You can’t perform that action at this time.
0 commit comments