Skip to content

Commit c92f888

Browse files
committed
feat: Add rate limiting
Adds a rate limiting mechanism, that will send HTTP/429 responses once a defined limit is reached (Token Bucket) Signed-off-by: Manuel Rüger <[email protected]>
1 parent 60725be commit c92f888

9 files changed

+90
-25
lines changed

docs/web-config.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,8 @@ tls_server_config:
1010
basic_auth_users:
1111
alice: $2y$10$mDwo.lAisC94iLAyP81MCesa29IzH37oigHC/42V2pdJlUprsJPze
1212
bob: $2y$10$hLqFl9jSjoAAy95Z/zw8Ye8wkdMBM8c5Bn1ptYqP/AXyV0.oy0S8m
13+
14+
# Rate limiting requests on the endpoint using a token bucket
15+
rate_limit:
16+
interval: "1s" # time interval between two requests, set to 0 to disable rate limiter
17+
burst: 20 # and permits a burst of up to 20 requests.

docs/web-configuration.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Generic placeholders are defined as follows:
2020
* `<filename>`: a valid path in the current working directory
2121
* `<secret>`: a regular string that is a secret, such as a password
2222
* `<string>`: a regular string
23+
* `<int>`: a regular integer
2324

2425
```
2526
tls_server_config:
@@ -125,6 +126,12 @@ http_server_config:
125126
# required. Passwords are hashed with bcrypt.
126127
basic_auth_users:
127128
[ <string>: <secret> ... ]
129+
130+
131+
# Rate limiting requests on the endpoint using a token bucket
132+
rate_limit:
133+
interval: <duration> # time interval between two requests, set to 0 to disable rate limiter
134+
burst: <int> # and permits a burst of <int> requests.
128135
```
129136

130137
[A sample configuration file](web-config.yml) is provided.

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010
go.yaml.in/yaml/v2 v2.4.2
1111
golang.org/x/crypto v0.41.0
1212
golang.org/x/sync v0.16.0
13+
golang.org/x/time v0.12.0
1314
)
1415

1516
require (

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
6262
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
6363
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
6464
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
65+
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
66+
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
6567
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
6668
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
6769
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

web/handler.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"sync"
2525

2626
"golang.org/x/crypto/bcrypt"
27+
"golang.org/x/time/rate"
2728
)
2829

2930
// extraHTTPHeaders is a map of HTTP headers that can be added to HTTP
@@ -80,6 +81,7 @@ type webHandler struct {
8081
handler http.Handler
8182
logger *slog.Logger
8283
cache *cache
84+
limiter *rate.Limiter
8385
// bcryptMtx is there to ensure that bcrypt.CompareHashAndPassword is run
8486
// only once in parallel as this is CPU intensive.
8587
bcryptMtx sync.Mutex
@@ -93,6 +95,11 @@ func (u *webHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
9395
return
9496
}
9597

98+
if u.limiter != nil && !u.limiter.Allow() {
99+
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
100+
return
101+
}
102+
96103
// Configure http headers.
97104
for k, v := range c.HTTPConfig.Header {
98105
w.Header().Set(k, v)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
rate_limit:
2+
interval: 0
3+
burst: 0
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
rate_limit:
2+
interval: "1s"
3+
burst: 0

web/tls_config.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ import (
2626
"path/filepath"
2727
"strconv"
2828
"strings"
29+
"time"
2930

3031
"github.com/coreos/go-systemd/v22/activation"
3132
"github.com/mdlayher/vsock"
3233
config_util "github.com/prometheus/common/config"
3334
"go.yaml.in/yaml/v2"
3435
"golang.org/x/sync/errgroup"
36+
"golang.org/x/time/rate"
3537
)
3638

3739
var (
@@ -40,9 +42,10 @@ var (
4042
)
4143

4244
type Config struct {
43-
TLSConfig TLSConfig `yaml:"tls_server_config"`
44-
HTTPConfig HTTPConfig `yaml:"http_server_config"`
45-
Users map[string]config_util.Secret `yaml:"basic_auth_users"`
45+
TLSConfig TLSConfig `yaml:"tls_server_config"`
46+
HTTPConfig HTTPConfig `yaml:"http_server_config"`
47+
RateLimiterConfig RateLimiterConfig `yaml:"rate_limit"`
48+
Users map[string]config_util.Secret `yaml:"basic_auth_users"`
4649
}
4750

4851
type TLSConfig struct {
@@ -109,6 +112,11 @@ type HTTPConfig struct {
109112
Header map[string]string `yaml:"headers,omitempty"`
110113
}
111114

115+
type RateLimiterConfig struct {
116+
Burst int `yaml:"burst"`
117+
Interval time.Duration `yaml:"interval"`
118+
}
119+
112120
func getConfig(configPath string) (*Config, error) {
113121
content, err := os.ReadFile(configPath)
114122
if err != nil {
@@ -365,11 +373,18 @@ func Serve(l net.Listener, server *http.Server, flags *FlagConfig, logger *slog.
365373
return err
366374
}
367375

376+
var limiter *rate.Limiter
377+
if c.RateLimiterConfig.Interval != 0 {
378+
limiter = rate.NewLimiter(rate.Every(c.RateLimiterConfig.Interval), c.RateLimiterConfig.Burst)
379+
logger.Info("Rate Limiter is enabled.", "burst", c.RateLimiterConfig.Burst, "interval", c.RateLimiterConfig.Interval)
380+
}
381+
368382
server.Handler = &webHandler{
369383
tlsConfigPath: tlsConfigPath,
370384
logger: logger,
371385
handler: handler,
372386
cache: newCache(),
387+
limiter: limiter,
373388
}
374389

375390
config, err := ConfigToTLSConfig(&c.TLSConfig)

web/tls_config_test.go

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ var (
7272
// Introduced in Go 1.21
7373
"Certificate required": regexp.MustCompile(`certificate required`),
7474
"Unknown CA": regexp.MustCompile(`unknown certificate authority`),
75+
"Too Many Requests": regexp.MustCompile(`Too Many Requests`),
7576
}
7677
)
7778

@@ -98,6 +99,7 @@ type TestInputs struct {
9899
Username string
99100
Password string
100101
ClientCertificate string
102+
Requests int
101103
}
102104

103105
func TestYAMLFiles(t *testing.T) {
@@ -364,6 +366,20 @@ func TestServerBehaviour(t *testing.T) {
364366
ClientCertificate: "client2_selfsigned",
365367
ExpectedError: ErrorMap["Invalid client cert"],
366368
},
369+
{
370+
Name: "valid rate limiter (no rate limiter set up) that doesn't block",
371+
YAMLConfigPath: "testdata/web_config_rate_limiter_nonblocking.yaml",
372+
UseTLSClient: false,
373+
Requests: 10,
374+
ExpectedError: nil,
375+
},
376+
{
377+
Name: "valid rate limiter with an interval of one second",
378+
YAMLConfigPath: "testdata/web_config_rate_limiter_one_second.yaml",
379+
UseTLSClient: false,
380+
Requests: 10,
381+
ExpectedError: ErrorMap["Too Many Requests"],
382+
},
367383
}
368384
for _, testInputs := range testTables {
369385
t.Run(testInputs.Name, testInputs.Test)
@@ -511,35 +527,41 @@ func (test *TestInputs) Test(t *testing.T) {
511527
if test.Username != "" {
512528
req.SetBasicAuth(test.Username, test.Password)
513529
}
530+
514531
return client.Do(req)
515532
}
516533
go func() {
517534
time.Sleep(250 * time.Millisecond)
518-
r, err := ClientConnection()
519-
if err != nil {
520-
recordConnectionError(err)
521-
return
522-
}
523535

524-
if test.ActualCipher != 0 {
525-
if r.TLS.CipherSuite != test.ActualCipher {
526-
recordConnectionError(
527-
fmt.Errorf("bad cipher suite selected. Expected: %s, got: %s",
528-
tls.CipherSuiteName(test.ActualCipher),
529-
tls.CipherSuiteName(r.TLS.CipherSuite),
530-
),
531-
)
536+
for req := 0; req <= test.Requests; req++ {
537+
538+
r, err := ClientConnection()
539+
540+
if err != nil {
541+
recordConnectionError(err)
542+
return
532543
}
533-
}
534544

535-
body, err := io.ReadAll(r.Body)
536-
if err != nil {
537-
recordConnectionError(err)
538-
return
539-
}
540-
if string(body) != "Hello World!" {
541-
recordConnectionError(errors.New(string(body)))
542-
return
545+
if test.ActualCipher != 0 {
546+
if r.TLS.CipherSuite != test.ActualCipher {
547+
recordConnectionError(
548+
fmt.Errorf("bad cipher suite selected. Expected: %s, got: %s",
549+
tls.CipherSuiteName(test.ActualCipher),
550+
tls.CipherSuiteName(r.TLS.CipherSuite),
551+
),
552+
)
553+
}
554+
}
555+
556+
body, err := io.ReadAll(r.Body)
557+
if err != nil {
558+
recordConnectionError(err)
559+
return
560+
}
561+
if string(body) != "Hello World!" {
562+
recordConnectionError(errors.New(string(body)))
563+
return
564+
}
543565
}
544566
recordConnectionError(nil)
545567
}()

0 commit comments

Comments
 (0)