diff --git a/src/cmd/serve.go b/src/cmd/serve.go index 2639c2a..f9ffc36 100644 --- a/src/cmd/serve.go +++ b/src/cmd/serve.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "log" + "net" "net/netip" "os" "slices" @@ -353,7 +354,15 @@ func (c serveCmdConfig) Run() { Peers: []peer.PeerConfigArgs{ { PublicKey: viper.GetString("Relay.Peer.publickey"), - Endpoint: viper.GetString("Relay.Peer.endpoint"), + Endpoint: func() string { + if len(viper.GetString("Relay.Peer.endpoint")) > 0 { + endpoint, err := net.ResolveUDPAddr("udp", (viper.GetString("Relay.Peer.endpoint"))) + check("failed to resolve endpoint DNS name for '" + viper.GetString("Relay.Peer.endpoint") + "'", err) + return endpoint.String() + } else { + return "" + } + }(), PersistentKeepaliveInterval: func() int { if len(viper.GetString("Relay.Peer.endpoint")) > 0 { return viper.GetInt("Relay.Peer.keepalive") diff --git a/src/peer/config.go b/src/peer/config.go index c8d5968..d9c100a 100644 --- a/src/peer/config.go +++ b/src/peer/config.go @@ -378,14 +378,15 @@ func (c *Config) GetPeerPublicKey(i int) string { func (c *Config) GetPeerEndpoint(i int) string { if len(c.peers) > i { + endpointDNS := c.peers[i].endpointDNS + if endpointDNS != "" { + return endpointDNS + } endpoint := c.peers[i].config.Endpoint if endpoint != nil { return endpoint.String() } - - return "" } - return "" } diff --git a/src/peer/peer_config.go b/src/peer/peer_config.go index 7cc4064..0b20e6a 100644 --- a/src/peer/peer_config.go +++ b/src/peer/peer_config.go @@ -15,6 +15,7 @@ import ( type PeerConfig struct { config wgtypes.PeerConfig privateKey *wgtypes.Key + endpointDNS string nickname string } @@ -22,6 +23,7 @@ type peerConfigJSON struct { Config wgtypes.PeerConfig PrivateKey *wgtypes.Key Nickname string + EndpointDNS string } type PeerConfigArgs struct { @@ -87,7 +89,7 @@ func GetPeerConfig(args PeerConfigArgs) (PeerConfig, error) { return PeerConfig{}, err } } - + if args.Nickname != "" { err = c.SetNickname(args.Nickname) if err != nil { @@ -109,7 +111,7 @@ func NewPeerConfig() (PeerConfig, error) { PublicKey: privateKey.PublicKey(), }, privateKey: &privateKey, - nickname: "", + nickname: "", }, nil } @@ -118,6 +120,7 @@ func (p *PeerConfig) MarshalJSON() ([]byte, error) { p.config, p.privateKey, p.nickname, + p.endpointDNS, }) } @@ -132,6 +135,7 @@ func (p *PeerConfig) UnmarshalJSON(b []byte) error { p.config = tmp.Config p.privateKey = tmp.PrivateKey p.nickname = tmp.Nickname + p.endpointDNS = tmp.EndpointDNS return nil } @@ -170,12 +174,20 @@ func (p *PeerConfig) SetPresharedKey(presharedKey string) error { } func (p *PeerConfig) SetEndpoint(addr string) error { - endpoint, err := net.ResolveUDPAddr("udp", addr) + host, _, err := net.SplitHostPort(addr) if err != nil { return err } - - p.config.Endpoint = endpoint + ip := net.ParseIP(host) + if ip != nil { + endpoint, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return err + } + p.config.Endpoint = endpoint + } else { + p.endpointDNS = addr + } return nil } @@ -258,13 +270,13 @@ func (p *PeerConfig) SetNickname(nickname string) error { func (p *PeerConfig) AsFile() string { var s strings.Builder s.WriteString("[Peer]\n") - + if p.nickname != "" { s.WriteString(fmt.Sprintf("%s Nickname = %s\n", CUSTOM_PREFIX, p.nickname)) } - + s.WriteString(fmt.Sprintf("PublicKey = %s\n", p.config.PublicKey.String())) - + ips := []string{} for _, a := range p.config.AllowedIPs { ips = append(ips, a.String()) @@ -275,6 +287,9 @@ func (p *PeerConfig) AsFile() string { if p.config.Endpoint != nil { s.WriteString(fmt.Sprintf("Endpoint = %s\n", p.config.Endpoint.String())) } + if p.endpointDNS != "" { + s.WriteString(fmt.Sprintf("Endpoint = %s\n", p.endpointDNS)) + } if p.config.PersistentKeepaliveInterval != nil { s.WriteString(fmt.Sprintf("PersistentKeepalive = %d\n", *p.config.PersistentKeepaliveInterval/time.Second)) } @@ -289,6 +304,9 @@ func (p *PeerConfig) AsIPC() string { if p.config.Endpoint != nil { s.WriteString(fmt.Sprintf("endpoint=%s\n", p.config.Endpoint.String())) } + if p.endpointDNS != "" { + s.WriteString(fmt.Sprintf("endpoint=%s\n", p.endpointDNS)) + } for _, a := range p.config.AllowedIPs { s.WriteString(fmt.Sprintf("allowed_ip=%s\n", a.String())) }