Skip to content

Commit f520652

Browse files
rustyrussellcdecker
authored andcommitted
pyln.proto.message.*: add type annotations.
Other changes along the way: 1. In a couple of places we passed None as a dummy for for `otherfields` where {} is just as good. 2. Turned bytes into hex for errors. 3. Remove nonsensical (unused) get_tlv_by_number() function from MessageNamespace 4. Renamed unrelated-but-overlapping `field_from_csv` and `type_from_csv` static methods, since mypy thought they should have the same type. 5. Unknown tlv fields are placed in dict as strings, not ints, for type simplicity. Signed-off-by: Rusty Russell <[email protected]>
1 parent da070e7 commit f520652

File tree

4 files changed

+161
-156
lines changed

4 files changed

+161
-156
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
*.po
1212
*.pyc
1313
.cppcheck-suppress
14+
.mypy_cache
1415
TAGS
1516
tags
1617
ccan/tools/configurator/configurator

contrib/pyln-proto/pyln/proto/message/array_types.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
from .fundamental_types import FieldType, IntegerType, split_field
2+
from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any
3+
from io import BufferedIOBase
4+
if TYPE_CHECKING:
5+
from .message import SubtypeType, TlvStreamType
26

37

48
class ArrayType(FieldType):
@@ -8,11 +12,11 @@ class ArrayType(FieldType):
812
wants an array of some type.
913
1014
"""
11-
def __init__(self, outer, name, elemtype):
15+
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType):
1216
super().__init__("{}.{}".format(outer.name, name))
1317
self.elemtype = elemtype
1418

15-
def val_from_str(self, s):
19+
def val_from_str(self, s: str) -> Tuple[List[Any], str]:
1620
# Simple arrays of bytes don't need commas
1721
if self.elemtype.name == 'byte':
1822
a, b = split_field(s)
@@ -30,20 +34,20 @@ def val_from_str(self, s):
3034
s = s[1:]
3135
return ret, s[1:]
3236

33-
def val_to_str(self, v, otherfields):
37+
def val_to_str(self, v: List[Any], otherfields: Dict[str, Any]) -> str:
3438
if self.elemtype.name == 'byte':
3539
return bytes(v).hex()
3640

3741
s = ','.join(self.elemtype.val_to_str(i, otherfields) for i in v)
3842
return '[' + s + ']'
3943

40-
def write(self, io_out, v, otherfields):
44+
def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None:
4145
for i in v:
4246
self.elemtype.write(io_out, i, otherfields)
4347

44-
def read_arr(self, io_in, otherfields, arraysize):
48+
def read_arr(self, io_in: BufferedIOBase, otherfields: Dict[str, Any], arraysize: Optional[int]) -> List[Any]:
4549
"""arraysize None means take rest of io entirely and exactly"""
46-
vals = []
50+
vals: List[Any] = []
4751
while arraysize is None or len(vals) < arraysize:
4852
# Throws an exception on partial read, so None means completely empty.
4953
val = self.elemtype.read(io_in, otherfields)
@@ -60,73 +64,73 @@ def read_arr(self, io_in, otherfields, arraysize):
6064

6165
class SizedArrayType(ArrayType):
6266
"""A fixed-size array"""
63-
def __init__(self, outer, name, elemtype, arraysize):
67+
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, arraysize: int):
6468
super().__init__(outer, name, elemtype)
6569
self.arraysize = arraysize
6670

67-
def val_to_str(self, v, otherfields):
71+
def val_to_str(self, v: List[Any], otherfields: Dict[str, Any]) -> str:
6872
if len(v) != self.arraysize:
6973
raise ValueError("Length of {} != {}", v, self.arraysize)
7074
return super().val_to_str(v, otherfields)
7175

72-
def val_from_str(self, s):
76+
def val_from_str(self, s: str) -> Tuple[List[Any], str]:
7377
a, b = super().val_from_str(s)
7478
if len(a) != self.arraysize:
7579
raise ValueError("Length of {} != {}", s, self.arraysize)
7680
return a, b
7781

78-
def write(self, io_out, v, otherfields):
82+
def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None:
7983
if len(v) != self.arraysize:
8084
raise ValueError("Length of {} != {}", v, self.arraysize)
8185
return super().write(io_out, v, otherfields)
8286

83-
def read(self, io_in, otherfields):
87+
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
8488
return super().read_arr(io_in, otherfields, self.arraysize)
8589

8690

8791
class EllipsisArrayType(ArrayType):
8892
"""This is used for ... fields at the end of a tlv: the array ends
8993
when the tlv ends"""
90-
def __init__(self, tlv, name, elemtype):
94+
def __init__(self, tlv: 'TlvStreamType', name: str, elemtype: FieldType):
9195
super().__init__(tlv, name, elemtype)
9296

93-
def read(self, io_in, otherfields):
97+
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
9498
"""Takes rest of bytestream"""
9599
return super().read_arr(io_in, otherfields, None)
96100

97-
def only_at_tlv_end(self):
101+
def only_at_tlv_end(self) -> bool:
98102
"""These only make sense at the end of a TLV"""
99103
return True
100104

101105

102106
class LengthFieldType(FieldType):
103107
"""Special type to indicate this serves as a length field for others"""
104-
def __init__(self, inttype):
108+
def __init__(self, inttype: IntegerType):
105109
if type(inttype) is not IntegerType:
106110
raise ValueError("{} cannot be a length; not an integer!"
107111
.format(self.name))
108112
super().__init__(inttype.name)
109113
self.underlying_type = inttype
110114
# You can be length for more than one field!
111-
self.len_for = []
115+
self.len_for: List[DynamicArrayType] = []
112116

113-
def is_optional(self):
117+
def is_optional(self) -> bool:
114118
"""This field value is always implies, never specified directly"""
115119
return True
116120

117-
def add_length_for(self, field):
121+
def add_length_for(self, field: 'DynamicArrayType') -> None:
118122
assert isinstance(field.fieldtype, DynamicArrayType)
119123
self.len_for.append(field)
120124

121-
def calc_value(self, otherfields):
125+
def calc_value(self, otherfields: Dict[str, Any]) -> int:
122126
"""Calculate length value from field(s) themselves"""
123127
if self.len_fields_bad('', otherfields):
124128
raise ValueError("Lengths of fields {} not equal!"
125129
.format(self.len_for))
126130

127131
return len(otherfields[self.len_for[0].name])
128132

129-
def _maybe_calc_value(self, fieldname, otherfields):
133+
def _maybe_calc_value(self, fieldname: str, otherfields: Dict[str, Any]) -> int:
130134
# Perhaps we're just demarshalling from binary now, so we actually
131135
# stored it. Remove, and we'll calc from now on.
132136
if fieldname in otherfields:
@@ -135,27 +139,27 @@ def _maybe_calc_value(self, fieldname, otherfields):
135139
return v
136140
return self.calc_value(otherfields)
137141

138-
def val_to_str(self, _, otherfields):
142+
def val_to_str(self, _, otherfields: Dict[str, Any]) -> str:
139143
return self.underlying_type.val_to_str(self.calc_value(otherfields),
140144
otherfields)
141145

142-
def name_and_val(self, name, v):
146+
def name_and_val(self, name: str, v: int) -> str:
143147
"""We don't print out length fields when printing out messages:
144148
they're implied by the length of other fields"""
145149
return ''
146150

147-
def read(self, io_in, otherfields):
151+
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None:
148152
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
149153
return self.underlying_type.read(io_in, otherfields)
150154

151-
def write(self, io_out, _, otherfields):
155+
def write(self, io_out: BufferedIOBase, _, otherfields: Dict[str, Any]) -> None:
152156
self.underlying_type.write(io_out, self.calc_value(otherfields),
153157
otherfields)
154158

155-
def val_from_str(self, s):
159+
def val_from_str(self, s: str):
156160
raise ValueError('{} is implied, cannot be specified'.format(self))
157161

158-
def len_fields_bad(self, fieldname, otherfields):
162+
def len_fields_bad(self, fieldname: str, otherfields: Dict[str, Any]) -> List[str]:
159163
"""fieldname is the name to return if this length is bad"""
160164
mylen = None
161165
for lens in self.len_for:
@@ -170,11 +174,11 @@ def len_fields_bad(self, fieldname, otherfields):
170174

171175
class DynamicArrayType(ArrayType):
172176
"""This is used for arrays where another field controls the size"""
173-
def __init__(self, outer, name, elemtype, lenfield):
177+
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: LengthFieldType):
174178
super().__init__(outer, name, elemtype)
175179
assert type(lenfield.fieldtype) is LengthFieldType
176180
self.lenfield = lenfield
177181

178-
def read(self, io_in, otherfields):
182+
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
179183
return super().read_arr(io_in, otherfields,
180184
self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields))

0 commit comments

Comments
 (0)