@@ -52,7 +52,7 @@ class EntailmentScore(Metric):
5252 batch_size : int = 4
5353 device : t .Literal ["cpu" , "cuda" ] | Device = "cpu"
5454
55- def __post_init__ (self ):
55+ def init_model (self ):
5656 self .device = device_check (self .device )
5757 self .tokenizer = AutoTokenizer .from_pretrained (self .model_name )
5858 self .model = AutoModelForSequenceClassification .from_pretrained (self .model_name )
@@ -212,10 +212,11 @@ class Qsquare(Metric):
212212 include_nouns : bool = True
213213 save_results : bool = False
214214
215- def __post_init__ (self ):
215+ def init_model (self ):
216216 self .qa = QAGQ .from_pretrained (self .qa_model_name )
217217 self .qg = QAGQ .from_pretrained (self .qg_model_name )
218218 self .nli = EntailmentScore ()
219+ self .nli .init_model ()
219220 try :
220221 self .nlp = spacy .load (SPACY_MODEL )
221222 except OSError :
@@ -326,15 +327,15 @@ def score(self, ground_truth: list[str], generated_text: list[str], **kwargs):
326327 )
327328 gnd_qans [i ] = [
328329 {"question" : qstn , "answer" : ans }
329- for qstn , ans in zip (questions , candidates )
330+ for qstn , ans in zip (questions , candidates ) # type: ignore
330331 ]
331332
332333 for i , gen_text in enumerate (generated_text ):
333334 questions = [item ["question" ] for item in gnd_qans [i ]]
334335 gen_answers = self .generate_answers (questions , gen_text )
335336 _ = [
336337 item .update ({"predicted_answer" : ans })
337- for item , ans in zip (gnd_qans [i ], gen_answers )
338+ for item , ans in zip (gnd_qans [i ], gen_answers ) # type: ignore
338339 ]
339340
340341 # del self.qa
0 commit comments