
Commit ca07b3d

jjmachan and Jithin James authored
fix: lazyloading of model used in metrics to speed up import (#32)
* added init_model to baseline
* added init_model to everything
* fix lint issues
* added init_model to qsquare
* ignore type issue
* fix linting

Co-authored-by: Jithin James <[email protected]>
1 parent 6548f3c commit ca07b3d
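
The core idea of this commit: metrics previously loaded their models in __post_init__, which runs as soon as the dataclass instance is created, and ragas.metrics exposes ready-made instances (the benchmark below imports rouge1, edit_ratio, bert_score), so a plain import pulled large HuggingFace / sentence-transformers models into memory. Moving the load into an explicit init_model() hook makes construction cheap and defers the heavy work until evaluation actually runs. A minimal sketch of the pattern, with hypothetical names (HeavyMetric, load_model) standing in for the real metric classes and loaders:

from dataclasses import dataclass


def load_model(name: str) -> object:
    # stand-in for e.g. AutoModelForSequenceClassification.from_pretrained(name)
    print(f"loading {name} ...")
    return object()


@dataclass
class HeavyMetric:
    model_name: str = "some-checkpoint"

    def init_model(self):
        # deferred: previously this ran in __post_init__ at construction time
        self.model = load_model(self.model_name)


heavy_metric = HeavyMetric()  # cheap; safe to create at import time
heavy_metric.init_model()     # the expensive load happens only here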

File tree: 6 files changed (+41, -17 lines)

Makefile
Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ lint: ## Running lint checker: ruff
 	@ruff check ragas examples tests
 type: ## Running type checker: pyright
 	@echo "(pyright) Typechecking codebase..."
-	@pyright -p ragas
+	@pyright ragas
 clean: ## Clean all generated files
 	@echo "Cleaning all generated files..."
 	@cd $(GIT_ROOT)/docs && make clean

ragas/metrics/base.py
Lines changed: 20 additions & 0 deletions

@@ -13,17 +13,33 @@ class Metric(ABC):
     @property
     @abstractmethod
     def name(self: t.Self) -> str:
+        """
+        the metric name
+        """
         ...
 
     @property
     @abstractmethod
     def is_batchable(self: t.Self) -> bool:
+        """
+        Attribute to check if this metric is is_batchable
+        """
+        ...
+
+    @abstractmethod
+    def init_model():
+        """
+        This method will lazy initialize the model.
+        """
         ...
 
     @abstractmethod
     def score(
         self: t.Self, ground_truth: list[str], generated_text: list[str]
     ) -> list[float]:
+        """
+        Run the metric on the ground_truth and generated_text and return score.
+        """
         ...
 
 
@@ -37,6 +53,10 @@ def eval(self, ground_truth: list[list[str]], generated_text: list[str]) -> Resu
         ds = Dataset.from_dict(
             {"ground_truth": ground_truth, "generated_text": generated_text}
         )
+
+        # initialize all the models in the metrics
+        [m.init_model() for m in self.metrics]
+
         ds = ds.map(
             self._get_score,
             batched=self.batched,
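
Since init_model is now part of the abstract Metric interface and Evaluation.eval() calls it on every metric before scoring, each concrete metric has to provide it, even as a no-op (as the simple metrics in simple.py below do). A rough standalone sketch of a metric written against this interface; the class and its behaviour are made up for illustration, and it deliberately does not import ragas so it stays self-contained:

from dataclasses import dataclass


@dataclass
class ExactMatch:  # hypothetical metric mirroring the Metric interface above
    @property
    def name(self) -> str:
        return "exact_match"

    @property
    def is_batchable(self) -> bool:
        return True

    def init_model(self):
        ...  # nothing heavy to load for plain string comparison

    def score(self, ground_truth: list[str], generated_text: list[str]) -> list[float]:
        return [float(gt == gen) for gt, gen in zip(ground_truth, generated_text)]


print(ExactMatch().score(["a"], ["a"]))  # [1.0]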

ragas/metrics/factual.py
Lines changed: 5 additions & 4 deletions

@@ -52,7 +52,7 @@ class EntailmentScore(Metric):
     batch_size: int = 4
     device: t.Literal["cpu", "cuda"] | Device = "cpu"
 
-    def __post_init__(self):
+    def init_model(self):
         self.device = device_check(self.device)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)

@@ -212,10 +212,11 @@ class Qsquare(Metric):
     include_nouns: bool = True
     save_results: bool = False
 
-    def __post_init__(self):
+    def init_model(self):
         self.qa = QAGQ.from_pretrained(self.qa_model_name)
         self.qg = QAGQ.from_pretrained(self.qg_model_name)
         self.nli = EntailmentScore()
+        self.nli.init_model()
         try:
             self.nlp = spacy.load(SPACY_MODEL)
         except OSError:

@@ -326,15 +327,15 @@ def score(self, ground_truth: list[str], generated_text: list[str], **kwargs):
             )
             gnd_qans[i] = [
                 {"question": qstn, "answer": ans}
-                for qstn, ans in zip(questions, candidates)
+                for qstn, ans in zip(questions, candidates)  # type: ignore
             ]
 
         for i, gen_text in enumerate(generated_text):
             questions = [item["question"] for item in gnd_qans[i]]
             gen_answers = self.generate_answers(questions, gen_text)
             _ = [
                 item.update({"predicted_answer": ans})
-                for item, ans in zip(gnd_qans[i], gen_answers)
+                for item, ans in zip(gnd_qans[i], gen_answers)  # type: ignore
             ]
 
             # del self.qa

ragas/metrics/similarity.py
Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ class BERTScore(Metric):
     model_path: str = "all-MiniLM-L6-v2"
     batch_size: int = 1000
 
-    def __post_init__(self):
+    def init_model(self):
         self.model = SentenceTransformer(self.model_path)
 
     @property
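
BERTScore's init_model shows what the deferral buys: constructing a SentenceTransformer loads (and on first use downloads) the embedding weights, which previously happened whenever the metric instance was built. A small usage sketch of the call being deferred, assuming sentence-transformers is installed; the model name comes from the diff above:

from sentence_transformers import SentenceTransformer

# this construction is the expensive step that init_model now postpones
model = SentenceTransformer("all-MiniLM-L6-v2")

embeddings = model.encode(["a generated answer", "the reference answer"])
print(embeddings.shape)  # (2, 384) for this MiniLM checkpoint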

ragas/metrics/simple.py
Lines changed: 7 additions & 1 deletion

@@ -26,6 +26,9 @@ def name(self):
     def is_batchable(self):
         return True
 
+    def init_model(self):
+        ...
+
     def score(self, ground_truth: t.List[str], generated_text: t.List[str]):
         ground_truth_ = [[word_tokenize(text)] for text in ground_truth]
         generated_text_ = [word_tokenize(text) for text in generated_text]

@@ -45,7 +48,7 @@ class ROUGE(Metric):
     type: t.Literal[ROUGE_TYPES]
     use_stemmer: bool = False
 
-    def __post_init__(self):
+    def init_model(self):
         self.scorer = rouge_scorer.RougeScorer(
             [self.type], use_stemmer=self.use_stemmer
         )

@@ -80,6 +83,9 @@ def name(self) -> str:
     def is_batchable(self):
         return True
 
+    def init_model(self):
+        ...
+
     def score(self, ground_truth: t.List[str], generated_text: t.List[str]):
         if self.measure == "distance":
             score = [distance(s1, s2) for s1, s2 in zip(ground_truth, generated_text)]

tests/benchmarks/benchmark.py
Lines changed: 7 additions & 10 deletions

@@ -7,26 +7,23 @@
 
 from ragas.metrics import (
     Evaluation,
-    edit_distance,
+    bert_score,
     edit_ratio,
-    q_square,
     rouge1,
-    rouge2,
-    rougeL,
 )
 
 DEVICE = "cuda" if is_available() else "cpu"
-BATCHES = [0, 1]
+BATCHES = [0, 1, 30, 60]
 
 METRICS = {
     "Rouge1": rouge1,
-    "Rouge2": rouge2,
-    "RougeL": rougeL,
+    # "Rouge2": rouge2,
+    # "RougeL": rougeL,
     "EditRatio": edit_ratio,
-    "EditDistance": edit_distance,
-    # "SBERTScore": bert_score,
+    # "EditDistance": edit_distance,
+    "SBERTScore": bert_score,
     # "EntailmentScore": entailment_score,
-    "Qsquare": q_square,
+    # "Qsquare": q_square,
 }
 DS = load_dataset("explodinggradients/eli5-test", split="test_eli5")
 assert isinstance(DS, arrow_dataset.Dataset), "Not an arrow_dataset"
