Skip to content

Commit 4e41e61

Browse files
committed
PYON: fix nested encoding
Forces recursive calls of pyon.encode(), which is compatible with the plugin architecture. Previously, values similar to {'entry': CustomType()} would not encode because _Encoder didn't recognze CustomType
1 parent 0d53280 commit 4e41e61

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

sipyco/pyon.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@
7272

7373

7474
class _Encoder:
75-
def __init__(self, pretty):
75+
def __init__(self, pretty, indent_level=0):
7676
self.pretty = pretty
77-
self.indent_level = 0
77+
self.indent_level = indent_level
7878

7979
def indent(self):
8080
return " " * self.indent_level
@@ -100,22 +100,22 @@ def encode_bytes(self, x):
100100

101101
def encode_tuple(self, x):
102102
if len(x) == 1:
103-
return "(" + self.encode(x[0]) + ", )"
103+
return "(" + encode(x[0], self.pretty, self.indent_level) + ", )"
104104
else:
105105
r = "("
106-
r += ", ".join([self.encode(item) for item in x])
106+
r += ", ".join([encode(item, self.pretty, self.indent_level) for item in x])
107107
r += ")"
108108
return r
109109

110110
def encode_list(self, x):
111111
r = "["
112-
r += ", ".join([self.encode(item) for item in x])
112+
r += ", ".join([encode(item, self.pretty, self.indent_level) for item in x])
113113
r += "]"
114114
return r
115115

116116
def encode_set(self, x):
117117
r = "{"
118-
r += ", ".join([self.encode(item) for item in x])
118+
r += ", ".join([encode(item, self.pretty, self.indent_level) for item in x])
119119
r += "}"
120120
return r
121121

@@ -127,8 +127,14 @@ def encode_dict(self, x):
127127

128128
r = "{"
129129
if not self.pretty or len(x) < 2:
130-
r += ", ".join([self.encode(k) + ": " + self.encode(v)
131-
for k, v in items()])
130+
r += ", ".join(
131+
[
132+
encode(k, self.pretty, self.indent_level)
133+
+ ": "
134+
+ encode(v, self.pretty, self.indent_level)
135+
for k, v in items()
136+
]
137+
)
132138
else:
133139
self.indent_level += 1
134140
r += "\n"
@@ -137,7 +143,12 @@ def encode_dict(self, x):
137143
if not first:
138144
r += ",\n"
139145
first = False
140-
r += self.indent() + self.encode(k) + ": " + self.encode(v)
146+
r += (
147+
self.indent()
148+
+ encode(k, self.pretty, self.indent_level)
149+
+ ": "
150+
+ encode(v, self.pretty, self.indent_level)
151+
)
141152
r += "\n" # no ','
142153
self.indent_level -= 1
143154
r += self.indent()
@@ -148,24 +159,30 @@ def encode_slice(self, x):
148159
return repr(x)
149160

150161
def encode_fraction(self, x):
151-
return "Fraction({}, {})".format(self.encode(x.numerator),
152-
self.encode(x.denominator))
162+
return "Fraction({}, {})".format(
163+
encode(x.numerator, self.pretty, self.indent_level),
164+
encode(x.denominator, self.pretty, self.indent_level),
165+
)
153166

154167
def encode_ordereddict(self, x):
155-
return "OrderedDict(" + self.encode(list(x.items())) + ")"
168+
return (
169+
"OrderedDict("
170+
+ encode(list(x.items()), self.pretty, self.indent_level)
171+
+ ")"
172+
)
156173

157174
def encode_nparray(self, x):
158175
x = numpy.ascontiguousarray(x)
159176
r = "nparray("
160-
r += self.encode(x.shape) + ", "
161-
r += self.encode(x.dtype.str) + ", b\""
177+
r += encode(x.shape, self.pretty, self.indent_level) + ", "
178+
r += encode(x.dtype.str, self.pretty, self.indent_level) + ", b\""
162179
r += base64.b64encode(x.data).decode()
163180
r += "\")"
164181
return r
165182

166183
def encode_npscalar(self, x):
167184
r = "npscalar("
168-
r += self.encode(x.dtype.str) + ", b\""
185+
r += encode(x.dtype.str, self.pretty, self.indent_level) + ", b\""
169186
r += base64.b64encode(x.data).decode()
170187
r += "\")"
171188
return r

sipyco/test/test_pyon_plugin.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,10 @@ def test_pyon_plugin_encode_decode(monkeypatch):
5050
monkeypatch.setattr(plugin, "get_plugin_manager", pyon_extra_plugin)
5151
test_value = Point(2.5, 3.4)
5252
assert pyon.decode(pyon.encode(test_value)) == test_value
53+
54+
55+
def test_pyon_nested_encode(monkeypatch):
56+
"""Tests that nested items will be properly encoded."""
57+
monkeypatch.setattr(plugin, "get_plugin_manager", pyon_extra_plugin)
58+
test_value = {"first": Point(2.5, {"nothing": 0})}
59+
assert pyon.decode(pyon.encode(test_value)) == test_value

0 commit comments

Comments
 (0)