Skip to content

Commit 2ead207

Browse files
committed
pyln.proto.message.*: add type annotations.
Other changes along the way: 1. In a couple of places we passed None as a dummy for for `otherfields` where {} is just as good. 2. Turned bytes into hex for errors. 3. Remove nonsensical (unused) get_tlv_by_number() function from MessageNamespace 4. Renamed unrelated-but-overlapping `field_from_csv` and `type_from_csv` static methods, since mypy thought they should have the same type. 5. Unknown tlv fields are placed in dict as strings, not ints, for type simplicity. Signed-off-by: Rusty Russell <[email protected]>
1 parent bef68d3 commit 2ead207

File tree

3 files changed

+160
-156
lines changed

3 files changed

+160
-156
lines changed

contrib/pyln-proto/pyln/proto/message/array_types.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
from .fundamental_types import FieldType, IntegerType, split_field
2+
from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any
3+
from io import BufferedIOBase
4+
if TYPE_CHECKING:
5+
from .message import SubtypeType, TlvStreamType
26

37

48
class ArrayType(FieldType):
@@ -8,11 +12,11 @@ class ArrayType(FieldType):
812
wants an array of some type.
913
1014
"""
11-
def __init__(self, outer, name, elemtype):
15+
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType):
1216
super().__init__("{}.{}".format(outer.name, name))
1317
self.elemtype = elemtype
1418

15-
def val_from_str(self, s):
19+
def val_from_str(self, s: str) -> Tuple[List[Any], str]:
1620
# Simple arrays of bytes don't need commas
1721
if self.elemtype.name == 'byte':
1822
a, b = split_field(s)
@@ -30,20 +34,20 @@ def val_from_str(self, s):
3034
s = s[1:]
3135
return ret, s[1:]
3236

33-
def val_to_str(self, v, otherfields):
37+
def val_to_str(self, v: List[Any], otherfields: Dict[str, Any]) -> str:
3438
if self.elemtype.name == 'byte':
3539
return bytes(v).hex()
3640

3741
s = ','.join(self.elemtype.val_to_str(i, otherfields) for i in v)
3842
return '[' + s + ']'
3943

40-
def write(self, io_out, v, otherfields):
44+
def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None:
4145
for i in v:
4246
self.elemtype.write(io_out, i, otherfields)
4347

44-
def read_arr(self, io_in, otherfields, arraysize):
48+
def read_arr(self, io_in: BufferedIOBase, otherfields: Dict[str, Any], arraysize: Optional[int]) -> List[Any]:
4549
"""arraysize None means take rest of io entirely and exactly"""
46-
vals = []
50+
vals: List[Any] = []
4751
while arraysize is None or len(vals) < arraysize:
4852
# Throws an exception on partial read, so None means completely empty.
4953
val = self.elemtype.read(io_in, otherfields)
@@ -60,73 +64,73 @@ def read_arr(self, io_in, otherfields, arraysize):
6064

6165
class SizedArrayType(ArrayType):
6266
"""A fixed-size array"""
63-
def __init__(self, outer, name, elemtype, arraysize):
67+
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, arraysize: int):
6468
super().__init__(outer, name, elemtype)
6569
self.arraysize = arraysize
6670

67-
def val_to_str(self, v, otherfields):
71+
def val_to_str(self, v: List[Any], otherfields: Dict[str, Any]) -> str:
6872
if len(v) != self.arraysize:
6973
raise ValueError("Length of {} != {}", v, self.arraysize)
7074
return super().val_to_str(v, otherfields)
7175

72-
def val_from_str(self, s):
76+
def val_from_str(self, s: str) -> Tuple[List[Any], str]:
7377
a, b = super().val_from_str(s)
7478
if len(a) != self.arraysize:
7579
raise ValueError("Length of {} != {}", s, self.arraysize)
7680
return a, b
7781

78-
def write(self, io_out, v, otherfields):
82+
def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None:
7983
if len(v) != self.arraysize:
8084
raise ValueError("Length of {} != {}", v, self.arraysize)
8185
return super().write(io_out, v, otherfields)
8286

83-
def read(self, io_in, otherfields):
87+
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
8488
return super().read_arr(io_in, otherfields, self.arraysize)
8589

8690

8791
class EllipsisArrayType(ArrayType):
8892
"""This is used for ... fields at the end of a tlv: the array ends
8993
when the tlv ends"""
90-
def __init__(self, tlv, name, elemtype):
94+
def __init__(self, tlv: 'TlvStreamType', name: str, elemtype: FieldType):
9195
super().__init__(tlv, name, elemtype)
9296

93-
def read(self, io_in, otherfields):
97+
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
9498
"""Takes rest of bytestream"""
9599
return super().read_arr(io_in, otherfields, None)
96100

97-
def only_at_tlv_end(self):
101+
def only_at_tlv_end(self) -> bool:
98102
"""These only make sense at the end of a TLV"""
99103
return True
100104

101105

102106
class LengthFieldType(FieldType):
103107
"""Special type to indicate this serves as a length field for others"""
104-
def __init__(self, inttype):
108+
def __init__(self, inttype: IntegerType):
105109
if type(inttype) is not IntegerType:
106110
raise ValueError("{} cannot be a length; not an integer!"
107111
.format(self.name))
108112
super().__init__(inttype.name)
109113
self.underlying_type = inttype
110114
# You can be length for more than one field!
111-
self.len_for = []
115+
self.len_for: List[DynamicArrayType] = []
112116

113-
def is_optional(self):
117+
def is_optional(self) -> bool:
114118
"""This field value is always implies, never specified directly"""
115119
return True
116120

117-
def add_length_for(self, field):
121+
def add_length_for(self, field: 'DynamicArrayType') -> None:
118122
assert isinstance(field.fieldtype, DynamicArrayType)
119123
self.len_for.append(field)
120124

121-
def calc_value(self, otherfields):
125+
def calc_value(self, otherfields: Dict[str, Any]) -> int:
122126
"""Calculate length value from field(s) themselves"""
123127
if self.len_fields_bad('', otherfields):
124128
raise ValueError("Lengths of fields {} not equal!"
125129
.format(self.len_for))
126130

127131
return len(otherfields[self.len_for[0].name])
128132

129-
def _maybe_calc_value(self, fieldname, otherfields):
133+
def _maybe_calc_value(self, fieldname: str, otherfields: Dict[str, Any]) -> int:
130134
# Perhaps we're just demarshalling from binary now, so we actually
131135
# stored it. Remove, and we'll calc from now on.
132136
if fieldname in otherfields:
@@ -135,27 +139,27 @@ def _maybe_calc_value(self, fieldname, otherfields):
135139
return v
136140
return self.calc_value(otherfields)
137141

138-
def val_to_str(self, _, otherfields):
142+
def val_to_str(self, _, otherfields: Dict[str, Any]) -> str:
139143
return self.underlying_type.val_to_str(self.calc_value(otherfields),
140144
otherfields)
141145

142-
def name_and_val(self, name, v):
146+
def name_and_val(self, name: str, v: int) -> str:
143147
"""We don't print out length fields when printing out messages:
144148
they're implied by the length of other fields"""
145149
return ''
146150

147-
def read(self, io_in, otherfields):
151+
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None:
148152
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
149153
return self.underlying_type.read(io_in, otherfields)
150154

151-
def write(self, io_out, _, otherfields):
155+
def write(self, io_out: BufferedIOBase, _, otherfields: Dict[str, Any]) -> None:
152156
self.underlying_type.write(io_out, self.calc_value(otherfields),
153157
otherfields)
154158

155-
def val_from_str(self, s):
159+
def val_from_str(self, s: str):
156160
raise ValueError('{} is implied, cannot be specified'.format(self))
157161

158-
def len_fields_bad(self, fieldname, otherfields):
162+
def len_fields_bad(self, fieldname: str, otherfields: Dict[str, Any]) -> List[str]:
159163
"""fieldname is the name to return if this length is bad"""
160164
mylen = None
161165
for lens in self.len_for:
@@ -170,11 +174,11 @@ def len_fields_bad(self, fieldname, otherfields):
170174

171175
class DynamicArrayType(ArrayType):
172176
"""This is used for arrays where another field controls the size"""
173-
def __init__(self, outer, name, elemtype, lenfield):
177+
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: LengthFieldType):
174178
super().__init__(outer, name, elemtype)
175179
assert type(lenfield.fieldtype) is LengthFieldType
176180
self.lenfield = lenfield
177181

178-
def read(self, io_in, otherfields):
182+
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
179183
return super().read_arr(io_in, otherfields,
180184
self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields))

contrib/pyln-proto/pyln/proto/message/fundamental_types.py

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import struct
2-
import io
2+
from io import BufferedIOBase
33
import sys
4-
from typing import Optional
4+
from typing import Dict, Optional, Tuple, List, Any
55

66

77
def try_unpack(name: str,
8-
io_out: io.BufferedIOBase,
8+
io_out: BufferedIOBase,
99
structfmt: str,
1010
empty_ok: bool) -> Optional[int]:
1111
"""Unpack a single value using struct.unpack.
@@ -20,7 +20,7 @@ def try_unpack(name: str,
2020
return struct.unpack(structfmt, b)[0]
2121

2222

23-
def split_field(s):
23+
def split_field(s: str) -> Tuple[str, str]:
2424
"""Helper to split string into first part and remainder"""
2525
def len_without(s, delim):
2626
pos = s.find(delim)
@@ -37,25 +37,28 @@ class FieldType(object):
3737
These are further specialized.
3838
3939
"""
40-
def __init__(self, name):
40+
def __init__(self, name: str):
4141
self.name = name
4242

43-
def only_at_tlv_end(self):
43+
def only_at_tlv_end(self) -> bool:
4444
"""Some types only make sense inside a tlv, at the end"""
4545
return False
4646

47-
def name_and_val(self, name, v):
47+
def name_and_val(self, name: str, v: Any) -> str:
4848
"""This is overridden by LengthFieldType to return nothing"""
49-
return " {}={}".format(name, self.val_to_str(v, None))
49+
return " {}={}".format(name, self.val_to_str(v, {}))
5050

51-
def is_optional(self):
51+
def is_optional(self) -> bool:
5252
"""Overridden for tlv fields and optional fields"""
5353
return False
5454

55-
def len_fields_bad(self, fieldname, fieldvals):
55+
def len_fields_bad(self, fieldname: str, fieldvals: Dict[str, Any]) -> List[str]:
5656
"""Overridden by length fields for arrays"""
5757
return []
5858

59+
def val_to_str(self, v: Any, otherfields: Dict[str, Any]) -> str:
60+
raise NotImplementedError()
61+
5962
def __str__(self):
6063
return self.name
6164

@@ -64,22 +67,22 @@ def __repr__(self):
6467

6568

6669
class IntegerType(FieldType):
67-
def __init__(self, name, bytelen, structfmt):
70+
def __init__(self, name: str, bytelen: int, structfmt: str):
6871
super().__init__(name)
6972
self.bytelen = bytelen
7073
self.structfmt = structfmt
7174

72-
def val_to_str(self, v, otherfields):
75+
def val_to_str(self, v: int, otherfields: Dict[str, Any]):
7376
return "{}".format(int(v))
7477

75-
def val_from_str(self, s):
78+
def val_from_str(self, s: str) -> Tuple[int, str]:
7679
a, b = split_field(s)
7780
return int(a), b
7881

79-
def write(self, io_out, v, otherfields):
82+
def write(self, io_out: BufferedIOBase, v: int, otherfields: Dict[str, Any]) -> None:
8083
io_out.write(struct.pack(self.structfmt, v))
8184

82-
def read(self, io_in, otherfields):
85+
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[int]:
8386
return try_unpack(self.name, io_in, self.structfmt, empty_ok=True)
8487

8588

@@ -91,11 +94,11 @@ class ShortChannelIDType(IntegerType):
9194
def __init__(self, name):
9295
super().__init__(name, 8, '>Q')
9396

94-
def val_to_str(self, v, otherfields):
97+
def val_to_str(self, v: int, otherfields: Dict[str, Any]) -> str:
9598
# See BOLT #7: ## Definition of `short_channel_id`
9699
return "{}x{}x{}".format(v >> 40, (v >> 16) & 0xFFFFFF, v & 0xFFFF)
97100

98-
def val_from_str(self, s):
101+
def val_from_str(self, s: str) -> Tuple[int, str]:
99102
a, b = split_field(s)
100103
parts = a.split('x')
101104
if len(parts) != 3:
@@ -107,25 +110,25 @@ def val_from_str(self, s):
107110

108111
class TruncatedIntType(FieldType):
109112
"""Truncated integer types"""
110-
def __init__(self, name, maxbytes):
113+
def __init__(self, name: str, maxbytes: int):
111114
super().__init__(name)
112115
self.maxbytes = maxbytes
113116

114-
def val_to_str(self, v, otherfields):
117+
def val_to_str(self, v: int, otherfields: Dict[str, Any]) -> str:
115118
return "{}".format(int(v))
116119

117-
def only_at_tlv_end(self):
120+
def only_at_tlv_end(self) -> bool:
118121
"""These only make sense at the end of a TLV"""
119122
return True
120123

121-
def val_from_str(self, s):
124+
def val_from_str(self, s: str) -> Tuple[int, str]:
122125
a, b = split_field(s)
123126
if int(a) >= (1 << (self.maxbytes * 8)):
124127
raise ValueError('{} exceeds maximum {} capacity'
125128
.format(a, self.name))
126129
return int(a), b
127130

128-
def write(self, io_out, v, otherfields):
131+
def write(self, io_out: BufferedIOBase, v: int, otherfields: Dict[str, Any]) -> None:
129132
binval = struct.pack('>Q', v)
130133
while len(binval) != 0 and binval[0] == 0:
131134
binval = binval[1:]
@@ -134,41 +137,41 @@ def write(self, io_out, v, otherfields):
134137
.format(v, self.name))
135138
io_out.write(binval)
136139

137-
def read(self, io_in, otherfields):
140+
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None:
138141
binval = io_in.read()
139142
if len(binval) > self.maxbytes:
140-
raise ValueError('{} is too long for {}'.format(binval, self.name))
143+
raise ValueError('{} is too long for {}'.format(binval.hex(), self.name))
141144
if len(binval) > 0 and binval[0] == 0:
142145
raise ValueError('{} encoding is not minimal: {}'
143-
.format(self.name, binval))
146+
.format(self.name, binval.hex()))
144147
# Pad with zeroes and convert as u64
145148
return struct.unpack_from('>Q', bytes(8 - len(binval)) + binval)[0]
146149

147150

148151
class FundamentalHexType(FieldType):
149152
"""The remaining fundamental types are simply represented as hex strings"""
150-
def __init__(self, name, bytelen):
153+
def __init__(self, name: str, bytelen: int):
151154
super().__init__(name)
152155
self.bytelen = bytelen
153156

154-
def val_to_str(self, v, otherfields):
157+
def val_to_str(self, v: bytes, otherfields: Dict[str, Any]) -> str:
155158
if len(bytes(v)) != self.bytelen:
156159
raise ValueError("Length of {} != {}", v, self.bytelen)
157160
return v.hex()
158161

159-
def val_from_str(self, s):
162+
def val_from_str(self, s: str) -> Tuple[bytes, str]:
160163
a, b = split_field(s)
161164
ret = bytes.fromhex(a)
162165
if len(ret) != self.bytelen:
163166
raise ValueError("Length of {} != {}", a, self.bytelen)
164167
return ret, b
165168

166-
def write(self, io_out, v, otherfields):
169+
def write(self, io_out: BufferedIOBase, v: bytes, otherfields: Dict[str, Any]) -> None:
167170
if len(bytes(v)) != self.bytelen:
168171
raise ValueError("Length of {} != {}", v, self.bytelen)
169172
io_out.write(v)
170173

171-
def read(self, io_in, otherfields):
174+
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[bytes]:
172175
val = io_in.read(self.bytelen)
173176
if len(val) == 0:
174177
return None
@@ -182,13 +185,13 @@ class BigSizeType(FieldType):
182185
def __init__(self, name):
183186
super().__init__(name)
184187

185-
def val_from_str(self, s):
188+
def val_from_str(self, s: str) -> Tuple[int, str]:
186189
a, b = split_field(s)
187190
return int(a), b
188191

189192
# For the convenience of TLV header parsing
190193
@staticmethod
191-
def write(io_out, v, otherfields=None):
194+
def write(io_out: BufferedIOBase, v: int, otherfields: Dict[str, Any] = {}) -> None:
192195
if v < 253:
193196
io_out.write(bytes([v]))
194197
elif v < 2**16:
@@ -199,7 +202,7 @@ def write(io_out, v, otherfields=None):
199202
io_out.write(bytes([255]) + struct.pack('>Q', v))
200203

201204
@staticmethod
202-
def read(io_in, otherfields=None):
205+
def read(io_in: BufferedIOBase, otherfields: Dict[str, Any] = {}) -> Optional[int]:
203206
"Returns value, or None on EOF"
204207
b = io_in.read(1)
205208
if len(b) == 0:
@@ -213,7 +216,7 @@ def read(io_in, otherfields=None):
213216
else:
214217
return try_unpack('BigSize', io_in, '>Q', empty_ok=False)
215218

216-
def val_to_str(self, v, otherfields):
219+
def val_to_str(self, v: int, otherfields: Dict[str, Any]) -> str:
217220
return "{}".format(int(v))
218221

219222

0 commit comments

Comments
 (0)