Skip to content

Commit 3583e60

Browse files
added input validation and improved merge_sort error handling
1 parent e2a78d4 commit 3583e60

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

sorts/merge_sort.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""
2-
This is a pure Python implementation of the merge sort algorithm.
2+
This is a pure Python implementation of the merge sort algorithm
3+
with added input validation for safer usage.
34
45
For doctests run following command:
56
python -m doctest -v merge_sort.py
@@ -9,6 +10,22 @@
910
python merge_sort.py
1011
"""
1112

13+
# -----------------------
14+
# Input validation helper
15+
# -----------------------
16+
def _validate_sort_input(arr):
17+
"""
18+
Ensures the input is a list (or tuple) of comparable elements.
19+
20+
Raises:
21+
TypeError: if arr is not a list or tuple.
22+
ValueError: if list contains uncomparable elements.
23+
"""
24+
if not isinstance(arr, (list, tuple)):
25+
raise TypeError("merge_sort() input must be a list or tuple.")
26+
if len(arr) > 1 and not all(isinstance(x, (int, float, str)) for x in arr):
27+
raise ValueError("merge_sort() elements must be comparable (int, float, or str).")
28+
1229

1330
def merge_sort(collection: list) -> list:
1431
"""
@@ -28,6 +45,7 @@ def merge_sort(collection: list) -> list:
2845
>>> merge_sort([-2, -5, -45])
2946
[-45, -5, -2]
3047
"""
48+
_validate_sort_input(collection)
3149

3250
def merge(left: list, right: list) -> list:
3351
"""
@@ -45,20 +63,25 @@ def merge(left: list, right: list) -> list:
4563
return result
4664

4765
if len(collection) <= 1:
48-
return collection
66+
return list(collection)
67+
4968
mid_index = len(collection) // 2
50-
return merge(merge_sort(collection[:mid_index]), merge_sort(collection[mid_index:]))
69+
return merge(
70+
merge_sort(collection[:mid_index]),
71+
merge_sort(collection[mid_index:])
72+
)
5173

5274

5375
if __name__ == "__main__":
5476
import doctest
55-
5677
doctest.testmod()
5778

5879
try:
5980
user_input = input("Enter numbers separated by a comma:\n").strip()
60-
unsorted = [int(item) for item in user_input.split(",")]
81+
unsorted = [int(item) for item in user_input.split(",") if item]
6182
sorted_list = merge_sort(unsorted)
62-
print(*sorted_list, sep=",")
83+
print("Sorted list:", *sorted_list, sep=" ")
6384
except ValueError:
6485
print("Invalid input. Please enter valid integers separated by commas.")
86+
except TypeError as e:
87+
print(e)

0 commit comments

Comments
 (0)