Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions src/feed/tests/views/test_feed_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from researchhub_document.related_models.researchhub_unified_document_model import (
ResearchhubUnifiedDocument,
)
from user.related_models.author_model import Author
from user.views.follow_view_mixins import create_follow

User = get_user_model()
Expand Down Expand Up @@ -818,3 +819,105 @@ def test_following_feed_default_sorting(self):
dates[i + 1],
"Feed items not sorted by action_date when using default sort",
)

def test_following_feed_with_followed_author(self):
"""Test that following an author returns their papers/posts in the feed."""
# Arrange
followed_author = Author.objects.create(
first_name="Followed",
last_name="Author",
)
create_follow(self.user, followed_author)

# Create a paper by the followed author in a hub we don't follow
author_paper_doc = ResearchhubUnifiedDocument.objects.create(
document_type="PAPER"
)
author_paper_doc.hubs.add(self.other_hub)
author_paper = Paper.objects.create(
title="Paper by Followed Author",
paper_publish_date=timezone.now(),
unified_document=author_paper_doc,
)

author_paper_entry = FeedEntry.objects.create(
user=self.other_user,
action="PUBLISH",
action_date=timezone.now(),
content_type=self.paper_content_type,
object_id=author_paper.id,
unified_document=author_paper_doc,
metrics={"votes": 25, "comments": 5},
)
author_paper_entry.hubs.add(self.other_hub)
author_paper_entry.authors.add(followed_author)

cache.clear()

url = reverse("feed-list")

# Act
response = self.client.get(url, {"feed_view": "following"})

# Assert
self.assertEqual(response.status_code, status.HTTP_200_OK)
result_ids = [r["content_object"]["id"] for r in response.data["results"]]

# Should see paper from followed author even though it's in an unfollowed hub
self.assertIn(author_paper.id, result_ids)
# Should still see paper from followed hub
self.assertIn(self.paper.id, result_ids)
# Should not see paper from unfollowed hub by unfollowed author
self.assertNotIn(self.other_paper.id, result_ids)

def test_following_feed_with_only_followed_author(self):
"""Test that following only authors (no hubs) returns their content."""
# Arrange
# Remove all hub follows
self.user.following.all().delete()

followed_author = Author.objects.create(
first_name="Only",
last_name="Author",
)
create_follow(self.user, followed_author)

# Create a paper by the followed author
author_paper_doc = ResearchhubUnifiedDocument.objects.create(
document_type="PAPER"
)
author_paper_doc.hubs.add(self.other_hub)
author_paper = Paper.objects.create(
title="Paper by Only Followed Author",
paper_publish_date=timezone.now(),
unified_document=author_paper_doc,
)

author_paper_entry = FeedEntry.objects.create(
user=self.other_user,
action="PUBLISH",
action_date=timezone.now(),
content_type=self.paper_content_type,
object_id=author_paper.id,
unified_document=author_paper_doc,
metrics={"votes": 25, "comments": 5},
)
author_paper_entry.hubs.add(self.other_hub)
author_paper_entry.authors.add(followed_author)

cache.clear()

url = reverse("feed-list")

# Act
response = self.client.get(url, {"feed_view": "following"})

# Assert
self.assertEqual(response.status_code, status.HTTP_200_OK)
result_ids = [r["content_object"]["id"] for r in response.data["results"]]

# Should see paper from followed author
self.assertIn(author_paper.id, result_ids)
# Should not see papers from unfollowed hubs/authors
self.assertNotIn(self.paper.id, result_ids)
self.assertNotIn(self.other_paper.id, result_ids)
16 changes: 11 additions & 5 deletions src/feed/views/feed_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from django.conf import settings
from django.core.cache import cache
from django.db.models import Case, IntegerField, Value, When
from django.db.models import Case, IntegerField, Q, Value, When
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.permissions import AllowAny
Expand Down Expand Up @@ -124,10 +124,16 @@ def get_queryset(self):
# Apply following filter only for "following" view
if feed_view == "following":
followed_hub_ids = self.get_followed_hub_ids()
if followed_hub_ids:
queryset = queryset.filter(
hubs__id__in=followed_hub_ids,
)
followed_author_ids = self.get_followed_author_ids()

if followed_hub_ids or followed_author_ids:
following_filter = Q()
if followed_hub_ids:
following_filter |= Q(hubs__id__in=followed_hub_ids)
if followed_author_ids:
following_filter |= Q(authors__id__in=followed_author_ids)

queryset = queryset.filter(following_filter)

# Only show paper and post for all following views
queryset = queryset.filter(
Expand Down
19 changes: 19 additions & 0 deletions src/feed/views/feed_view_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from paper.related_models.paper_model import Paper
from researchhub_comment.related_models.rh_comment_model import RhCommentModel
from researchhub_document.related_models.researchhub_post_model import ResearchhubPost
from user.related_models.author_model import Author


class FeedViewMixin:
Expand All @@ -20,6 +21,10 @@ class FeedViewMixin:

_content_types = {}

@property
def _author_content_type(self):
return self._get_content_type(Author)

@property
def _comment_content_type(self):
return self._get_content_type(RhCommentModel)
Expand Down Expand Up @@ -211,6 +216,20 @@ def get_followed_hub_ids(self):
).values_list("object_id", flat=True)
)

def get_followed_author_ids(self):
"""
Get IDs of authors followed by the current user.
Returns empty list if user is not authenticated.
"""
if not self.request.user.is_authenticated:
return []

return list(
self.request.user.following.filter(
content_type=self._author_content_type
).values_list("object_id", flat=True)
)

def _get_user_votes(self, created_by, doc_ids, reaction_content_type):
return Vote.objects.filter(
content_type=reaction_content_type,
Expand Down
Loading