diff --git a/fuzzywuzzy/process.py b/fuzzywuzzy/process.py
index 4d73248d..c4befd41 100644
--- a/fuzzywuzzy/process.py
+++ b/fuzzywuzzy/process.py
@@ -119,7 +119,7 @@ def no_process(x):
                 yield (choice, score)
 
 
-def extract(query, choices, processor=default_processor, scorer=default_scorer, limit=5):
+def extract(query, choices, processor=default_processor, scorer=default_scorer, limit=5, char_sort=False):
     """Select the best match in a list or dictionary of choices.
 
     Find best matches in a list or dictionary of choices, return a
@@ -165,11 +165,17 @@ def extract(query, choices, processor=default_processor, scorer=default_scorer,
         [('train', 22, 'bard'), ('man', 0, 'dog')]
     """
     sl = extractWithoutOrder(query, choices, processor, scorer)
-    return heapq.nlargest(limit, sl, key=lambda i: i[1]) if limit is not None else \
-        sorted(sl, key=lambda i: i[1], reverse=True)
+    if not char_sort:
+        return heapq.nlargest(limit, sl, key=lambda i: i[1]) if limit is not None else \
+            sorted(sl, key=lambda i: i[1], reverse=True)
+    # Rank by score first, then break score ties by letters shared with the query.
+    ranked = sortByCommonLetter(sorted(sl, key=lambda i: i[1], reverse=True), query)
+    # A slice clamps to the list length by itself, so no min() is needed.
+    return ranked[:limit] if limit is not None else ranked
 
 
-def extractBests(query, choices, processor=default_processor, scorer=default_scorer, score_cutoff=0, limit=5):
+def extractBests(query, choices, processor=default_processor, scorer=default_scorer, score_cutoff=0, limit=5,
+                 char_sort=False):
     """Get a list of the best matches to a collection of choices.
 
     Convenience function for getting the choices with best scores.
@@ -188,13 +194,16 @@ def extractBests(query, choices, processor=default_processor, scorer=default_sco
 
     Returns: A a list of (match, score) tuples.
     """
-
-    best_list = extractWithoutOrder(query, choices, processor, scorer, score_cutoff)
-    return heapq.nlargest(limit, best_list, key=lambda i: i[1]) if limit is not None else \
-        sorted(best_list, key=lambda i: i[1], reverse=True)
+    # Keep forwarding score_cutoff so matches below the cutoff are still dropped.
+    best_list = extractWithoutOrder(query, choices, processor, scorer, score_cutoff)
+    if not char_sort:
+        return heapq.nlargest(limit, best_list, key=lambda i: i[1]) if limit is not None else \
+            sorted(best_list, key=lambda i: i[1], reverse=True)
+    ranked = sortByCommonLetter(sorted(best_list, key=lambda i: i[1], reverse=True), query)
+    return ranked[:limit] if limit is not None else ranked
 
 
-def extractOne(query, choices, processor=default_processor, scorer=default_scorer, score_cutoff=0):
+def extractOne(query, choices, processor=default_processor, scorer=default_scorer, score_cutoff=0, char_sort=False):
     """Find the single best match above a score in a list of choices.
 
     This is a convenience method which returns the single best choice.
@@ -216,10 +225,80 @@ def extractOne(query, choices, processor=default_processor, scorer=default_score
         was found that was above score_cutoff. Otherwise, returns None.
     """
     best_list = extractWithoutOrder(query, choices, processor, scorer, score_cutoff)
-    try:
-        return max(best_list, key=lambda i: i[1])
-    except ValueError:
-        return None
+    if not char_sort:
+        try:
+            return max(best_list, key=lambda i: i[1])
+        except ValueError:
+            return None
+    ranked = sortByCommonLetter(sorted(best_list, key=lambda i: i[1], reverse=True), query)
+    # After tie-breaking, the best match is simply the first entry (if any).
+    return ranked[0] if ranked else None
+
+
+def sortByCommonLetter(sl, query):
+    """Break score ties by how many letters each choice shares with the query.
+
+    Args:
+        sl: A list of (choice, score) or (choice, score, key) tuples that is
+            already sorted by score in descending order. It is reordered in
+            place.
+        query: The string the choices were matched against.
+
+    Returns:
+        The same list, where every run of equal scores is ordered by
+        calculateCommonLetter(query, choice), highest first. Tuples are
+        moved as a whole, so scores and dictionary keys are preserved.
+    """
+    start = 0
+    while start < len(sl):
+        # Find the end of the run of entries sharing sl[start]'s score.
+        end = start + 1
+        while end < len(sl) and sl[end][1] == sl[start][1]:
+            end += 1
+        # Only runs of two or more entries need a tie-break.
+        if end - start > 1:
+            # sorted() is stable, so equally ranked entries keep their order.
+            sl[start:end] = sorted(sl[start:end],
+                                   key=lambda item: calculateCommonLetter(query, item[0]),
+                                   reverse=True)
+        start = end
+    return sl
+
+
+def calculateCommonLetter(s1, s2):
+    """Score how closely the letters of s2 match the letters of s1.
+
+    Both strings are normalised with utils.full_process first. Every
+    character of s2 that can be paired with a still unmatched character of
+    s1 adds one point; every character that is extra in s2 or left
+    unmatched in s1 subtracts one point.
+
+    Args:
+        s1: The reference string (the query).
+        s2: The candidate string (a choice).
+
+    Returns:
+        An int; higher means the two strings share more letters.
+    """
+    s1 = utils.full_process(s1)
+    s2 = utils.full_process(s2)
+
+    # Count how often each character occurs in s1.
+    remaining = {}
+    for char in s1:
+        remaining[char] = remaining.get(char, 0) + 1
+
+    common_letter_count = 0
+    for char in s2:
+        if remaining.get(char, 0) > 0:
+            common_letter_count += 1
+            remaining[char] -= 1
+        else:
+            # Penalty for a letter s2 has but s1 does not (or not as often).
+            common_letter_count -= 1
+
+    # Penalty for every occurrence of an s1 letter that s2 never matched.
+    return common_letter_count - sum(remaining.values())
 
 
 def dedupe(contains_dupes, threshold=70, scorer=fuzz.token_set_ratio):
diff --git a/test_fuzzywuzzy.py b/test_fuzzywuzzy.py
index 58617b68..0d086b42 100644
--- a/test_fuzzywuzzy.py
+++ b/test_fuzzywuzzy.py
@@ -520,6 +520,45 @@ def test_simplematch(self):
         self.assertEqual(part_result, ('a, b', 100))
 
+    def test_extractOne_sort_by_common_letter_count(self):
+        # Test case 1
+        query_1 = 'Company 2'
+        choices_1 = ['Company', 'Company 1', 'Company 2', 'Awesome Company']
+
+        result_without_sort_1 = process.extractOne(query_1, choices_1, scorer=fuzz.partial_token_set_ratio, char_sort=False)
+        result_char_sort_1 = process.extractOne(query_1, choices_1, scorer=fuzz.partial_token_set_ratio, char_sort=True)
+
+        self.assertEqual(result_without_sort_1, ('Company', 100))
+        self.assertEqual(result_char_sort_1, ('Company 2', 100))
+
+        # Test case 2
+        query_2 = 'apple pie'
+        choices_2 = ['pie', 'apple', 'pieapples', 'apple pie', 'pie apple']
+
+        result_char_sort_2 = process.extractOne(query_2, choices_2, scorer=fuzz.ratio, char_sort=True)
+
+        self.assertEqual(result_char_sort_2, ('apple pie', 100))
+
+        # Test case 3
+        query_3 = 'ABC NEWS'
+        choices_3 = ['BCD NEWS', 'NEWS ABC', 'DFG NEWS']
+        result_char_sort_3 = process.extractOne(query_3, choices_3, scorer=fuzz.partial_token_set_ratio, char_sort=True)
+
+        self.assertEqual(result_char_sort_3, ('NEWS ABC', 100))
+
+    def test_sortByCommonLetter_keeps_scores_and_sorts_last_group(self):
+        # Two tie groups; the second one ends exactly at the end of the list.
+        ranked = process.sortByCommonLetter(
+            [('abd', 100), ('abc', 100), ('xyz', 90), ('abc', 90)], 'abc')
+
+        self.assertEqual(ranked, [('abc', 100), ('abd', 100), ('abc', 90), ('xyz', 90)])
+
+    def test_extractBests_respects_score_cutoff(self):
+        for char_sort in (False, True):
+            result = process.extractBests('abc', ['abc', 'xyz'], scorer=fuzz.ratio, score_cutoff=50,
+                                          char_sort=char_sort)
+            self.assertEqual(result, [('abc', 100)])
+
 
 class TestCodeFormat(unittest.TestCase):
     def test_pep8_conformance(self):
         pep8style = pycodestyle.StyleGuide(quiet=False)