From 5e1d2d1369e9a4194baab46dc42cb64d95b44315 Mon Sep 17 00:00:00 2001 From: Oliver Sauder Date: Mon, 14 Jul 2025 20:06:25 +0700 Subject: [PATCH] nsured that interpreting `include` query parameter is done in internal Python naming. --- CHANGELOG.md | 5 +++++ rest_framework_json_api/renderers.py | 4 ---- rest_framework_json_api/serializers.py | 3 +-- rest_framework_json_api/utils.py | 12 ++++++++++-- tests/conftest.py | 7 ++++++- tests/test_utils.py | 21 +++++++++++++++++++++ 6 files changed, 43 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ec7ff7d..eae89c70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,11 @@ any parts of the framework not mentioned in the documentation should generally b * Added support for Django REST framework 3.16. * Added support for Django 5.2. +### Fixed + +* Ensured that interpreting `include` query parameter is done in internal Python naming. + This adds full support for using multipart field names for includes while configuring `JSON_API_FORMAT_FIELD_NAMES`. + ### Removed * Removed support for Python 3.8. diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 8c19934f..b670338f 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -6,7 +6,6 @@ from collections import defaultdict from collections.abc import Iterable -import inflection from django.db.models import Manager from django.template import loader from django.utils.encoding import force_str @@ -277,9 +276,6 @@ def extract_included( current_serializer, "included_serializers", dict() ) included_resources = copy.copy(included_resources) - included_resources = [ - inflection.underscore(value) for value in included_resources - ] for field_name, field in iter(fields.items()): # Skip URL field diff --git a/rest_framework_json_api/serializers.py b/rest_framework_json_api/serializers.py index d59dbd88..75764a5d 100644 --- a/rest_framework_json_api/serializers.py +++ b/rest_framework_json_api/serializers.py @@ -1,6 +1,5 @@ from collections.abc import Mapping -import inflection from django.core.exceptions import ObjectDoesNotExist from django.db.models.query import QuerySet from django.utils.module_loading import import_string as import_class_from_dotted_path @@ -129,7 +128,7 @@ def validate_path(serializer_class, field_path, path): serializers = getattr(serializer_class, "included_serializers", None) if serializers is None: raise ParseError("This endpoint does not support the include parameter") - this_field_name = inflection.underscore(field_path[0]) + this_field_name = field_path[0] this_included_serializer = serializers.get(this_field_name) if this_included_serializer is None: raise ParseError( diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index 805f5f09..2dd79677 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -316,10 +316,18 @@ def get_resource_id(resource_instance, resource): def get_included_resources(request, serializer=None): - """Build a list of included resources.""" + """ + Build a list of included resources. + + This method ensures that returned includes are in Python internally used + format. + """ include_resources_param = request.query_params.get("include") if request else None if include_resources_param: - return include_resources_param.split(",") + return [ + undo_format_field_name(include) + for include in include_resources_param.split(",") + ] else: return get_default_included_resources_from_serializer(serializer) diff --git a/tests/conftest.py b/tests/conftest.py index 865244e0..77b3676b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ import pytest -from rest_framework.test import APIClient +from rest_framework.test import APIClient, APIRequestFactory from tests.models import ( BasicModel, @@ -98,3 +98,8 @@ def nested_related_source( @pytest.fixture def client(): return APIClient() + + +@pytest.fixture +def rf(): + return APIRequestFactory() diff --git a/tests/test_utils.py b/tests/test_utils.py index a3beb12e..08e36b6a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ from rest_framework import status from rest_framework.fields import Field from rest_framework.generics import GenericAPIView +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.views import APIView @@ -13,6 +14,7 @@ format_link_segment, format_resource_type, format_value, + get_included_resources, get_related_resource_type, get_resource_id, get_resource_name, @@ -456,3 +458,22 @@ def test_get_resource_id(resource_instance, resource, expected): ) def test_format_error_object(message, pointer, response, result): assert result == format_error_object(message, pointer, response) + + +@pytest.mark.parametrize( + "format_type,include_param,expected_includes", + [ + ("dasherize", "author-bio", ["author_bio"]), + ("dasherize", "author-bio,author-type", ["author_bio", "author_type"]), + ("dasherize", "author-bio.author-type", ["author_bio.author_type"]), + ("camelize", "authorBio", ["author_bio"]), + ], +) +def test_get_included_resources( + rf, include_param, expected_includes, format_type, settings +): + settings.JSON_API_FORMAT_FIELD_NAMES = format_type + + request = Request(rf.get("/test/", {"include": include_param})) + includes = get_included_resources(request) + assert includes == expected_includes