Skip to content

Commit 738689f

Browse files
conn: Add the possibility to send and receive message through the connection
This commit enable the connection to send a message through the wire in an ergonomic way. This feature is a basic blocks for the lnprototest refactoring that allow to semplify how to write test with lnprototest in the future by keeping the state with the peer by connection and keep inside the runner just the necessary logic to interact with the node. Signed-off-by: Vincenzo Palazzo <[email protected]>
1 parent 9f1d7a4 commit 738689f

File tree

4 files changed

+63
-9
lines changed

4 files changed

+63
-9
lines changed

lnprototest/runner.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,16 @@
77
import coincurve
88
import functools
99

10-
import pyln
11-
from pyln.proto.message import Message
12-
1310
from abc import ABC, abstractmethod
1411
from typing import Dict, Optional, List, Union, Any, Callable
1512

13+
from pyln.proto.message import Message
14+
1615
from .bitfield import bitfield
1716
from .errors import SpecFileError
1817
from .structure import Sequence
1918
from .event import Event, MustNotMsg, ExpectMsg
20-
from .utils import privkey_expand
19+
from .utils import privkey_expand, ResolvableStr, ResolvableInt, resolve_args
2120
from .keyset import KeySet
2221
from .namespace import namespace
2322

@@ -78,6 +77,33 @@ def get_stash(self, event: Event, stashname: str, default: Any = None) -> Any:
7877
raise SpecFileError(event, "Unknown stash name {}".format(stashname))
7978
return self.stash[stashname]
8079

80+
def recv_msg(
81+
self, timeout: int = 1000, skip_filter: Optional[int] = None
82+
) -> Message:
83+
"""Listen on the connection for incoming message.
84+
85+
If the {skip_filter} is specified, the message that
86+
match the filters are skipped.
87+
"""
88+
raw_msg = self.connection.read_message()
89+
msg = Message.read(namespace(), io.BytesIO(raw_msg))
90+
self.add_stash(msg.messagetype.name, msg)
91+
return msg
92+
93+
def send_msg(
94+
self, msg_name: str, **kwargs: Union[ResolvableStr, ResolvableInt]
95+
) -> None:
96+
"""Send a message through the last connection"""
97+
msgtype = namespace().get_msgtype(msg_name)
98+
msg = Message(msgtype, **resolve_args(self, kwargs))
99+
missing = msg.missing_fields()
100+
if missing:
101+
raise SpecFileError(self, "Missing fields {}".format(missing))
102+
binmsg = io.BytesIO()
103+
msg.write(binmsg)
104+
self.connection.send_message(binmsg.getvalue())
105+
# FIXME: we should listen to possible connection here
106+
81107

82108
class Runner(ABC):
83109
"""Abstract base class for runners.
@@ -189,7 +215,7 @@ def is_running(self) -> bool:
189215
pass
190216

191217
@abstractmethod
192-
def connect(self, event: Event, connprivkey: str) -> None:
218+
def connect(self, event: Event, connprivkey: str) -> RunnerConn:
193219
pass
194220

195221
def send_msg(self, msg: Message) -> None:

lnprototest/utils/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
check_hex,
1515
privkey_for_index,
1616
merge_events_sequences,
17+
Resolvable,
18+
ResolvableBool,
19+
ResolvableInt,
20+
ResolvableStr,
21+
resolve_arg,
22+
resolve_args,
1723
)
1824
from .bitcoin_utils import (
1925
ScriptType,

lnprototest/utils/utils.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,17 @@
88
import logging
99
import traceback
1010

11-
from typing import Union, Sequence, List
11+
from typing import Union, Sequence, List, Dict, Callable, Any
1212
from enum import IntEnum
1313

1414
from lnprototest.keyset import KeySet
1515

16+
# Type for arguments: either strings, or functions to call at runtime
17+
ResolvableStr = Union[str, Callable[["RunnerConn", "Event", str], str]]
18+
ResolvableInt = Union[int, Callable[["RunnerConn", "Event", str], int]]
19+
ResolvableBool = Union[int, Callable[["RunnerConn", "Event", str], bool]]
20+
Resolvable = Union[Any, Callable[["RunnerConn", "Event", str], Any]]
21+
1622

1723
class Side(IntEnum):
1824
local = 0
@@ -106,3 +112,19 @@ def merge_events_sequences(
106112
"""Merge the two list in the pre-post order"""
107113
pre.extend(post)
108114
return pre
115+
116+
117+
def resolve_arg(fieldname: str, conn: "RunnerConn", arg: Resolvable) -> Any:
118+
"""If this is a string, return it, otherwise call it to get result"""
119+
if callable(arg):
120+
return arg(conn, fieldname)
121+
else:
122+
return arg
123+
124+
125+
def resolve_args(conn: "RunnerConn", kwargs: Dict[str, Resolvable]) -> Dict[str, Any]:
126+
"""Take a dict of args, replace callables with their return values"""
127+
ret: Dict[str, str] = {}
128+
for field, str_or_func in kwargs.items():
129+
ret[field] = resolve_arg(field, conn, str_or_func)
130+
return ret

tests/test_v2_bolt1-01-init.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ def test_v2_init_is_first_msg(runner: Runner, namespaceoverride: Any) -> None:
1111
"""
1212
runner.start()
1313

14-
runner.connect(None, connprivkey="03")
15-
init_msg = runner.recv_msg()
14+
conn1 = runner.connect(None, connprivkey="03")
15+
init_msg = conn1.recv_msg()
1616
assert (
1717
init_msg.messagetype.number == 16
1818
), f"received not an init msg but: {init_msg.to_str()}"
19-
19+
conn1.send_msg("init", globalfeatures="", features="")
2020
runner.stop()

0 commit comments

Comments
 (0)