1
1
from .fundamental_types import FieldType , IntegerType , split_field
2
+ from typing import List , Optional , Dict , Tuple , TYPE_CHECKING , Any
3
+ from io import BufferedIOBase
4
+ if TYPE_CHECKING :
5
+ from .message import SubtypeType , TlvStreamType
2
6
3
7
4
8
class ArrayType (FieldType ):
@@ -8,11 +12,11 @@ class ArrayType(FieldType):
8
12
wants an array of some type.
9
13
10
14
"""
11
- def __init__ (self , outer , name , elemtype ):
15
+ def __init__ (self , outer : 'SubtypeType' , name : str , elemtype : FieldType ):
12
16
super ().__init__ ("{}.{}" .format (outer .name , name ))
13
17
self .elemtype = elemtype
14
18
15
- def val_from_str (self , s ) :
19
+ def val_from_str (self , s : str ) -> Tuple [ List [ Any ], str ] :
16
20
# Simple arrays of bytes don't need commas
17
21
if self .elemtype .name == 'byte' :
18
22
a , b = split_field (s )
@@ -30,20 +34,20 @@ def val_from_str(self, s):
30
34
s = s [1 :]
31
35
return ret , s [1 :]
32
36
33
- def val_to_str (self , v , otherfields ) :
37
+ def val_to_str (self , v : List [ Any ] , otherfields : Dict [ str , Any ]) -> str :
34
38
if self .elemtype .name == 'byte' :
35
39
return bytes (v ).hex ()
36
40
37
41
s = ',' .join (self .elemtype .val_to_str (i , otherfields ) for i in v )
38
42
return '[' + s + ']'
39
43
40
- def write (self , io_out , v , otherfields ) :
44
+ def write (self , io_out : BufferedIOBase , v : List [ Any ] , otherfields : Dict [ str , Any ]) -> None :
41
45
for i in v :
42
46
self .elemtype .write (io_out , i , otherfields )
43
47
44
- def read_arr (self , io_in , otherfields , arraysize ) :
48
+ def read_arr (self , io_in : BufferedIOBase , otherfields : Dict [ str , Any ], arraysize : Optional [ int ]) -> List [ Any ] :
45
49
"""arraysize None means take rest of io entirely and exactly"""
46
- vals = []
50
+ vals : List [ Any ] = []
47
51
while arraysize is None or len (vals ) < arraysize :
48
52
# Throws an exception on partial read, so None means completely empty.
49
53
val = self .elemtype .read (io_in , otherfields )
@@ -60,73 +64,73 @@ def read_arr(self, io_in, otherfields, arraysize):
60
64
61
65
class SizedArrayType (ArrayType ):
62
66
"""A fixed-size array"""
63
- def __init__ (self , outer , name , elemtype , arraysize ):
67
+ def __init__ (self , outer : 'SubtypeType' , name : str , elemtype : FieldType , arraysize : int ):
64
68
super ().__init__ (outer , name , elemtype )
65
69
self .arraysize = arraysize
66
70
67
- def val_to_str (self , v , otherfields ) :
71
+ def val_to_str (self , v : List [ Any ] , otherfields : Dict [ str , Any ]) -> str :
68
72
if len (v ) != self .arraysize :
69
73
raise ValueError ("Length of {} != {}" , v , self .arraysize )
70
74
return super ().val_to_str (v , otherfields )
71
75
72
- def val_from_str (self , s ) :
76
+ def val_from_str (self , s : str ) -> Tuple [ List [ Any ], str ] :
73
77
a , b = super ().val_from_str (s )
74
78
if len (a ) != self .arraysize :
75
79
raise ValueError ("Length of {} != {}" , s , self .arraysize )
76
80
return a , b
77
81
78
- def write (self , io_out , v , otherfields ) :
82
+ def write (self , io_out : BufferedIOBase , v : List [ Any ] , otherfields : Dict [ str , Any ]) -> None :
79
83
if len (v ) != self .arraysize :
80
84
raise ValueError ("Length of {} != {}" , v , self .arraysize )
81
85
return super ().write (io_out , v , otherfields )
82
86
83
- def read (self , io_in , otherfields ) :
87
+ def read (self , io_in : BufferedIOBase , otherfields : Dict [ str , Any ]) -> List [ Any ] :
84
88
return super ().read_arr (io_in , otherfields , self .arraysize )
85
89
86
90
87
91
class EllipsisArrayType (ArrayType ):
88
92
"""This is used for ... fields at the end of a tlv: the array ends
89
93
when the tlv ends"""
90
- def __init__ (self , tlv , name , elemtype ):
94
+ def __init__ (self , tlv : 'TlvStreamType' , name : str , elemtype : FieldType ):
91
95
super ().__init__ (tlv , name , elemtype )
92
96
93
- def read (self , io_in , otherfields ) :
97
+ def read (self , io_in : BufferedIOBase , otherfields : Dict [ str , Any ]) -> List [ Any ] :
94
98
"""Takes rest of bytestream"""
95
99
return super ().read_arr (io_in , otherfields , None )
96
100
97
- def only_at_tlv_end (self ):
101
+ def only_at_tlv_end (self ) -> bool :
98
102
"""These only make sense at the end of a TLV"""
99
103
return True
100
104
101
105
102
106
class LengthFieldType (FieldType ):
103
107
"""Special type to indicate this serves as a length field for others"""
104
- def __init__ (self , inttype ):
108
+ def __init__ (self , inttype : IntegerType ):
105
109
if type (inttype ) is not IntegerType :
106
110
raise ValueError ("{} cannot be a length; not an integer!"
107
111
.format (self .name ))
108
112
super ().__init__ (inttype .name )
109
113
self .underlying_type = inttype
110
114
# You can be length for more than one field!
111
- self .len_for = []
115
+ self .len_for : List [ DynamicArrayType ] = []
112
116
113
- def is_optional (self ):
117
+ def is_optional (self ) -> bool :
114
118
"""This field value is always implies, never specified directly"""
115
119
return True
116
120
117
- def add_length_for (self , field ) :
121
+ def add_length_for (self , field : 'DynamicArrayType' ) -> None :
118
122
assert isinstance (field .fieldtype , DynamicArrayType )
119
123
self .len_for .append (field )
120
124
121
- def calc_value (self , otherfields ) :
125
+ def calc_value (self , otherfields : Dict [ str , Any ]) -> int :
122
126
"""Calculate length value from field(s) themselves"""
123
127
if self .len_fields_bad ('' , otherfields ):
124
128
raise ValueError ("Lengths of fields {} not equal!"
125
129
.format (self .len_for ))
126
130
127
131
return len (otherfields [self .len_for [0 ].name ])
128
132
129
- def _maybe_calc_value (self , fieldname , otherfields ) :
133
+ def _maybe_calc_value (self , fieldname : str , otherfields : Dict [ str , Any ]) -> int :
130
134
# Perhaps we're just demarshalling from binary now, so we actually
131
135
# stored it. Remove, and we'll calc from now on.
132
136
if fieldname in otherfields :
@@ -135,27 +139,27 @@ def _maybe_calc_value(self, fieldname, otherfields):
135
139
return v
136
140
return self .calc_value (otherfields )
137
141
138
- def val_to_str (self , _ , otherfields ) :
142
+ def val_to_str (self , _ , otherfields : Dict [ str , Any ]) -> str :
139
143
return self .underlying_type .val_to_str (self .calc_value (otherfields ),
140
144
otherfields )
141
145
142
- def name_and_val (self , name , v ) :
146
+ def name_and_val (self , name : str , v : int ) -> str :
143
147
"""We don't print out length fields when printing out messages:
144
148
they're implied by the length of other fields"""
145
149
return ''
146
150
147
- def read (self , io_in , otherfields ) :
151
+ def read (self , io_in : BufferedIOBase , otherfields : Dict [ str , Any ]) -> None :
148
152
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
149
153
return self .underlying_type .read (io_in , otherfields )
150
154
151
- def write (self , io_out , _ , otherfields ) :
155
+ def write (self , io_out : BufferedIOBase , _ , otherfields : Dict [ str , Any ]) -> None :
152
156
self .underlying_type .write (io_out , self .calc_value (otherfields ),
153
157
otherfields )
154
158
155
- def val_from_str (self , s ):
159
+ def val_from_str (self , s : str ):
156
160
raise ValueError ('{} is implied, cannot be specified' .format (self ))
157
161
158
- def len_fields_bad (self , fieldname , otherfields ) :
162
+ def len_fields_bad (self , fieldname : str , otherfields : Dict [ str , Any ]) -> List [ str ] :
159
163
"""fieldname is the name to return if this length is bad"""
160
164
mylen = None
161
165
for lens in self .len_for :
@@ -170,11 +174,11 @@ def len_fields_bad(self, fieldname, otherfields):
170
174
171
175
class DynamicArrayType (ArrayType ):
172
176
"""This is used for arrays where another field controls the size"""
173
- def __init__ (self , outer , name , elemtype , lenfield ):
177
+ def __init__ (self , outer : 'SubtypeType' , name : str , elemtype : FieldType , lenfield : LengthFieldType ):
174
178
super ().__init__ (outer , name , elemtype )
175
179
assert type (lenfield .fieldtype ) is LengthFieldType
176
180
self .lenfield = lenfield
177
181
178
- def read (self , io_in , otherfields ) :
182
+ def read (self , io_in : BufferedIOBase , otherfields : Dict [ str , Any ]) -> List [ Any ] :
179
183
return super ().read_arr (io_in , otherfields ,
180
184
self .lenfield .fieldtype ._maybe_calc_value (self .lenfield .name , otherfields ))
0 commit comments