Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Server struct {
LMTP bool

Domain string
MaxConnections int
MaxRecipients int
MaxMessageBytes int64
MaxLineLength int
Expand Down Expand Up @@ -127,17 +128,7 @@ func (s *Server) Serve(l net.Listener) error {
}

func (s *Server) handleConn(c *Conn) error {
s.locker.Lock()
s.conns[c] = struct{}{}
s.locker.Unlock()

defer func() {
c.Close()

s.locker.Lock()
delete(s.conns, c)
s.locker.Unlock()
}()
defer c.Close()

if tlsConn, ok := c.conn.(*tls.Conn); ok {
if d := s.ReadTimeout; d != 0 {
Expand All @@ -151,6 +142,29 @@ func (s *Server) handleConn(c *Conn) error {
}
}

// register connection
maxConnsExceeded := false
s.locker.Lock()
if s.MaxConnections > 0 && len(s.conns) >= s.MaxConnections {
maxConnsExceeded = true
} else {
s.conns[c] = struct{}{}
}
s.locker.Unlock()

// limit connections
if maxConnsExceeded {
c.writeResponse(421, EnhancedCode{4, 4, 5}, "Too busy. Try again later.")
return nil
}

// unregister connection
defer func() {
s.locker.Lock()
delete(s.conns, c)
s.locker.Unlock()
}()

c.greet()

for {
Expand Down
46 changes: 46 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1514,3 +1514,49 @@ func TestServerDSNwithSMTPUTF8(t *testing.T) {
t.Fatal("Invalid ORCPT address:", val)
}
}

func TestServer_MaxConnections(t *testing.T) {
cases := []struct {
name string
maxConnections int
expected string
}{
// 0 = unlimited; all connections should be accepted
{name: "MaxConnections set to 0", maxConnections: 0, expected: "220 localhost ESMTP Service Ready"},
// 1 = only one connection is allowed; the second connection should be rejected
{name: "MaxConnections set to 1", maxConnections: 1, expected: "421 4.4.5 Too busy. Try again later."},
// 2 = two connections are allowed; the second connection should be accepted
{name: "MaxConnections set to 2", maxConnections: 2, expected: "220 localhost ESMTP Service Ready"},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
// create server with limited allowed connections
_, s, c, scanner1 := testServer(t, func(s *smtp.Server) {
s.MaxConnections = tc.maxConnections
})
defer s.Close()

// there is already be one connection registered
// and we can read the greeting from it (see testServerGreeted())
scanner1.Scan()
if scanner1.Text() != "220 localhost ESMTP Service Ready" {
t.Fatal("Invalid first greeting:", scanner1.Text())
}

// now we create a second connection
c2, err := net.Dial("tcp", c.RemoteAddr().String())
if err != nil {
t.Fatal("Error creating second connection:", err)
}

// we should get an appropriate greeting now
scanner2 := bufio.NewScanner(c2)
scanner2.Scan()
if scanner2.Text() != tc.expected {
t.Fatal("Invalid second greeting:", scanner2.Text())
}
})
}

}