Skip to content

Migrate from built-in json to faster orjson #576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
from codecs import open
from typing import Any
from typing import Dict
Expand All @@ -26,6 +27,10 @@
with open(os.path.join(here, "README.md"), "r", "utf-8") as f:
readme = f.read()

orjson_require = []
if platform.python_implementation() != 'PyPy':
orjson_require = ['orjson >= 3.11.0']

kerberos_require = ["requests_kerberos"]
gssapi_require = [""
"requests_gssapi",
Expand Down Expand Up @@ -90,7 +95,7 @@
"requests>=2.31.0",
"tzlocal",
"zstandard",
],
] + orjson_require,
extras_require={
"all": all_require,
"kerberos": kerberos_require,
Expand Down
19 changes: 11 additions & 8 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import threading
import time
import urllib
Expand All @@ -26,6 +25,10 @@
import gssapi
import httpretty
import keyring
try:
import orjson as json
except ImportError:
import json
import pytest
import requests
from httpretty import httprettified
Expand Down Expand Up @@ -61,8 +64,7 @@

@mock.patch("trino.client.TrinoRequest.http")
def test_trino_initial_request(mock_requests, sample_post_response_data):
mock_requests.Response.return_value.json.return_value = sample_post_response_data

mock_requests.Response.return_value.text = json.dumps(sample_post_response_data)
req = TrinoRequest(
host="coordinator",
port=8080,
Expand Down Expand Up @@ -692,7 +694,7 @@ def run(self) -> None:

@mock.patch("trino.client.TrinoRequest.http")
def test_trino_fetch_request(mock_requests, sample_get_response_data):
mock_requests.Response.return_value.json.return_value = sample_get_response_data
mock_requests.Response.return_value.text = json.dumps(sample_get_response_data)

req = TrinoRequest(
host="coordinator",
Expand All @@ -718,7 +720,7 @@ def test_trino_fetch_request(mock_requests, sample_get_response_data):

@mock.patch("trino.client.TrinoRequest.http")
def test_trino_fetch_request_data_none(mock_requests, sample_get_response_data_none):
mock_requests.Response.return_value.json.return_value = sample_get_response_data_none
mock_requests.Response.return_value.text = json.dumps(sample_get_response_data_none)

req = TrinoRequest(
host="coordinator",
Expand All @@ -744,7 +746,7 @@ def test_trino_fetch_request_data_none(mock_requests, sample_get_response_data_n

@mock.patch("trino.client.TrinoRequest.http")
def test_trino_fetch_error(mock_requests, sample_get_error_response_data):
mock_requests.Response.return_value.json.return_value = sample_get_error_response_data
mock_requests.Response.return_value.text = json.dumps(sample_get_error_response_data)

req = TrinoRequest(
host="coordinator",
Expand Down Expand Up @@ -1154,8 +1156,9 @@ def headers(self):
'X-Trino-Fake-2': 'two',
}

def json(self):
return sample_get_response_data
@property
def text(self):
return json.dumps(sample_get_response_data)

req = TrinoRequest(
host="coordinator",
Expand Down
8 changes: 6 additions & 2 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import base64
import copy
import functools
import json
import os
import random
import re
Expand All @@ -66,6 +65,11 @@
from zoneinfo import ZoneInfo

import lz4.block
try:
import orjson as json
except ImportError:
import json

import requests
import zstandard
from requests import Response
Expand Down Expand Up @@ -687,7 +691,7 @@ def process(self, http_response: Response) -> TrinoStatus:
self.raise_response_error(http_response)

http_response.encoding = "utf-8"
response = http_response.json()
response = json.loads(http_response.text)
if "error" in response and response["error"]:
raise self._process_error(response["error"], response.get("id"))

Expand Down