Skip to content
This repository was archived by the owner on Sep 23, 2024. It is now read-only.

Commit 954acf0

Browse files
Merge pull request #8 from korbit-ai/fix-json-key-type-crash
Fix JSON TypeError on dict key can now be a numpy type
2 parents fd440ad + 422cf04 commit 954acf0

File tree

2 files changed

+8
-11
lines changed

2 files changed

+8
-11
lines changed

codejail/custom_encoder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
import numpy as np
32

43

54
class GlobalEncoder(json.JSONEncoder):
@@ -14,10 +13,12 @@ def default(self, obj):
1413

1514
class NumpyEncoder(json.JSONEncoder):
1615
""" Custom encoder for numpy data types """
16+
1717
def default(self, obj):
18+
import numpy as np
1819
if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
19-
np.int16, np.int32, np.int64, np.uint8,
20-
np.uint16, np.uint32, np.uint64)):
20+
np.int16, np.int32, np.int64, np.uint8,
21+
np.uint16, np.uint32, np.uint64)):
2122
return int(obj)
2223
elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
2324
return float(obj)
@@ -27,6 +28,6 @@ def default(self, obj):
2728
return obj.tolist()
2829
elif isinstance(obj, (np.bool_)):
2930
return bool(obj)
30-
elif isinstance(obj, (np.void)):
31+
elif isinstance(obj, (np.void)):
3132
return None
3233
return json.JSONEncoder.default(self, obj)

codejail/safe_exec.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,6 @@ def json_safe(d):
194194
"""
195195
# pylint: disable=invalid-name
196196

197-
# six.binary_type is here because bytes are sometimes ok if they represent valid utf8
198-
# so we consider them valid for now and try to decode them with decode_object. If that
199-
# doesn't work they'll get dropped later in the process.
200-
ok_types = (type(None), int, float, six.binary_type, six.text_type, list, tuple, dict)
201-
202197
def decode_object(obj):
203198
"""
204199
Convert an object to a JSON serializable form by decoding all byte strings.
@@ -213,6 +208,9 @@ def decode_object(obj):
213208
"""
214209
if isinstance(obj, bytes):
215210
return obj.decode('utf-8')
211+
if type(obj).__module__ == 'numpy':
212+
from custom_encoder import NumpyEncoder
213+
return NumpyEncoder().default(obj)
216214
if isinstance(obj, (list, tuple)):
217215
new_list = []
218216
for i in obj:
@@ -231,8 +229,6 @@ def decode_object(obj):
231229
bad_keys = ("__builtins__",)
232230
jd = {}
233231
for k, v in six.iteritems(d):
234-
if not isinstance(v, ok_types):
235-
continue
236232
if k in bad_keys:
237233
continue
238234
try:

0 commit comments

Comments
 (0)