diff --git a/tests/conftest.py b/tests/conftest.py index e317bbc0a6..99e872d21d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -480,7 +480,8 @@ def empty_index(): stream_msg_template['id']: stream_msg_template, pm_template['id']: pm_template, group_pm_template['id']: group_pm_template, - }) + }), + 'unread_msg_ids': set() }) @@ -704,10 +705,10 @@ def stream_dict(streams_fixture): @pytest.fixture def classified_unread_counts(): """ - Unread counts return by + Unread counts and unread_msg_ids returned by helper.classify_unread_counts function. """ - return { + return ({ 'all_msg': 12, 'all_pms': 8, 'unread_topics': { @@ -726,4 +727,4 @@ def classified_unread_counts(): 1000: 3, 99: 1 } - } + }, {1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 101, 102}) diff --git a/tests/helper/test_helper.py b/tests/helper/test_helper.py index f810f08464..b958b10f6e 100644 --- a/tests/helper/test_helper.py +++ b/tests/helper/test_helper.py @@ -183,8 +183,9 @@ def test_classify_unread_counts(mocker, initial_data, stream_dict, model.initial_data = initial_data model.muted_topics = muted_topics model.muted_streams = muted_streams - assert classify_unread_counts(model) == dict(classified_unread_counts, - **vary_in_unreads) + assert classify_unread_counts(model) == (dict(classified_unread_counts[0], + **vary_in_unreads), + classified_unread_counts[1]) @pytest.mark.parametrize('color', [ diff --git a/tests/model/test_model.py b/tests/model/test_model.py index d3d25e4183..d164a9cdd2 100644 --- a/tests/model/test_model.py +++ b/tests/model/test_model.py @@ -34,7 +34,7 @@ def model(self, mocker, initial_data, user_profile): # NOTE: PATCH WHERE USED NOT WHERE DEFINED self.classify_unread_counts = mocker.patch( 'zulipterminal.model.classify_unread_counts', - return_value=[]) + return_value=([], set())) self.client.get_profile.return_value = user_profile model = Model(self.controller) return model @@ -436,7 +436,7 @@ def test_success_get_messages(self, mocker, messages_successful_response, return_value=({}, set(), [], [])) self.classify_unread_counts = mocker.patch( 'zulipterminal.model.classify_unread_counts', - return_value=[]) + return_value=([], set())) # Setup mocks before calling get_messages self.client.get_messages.return_value = messages_successful_response @@ -476,7 +476,7 @@ def test_get_message_false_first_anchor( return_value=({}, set(), [], [])) self.classify_unread_counts = mocker.patch( 'zulipterminal.model.classify_unread_counts', - return_value=[]) + return_value=([], set())) # Setup mocks before calling get_messages messages_successful_response['anchor'] = 0 @@ -509,7 +509,7 @@ def test_fail_get_messages(self, mocker, error_response, return_value=({}, set(), [], [])) self.classify_unread_counts = mocker.patch( 'zulipterminal.model.classify_unread_counts', - return_value=[]) + return_value=([], set())) # Setup mock before calling get_messages # FIXME This has no influence on the result @@ -588,7 +588,7 @@ def test__update_initial_data_raises_exception(self, mocker, initial_data): return_value=({}, set(), [], [])) self.classify_unread_counts = mocker.patch( 'zulipterminal.model.classify_unread_counts', - return_value=[]) + return_value=([], set())) # Setup mocks before calling get_messages self.client.register.return_value = initial_data @@ -622,7 +622,7 @@ def test_get_all_users(self, mocker, initial_data, user_list, user_dict, return_value=({}, set(), [], [])) self.classify_unread_counts = mocker.patch( 'zulipterminal.model.classify_unread_counts', - return_value=[]) + return_value=([], set())) model = Model(self.controller) assert model.user_dict == user_dict assert model.users == user_list @@ -1145,13 +1145,15 @@ def test_update_reaction_remove_reaction(self, mocker, model, response, def test_update_star_status_no_index(self, mocker, model): model.index = dict(messages={}) # Not indexed - event = dict(messages=[1], flag='starred', all=False) + event = dict(messages=[1], flag='starred', all=False, operation='add') mocker.patch('zulipterminal.model.Model.update_rendered_view') + set_count = mocker.patch('zulipterminal.model.set_count') model.update_message_flag_status(event) assert model.index == dict(messages={}) model.update_rendered_view.assert_not_called() + set_count.assert_not_called() def test_update_star_status_invalid_operation(self, mocker, model): model.index = dict(messages={1: {'flags': None}}) # Minimal @@ -1163,9 +1165,11 @@ def test_update_star_status_invalid_operation(self, mocker, model): 'all': False, } mocker.patch('zulipterminal.model.Model.update_rendered_view') + set_count = mocker.patch('zulipterminal.model.set_count') with pytest.raises(RuntimeError): model.update_message_flag_status(event) model.update_rendered_view.assert_not_called() + set_count.assert_not_called() @pytest.mark.parametrize('event_message_ids, indexed_ids', [ ([1], [1]), @@ -1199,6 +1203,7 @@ def test_update_star_status(self, mocker, model, event_op, 'all': False, } mocker.patch('zulipterminal.model.Model.update_rendered_view') + set_count = mocker.patch('zulipterminal.model.set_count') model.update_message_flag_status(event) @@ -1206,12 +1211,76 @@ def test_update_star_status(self, mocker, model, event_op, for changed_id in changed_ids: assert model.index['messages'][changed_id]['flags'] == flags_after (model.update_rendered_view. - has_calls([mocker.call(changed_id) for changed_id in changed_ids])) + assert_has_calls([mocker.call(changed_id) + for changed_id in changed_ids])) + + for unchanged_id in (set(indexed_ids) - set(event_message_ids)): + assert (model.index['messages'][unchanged_id]['flags'] == + flags_before) + + set_count.assert_not_called() + + @pytest.mark.parametrize('event_message_ids, indexed_ids', [ + ([1], [1]), + ([1, 2], [1]), + ([1, 2], [1, 2]), + ([1], [1, 2]), + ([], [1, 2]), + ([1, 2], []), + ]) + @pytest.mark.parametrize('event_op, flags_before, flags_after', [ + ('add', [], ['read']), + ('add', ['read'], ['read']), + ('add', ['starred'], ['starred', 'read']), + ('add', ['read', 'starred'], ['read', 'starred']), + ('remove', [], []), + ('remove', ['read'], ['read']), # msg cannot be marked 'unread' + ('remove', ['starred'], ['starred']), + ('remove', ['starred', 'read'], ['starred', 'read']), + ('remove', ['read', 'starred'], ['read', 'starred']), + ]) + def test_update_read_status(self, mocker, model, event_op, + event_message_ids, indexed_ids, + flags_before, flags_after): + model.index = dict(messages={msg_id: {'flags': flags_before} + for msg_id in indexed_ids}) + model.index['unread_msg_ids'] = set(event_message_ids) + + event = { + 'messages': event_message_ids, + 'type': 'update_message_flags', + 'flag': 'read', + 'operation': event_op, + 'all': False, + } + + mocker.patch('zulipterminal.model.Model.update_rendered_view') + set_count = mocker.patch('zulipterminal.model.set_count') + + model.update_message_flag_status(event) + + changed_ids = set(indexed_ids) & set(event_message_ids) + for changed_id in changed_ids: + assert model.index['messages'][changed_id]['flags'] == flags_after + + if event_op == 'add': + model.update_rendered_view.assert_has_calls( + [mocker.call(changed_id)]) + elif event_op == 'remove': + model.update_rendered_view.assert_not_called() for unchanged_id in (set(indexed_ids) - set(event_message_ids)): assert (model.index['messages'][unchanged_id]['flags'] == flags_before) + if event_op == 'add': + set_count.assert_called_once_with(list(changed_ids), + self.controller, -1) + assert len(model.index['unread_msg_ids']) == 0 + elif event_op == 'remove': + set_count.assert_not_called() + assert model.index['unread_msg_ids'] == set(event_message_ids) + @pytest.mark.parametrize('narrow, event, called', [ # Not in PM Narrow ([], {}, False), diff --git a/zulipterminal/helper.py b/zulipterminal/helper.py index b8593b8e0d..1abb7b479b 100644 --- a/zulipterminal/helper.py +++ b/zulipterminal/helper.py @@ -38,6 +38,8 @@ 'search': Set[int], # {message_id, ...} # Downloaded message data 'messages': Dict[int, Message], # message_id: Message + # unread message data; additional data in model.initial_data['unread_msgs'] + 'unread_msg_ids': Set[int] # {message_ids, ...} }) initial_index = Index( @@ -52,6 +54,7 @@ topics=defaultdict(list), search=set(), messages=defaultdict(dict), + unread_msg_ids=set(), ) @@ -351,9 +354,10 @@ def index_messages(messages: List[Message], return index -def classify_unread_counts(model: Any) -> UnreadCounts: +def classify_unread_counts(model: Any) -> Tuple[UnreadCounts, Set[int]]: # TODO: support group pms unread_msg_counts = model.initial_data['unread_msgs'] + unread_msg_ids = set() # type: Set[int] unread_counts = UnreadCounts( all_msg=0, @@ -365,13 +369,17 @@ def classify_unread_counts(model: Any) -> UnreadCounts: ) for pm in unread_msg_counts['pms']: - count = len(pm['unread_message_ids']) + message_ids = pm['unread_message_ids'] + unread_msg_ids.update(message_ids) + count = len(message_ids) unread_counts['unread_pms'][pm['sender_id']] = count unread_counts['all_msg'] += count unread_counts['all_pms'] += count for stream in unread_msg_counts['streams']: - count = len(stream['unread_message_ids']) + message_ids = stream['unread_message_ids'] + unread_msg_ids.update(message_ids) + count = len(message_ids) stream_id = stream['stream_id'] if [model.stream_dict[stream_id]['name'], stream['topic']] in model.muted_topics: @@ -386,14 +394,16 @@ def classify_unread_counts(model: Any) -> UnreadCounts: # store unread count of group pms in `unread_huddles` for group_pm in unread_msg_counts['huddles']: - count = len(group_pm['unread_message_ids']) + message_ids = group_pm['unread_message_ids'] + unread_msg_ids.update(message_ids) + count = len(message_ids) user_ids = group_pm['user_ids_string'].split(',') user_ids = frozenset(map(int, user_ids)) unread_counts['unread_huddles'][user_ids] = count unread_counts['all_msg'] += count unread_counts['all_pms'] += count - return unread_counts + return unread_counts, unread_msg_ids def match_user(user: Any, text: str) -> bool: diff --git a/zulipterminal/model.py b/zulipterminal/model.py index 48400ad3cf..2e01d3682d 100644 --- a/zulipterminal/model.py +++ b/zulipterminal/model.py @@ -116,7 +116,8 @@ def __init__(self, controller: Any) -> None: self.user_group_by_id = {} # type: Dict[int, Dict[str, Any]] self.user_group_names = self._group_info_from_realm_user_groups(groups) - self.unread_counts = classify_unread_counts(self) + unread_data = classify_unread_counts(self) + self.unread_counts, self.index['unread_msg_ids'] = unread_data self.fetch_all_topics(workers=5) @@ -285,7 +286,6 @@ def mark_message_ids_as_read(self, id_list: List[int]) -> None: 'flag': 'read', 'op': 'add', }) - set_count(id_list, self.controller, -1) # FIXME Update? def send_private_message(self, recipients: str, content: str) -> bool: @@ -811,9 +811,11 @@ def update_message_flag_status(self, event: Event) -> None: if event['all']: # FIXME Should handle eventually return - # TODO: Expand from 'starred' to also support 'read' flag changes? flag_to_change = event['flag'] - if flag_to_change != 'starred': + if flag_to_change not in {'starred', 'read'}: + return + + if flag_to_change == 'read' and event['operation'] == 'remove': return indexed_message_ids = set(self.index['messages']) @@ -833,6 +835,11 @@ def update_message_flag_status(self, event: Event) -> None: self.index['messages'][message_id] = msg self.update_rendered_view(message_id) + if event['operation'] == 'add' and flag_to_change == 'read': + set_count(list(message_ids_to_mark & indexed_message_ids), + self.controller, -1) + self.index['unread_msg_ids'].difference_update(message_ids_to_mark) + def update_rendered_view(self, msg_id: int) -> None: # Update new content in the rendered view for msg_w in self.msg_list.log: