diff --git a/.gitignore b/.gitignore index 2a7ed297a6df..b2e02920c52a 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ *.po *.pyc .cppcheck-suppress +.mypy_cache TAGS tags ccan/tools/configurator/configurator diff --git a/contrib/pyln-proto/Makefile b/contrib/pyln-proto/Makefile new file mode 100644 index 000000000000..6cd09e987eff --- /dev/null +++ b/contrib/pyln-proto/Makefile @@ -0,0 +1,13 @@ +#! /usr/bin/make + +check: + pytest + +check-source: check-flake8 check-mypy + +check-flake8: + flake8 --ignore=E501,E731,W503 + +# mypy . does not recurse. I have no idea why... +check-mypy: + mypy --ignore-missing-imports `find * -name '*.py'` diff --git a/contrib/pyln-proto/pyln/proto/message/Makefile b/contrib/pyln-proto/pyln/proto/message/Makefile new file mode 100644 index 000000000000..3e27da7e1719 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/Makefile @@ -0,0 +1,4 @@ +#! /usr/bin/make + +refresh: + for d in bolt*; do $(MAKE) -C $$d; done diff --git a/contrib/pyln-proto/pyln/proto/message/__init__.py b/contrib/pyln-proto/pyln/proto/message/__init__.py new file mode 100644 index 000000000000..286afb1947ef --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/__init__.py @@ -0,0 +1,33 @@ +from .array_types import SizedArrayType, DynamicArrayType, EllipsisArrayType +from .message import MessageNamespace, MessageType, Message, SubtypeType +from .fundamental_types import split_field, FieldType + +__version__ = '0.0.1' + +__all__ = [ + "MessageNamespace", + "MessageType", + "Message", + "SubtypeType", + "FieldType", + "split_field", + "SizedArrayType", + "DynamicArrayType", + "EllipsisArrayType", + + # fundamental_types + 'byte', + 'u16', + 'u32', + 'u64', + 'tu16', + 'tu32', + 'tu64', + 'chain_hash', + 'channel_id', + 'sha256', + 'point', + 'short_channel_id', + 'signature', + 'bigsize', +] diff --git a/contrib/pyln-proto/pyln/proto/message/array_types.py b/contrib/pyln-proto/pyln/proto/message/array_types.py new file mode 100644 index 
000000000000..077609dd4f6b --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/array_types.py @@ -0,0 +1,184 @@ +from .fundamental_types import FieldType, IntegerType, split_field +from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any +from io import BufferedIOBase +if TYPE_CHECKING: + from .message import SubtypeType, TlvStreamType + + +class ArrayType(FieldType): + """Abstract class for the different kinds of arrays. + +These are not in the namespace, but generated when a message says it +wants an array of some type. + + """ + def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType): + super().__init__("{}.{}".format(outer.name, name)) + self.elemtype = elemtype + + def val_from_str(self, s: str) -> Tuple[List[Any], str]: + # Simple arrays of bytes don't need commas + if self.elemtype.name == 'byte': + a, b = split_field(s) + return [b for b in bytes.fromhex(a)], b + + if not s.startswith('['): + raise ValueError("array of {} must be wrapped in '[]': bad {}" + .format(self.elemtype.name, s)) + s = s[1:] + ret = [] + while not s.startswith(']'): + val, s = self.elemtype.val_from_str(s) + ret.append(val) + if s[0] == ',': + s = s[1:] + return ret, s[1:] + + def val_to_str(self, v: List[Any], otherfields: Dict[str, Any]) -> str: + if self.elemtype.name == 'byte': + return bytes(v).hex() + + s = ','.join(self.elemtype.val_to_str(i, otherfields) for i in v) + return '[' + s + ']' + + def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None: + for i in v: + self.elemtype.write(io_out, i, otherfields) + + def read_arr(self, io_in: BufferedIOBase, otherfields: Dict[str, Any], arraysize: Optional[int]) -> List[Any]: + """arraysize None means take rest of io entirely and exactly""" + vals: List[Any] = [] + while arraysize is None or len(vals) < arraysize: + # Throws an exception on partial read, so None means completely empty. 
+ val = self.elemtype.read(io_in, otherfields) + if val is None: + if arraysize is not None: + raise ValueError('{}: not enough remaining to read' + .format(self)) + break + + vals.append(val) + + return vals + + +class SizedArrayType(ArrayType): + """A fixed-size array""" + def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, arraysize: int): + super().__init__(outer, name, elemtype) + self.arraysize = arraysize + + def val_to_str(self, v: List[Any], otherfields: Dict[str, Any]) -> str: + if len(v) != self.arraysize: + raise ValueError("Length of {} != {}", v, self.arraysize) + return super().val_to_str(v, otherfields) + + def val_from_str(self, s: str) -> Tuple[List[Any], str]: + a, b = super().val_from_str(s) + if len(a) != self.arraysize: + raise ValueError("Length of {} != {}", s, self.arraysize) + return a, b + + def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None: + if len(v) != self.arraysize: + raise ValueError("Length of {} != {}", v, self.arraysize) + return super().write(io_out, v, otherfields) + + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]: + return super().read_arr(io_in, otherfields, self.arraysize) + + +class EllipsisArrayType(ArrayType): + """This is used for ... 
fields at the end of a tlv: the array ends +when the tlv ends""" + def __init__(self, tlv: 'TlvStreamType', name: str, elemtype: FieldType): + super().__init__(tlv, name, elemtype) + + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]: + """Takes rest of bytestream""" + return super().read_arr(io_in, otherfields, None) + + def only_at_tlv_end(self) -> bool: + """These only make sense at the end of a TLV""" + return True + + +class LengthFieldType(FieldType): + """Special type to indicate this serves as a length field for others""" + def __init__(self, inttype: IntegerType): + if type(inttype) is not IntegerType: + raise ValueError("{} cannot be a length; not an integer!" + .format(self.name)) + super().__init__(inttype.name) + self.underlying_type = inttype + # You can be length for more than one field! + self.len_for: List[DynamicArrayType] = [] + + def is_optional(self) -> bool: + """This field value is always implied, never specified directly""" + return True + + def add_length_for(self, field: 'DynamicArrayType') -> None: + assert isinstance(field.fieldtype, DynamicArrayType) + self.len_for.append(field) + + def calc_value(self, otherfields: Dict[str, Any]) -> int: + """Calculate length value from field(s) themselves""" + if self.len_fields_bad('', otherfields): + raise ValueError("Lengths of fields {} not equal!" + .format(self.len_for)) + + return len(otherfields[self.len_for[0].name]) + + def _maybe_calc_value(self, fieldname: str, otherfields: Dict[str, Any]) -> int: + # Perhaps we're just demarshalling from binary now, so we actually + # stored it. Remove, and we'll calc from now on. 
+ if fieldname in otherfields: + v = otherfields[fieldname] + del otherfields[fieldname] + return v + return self.calc_value(otherfields) + + def val_to_str(self, _, otherfields: Dict[str, Any]) -> str: + return self.underlying_type.val_to_str(self.calc_value(otherfields), + otherfields) + + def name_and_val(self, name: str, v: int) -> str: + """We don't print out length fields when printing out messages: +they're implied by the length of other fields""" + return '' + + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None: + """We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)""" + return self.underlying_type.read(io_in, otherfields) + + def write(self, io_out: BufferedIOBase, _, otherfields: Dict[str, Any]) -> None: + self.underlying_type.write(io_out, self.calc_value(otherfields), + otherfields) + + def val_from_str(self, s: str): + raise ValueError('{} is implied, cannot be specified'.format(self)) + + def len_fields_bad(self, fieldname: str, otherfields: Dict[str, Any]) -> List[str]: + """fieldname is the name to return if this length is bad""" + mylen = None + for lens in self.len_for: + if mylen is not None: + if mylen != len(otherfields[lens.name]): + return [fieldname] + # Field might be missing! 
+ if lens.name in otherfields: + mylen = len(otherfields[lens.name]) + return [] + + +class DynamicArrayType(ArrayType): + """This is used for arrays where another field controls the size""" + def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: LengthFieldType): + super().__init__(outer, name, elemtype) + assert type(lenfield.fieldtype) is LengthFieldType + self.lenfield = lenfield + + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]: + return super().read_arr(io_in, otherfields, + self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields)) diff --git a/contrib/pyln-proto/pyln/proto/message/bolt1/Makefile b/contrib/pyln-proto/pyln/proto/message/bolt1/Makefile new file mode 100755 index 000000000000..7992280f2b2f --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt1/Makefile @@ -0,0 +1,7 @@ +#! /usr/bin/make + +SPECDIR := ../../../../../../../lightning-rfc + +csv.py: $(SPECDIR)/01-messaging.md Makefile + SPECNUM=`basename $< | sed 's/-.*//'`; (echo csv = '['; python3 $(SPECDIR)/tools/extract-formats.py $< | sed 's/\(.*\)/ "\1",/'; echo ']') > $@ + chmod a+x $@ diff --git a/contrib/pyln-proto/pyln/proto/message/bolt1/__init__.py b/contrib/pyln-proto/pyln/proto/message/bolt1/__init__.py new file mode 100644 index 000000000000..2ba3aceb67ae --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt1/__init__.py @@ -0,0 +1,16 @@ +from .csv import csv +from .bolt import namespace +import sys + +__version__ = '0.0.1' + +__all__ = [ + 'csv', + 'namespace', +] + +mod = sys.modules[__name__] +for d in namespace.subtypes, namespace.tlvtypes, namespace.messagetypes: + for name in d: + setattr(mod, name, d[name]) + __all__.append(name) diff --git a/contrib/pyln-proto/pyln/proto/message/bolt1/bolt.py b/contrib/pyln-proto/pyln/proto/message/bolt1/bolt.py new file mode 100644 index 000000000000..565c41228744 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt1/bolt.py @@ -0,0 
+1,5 @@ +from pyln.proto.message import MessageNamespace +from .csv import csv + + +namespace = MessageNamespace(csv_lines=csv) diff --git a/contrib/pyln-proto/pyln/proto/message/bolt1/csv.py b/contrib/pyln-proto/pyln/proto/message/bolt1/csv.py new file mode 100755 index 000000000000..4c82899920b9 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt1/csv.py @@ -0,0 +1,35 @@ +csv = [ + "msgtype,init,16", + "msgdata,init,gflen,u16,", + "msgdata,init,globalfeatures,byte,gflen", + "msgdata,init,flen,u16,", + "msgdata,init,features,byte,flen", + "msgdata,init,tlvs,init_tlvs,", + "tlvtype,init_tlvs,networks,1", + "tlvdata,init_tlvs,networks,chains,chain_hash,...", + "msgtype,error,17", + "msgdata,error,channel_id,channel_id,", + "msgdata,error,len,u16,", + "msgdata,error,data,byte,len", + "msgtype,ping,18", + "msgdata,ping,num_pong_bytes,u16,", + "msgdata,ping,byteslen,u16,", + "msgdata,ping,ignored,byte,byteslen", + "msgtype,pong,19", + "msgdata,pong,byteslen,u16,", + "msgdata,pong,ignored,byte,byteslen", + "tlvtype,n1,tlv1,1", + "tlvdata,n1,tlv1,amount_msat,tu64,", + "tlvtype,n1,tlv2,2", + "tlvdata,n1,tlv2,scid,short_channel_id,", + "tlvtype,n1,tlv3,3", + "tlvdata,n1,tlv3,node_id,point,", + "tlvdata,n1,tlv3,amount_msat_1,u64,", + "tlvdata,n1,tlv3,amount_msat_2,u64,", + "tlvtype,n1,tlv4,254", + "tlvdata,n1,tlv4,cltv_delta,u16,", + "tlvtype,n2,tlv1,0", + "tlvdata,n2,tlv1,amount_msat,tu64,", + "tlvtype,n2,tlv2,11", + "tlvdata,n2,tlv2,cltv_expiry,tu32,", +] diff --git a/contrib/pyln-proto/pyln/proto/message/bolt2/Makefile b/contrib/pyln-proto/pyln/proto/message/bolt2/Makefile new file mode 100755 index 000000000000..832891543e60 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt2/Makefile @@ -0,0 +1,7 @@ +#! 
/usr/bin/make + +SPECDIR := ../../../../../../../lightning-rfc + +csv.py: $(SPECDIR)/02-peer-protocol.md Makefile + SPECNUM=`basename $< | sed 's/-.*//'`; (echo csv = '['; python3 $(SPECDIR)/tools/extract-formats.py $< | sed 's/\(.*\)/ "\1",/'; echo ']') > $@ + chmod a+x $@ diff --git a/contrib/pyln-proto/pyln/proto/message/bolt2/__init__.py b/contrib/pyln-proto/pyln/proto/message/bolt2/__init__.py new file mode 100644 index 000000000000..2ba3aceb67ae --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt2/__init__.py @@ -0,0 +1,16 @@ +from .csv import csv +from .bolt import namespace +import sys + +__version__ = '0.0.1' + +__all__ = [ + 'csv', + 'namespace', +] + +mod = sys.modules[__name__] +for d in namespace.subtypes, namespace.tlvtypes, namespace.messagetypes: + for name in d: + setattr(mod, name, d[name]) + __all__.append(name) diff --git a/contrib/pyln-proto/pyln/proto/message/bolt2/bolt.py b/contrib/pyln-proto/pyln/proto/message/bolt2/bolt.py new file mode 100644 index 000000000000..565c41228744 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt2/bolt.py @@ -0,0 +1,5 @@ +from pyln.proto.message import MessageNamespace +from .csv import csv + + +namespace = MessageNamespace(csv_lines=csv) diff --git a/contrib/pyln-proto/pyln/proto/message/bolt2/csv.py b/contrib/pyln-proto/pyln/proto/message/bolt2/csv.py new file mode 100755 index 000000000000..f43d75bbee3d --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt2/csv.py @@ -0,0 +1,100 @@ +csv = [ + "msgtype,open_channel,32", + "msgdata,open_channel,chain_hash,chain_hash,", + "msgdata,open_channel,temporary_channel_id,byte,32", + "msgdata,open_channel,funding_satoshis,u64,", + "msgdata,open_channel,push_msat,u64,", + "msgdata,open_channel,dust_limit_satoshis,u64,", + "msgdata,open_channel,max_htlc_value_in_flight_msat,u64,", + "msgdata,open_channel,channel_reserve_satoshis,u64,", + "msgdata,open_channel,htlc_minimum_msat,u64,", + "msgdata,open_channel,feerate_per_kw,u32,", + 
"msgdata,open_channel,to_self_delay,u16,", + "msgdata,open_channel,max_accepted_htlcs,u16,", + "msgdata,open_channel,funding_pubkey,point,", + "msgdata,open_channel,revocation_basepoint,point,", + "msgdata,open_channel,payment_basepoint,point,", + "msgdata,open_channel,delayed_payment_basepoint,point,", + "msgdata,open_channel,htlc_basepoint,point,", + "msgdata,open_channel,first_per_commitment_point,point,", + "msgdata,open_channel,channel_flags,byte,", + "msgdata,open_channel,tlvs,open_channel_tlvs,", + "tlvtype,open_channel_tlvs,upfront_shutdown_script,0", + "tlvdata,open_channel_tlvs,upfront_shutdown_script,shutdown_scriptpubkey,byte,...", + "msgtype,accept_channel,33", + "msgdata,accept_channel,temporary_channel_id,byte,32", + "msgdata,accept_channel,dust_limit_satoshis,u64,", + "msgdata,accept_channel,max_htlc_value_in_flight_msat,u64,", + "msgdata,accept_channel,channel_reserve_satoshis,u64,", + "msgdata,accept_channel,htlc_minimum_msat,u64,", + "msgdata,accept_channel,minimum_depth,u32,", + "msgdata,accept_channel,to_self_delay,u16,", + "msgdata,accept_channel,max_accepted_htlcs,u16,", + "msgdata,accept_channel,funding_pubkey,point,", + "msgdata,accept_channel,revocation_basepoint,point,", + "msgdata,accept_channel,payment_basepoint,point,", + "msgdata,accept_channel,delayed_payment_basepoint,point,", + "msgdata,accept_channel,htlc_basepoint,point,", + "msgdata,accept_channel,first_per_commitment_point,point,", + "msgdata,accept_channel,tlvs,accept_channel_tlvs,", + "tlvtype,accept_channel_tlvs,upfront_shutdown_script,0", + "tlvdata,accept_channel_tlvs,upfront_shutdown_script,shutdown_scriptpubkey,byte,...", + "msgtype,funding_created,34", + "msgdata,funding_created,temporary_channel_id,byte,32", + "msgdata,funding_created,funding_txid,sha256,", + "msgdata,funding_created,funding_output_index,u16,", + "msgdata,funding_created,signature,signature,", + "msgtype,funding_signed,35", + "msgdata,funding_signed,channel_id,channel_id,", + 
"msgdata,funding_signed,signature,signature,", + "msgtype,funding_locked,36", + "msgdata,funding_locked,channel_id,channel_id,", + "msgdata,funding_locked,next_per_commitment_point,point,", + "msgtype,shutdown,38", + "msgdata,shutdown,channel_id,channel_id,", + "msgdata,shutdown,len,u16,", + "msgdata,shutdown,scriptpubkey,byte,len", + "msgtype,closing_signed,39", + "msgdata,closing_signed,channel_id,channel_id,", + "msgdata,closing_signed,fee_satoshis,u64,", + "msgdata,closing_signed,signature,signature,", + "msgtype,update_add_htlc,128", + "msgdata,update_add_htlc,channel_id,channel_id,", + "msgdata,update_add_htlc,id,u64,", + "msgdata,update_add_htlc,amount_msat,u64,", + "msgdata,update_add_htlc,payment_hash,sha256,", + "msgdata,update_add_htlc,cltv_expiry,u32,", + "msgdata,update_add_htlc,onion_routing_packet,byte,1366", + "msgtype,update_fulfill_htlc,130", + "msgdata,update_fulfill_htlc,channel_id,channel_id,", + "msgdata,update_fulfill_htlc,id,u64,", + "msgdata,update_fulfill_htlc,payment_preimage,byte,32", + "msgtype,update_fail_htlc,131", + "msgdata,update_fail_htlc,channel_id,channel_id,", + "msgdata,update_fail_htlc,id,u64,", + "msgdata,update_fail_htlc,len,u16,", + "msgdata,update_fail_htlc,reason,byte,len", + "msgtype,update_fail_malformed_htlc,135", + "msgdata,update_fail_malformed_htlc,channel_id,channel_id,", + "msgdata,update_fail_malformed_htlc,id,u64,", + "msgdata,update_fail_malformed_htlc,sha256_of_onion,sha256,", + "msgdata,update_fail_malformed_htlc,failure_code,u16,", + "msgtype,commitment_signed,132", + "msgdata,commitment_signed,channel_id,channel_id,", + "msgdata,commitment_signed,signature,signature,", + "msgdata,commitment_signed,num_htlcs,u16,", + "msgdata,commitment_signed,htlc_signature,signature,num_htlcs", + "msgtype,revoke_and_ack,133", + "msgdata,revoke_and_ack,channel_id,channel_id,", + "msgdata,revoke_and_ack,per_commitment_secret,byte,32", + "msgdata,revoke_and_ack,next_per_commitment_point,point,", + "msgtype,update_fee,134", + 
"msgdata,update_fee,channel_id,channel_id,", + "msgdata,update_fee,feerate_per_kw,u32,", + "msgtype,channel_reestablish,136", + "msgdata,channel_reestablish,channel_id,channel_id,", + "msgdata,channel_reestablish,next_commitment_number,u64,", + "msgdata,channel_reestablish,next_revocation_number,u64,", + "msgdata,channel_reestablish,your_last_per_commitment_secret,byte,32", + "msgdata,channel_reestablish,my_current_per_commitment_point,point,", +] diff --git a/contrib/pyln-proto/pyln/proto/message/bolt4/Makefile b/contrib/pyln-proto/pyln/proto/message/bolt4/Makefile new file mode 100755 index 000000000000..3416819510b0 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt4/Makefile @@ -0,0 +1,7 @@ +#! /usr/bin/make + +SPECDIR := ../../../../../../../lightning-rfc + +csv.py: $(SPECDIR)/04-onion-routing.md Makefile + SPECNUM=`basename $< | sed 's/-.*//'`; (echo csv = '['; python3 $(SPECDIR)/tools/extract-formats.py $< | sed 's/\(.*\)/ "\1",/'; echo ']') > $@ + chmod a+x $@ diff --git a/contrib/pyln-proto/pyln/proto/message/bolt4/__init__.py b/contrib/pyln-proto/pyln/proto/message/bolt4/__init__.py new file mode 100644 index 000000000000..2ba3aceb67ae --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt4/__init__.py @@ -0,0 +1,16 @@ +from .csv import csv +from .bolt import namespace +import sys + +__version__ = '0.0.1' + +__all__ = [ + 'csv', + 'namespace', +] + +mod = sys.modules[__name__] +for d in namespace.subtypes, namespace.tlvtypes, namespace.messagetypes: + for name in d: + setattr(mod, name, d[name]) + __all__.append(name) diff --git a/contrib/pyln-proto/pyln/proto/message/bolt4/bolt.py b/contrib/pyln-proto/pyln/proto/message/bolt4/bolt.py new file mode 100644 index 000000000000..565c41228744 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt4/bolt.py @@ -0,0 +1,5 @@ +from pyln.proto.message import MessageNamespace +from .csv import csv + + +namespace = MessageNamespace(csv_lines=csv) diff --git 
a/contrib/pyln-proto/pyln/proto/message/bolt4/csv.py b/contrib/pyln-proto/pyln/proto/message/bolt4/csv.py new file mode 100755 index 000000000000..f51bcf829c77 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt4/csv.py @@ -0,0 +1,55 @@ +csv = [ + "tlvtype,tlv_payload,amt_to_forward,2", + "tlvdata,tlv_payload,amt_to_forward,amt_to_forward,tu64,", + "tlvtype,tlv_payload,outgoing_cltv_value,4", + "tlvdata,tlv_payload,outgoing_cltv_value,outgoing_cltv_value,tu32,", + "tlvtype,tlv_payload,short_channel_id,6", + "tlvdata,tlv_payload,short_channel_id,short_channel_id,short_channel_id,", + "tlvtype,tlv_payload,payment_data,8", + "tlvdata,tlv_payload,payment_data,payment_secret,byte,32", + "tlvdata,tlv_payload,payment_data,total_msat,tu64,", + "msgtype,invalid_realm,PERM|1", + "msgtype,temporary_node_failure,NODE|2", + "msgtype,permanent_node_failure,PERM|NODE|2", + "msgtype,required_node_feature_missing,PERM|NODE|3", + "msgtype,invalid_onion_version,BADONION|PERM|4", + "msgdata,invalid_onion_version,sha256_of_onion,sha256,", + "msgtype,invalid_onion_hmac,BADONION|PERM|5", + "msgdata,invalid_onion_hmac,sha256_of_onion,sha256,", + "msgtype,invalid_onion_key,BADONION|PERM|6", + "msgdata,invalid_onion_key,sha256_of_onion,sha256,", + "msgtype,temporary_channel_failure,UPDATE|7", + "msgdata,temporary_channel_failure,len,u16,", + "msgdata,temporary_channel_failure,channel_update,byte,len", + "msgtype,permanent_channel_failure,PERM|8", + "msgtype,required_channel_feature_missing,PERM|9", + "msgtype,unknown_next_peer,PERM|10", + "msgtype,amount_below_minimum,UPDATE|11", + "msgdata,amount_below_minimum,htlc_msat,u64,", + "msgdata,amount_below_minimum,len,u16,", + "msgdata,amount_below_minimum,channel_update,byte,len", + "msgtype,fee_insufficient,UPDATE|12", + "msgdata,fee_insufficient,htlc_msat,u64,", + "msgdata,fee_insufficient,len,u16,", + "msgdata,fee_insufficient,channel_update,byte,len", + "msgtype,incorrect_cltv_expiry,UPDATE|13", + 
"msgdata,incorrect_cltv_expiry,cltv_expiry,u32,", + "msgdata,incorrect_cltv_expiry,len,u16,", + "msgdata,incorrect_cltv_expiry,channel_update,byte,len", + "msgtype,expiry_too_soon,UPDATE|14", + "msgdata,expiry_too_soon,len,u16,", + "msgdata,expiry_too_soon,channel_update,byte,len", + "msgtype,incorrect_or_unknown_payment_details,PERM|15", + "msgdata,incorrect_or_unknown_payment_details,htlc_msat,u64,", + "msgdata,incorrect_or_unknown_payment_details,height,u32,", + "msgtype,final_incorrect_cltv_expiry,18", + "msgdata,final_incorrect_cltv_expiry,cltv_expiry,u32,", + "msgtype,final_incorrect_htlc_amount,19", + "msgdata,final_incorrect_htlc_amount,incoming_htlc_amt,u64,", + "msgtype,channel_disabled,UPDATE|20", + "msgtype,expiry_too_far,21", + "msgtype,invalid_onion_payload,PERM|22", + "msgdata,invalid_onion_payload,type,varint,", + "msgdata,invalid_onion_payload,offset,u16,", + "msgtype,mpp_timeout,23", +] diff --git a/contrib/pyln-proto/pyln/proto/message/bolt7/Makefile b/contrib/pyln-proto/pyln/proto/message/bolt7/Makefile new file mode 100755 index 000000000000..13a7d684747a --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt7/Makefile @@ -0,0 +1,7 @@ +#! 
/usr/bin/make + +SPECDIR := ../../../../../../../lightning-rfc + +csv.py: $(SPECDIR)/07-routing-gossip.md Makefile + SPECNUM=`basename $< | sed 's/-.*//'`; (echo csv = '['; python3 $(SPECDIR)/tools/extract-formats.py $< | sed 's/\(.*\)/ "\1",/'; echo ']') > $@ + chmod a+x $@ diff --git a/contrib/pyln-proto/pyln/proto/message/bolt7/__init__.py b/contrib/pyln-proto/pyln/proto/message/bolt7/__init__.py new file mode 100644 index 000000000000..2ba3aceb67ae --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt7/__init__.py @@ -0,0 +1,16 @@ +from .csv import csv +from .bolt import namespace +import sys + +__version__ = '0.0.1' + +__all__ = [ + 'csv', + 'namespace', +] + +mod = sys.modules[__name__] +for d in namespace.subtypes, namespace.tlvtypes, namespace.messagetypes: + for name in d: + setattr(mod, name, d[name]) + __all__.append(name) diff --git a/contrib/pyln-proto/pyln/proto/message/bolt7/bolt.py b/contrib/pyln-proto/pyln/proto/message/bolt7/bolt.py new file mode 100644 index 000000000000..565c41228744 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt7/bolt.py @@ -0,0 +1,5 @@ +from pyln.proto.message import MessageNamespace +from .csv import csv + + +namespace = MessageNamespace(csv_lines=csv) diff --git a/contrib/pyln-proto/pyln/proto/message/bolt7/csv.py b/contrib/pyln-proto/pyln/proto/message/bolt7/csv.py new file mode 100755 index 000000000000..6c33c7b66382 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/bolt7/csv.py @@ -0,0 +1,83 @@ +csv = [ + "msgtype,announcement_signatures,259", + "msgdata,announcement_signatures,channel_id,channel_id,", + "msgdata,announcement_signatures,short_channel_id,short_channel_id,", + "msgdata,announcement_signatures,node_signature,signature,", + "msgdata,announcement_signatures,bitcoin_signature,signature,", + "msgtype,channel_announcement,256", + "msgdata,channel_announcement,node_signature_1,signature,", + "msgdata,channel_announcement,node_signature_2,signature,", + 
"msgdata,channel_announcement,bitcoin_signature_1,signature,", + "msgdata,channel_announcement,bitcoin_signature_2,signature,", + "msgdata,channel_announcement,len,u16,", + "msgdata,channel_announcement,features,byte,len", + "msgdata,channel_announcement,chain_hash,chain_hash,", + "msgdata,channel_announcement,short_channel_id,short_channel_id,", + "msgdata,channel_announcement,node_id_1,point,", + "msgdata,channel_announcement,node_id_2,point,", + "msgdata,channel_announcement,bitcoin_key_1,point,", + "msgdata,channel_announcement,bitcoin_key_2,point,", + "msgtype,node_announcement,257", + "msgdata,node_announcement,signature,signature,", + "msgdata,node_announcement,flen,u16,", + "msgdata,node_announcement,features,byte,flen", + "msgdata,node_announcement,timestamp,u32,", + "msgdata,node_announcement,node_id,point,", + "msgdata,node_announcement,rgb_color,byte,3", + "msgdata,node_announcement,alias,byte,32", + "msgdata,node_announcement,addrlen,u16,", + "msgdata,node_announcement,addresses,byte,addrlen", + "msgtype,channel_update,258", + "msgdata,channel_update,signature,signature,", + "msgdata,channel_update,chain_hash,chain_hash,", + "msgdata,channel_update,short_channel_id,short_channel_id,", + "msgdata,channel_update,timestamp,u32,", + "msgdata,channel_update,message_flags,byte,", + "msgdata,channel_update,channel_flags,byte,", + "msgdata,channel_update,cltv_expiry_delta,u16,", + "msgdata,channel_update,htlc_minimum_msat,u64,", + "msgdata,channel_update,fee_base_msat,u32,", + "msgdata,channel_update,fee_proportional_millionths,u32,", + "msgdata,channel_update,htlc_maximum_msat,u64,,option_channel_htlc_max", + "msgtype,query_short_channel_ids,261,gossip_queries", + "msgdata,query_short_channel_ids,chain_hash,chain_hash,", + "msgdata,query_short_channel_ids,len,u16,", + "msgdata,query_short_channel_ids,encoded_short_ids,byte,len", + "msgdata,query_short_channel_ids,tlvs,query_short_channel_ids_tlvs,", + "tlvtype,query_short_channel_ids_tlvs,query_flags,1", + 
"tlvdata,query_short_channel_ids_tlvs,query_flags,encoding_type,u8,", + "tlvdata,query_short_channel_ids_tlvs,query_flags,encoded_query_flags,byte,...", + "msgtype,reply_short_channel_ids_end,262,gossip_queries", + "msgdata,reply_short_channel_ids_end,chain_hash,chain_hash,", + "msgdata,reply_short_channel_ids_end,full_information,byte,", + "msgtype,query_channel_range,263,gossip_queries", + "msgdata,query_channel_range,chain_hash,chain_hash,", + "msgdata,query_channel_range,first_blocknum,u32,", + "msgdata,query_channel_range,number_of_blocks,u32,", + "msgdata,query_channel_range,tlvs,query_channel_range_tlvs,", + "tlvtype,query_channel_range_tlvs,query_option,1", + "tlvdata,query_channel_range_tlvs,query_option,query_option_flags,varint,", + "msgtype,reply_channel_range,264,gossip_queries", + "msgdata,reply_channel_range,chain_hash,chain_hash,", + "msgdata,reply_channel_range,first_blocknum,u32,", + "msgdata,reply_channel_range,number_of_blocks,u32,", + "msgdata,reply_channel_range,full_information,byte,", + "msgdata,reply_channel_range,len,u16,", + "msgdata,reply_channel_range,encoded_short_ids,byte,len", + "msgdata,reply_channel_range,tlvs,reply_channel_range_tlvs,", + "tlvtype,reply_channel_range_tlvs,timestamps_tlv,1", + "tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,", + "tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoded_timestamps,byte,...", + "tlvtype,reply_channel_range_tlvs,checksums_tlv,3", + "tlvdata,reply_channel_range_tlvs,checksums_tlv,checksums,channel_update_checksums,...", + "subtype,channel_update_timestamps", + "subtypedata,channel_update_timestamps,timestamp_node_id_1,u32,", + "subtypedata,channel_update_timestamps,timestamp_node_id_2,u32,", + "subtype,channel_update_checksums", + "subtypedata,channel_update_checksums,checksum_node_id_1,u32,", + "subtypedata,channel_update_checksums,checksum_node_id_2,u32,", + "msgtype,gossip_timestamp_filter,265,gossip_queries", + 
"msgdata,gossip_timestamp_filter,chain_hash,chain_hash,", + "msgdata,gossip_timestamp_filter,first_timestamp,u32,", + "msgdata,gossip_timestamp_filter,timestamp_range,u32,", +] diff --git a/contrib/pyln-proto/pyln/proto/message/fundamental_types.py b/contrib/pyln-proto/pyln/proto/message/fundamental_types.py new file mode 100644 index 000000000000..80c21570e8f7 --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/fundamental_types.py @@ -0,0 +1,249 @@ +import struct +from io import BufferedIOBase +import sys +from typing import Dict, Optional, Tuple, List, Any + + +def try_unpack(name: str, + io_out: BufferedIOBase, + structfmt: str, + empty_ok: bool) -> Optional[int]: + """Unpack a single value using struct.unpack. + +If empty_ok, returns None on EOF; otherwise never returns None.""" + b = io_out.read(struct.calcsize(structfmt)) + if len(b) == 0 and empty_ok: + return None + elif len(b) < struct.calcsize(structfmt): + raise ValueError("{}: not enough bytes", name) + + return struct.unpack(structfmt, b)[0] + + +def split_field(s: str) -> Tuple[str, str]: + """Helper to split string into first part and remainder""" + def len_without(s, delim): + pos = s.find(delim) + if pos == -1: + return len(s) + return pos + + firstlen = min([len_without(s, d) for d in (',', '}', ']')]) + return s[:firstlen], s[firstlen:] + + +class FieldType(object): + """An (abstract) class representing the underlying type of a field. +These are further specialized. 
+ + """ + def __init__(self, name: str): + self.name = name + + def only_at_tlv_end(self) -> bool: + """Some types only make sense inside a tlv, at the end""" + return False + + def name_and_val(self, name: str, v: Any) -> str: + """This is overridden by LengthFieldType to return nothing""" + return " {}={}".format(name, self.val_to_str(v, {})) + + def is_optional(self) -> bool: + """Overridden for tlv fields and optional fields""" + return False + + def len_fields_bad(self, fieldname: str, fieldvals: Dict[str, Any]) -> List[str]: + """Overridden by length fields for arrays""" + return [] + + def val_to_str(self, v: Any, otherfields: Dict[str, Any]) -> str: + raise NotImplementedError() + + def __str__(self): + return self.name + + def __repr__(self): + return 'FieldType({})'.format(self.name) + + +class IntegerType(FieldType): + def __init__(self, name: str, bytelen: int, structfmt: str): + super().__init__(name) + self.bytelen = bytelen + self.structfmt = structfmt + + def val_to_str(self, v: int, otherfields: Dict[str, Any]): + return "{}".format(int(v)) + + def val_from_str(self, s: str) -> Tuple[int, str]: + a, b = split_field(s) + return int(a), b + + def write(self, io_out: BufferedIOBase, v: int, otherfields: Dict[str, Any]) -> None: + io_out.write(struct.pack(self.structfmt, v)) + + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[int]: + return try_unpack(self.name, io_in, self.structfmt, empty_ok=True) + + +class ShortChannelIDType(IntegerType): + """short_channel_id has a special string representation, but is +basically a u64. 
+ + """ + def __init__(self, name): + super().__init__(name, 8, '>Q') + + def val_to_str(self, v: int, otherfields: Dict[str, Any]) -> str: + # See BOLT #7: ## Definition of `short_channel_id` + return "{}x{}x{}".format(v >> 40, (v >> 16) & 0xFFFFFF, v & 0xFFFF) + + def val_from_str(self, s: str) -> Tuple[int, str]: + a, b = split_field(s) + parts = a.split('x') + if len(parts) != 3: + raise ValueError("short_channel_id should be NxNxN") + return ((int(parts[0]) << 40) + | (int(parts[1]) << 16) + | (int(parts[2]))), b + + +class TruncatedIntType(FieldType): + """Truncated integer types""" + def __init__(self, name: str, maxbytes: int): + super().__init__(name) + self.maxbytes = maxbytes + + def val_to_str(self, v: int, otherfields: Dict[str, Any]) -> str: + return "{}".format(int(v)) + + def only_at_tlv_end(self) -> bool: + """These only make sense at the end of a TLV""" + return True + + def val_from_str(self, s: str) -> Tuple[int, str]: + a, b = split_field(s) + if int(a) >= (1 << (self.maxbytes * 8)): + raise ValueError('{} exceeds maximum {} capacity' + .format(a, self.name)) + return int(a), b + + def write(self, io_out: BufferedIOBase, v: int, otherfields: Dict[str, Any]) -> None: + binval = struct.pack('>Q', v) + while len(binval) != 0 and binval[0] == 0: + binval = binval[1:] + if len(binval) > self.maxbytes: + raise ValueError('{} exceeds maximum {} capacity' + .format(v, self.name)) + io_out.write(binval) + + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None: + binval = io_in.read() + if len(binval) > self.maxbytes: + raise ValueError('{} is too long for {}'.format(binval.hex(), self.name)) + if len(binval) > 0 and binval[0] == 0: + raise ValueError('{} encoding is not minimal: {}' + .format(self.name, binval.hex())) + # Pad with zeroes and convert as u64 + return struct.unpack_from('>Q', bytes(8 - len(binval)) + binval)[0] + + +class FundamentalHexType(FieldType): + """The remaining fundamental types are simply represented as hex 
strings""" + def __init__(self, name: str, bytelen: int): + super().__init__(name) + self.bytelen = bytelen + + def val_to_str(self, v: bytes, otherfields: Dict[str, Any]) -> str: + if len(bytes(v)) != self.bytelen: + raise ValueError("Length of {} != {}", v, self.bytelen) + return v.hex() + + def val_from_str(self, s: str) -> Tuple[bytes, str]: + a, b = split_field(s) + ret = bytes.fromhex(a) + if len(ret) != self.bytelen: + raise ValueError("Length of {} != {}", a, self.bytelen) + return ret, b + + def write(self, io_out: BufferedIOBase, v: bytes, otherfields: Dict[str, Any]) -> None: + if len(bytes(v)) != self.bytelen: + raise ValueError("Length of {} != {}", v, self.bytelen) + io_out.write(v) + + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[bytes]: + val = io_in.read(self.bytelen) + if len(val) == 0: + return None + elif len(val) != self.bytelen: + raise ValueError('{}: not enough remaining'.format(self)) + return val + + +class BigSizeType(FieldType): + """BigSize type, mainly used to encode TLV headers""" + def __init__(self, name): + super().__init__(name) + + def val_from_str(self, s: str) -> Tuple[int, str]: + a, b = split_field(s) + return int(a), b + + # For the convenience of TLV header parsing + @staticmethod + def write(io_out: BufferedIOBase, v: int, otherfields: Dict[str, Any] = {}) -> None: + if v < 253: + io_out.write(bytes([v])) + elif v < 2**16: + io_out.write(bytes([253]) + struct.pack('>H', v)) + elif v < 2**32: + io_out.write(bytes([254]) + struct.pack('>I', v)) + else: + io_out.write(bytes([255]) + struct.pack('>Q', v)) + + @staticmethod + def read(io_in: BufferedIOBase, otherfields: Dict[str, Any] = {}) -> Optional[int]: + "Returns value, or None on EOF" + b = io_in.read(1) + if len(b) == 0: + return None + if b[0] < 253: + return int(b[0]) + elif b[0] == 253: + return try_unpack('BigSize', io_in, '>H', empty_ok=False) + elif b[0] == 254: + return try_unpack('BigSize', io_in, '>I', empty_ok=False) + else: + 
return try_unpack('BigSize', io_in, '>Q', empty_ok=False) + + def val_to_str(self, v: int, otherfields: Dict[str, Any]) -> str: + return "{}".format(int(v)) + + +def fundamental_types(): + # From 01-messaging.md#fundamental-types: + return [IntegerType('byte', 1, 'B'), + IntegerType('u16', 2, '>H'), + IntegerType('u32', 4, '>I'), + IntegerType('u64', 8, '>Q'), + TruncatedIntType('tu16', 2), + TruncatedIntType('tu32', 4), + TruncatedIntType('tu64', 8), + FundamentalHexType('chain_hash', 32), + FundamentalHexType('channel_id', 32), + FundamentalHexType('sha256', 32), + FundamentalHexType('point', 33), + ShortChannelIDType('short_channel_id'), + FundamentalHexType('signature', 64), + BigSizeType('bigsize'), + # FIXME: See https://github.com/lightningnetwork/lightning-rfc/pull/778 + BigSizeType('varint'), + # FIXME + IntegerType('u8', 1, 'B'), + ] + + +# Expose these as native types. +mod = sys.modules[FieldType.__module__] +for m in fundamental_types(): + setattr(mod, m.name, m) diff --git a/contrib/pyln-proto/pyln/proto/message/message.py b/contrib/pyln-proto/pyln/proto/message/message.py new file mode 100644 index 000000000000..1b8d8f37f66f --- /dev/null +++ b/contrib/pyln-proto/pyln/proto/message/message.py @@ -0,0 +1,658 @@ +import struct +from io import BufferedIOBase, BytesIO +from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType +from .array_types import ( + SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType +) +from typing import Dict, List, Optional, Tuple, Any, cast + + +class MessageNamespace(object): + """A class which contains all FieldTypes and Messages in a particular +domain, such as within a given BOLT""" + def __init__(self, csv_lines: List[str] = []): + self.subtypes: Dict[str, SubtypeType] = {} + self.fundamentaltypes: Dict[str, SubtypeType] = {} + self.tlvtypes: Dict[str, TlvStreamType] = {} + self.messagetypes: Dict[str, MessageType] = {} + + # For convenience, basic types go in every 
namespace + for t in fundamental_types(): + self.add_fundamentaltype(t) + + self.load_csv(csv_lines) + + def __add__(self, other: 'MessageNamespace'): + ret = MessageNamespace() + ret.subtypes = self.subtypes.copy() + for v in other.subtypes.values(): + ret.add_subtype(v) + ret.tlvtypes = self.tlvtypes.copy() + for v in other.tlvtypes.values(): + ret.add_tlvtype(v) + ret.messagetypes = self.messagetypes.copy() + for v in other.messagetypes.values(): + ret.add_messagetype(v) + return ret + + def add_subtype(self, t: 'SubtypeType') -> None: + prev = self.get_type(t.name) + if prev: + raise ValueError('Already have {}'.format(prev)) + self.subtypes[t.name] = t + + def add_fundamentaltype(self, t: 'SubtypeType') -> None: + assert not self.get_type(t.name) + self.fundamentaltypes[t.name] = t + + def add_tlvtype(self, t: 'TlvStreamType') -> None: + prev = self.get_type(t.name) + if prev: + raise ValueError('Already have {}'.format(prev)) + self.tlvtypes[t.name] = t + + def add_messagetype(self, m: 'MessageType') -> None: + if self.get_msgtype(m.name): + raise ValueError('{}: message already exists'.format(m.name)) + if self.get_msgtype_by_number(m.number): + raise ValueError('{}: message {} already number {}'.format( + m.name, self.get_msgtype_by_number(m.number), m.number)) + self.messagetypes[m.name] = m + + def get_msgtype(self, name: str) -> Optional['MessageType']: + if name in self.messagetypes: + return self.messagetypes[name] + return None + + def get_msgtype_by_number(self, num: int) -> Optional['MessageType']: + for m in self.messagetypes.values(): + if m.number == num: + return m + return None + + def get_fundamentaltype(self, name: str) -> Optional['SubtypeType']: + if name in self.fundamentaltypes: + return self.fundamentaltypes[name] + return None + + def get_subtype(self, name: str) -> Optional['SubtypeType']: + if name in self.subtypes: + return self.subtypes[name] + return None + + def get_tlvtype(self, name: str) -> Optional['TlvStreamType']: + if name 
in self.tlvtypes: + return self.tlvtypes[name] + return None + + def get_type(self, name: str) -> Optional['SubtypeType']: + t = self.get_fundamentaltype(name) + if t is None: + t = self.get_subtype(name) + if t is None: + t = self.get_tlvtype(name) + return t + + def load_csv(self, lines: List[str]) -> None: + """Load a series of comma-separate-value lines into the namespace""" + vals: Dict[str, List[List[str]]] = {'msgtype': [], + 'msgdata': [], + 'tlvtype': [], + 'tlvdata': [], + 'subtype': [], + 'subtypedata': []} + for l in lines: + parts = l.split(',') + if parts[0] not in vals: + raise ValueError("Unknown type {} in {}".format(parts[0], l)) + vals[parts[0]].append(parts[1:]) + + # Types can refer to other types, so add data last. + for parts in vals['msgtype']: + self.add_messagetype(MessageType.msgtype_from_csv(parts)) + + for parts in vals['subtype']: + self.add_subtype(SubtypeType.subtype_from_csv(parts)) + + for parts in vals['tlvtype']: + TlvStreamType.tlvtype_from_csv(self, parts) + + for parts in vals['msgdata']: + MessageType.msgfield_from_csv(self, parts) + + for parts in vals['subtypedata']: + SubtypeType.subfield_from_csv(self, parts) + + for parts in vals['tlvdata']: + TlvStreamType.tlvfield_from_csv(self, parts) + + +class MessageTypeField(object): + """A field within a particular message type or subtype""" + def __init__(self, ownername: str, name: str, fieldtype: FieldType, option: Optional[str] = None): + self.full_name = "{}.{}".format(ownername, name) + self.name = name + self.fieldtype = fieldtype + self.option = option + + def missing_fields(self, fieldvals: Dict[str, Any]): + """Return this field if it's not in fields""" + if self.name not in fieldvals and not self.option and not self.fieldtype.is_optional(): + return [self] + return [] + + def len_fields_bad(self, fieldname: str, otherfields: Dict[str, Any]) -> List[str]: + return self.fieldtype.len_fields_bad(fieldname, otherfields) + + def __str__(self): + return self.full_name + + 
def __repr__(self): + """Yuck, but this is what format() uses for lists""" + return self.full_name + + +class SubtypeType(object): + """This defines a 'subtype' in BOLT-speak. It consists of fields of +other types. Since 'msgtype' and 'tlvtype' are almost identical, they +inherit from this too. + + """ + def __init__(self, name: str): + self.name = name + self.fields: List[FieldType] = [] + + def find_field(self, fieldname: str): + for f in self.fields: + if f.name == fieldname: + return f + return None + + def add_field(self, field: FieldType): + if self.find_field(field.name): + raise ValueError("{}: duplicate field {}".format(self, field)) + self.fields.append(field) + + def __str__(self): + return "subtype-{}".format(self.name) + + def len_fields_bad(self, fieldname: str, otherfields: Dict[str, Any]) -> List[str]: + bad_fields: List[str] = [] + for f in self.fields: + bad_fields += f.len_fields_bad('{}.{}'.format(fieldname, f.name), + otherfields) + + return bad_fields + + @staticmethod + def subtype_from_csv(parts: List[str]) -> 'SubtypeType': + """e.g subtype,channel_update_timestamps""" + if len(parts) != 1: + raise ValueError("subtype expected 2 CSV parts, not {}" + .format(parts)) + return SubtypeType(parts[0]) + + def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], ellipsisok=False, option: str = None) -> MessageTypeField: + """Takes msgdata/subtypedata after first two fields + e.g. [...]timestamp_node_id_1,u32, + + """ + basetype = namespace.get_type(parts[1]) + if basetype is None: + raise ValueError('Unknown type {}'.format(parts[1])) + + # Fixed number, or another field. + if parts[2] != '': + lenfield = self.find_field(parts[2]) + if lenfield is not None: + # If we didn't know that field was a length, we do now! 
+ if type(lenfield.fieldtype) is not LengthFieldType: + lenfield.fieldtype = LengthFieldType(lenfield.fieldtype) + field = MessageTypeField(self.name, parts[0], + DynamicArrayType(self, + parts[0], + basetype, + lenfield), + option) + lenfield.fieldtype.add_length_for(field) + elif ellipsisok and parts[2] == '...': + field = MessageTypeField(self.name, parts[0], + EllipsisArrayType(self, + parts[0], basetype), + option) + else: + field = MessageTypeField(self.name, parts[0], + SizedArrayType(self, + parts[0], basetype, + int(parts[2])), + option) + else: + field = MessageTypeField(self.name, parts[0], basetype, option) + + return field + + def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]: + if not s.startswith('{'): + raise ValueError("subtype {} must be wrapped in '{{}}': bad {}" + .format(self, s)) + s = s[1:] + ret: Dict[str, Any] = {} + # FIXME: perhaps allow unlabelled fields to imply assign fields in order? + while not s.startswith('}'): + fieldname, s = s.split('=', 1) + f = self.find_field(fieldname) + if f is None: + raise ValueError("Unknown field name {}".format(fieldname)) + ret[fieldname], s = f.fieldtype.val_from_str(s) + if s[0] == ',': + s = s[1:] + + # All non-optional fields must be specified. + for f in self.fields: + if not f.fieldtype.is_optional() and f.name not in ret: + raise ValueError("{} missing field {}".format(self, f)) + + return ret, s[1:] + + def _raise_if_badvals(self, v: Dict[str, Any]) -> None: + # Every non-optional value must be specified, and no others. 
+ defined = set([f.name for f in self.fields]) + have = set(v) + + unknown = have.difference(defined) + if unknown: + raise ValueError("Unknown fields specified: {}".format(unknown)) + + for f in defined.difference(have): + if not f.fieldtype.is_optional(): + raise ValueError("Missing value for {}".format(f)) + + def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str: + self._raise_if_badvals(v) + s = '' + sep = '' + for fname, val in v.items(): + field = self.find_field(fname) + s += sep + fname + '=' + field.fieldtype.val_to_str(val, otherfields) + sep = ',' + + return '{' + s + '}' + + def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None: + self._raise_if_badvals(v) + for fname, val in v.items(): + field = self.find_field(fname) + field.fieldtype.write(io_out, val, otherfields) + + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]: + vals = {} + for field in self.fields: + val = field.fieldtype.read(io_in, otherfields) + if val is None: + # Might only exist with certain options available + if field.fieldtype.option is None: + raise ValueError("{}.{}: short read".format(self, field)) + vals[field.name] = val + + return vals + + @staticmethod + def subfield_from_csv(namespace: MessageNamespace, parts: List[str]) -> None: + """e.g +subtypedata,channel_update_timestamps,timestamp_node_id_1,u32,""" + if len(parts) != 4: + raise ValueError("subtypedata expected 4 CSV parts, not {}" + .format(parts)) + subtype = namespace.get_subtype(parts[0]) + if subtype is None: + raise ValueError("unknown subtype {}".format(parts[0])) + + field = subtype._field_from_csv(namespace, parts[1:]) + if field.fieldtype.only_at_tlv_end(): + raise ValueError("{}: cannot have TLV field {}" + .format(subtype, field)) + subtype.add_field(field) + + +class MessageType(SubtypeType): + """Each MessageType has a specific value, eg 17 is error""" + # * 0x8000 (BADONION): unparsable onion encrypted by 
sending peer + # * 0x4000 (PERM): permanent failure (otherwise transient) + # * 0x2000 (NODE): node failure (otherwise channel) + # * 0x1000 (UPDATE): new channel update enclosed + onion_types = {'BADONION': 0x8000, + 'PERM': 0x4000, + 'NODE': 0x2000, + 'UPDATE': 0x1000} + + def __init__(self, name: str, value: str, option: Optional[str] = None): + super().__init__(name) + self.number = self.parse_value(value) + self.option = option + + def parse_value(self, value: str) -> int: + result = 0 + for token in value.split('|'): + if token in self.onion_types.keys(): + result |= self.onion_types[token] + else: + result |= int(token) + + return result + + def __str__(self): + return "msgtype-{}".format(self.name) + + @staticmethod + def msgtype_from_csv(parts: List[str]) -> 'MessageType': + """e.g msgtype,open_channel,32,option_foo""" + option = None + if len(parts) == 3: + option = parts[2] + elif len(parts) < 2 or len(parts) > 3: + raise ValueError("msgtype expected 3 CSV parts, not {}" + .format(parts)) + return MessageType(parts[0], parts[1], option) + + @staticmethod + def msgfield_from_csv(namespace: MessageNamespace, parts: List[str]) -> None: + """e.g msgdata,open_channel,temporary_channel_id,byte,32[,opt]""" + option = None + if len(parts) == 5: + option = parts[4] + elif len(parts) != 4: + raise ValueError("msgdata expected 4 CSV parts, not {}" + .format(parts)) + messagetype = namespace.get_msgtype(parts[0]) + if messagetype is None: + raise ValueError("unknown subtype {}".format(parts[0])) + + field = messagetype._field_from_csv(namespace, parts[1:4], + option=option) + messagetype.add_field(field) + + +class TlvStreamType(SubtypeType): + """A TlvStreamType is just a Subtype, but its fields are +TlvMessageTypes. In the CSV format these are created implicitly, when +a tlvtype line (which defines a TlvMessageType within the TlvType, +confusingly) refers to them. 
+ + """ + def __init__(self, name): + super().__init__(name) + + def __str__(self): + return "tlvstreamtype-{}".format(self.name) + + def find_field_by_number(self, num: int) -> Optional['TlvMessageType']: + for f in self.fields: + if f.number == num: + return f + return None + + def is_optional(self) -> bool: + """You can omit a tlvstream= altogether""" + return True + + @staticmethod + def tlvtype_from_csv(namespace: MessageNamespace, parts: List[str]) -> None: + """e.g tlvtype,reply_channel_range_tlvs,timestamps_tlv,1""" + if len(parts) != 3: + raise ValueError("tlvtype expected 4 CSV parts, not {}" + .format(parts)) + tlvstream = namespace.get_tlvtype(parts[0]) + if tlvstream is None: + tlvstream = TlvStreamType(parts[0]) + namespace.add_tlvtype(tlvstream) + + tlvstream.add_field(TlvMessageType(parts[1], parts[2])) + + @staticmethod + def tlvfield_from_csv(namespace: MessageNamespace, parts: List[str]) -> None: + """e.g +tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8, + + """ + if len(parts) != 5: + raise ValueError("tlvdata expected 6 CSV parts, not {}" + .format(parts)) + + tlvstream = namespace.get_tlvtype(parts[0]) + if tlvstream is None: + raise ValueError("unknown tlvtype {}".format(parts[0])) + + field = tlvstream.find_field(parts[1]) + if field is None: + raise ValueError("Unknown tlv field {}.{}" + .format(tlvstream, parts[1])) + + subfield = field._field_from_csv(namespace, parts[2:], ellipsisok=True) + field.add_field(subfield) + + def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]: + """{fieldname={...},...}. Returns dict of fieldname->val""" + if not s.startswith('{'): + raise ValueError("tlvtype {} must be wrapped in '{{}}': bad {}" + .format(self, s)) + s = s[1:] + ret: Dict[str, Any] = {} + while not s.startswith('}'): + fieldname, s = s.split('=', 1) + f = self.find_field(fieldname) + if f is None: + # Unknown fields are number=hexstring + hexstring, s = split_field(s) + # Make sure it is actually a valid int! 
+ ret[str(int(fieldname))] = bytes.fromhex(hexstring) + else: + ret[fieldname], s = f.val_from_str(s) + if s[0] == ',': + s = s[1:] + + return ret, s[1:] + + def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str: + s = '' + sep = '' + for fieldname in v: + f = self.find_field(fieldname) + s += sep + if f is None: + s += str(int(fieldname)) + '=' + v[fieldname].hex() + else: + s += f.name + '=' + f.val_to_str(v[fieldname], otherfields) + sep = ',' + + return '{' + s + '}' + + def write(self, io_out: BufferedIOBase, v: Optional[Dict[str, Any]], otherfields: Dict[str, Any]) -> None: + # If they didn't specify this tlvstream, it's empty. + if v is None: + return + + # Make a tuple of (fieldnum, val_to_bin, val) so we can sort into + # ascending order as TLV spec requires. + def write_raw_val(iobuf, val, otherfields: Dict[str, Any]): + iobuf.write(val) + + def get_value(tup): + """Get value from num, fun, val tuple""" + return tup[0] + + ordered = [] + for fieldname in v: + f = self.find_field(fieldname) + if f is None: + # fieldname can be an integer for a raw field. 
+ ordered.append((int(fieldname), write_raw_val, v[fieldname])) + else: + ordered.append((f.number, f.write, v[fieldname])) + + ordered.sort(key=get_value) + + for typenum, writefunc, val in ordered: + buf = BytesIO() + writefunc(buf, val, otherfields) + BigSizeType.write(io_out, typenum) + BigSizeType.write(io_out, len(buf.getvalue())) + io_out.write(buf.getvalue()) + + def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]: + vals: Dict[str, Any] = {} + + while True: + tlv_type = BigSizeType.read(io_in) + if tlv_type is None: + return vals + + tlv_len = BigSizeType.read(io_in) + if tlv_len is None: + raise ValueError("{}: truncated tlv_len field".format(self)) + binval = io_in.read(tlv_len) + if len(binval) != tlv_len: + raise ValueError("{}: truncated tlv {} value" + .format(tlv_type, self)) + f = self.find_field_by_number(tlv_type) + if f is None: + # Raw fields are allowed, just index by number. + vals[tlv_type] = binval + else: + # FIXME: Why doesn't mypy think BytesIO is a valid BufferedIOBase? + vals[f.name] = f.read(cast(BufferedIOBase, BytesIO(binval)), otherfields) + + def name_and_val(self, name: str, v: Dict[str, Any]) -> str: + """This is overridden by LengthFieldType to return nothing""" + return " {}={}".format(name, self.val_to_str(v, {})) + + +class TlvMessageType(MessageType): + """A 'tlvtype' in BOLT-speak""" + + def __init__(self, name: str, value: str): + super().__init__(name, value) + + def __str__(self): + return "tlvmsgtype-{}".format(self.name) + + +class Message(object): + """A particular message instance""" + def __init__(self, messagetype: MessageType, **kwargs): + """MessageType is the type of this msg, with fields. Fields can either be valid values for the type, or if they are strings they are converted according to the field type""" + self.messagetype = messagetype + self.fields: Dict[str, Any] = {} + + # Convert arguments from strings to values if necessary. 
+ for field in kwargs: + self.set_field(field, kwargs[field]) + + bad_lens = self.messagetype.len_fields_bad(self.messagetype.name, + self.fields) + if bad_lens: + raise ValueError("Inconsistent length fields: {}".format(bad_lens)) + + def set_field(self, field: str, val: Any) -> None: + f = self.messagetype.find_field(field) + if f is None: + raise ValueError("Unknown field {}".format(field)) + + if isinstance(val, str): + val, remainder = f.fieldtype.val_from_str(val) + if remainder != '': + raise ValueError('Unexpected {} at end of initializer for {}'.format(remainder, field)) + self.fields[field] = val + + def missing_fields(self) -> List[str]: + """Are any required fields missing?""" + missing: List[str] = [] + for ftype in self.messagetype.fields: + missing += ftype.missing_fields(self.fields) + + return missing + + @staticmethod + def read(namespace: MessageNamespace, io_in: BufferedIOBase) -> Optional['Message']: + """Read and decode a Message within that namespace. + +Returns None on EOF + + """ + typenum = try_unpack('message_type', io_in, ">H", empty_ok=True) + if typenum is None: + return None + + mtype = namespace.get_msgtype_by_number(typenum) + if mtype is None: + raise ValueError('Unknown message type number {}'.format(typenum)) + + fields: Dict[str, Any] = {} + for f in mtype.fields: + fields[f.name] = f.fieldtype.read(io_in, fields) + if fields[f.name] is None: + # optional fields are OK to be missing at end! + if f.option is not None: + break + raise ValueError('{}: truncated at field {}' + .format(mtype, f.name)) + + return Message(mtype, **fields) + + @staticmethod + def from_str(namespace: MessageNamespace, s: str, incomplete_ok=False) -> 'Message': + """Decode a string to a Message within that namespace. + +Format is msgname [ field=...]*. 
+ + """ + parts = s.split() + + mtype = namespace.get_msgtype(parts[0]) + if mtype is None: + raise ValueError('Unknown message type name {}'.format(parts[0])) + + args = {} + for p in parts[1:]: + assign = p.split('=', 1) + args[assign[0]] = assign[1] + + m = Message(mtype, **args) + + if not incomplete_ok: + missing = m.missing_fields() + if len(missing): + raise ValueError('Missing fields: {}'.format(missing)) + + return m + + def write(self, io_out: BufferedIOBase) -> None: + """Write a Message into its wire format. + +Must not have missing fields. + + """ + if self.missing_fields(): + raise ValueError('Missing fields: {}' + .format(self.missing_fields())) + + io_out.write(struct.pack(">H", self.messagetype.number)) + for f in self.messagetype.fields: + # Optional fields get val == None. Usually this means they don't + # write anything, but length fields are an exception: they intuit + # their value from other fields. + if f.name in self.fields: + val = self.fields[f.name] + else: + # If this isn't present, and it's marked optional, don't write. 
+ if f.option is not None: + return + val = None + f.fieldtype.write(io_out, val, self.fields) + + def to_str(self) -> str: + """Encode a Message into a string""" + ret = "{}".format(self.messagetype.name) + for f in self.messagetype.fields: + if f.name in self.fields: + ret += f.fieldtype.name_and_val(f.name, self.fields[f.name]) + return ret diff --git a/contrib/pyln-proto/requirements.txt b/contrib/pyln-proto/requirements.txt index 98a17156b21f..4c579bfd197f 100644 --- a/contrib/pyln-proto/requirements.txt +++ b/contrib/pyln-proto/requirements.txt @@ -2,3 +2,4 @@ bitstring==3.1.6 cryptography==2.8 coincurve==13.0.0 base58==1.0.2 +mypy diff --git a/contrib/pyln-proto/setup.py b/contrib/pyln-proto/setup.py index 0b9e8721a9be..e50e5326f784 100644 --- a/contrib/pyln-proto/setup.py +++ b/contrib/pyln-proto/setup.py @@ -17,7 +17,7 @@ author='Christian Decker', author_email='decker.christian@gmail.com', license='MIT', - packages=['pyln.proto'], + packages=['pyln.proto', 'pyln.proto.message', 'pyln.proto.message.bolts', 'pyln.proto.message.bolt1'], scripts=[], zip_safe=True, install_requires=requirements) diff --git a/contrib/pyln-proto/tests/test_array_types.py b/contrib/pyln-proto/tests/test_array_types.py new file mode 100644 index 000000000000..2c1c9f2279f4 --- /dev/null +++ b/contrib/pyln-proto/tests/test_array_types.py @@ -0,0 +1,100 @@ +#! /usr/bin/python3 +from pyln.proto.message.fundamental_types import byte, u16, short_channel_id +from pyln.proto.message.array_types import SizedArrayType, DynamicArrayType, EllipsisArrayType, LengthFieldType +import io + + +def test_sized_array(): + + # Simple class to make outer work. 
+ class dummy: + def __init__(self, name): + self.name = name + + for arrtype, s, b in [[SizedArrayType(dummy("test1"), "test_arr", byte, 4), + "00010203", + bytes([0, 1, 2, 3])], + [SizedArrayType(dummy("test2"), "test_arr", u16, 4), + "[0,1,2,256]", + bytes([0, 0, 0, 1, 0, 2, 1, 0])], + [SizedArrayType(dummy("test3"), "test_arr", short_channel_id, 4), + "[1x2x3,4x5x6,7x8x9,10x11x12]", + bytes([0, 0, 1, 0, 0, 2, 0, 3] + + [0, 0, 4, 0, 0, 5, 0, 6] + + [0, 0, 7, 0, 0, 8, 0, 9] + + [0, 0, 10, 0, 0, 11, 0, 12])]]: + v, _ = arrtype.val_from_str(s) + assert arrtype.val_to_str(v, None) == s + v2 = arrtype.read(io.BytesIO(b), None) + assert v2 == v + buf = io.BytesIO() + arrtype.write(buf, v, None) + assert buf.getvalue() == b + + +def test_ellipsis_array(): + # Simple class to make outer work. + class dummy: + def __init__(self, name): + self.name = name + + for arrtype, s, b in [[EllipsisArrayType(dummy("test1"), "test_arr", byte), + "00010203", + bytes([0, 1, 2, 3])], + [EllipsisArrayType(dummy("test2"), "test_arr", u16), + "[0,1,2,256]", + bytes([0, 0, 0, 1, 0, 2, 1, 0])], + [EllipsisArrayType(dummy("test3"), "test_arr", short_channel_id), + "[1x2x3,4x5x6,7x8x9,10x11x12]", + bytes([0, 0, 1, 0, 0, 2, 0, 3] + + [0, 0, 4, 0, 0, 5, 0, 6] + + [0, 0, 7, 0, 0, 8, 0, 9] + + [0, 0, 10, 0, 0, 11, 0, 12])]]: + v, _ = arrtype.val_from_str(s) + assert arrtype.val_to_str(v, None) == s + v2 = arrtype.read(io.BytesIO(b), None) + assert v2 == v + buf = io.BytesIO() + arrtype.write(buf, v, None) + assert buf.getvalue() == b + + +def test_dynamic_array(): + # Simple class to make outer. 
+ class dummy: + def __init__(self, name): + self.name = name + + class field_dummy: + def __init__(self, name, ftype): + self.fieldtype = ftype + self.name = name + + lenfield = field_dummy('lenfield', LengthFieldType(u16)) + + for arrtype, s, b in [[DynamicArrayType(dummy("test1"), "test_arr", byte, + lenfield), + "00010203", + bytes([0, 1, 2, 3])], + [DynamicArrayType(dummy("test2"), "test_arr", u16, + lenfield), + "[0,1,2,256]", + bytes([0, 0, 0, 1, 0, 2, 1, 0])], + [DynamicArrayType(dummy("test3"), "test_arr", short_channel_id, + lenfield), + "[1x2x3,4x5x6,7x8x9,10x11x12]", + bytes([0, 0, 1, 0, 0, 2, 0, 3] + + [0, 0, 4, 0, 0, 5, 0, 6] + + [0, 0, 7, 0, 0, 8, 0, 9] + + [0, 0, 10, 0, 0, 11, 0, 12])]]: + + lenfield.fieldtype.add_length_for(field_dummy(s, arrtype)) + v, _ = arrtype.val_from_str(s) + otherfields = {s: v} + assert arrtype.val_to_str(v, otherfields) == s + v2 = arrtype.read(io.BytesIO(b), otherfields) + assert v2 == v + buf = io.BytesIO() + arrtype.write(buf, v, None) + assert buf.getvalue() == b + lenfield.fieldtype.len_for = [] diff --git a/contrib/pyln-proto/tests/test_bolt1.py b/contrib/pyln-proto/tests/test_bolt1.py new file mode 100644 index 000000000000..d64578dee6c0 --- /dev/null +++ b/contrib/pyln-proto/tests/test_bolt1.py @@ -0,0 +1,61 @@ +#! /usr/bin/python3 +from pyln.proto.message import Message, MessageNamespace +import pyln.proto.message.bolt1 as bolt1 +import io + + +def test_bolt_01_csv_tlv(): + # FIXME: Test failure cases too! 
+ for t in [['0x', ''], + ['0x21 00', '33='], + ['0xfd0201 00', '513='], + ['0xfd00fd 00', '253='], + ['0xfd00ff 00', '255='], + ['0xfe02000001 00', '33554433='], + ['0xff0200000000000001 00', '144115188075855873='], + ['0x01 00', 'tlv1={amount_msat=0}'], + ['0x01 01 01', 'tlv1={amount_msat=1}'], + ['0x01 02 0100', 'tlv1={amount_msat=256}'], + ['0x01 03 010000', 'tlv1={amount_msat=65536}'], + ['0x01 04 01000000', 'tlv1={amount_msat=16777216}'], + ['0x01 05 0100000000', 'tlv1={amount_msat=4294967296}'], + ['0x01 06 010000000000', 'tlv1={amount_msat=1099511627776}'], + ['0x01 07 01000000000000', 'tlv1={amount_msat=281474976710656}'], + ['0x01 08 0100000000000000', 'tlv1={amount_msat=72057594037927936}'], + ['0x02 08 0000000000000226', 'tlv2={scid=0x0x550}'], + ['0x03 31 023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb00000000000000010000000000000002', 'tlv3={node_id=023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb,amount_msat_1=1,amount_msat_2=2}'], + ['0xfd00fe 02 0226', 'tlv4={cltv_delta=550}']]: + msg = io.BytesIO(bytes.fromhex(t[0][2:].replace(' ', ''))) + + val = bolt1.n1.read(msg, None) + assert len(msg.read()) == 0 + assert bolt1.n1.val_to_str(val, None) == '{' + t[1] + '}' + + +def test_bolt_01_csv(): + # We can create a namespace from the csv. 
+ ns = MessageNamespace(bolt1.csv) + + # string [expected string] + for t in [['init globalfeatures= features=80', + 'init globalfeatures= features=80 tlvs={}'], + ['init globalfeatures= features=80 tlvs={}'], + ['init globalfeatures= features=80 tlvs={networks={chains=[6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000]}}'], + ['init globalfeatures= features=80 tlvs={networks={chains=[6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000,1fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000]}}'], + ['error channel_id=0000000000000000000000000000000000000000000000000000000000000000 data=00'], + ['ping num_pong_bytes=0 ignored='], + ['ping num_pong_bytes=3 ignored=0000'], + ['pong ignored='], + ['pong ignored=000000']]: + m = Message.from_str(bolt1.namespace, t[0]) + b = io.BytesIO() + m.write(b) + + # Works with our manually-made namespace, and the builtin one. + b.seek(0) + m2 = Message.read(bolt1.namespace, b) + assert m2.to_str() == t[-1] + + b.seek(0) + m2 = Message.read(ns, b) + assert m2.to_str() == t[-1] diff --git a/contrib/pyln-proto/tests/test_bolt2.py b/contrib/pyln-proto/tests/test_bolt2.py new file mode 100644 index 000000000000..7068b5803d1d --- /dev/null +++ b/contrib/pyln-proto/tests/test_bolt2.py @@ -0,0 +1,8 @@ +#! /usr/bin/python3 +from pyln.proto.message import MessageNamespace +import pyln.proto.message.bolt2 as bolt2 + + +# FIXME: more tests +def test_bolt_02_csv(): + MessageNamespace(bolt2.csv) diff --git a/contrib/pyln-proto/tests/test_bolt4.py b/contrib/pyln-proto/tests/test_bolt4.py new file mode 100644 index 000000000000..460611609419 --- /dev/null +++ b/contrib/pyln-proto/tests/test_bolt4.py @@ -0,0 +1,8 @@ +#! 
/usr/bin/python3 +from pyln.proto.message import MessageNamespace +import pyln.proto.message.bolt4 as bolt4 + + +# FIXME: more tests +def test_bolt_04_csv(): + MessageNamespace(bolt4.csv) diff --git a/contrib/pyln-proto/tests/test_bolt7.py b/contrib/pyln-proto/tests/test_bolt7.py new file mode 100644 index 000000000000..0ab7decb74cf --- /dev/null +++ b/contrib/pyln-proto/tests/test_bolt7.py @@ -0,0 +1,14 @@ +#! /usr/bin/python3 +from pyln.proto.message import MessageNamespace +import pyln.proto.message.bolt7 as bolt7 + + +# FIXME: more tests +def test_bolt_07_csv(): + MessageNamespace(bolt7.csv) + + +def test_bolt_07_subtypes(): + for t in ['{timestamp_node_id_1=1,timestamp_node_id_2=2}']: + vals, _ = bolt7.channel_update_timestamps.val_from_str(t) + assert bolt7.channel_update_timestamps.val_to_str(vals, None) == t diff --git a/contrib/pyln-proto/tests/test_fundamental_types.py b/contrib/pyln-proto/tests/test_fundamental_types.py new file mode 100644 index 000000000000..5513a13e545b --- /dev/null +++ b/contrib/pyln-proto/tests/test_fundamental_types.py @@ -0,0 +1,77 @@ +#! 
#! /usr/bin/python3
from pyln.proto.message.fundamental_types import fundamental_types
from pyln.proto.message import MessageNamespace, Message
import pytest
import io


def test_fundamental_types():
    """Each fundamental type must round-trip both its string and wire forms."""
    # Map: type name -> list of (string form, expected wire encoding) pairs.
    expect = {
        'byte': [('255', b'\xff'),
                 ('0', b'\x00')],
        'u16': [('65535', b'\xff\xff'),
                ('0', b'\x00\x00')],
        'u32': [('4294967295', b'\xff\xff\xff\xff'),
                ('0', b'\x00\x00\x00\x00')],
        'u64': [('18446744073709551615',
                 b'\xff\xff\xff\xff\xff\xff\xff\xff'),
                ('0', b'\x00\x00\x00\x00\x00\x00\x00\x00')],
        # Truncated types drop leading zero bytes; 0 encodes to nothing.
        'tu16': [('65535', b'\xff\xff'),
                 ('256', b'\x01\x00'),
                 ('255', b'\xff'),
                 ('0', b'')],
        'tu32': [('4294967295', b'\xff\xff\xff\xff'),
                 ('65536', b'\x01\x00\x00'),
                 ('65535', b'\xff\xff'),
                 ('256', b'\x01\x00'),
                 ('255', b'\xff'),
                 ('0', b'')],
        'tu64': [('18446744073709551615',
                  b'\xff\xff\xff\xff\xff\xff\xff\xff'),
                 ('4294967296', b'\x01\x00\x00\x00\x00'),
                 ('4294967295', b'\xff\xff\xff\xff'),
                 ('65536', b'\x01\x00\x00'),
                 ('65535', b'\xff\xff'),
                 ('256', b'\x01\x00'),
                 ('255', b'\xff'),
                 ('0', b'')],
        'chain_hash': [('0102030405060708090a0b0c0d0e0f10'
                        '1112131415161718191a1b1c1d1e1f20',
                        bytes(range(1, 33)))],
        'channel_id': [('0102030405060708090a0b0c0d0e0f10'
                        '1112131415161718191a1b1c1d1e1f20',
                        bytes(range(1, 33)))],
        'sha256': [('0102030405060708090a0b0c0d0e0f10'
                    '1112131415161718191a1b1c1d1e1f20',
                    bytes(range(1, 33)))],
        'signature': [('0102030405060708090a0b0c0d0e0f10'
                       '1112131415161718191a1b1c1d1e1f20'
                       '2122232425262728292a2b2c2d2e2f30'
                       '3132333435363738393a3b3c3d3e3f40',
                       bytes(range(1, 65)))],
        'point': [('02030405060708090a0b0c0d0e0f10'
                   '1112131415161718191a1b1c1d1e1f20'
                   '2122',
                   bytes(range(2, 35)))],
        'short_channel_id': [('1x2x3', bytes([0, 0, 1, 0, 0, 2, 0, 3]))],
        'bigsize': [('0', bytes([0])),
                    ('252', bytes([252])),
                    ('253', bytes([253, 0, 253])),
                    ('65535', bytes([253, 255, 255])),
                    ('65536', bytes([254, 0, 1, 0, 0])),
                    ('4294967295', bytes([254, 255, 255, 255, 255])),
                    ('4294967296', bytes([255, 0, 0, 0, 1, 0, 0, 0, 0]))],
    }

    untested = set()
    for ftype in fundamental_types():
        if ftype.name not in expect:
            untested.add(ftype.name)
            continue
        for text, wire in expect[ftype.name]:
            # String form must survive a parse/print round-trip.
            val, _ = ftype.val_from_str(text)
            assert ftype.val_to_str(val, None) == text
            # Wire form must decode to the same value...
            assert ftype.read(io.BytesIO(wire), None) == val
            # ...and the value must encode back to the same bytes.
            out = io.BytesIO()
            ftype.write(out, val, None)
            assert out.getvalue() == wire

    # These are aliases of types already covered above.
    assert untested == {'varint', 'u8'}


def _roundtrip(ns, text, wire):
    """Assert `text` and `wire` are equivalent encodings of one message."""
    msg = Message.from_str(ns, text)
    assert msg.to_str() == text
    out = io.BytesIO()
    msg.write(out)
    assert out.getvalue() == wire
    assert Message.read(ns, io.BytesIO(wire)).to_str() == text


def test_fundamental():
    """A message using every fundamental field type must round-trip as text."""
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test,1',
                 'msgdata,test,test_byte,byte,',
                 'msgdata,test,test_u16,u16,',
                 'msgdata,test,test_u32,u32,',
                 'msgdata,test,test_u64,u64,',
                 'msgdata,test,test_chain_hash,chain_hash,',
                 'msgdata,test,test_channel_id,channel_id,',
                 'msgdata,test,test_sha256,sha256,',
                 'msgdata,test,test_signature,signature,',
                 'msgdata,test,test_point,point,',
                 'msgdata,test,test_short_channel_id,short_channel_id,',
                 ])

    mstr = """test
    test_byte=255
    test_u16=65535
    test_u32=4294967295
    test_u64=18446744073709551615
    test_chain_hash=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
    test_channel_id=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
    test_sha256=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
    test_signature=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f40
    test_point=0201030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f2021
    test_short_channel_id=1x2x3"""
    msg = Message.from_str(ns, mstr)

    # Same (ignoring whitespace differences)
    assert msg.to_str().split() == mstr.split()


def test_static_array():
    """Fixed-size arrays of bytes and of subfields must round-trip."""
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test1,1',
                 'msgdata,test1,test_arr,byte,4'])
    ns.load_csv(['msgtype,test2,2',
                 'msgdata,test2,test_arr,short_channel_id,4'])

    cases = [
        ("test1 test_arr=00010203", bytes([0, 1] + [0, 1, 2, 3])),
        ("test2 test_arr=[0x1x2,4x5x6,7x8x9,10x11x12]",
         bytes([0, 2]
               + [0, 0, 0, 0, 0, 1, 0, 2]
               + [0, 0, 4, 0, 0, 5, 0, 6]
               + [0, 0, 7, 0, 0, 8, 0, 9]
               + [0, 0, 10, 0, 0, 11, 0, 12])),
    ]
    for text, wire in cases:
        _roundtrip(ns, text, wire)


def test_subtype():
    """Arrays of subtypes must round-trip; absent fields must be reported."""
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test1,1',
                 'msgdata,test1,test_sub,channel_update_timestamps,4',
                 'subtype,channel_update_timestamps',
                 'subtypedata,channel_update_timestamps,timestamp_node_id_1,u32,',
                 'subtypedata,channel_update_timestamps,timestamp_node_id_2,u32,'])

    cases = [
        ("test1 test_sub=["
         "{timestamp_node_id_1=1,timestamp_node_id_2=2}"
         ",{timestamp_node_id_1=3,timestamp_node_id_2=4}"
         ",{timestamp_node_id_1=5,timestamp_node_id_2=6}"
         ",{timestamp_node_id_1=7,timestamp_node_id_2=8}]",
         bytes([0, 1]
               + [0, 0, 0, 1, 0, 0, 0, 2]
               + [0, 0, 0, 3, 0, 0, 0, 4]
               + [0, 0, 0, 5, 0, 0, 0, 6]
               + [0, 0, 0, 7, 0, 0, 0, 8])),
    ]
    for text, wire in cases:
        _roundtrip(ns, text, wire)

    # Test missing field logic.
    msg = Message.from_str(ns, "test1", incomplete_ok=True)
    assert msg.missing_fields()


def test_tlv():
    """TLV streams must round-trip and serialize in canonical order."""
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test1,1',
                 'msgdata,test1,tlvs,test_tlvstream,',
                 'tlvtype,test_tlvstream,tlv1,1',
                 'tlvdata,test_tlvstream,tlv1,field1,byte,4',
                 'tlvdata,test_tlvstream,tlv1,field2,u32,',
                 'tlvtype,test_tlvstream,tlv2,255',
                 'tlvdata,test_tlvstream,tlv2,field3,byte,...'])

    cases = [
        ("test1 tlvs={tlv1={field1=01020304,field2=5}}",
         bytes([0, 1]
               + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5])),
        ("test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304}}",
         bytes([0, 1]
               + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
               + [253, 0, 255, 4, 1, 2, 3, 4])),
        # Unknown tlv type 4 appears between the known ones.
        ("test1 tlvs={tlv1={field1=01020304,field2=5},4=010203,tlv2={field3=01020304}}",
         bytes([0, 1]
               + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
               + [4, 3, 1, 2, 3]
               + [253, 0, 255, 4, 1, 2, 3, 4])),
    ]
    for text, wire in cases:
        _roundtrip(ns, text, wire)

    # Ordering test (turns into canonical ordering)
    msg = Message.from_str(ns, 'test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304},4=010203}')
    out = io.BytesIO()
    msg.write(out)
    assert out.getvalue() == bytes([0, 1]
                                   + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
                                   + [4, 3, 1, 2, 3]
                                   + [253, 0, 255, 4, 1, 2, 3, 4])


def test_message_constructor():
    """Fields given as keyword arguments to Message() must be honored."""
    ns = MessageNamespace(['msgtype,test1,1',
                           'msgdata,test1,tlvs,test_tlvstream,',
                           'tlvtype,test_tlvstream,tlv1,1',
                           'tlvdata,test_tlvstream,tlv1,field1,byte,4',
                           'tlvdata,test_tlvstream,tlv1,field2,u32,',
                           'tlvtype,test_tlvstream,tlv2,255',
                           'tlvdata,test_tlvstream,tlv2,field3,byte,...'])

    msg = Message(ns.get_msgtype('test1'),
                  tlvs='{tlv1={field1=01020304,field2=5}'
                  ',tlv2={field3=01020304},4=010203}')
    out = io.BytesIO()
    msg.write(out)
    assert out.getvalue() == bytes([0, 1]
                                   + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
                                   + [4, 3, 1, 2, 3]
                                   + [253, 0, 255, 4, 1, 2, 3, 4])


def test_dynamic_array():
    """Test that dynamic array types enforce matching lengths"""
    ns = MessageNamespace(['msgtype,test1,1',
                           'msgdata,test1,count,u16,',
                           'msgdata,test1,arr1,byte,count',
                           'msgdata,test1,arr2,u32,count'])

    # This one is fine.
    msg = Message(ns.get_msgtype('test1'),
                  arr1='01020304', arr2='[1,2,3,4]')
    out = io.BytesIO()
    msg.write(out)
    assert out.getvalue() == bytes([0, 1]
                                   + [0, 4]
                                   + [1, 2, 3, 4]
                                   + [0, 0, 0, 1,
                                      0, 0, 0, 2,
                                      0, 0, 0, 3,
                                      0, 0, 0, 4])

    # These ones are not: lengths disagree with `count`.
    with pytest.raises(ValueError, match='Inconsistent length.*count'):
        msg = Message(ns.get_msgtype('test1'),
                      arr1='01020304', arr2='[1,2,3]')

    with pytest.raises(ValueError, match='Inconsistent length.*count'):
        msg = Message(ns.get_msgtype('test1'),
                      arr1='01020304', arr2='[1,2,3,4,5]')