diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py
index c3394971e31..aecfa9c06df 100644
--- a/rag/nlp/rag_tokenizer.py
+++ b/rag/nlp/rag_tokenizer.py
@@ -293,8 +293,14 @@ def maxBackward_(self, line):
 
         return self.score_(res[::-1])
 
+    def _stem_and_lemmatize(self, input):
+        lemmatize = self.lemmatizer.lemmatize(input)
+        if lemmatize.endswith('e'):
+            return lemmatize
+        return self.stemmer.stem(lemmatize)
+
     def english_normalize_(self, tks):
-        return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
+        return [self._stem_and_lemmatize(t) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
 
     def _split_by_lang(self, line):
         txt_lang_pairs = []
@@ -328,7 +334,7 @@ def tokenize(self, line):
         res = []
         for L,lang in arr:
             if not lang:
-                res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
+                res.extend([self._stem_and_lemmatize(t) for t in word_tokenize(L)])
                 continue
            if len(L) < 2 or re.match(
                     r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
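
Note: the new helper changes English normalization so that a token is lemmatized first, and stemming is skipped whenever the lemma ends in 'e'. This keeps the stemmer from truncating dictionary forms such as "table" to "tabl". Below is a minimal standalone sketch of that logic, assuming self.lemmatizer and self.stemmer are NLTK's WordNetLemmatizer and PorterStemmer; the function name and wiring outside the diff are illustrative only.

# Standalone sketch of the diff's _stem_and_lemmatize logic; assumes NLTK's
# WordNetLemmatizer and PorterStemmer stand in for self.lemmatizer and
# self.stemmer. Requires: pip install nltk, then nltk.download("wordnet").
from nltk.stem import PorterStemmer, WordNetLemmatizer

lemmatizer = WordNetLemmatizer()
stemmer = PorterStemmer()

def stem_and_lemmatize(token: str) -> str:
    # Lemmatize first (e.g. "tables" -> "table"); if the lemma ends in 'e',
    # return it unstemmed so the stemmer cannot clip the trailing 'e'.
    lemma = lemmatizer.lemmatize(token)
    if lemma.endswith("e"):
        return lemma
    return stemmer.stem(lemma)

print(stem_and_lemmatize("tables"))   # "table"  (guard hit; the old stem-after-lemmatize path gave "tabl")
print(stem_and_lemmatize("running"))  # "run"    (guard not hit; stemmed as before)

The endswith('e') check is a coarse heuristic: it preserves e-final lemmas that the Porter stemmer would otherwise clip, at the cost of leaving some e-final words entirely unstemmed.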