diff --git a/ninja/operation.py b/ninja/operation.py index 623c0b7d0..7fd01e748 100644 --- a/ninja/operation.py +++ b/ninja/operation.py @@ -255,26 +255,31 @@ def _result_to_response( # Empty response. return temporal_response - resp_object = ResponseObject(result) - # ^ we need object because getter_dict seems work only with model_validate - validated_object = response_model.model_validate( - resp_object, context={"request": request, "response_status": status} - ) - - model_dump_kwargs: Dict[str, Any] = {} + model_dump_kwargs: Dict[str, Any] = { + "by_alias": self.by_alias, + "exclude_unset": self.exclude_unset, + "exclude_defaults": self.exclude_defaults, + "exclude_none": self.exclude_none, + } if pydantic_version >= [2, 7]: # pydantic added support for serialization context at 2.7 - model_dump_kwargs.update( - context={"request": request, "response_status": status} + model_dump_kwargs.update(context={"request": request, "response_status": status}) + + response_schema = response_model.__annotations__["response"] + if ( + isinstance(response_schema, type) + and issubclass(response_schema, pydantic.BaseModel) + and isinstance(result, response_schema) + ): + validated_object = result + result = validated_object.model_dump(**model_dump_kwargs) + else: + # ^ we need object because getter_dict seems work only with model_validate + validated_object = response_model.model_validate( + ResponseObject(result), context={"request": request, "response_status": status} ) + result = validated_object.model_dump(**model_dump_kwargs)["response"] - result = validated_object.model_dump( - by_alias=self.by_alias, - exclude_unset=self.exclude_unset, - exclude_defaults=self.exclude_defaults, - exclude_none=self.exclude_none, - **model_dump_kwargs, - )["response"] return self.api.create_response( request, result, temporal_response=temporal_response ) diff --git a/tests/test_response.py b/tests/test_response.py index 6b234a6cb..6d161c39f 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -40,6 +40,10 @@ def to_camel(string: str) -> str: return "".join(word.capitalize() for word in string.split("_")) +class MessageModel(BaseModel): + message: str + + class UserModel(BaseModel): id: int user_name: str @@ -67,6 +71,14 @@ def check_model_alias(request): return User(1, "John", "Password") +@router.get("/check_pydantic", response={200: UserModel, 201: MessageModel}, by_alias=True) +def check_pydantic(request, message_only: bool): + print(message_only) + if message_only: + return 201, {"message": "Created"} + return 200, UserModel(id=1, user_name="John") + + @router.get("/check_union", response=Union[int, UserModel]) def check_union(request, q: int): if q == 0: @@ -99,23 +111,26 @@ def check_del_cookie(request, response: HttpResponse): @pytest.mark.parametrize( - "path,expected_response", + "path,expected_status_code,expected_response", [ - ("/check_int", 1), - ("/check_model", {"id": 1, "user_name": "John"}), # the password is skipped + ("/check_int", 200, 1), + ("/check_model", 200, {"id": 1, "user_name": "John"}), # the password is skipped ( "/check_list_model", + 200, [{"id": 1, "user_name": "John"}], ), # the password is skipped - ("/check_model", {"id": 1, "user_name": "John"}), # the password is skipped - ("/check_model_alias", {"Id": 1, "UserName": "John"}), # result is Camal Case - ("/check_union?q=0", 1), - ("/check_union?q=1", {"id": 1, "user_name": "John"}), + ("/check_model", 200, {"id": 1, "user_name": "John"}), # the password is skipped + ("/check_model_alias", 200, {"Id": 1, "UserName": "John"}), # result is Camal Case + ("/check_pydantic?message_only=0", 200, {"Id": 1, "UserName": "John"}), # result is Camal Case + ("/check_pydantic?message_only=1", 201, {"message": "Created"}), + ("/check_union?q=0", 200, 1), + ("/check_union?q=1", 200, {"id": 1, "user_name": "John"}), ], ) -def test_responses(path, expected_response): +def test_responses(path, expected_status_code, expected_response): response = client.get(path) - assert response.status_code == 200, response.content + assert response.status_code == expected_status_code, response.content assert response.json() == expected_response assert response.data == response.data == expected_response # Ensures cache works