From 711eafef717c54520b4f5bb360191c27d29a7b3e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 17 May 2024 13:11:35 +0800 Subject: [PATCH] add support for connected UDP sockets --- wasip1/listen_wasip1.go | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/wasip1/listen_wasip1.go b/wasip1/listen_wasip1.go index 38ff140..5e5940b 100644 --- a/wasip1/listen_wasip1.go +++ b/wasip1/listen_wasip1.go @@ -44,13 +44,37 @@ func ListenPacket(network, address string) (net.PacketConn, error) { addr := &netAddr{network, address} return nil, listenErr(addr, err) } - conn, err := listenPacketAddr(addrs[0]) + conn, err := listenPacketAddr(addrs[0], nil) if err != nil { return nil, listenErr(addrs[0], err) } return conn, nil } +// DialUDP connects to the remote UDP network address from the local UDP network address. +func DialUDP(network, localAddr, remoteAddr string) (net.PacketConn, error) { + switch network { + case "udp", "udp4", "udp6": + default: + return nil, unsupportedNetwork(network, localAddr) + } + laddrs, err := lookupAddr(context.Background(), "listen", network, localAddr) + if err != nil { + addr := &netAddr{network, localAddr} + return nil, listenErr(addr, err) + } + raddrs, err := lookupAddr(context.Background(), "dial", network, remoteAddr) + if err != nil { + addr := &netAddr{network, localAddr} + return nil, dialErr(addr, err) + } + conn, err := listenPacketAddr(laddrs[0], raddrs[0]) + if err != nil { + return nil, listenErr(laddrs[0], err) + } + return conn, nil +} + func unsupportedNetwork(network, address string) error { return fmt.Errorf("unsupported network: %s://%s", network, address) } @@ -105,7 +129,8 @@ func listenAddr(addr net.Addr) (net.Listener, error) { return makeListener(l, name), nil } -func listenPacketAddr(addr net.Addr) (net.PacketConn, error) { +// If remoteAddr is set, the socket will be connected to that address. +func listenPacketAddr(addr net.Addr, remoteAddr net.Addr) (net.PacketConn, error) { fd, err := socket(family(addr), SOCK_DGRAM, 0) if err != nil { return nil, os.NewSyscallError("socket", err) @@ -131,6 +156,16 @@ func listenPacketAddr(addr net.Addr) (net.PacketConn, error) { return nil, os.NewSyscallError("bind", err) } + if remoteAddr != nil { + remoteSockAddr, err := socketAddress(remoteAddr) + if err != nil { + return nil, os.NewSyscallError("connect", err) + } + if err := connect(fd, remoteSockAddr); err != nil { + return nil, os.NewSyscallError("connect", err) + } + } + name, err := getsockname(fd) if err != nil { return nil, os.NewSyscallError("getsockname", err)