Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/golangci-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:
Expand Down
10 changes: 8 additions & 2 deletions management/internals/modules/reverseproxy/service/manager/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
}

service := new(rpservice.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}

if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
Expand Down Expand Up @@ -132,7 +135,10 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {

service := new(rpservice.Service)
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}

if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
Expand Down
235 changes: 219 additions & 16 deletions management/internals/modules/reverseproxy/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ import (
"fmt"
"math/big"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"

"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/durationpb"

"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
Expand Down Expand Up @@ -49,17 +52,25 @@ const (
SourceEphemeral = "ephemeral"
)

type TargetOptions struct {
SkipTLSVerify bool `json:"skip_tls_verify"`
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
}

type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
Options TargetOptions `gorm:"embedded" json:"options"`
}

type PasswordAuthConfig struct {
Expand Down Expand Up @@ -194,15 +205,17 @@ func (s *Service) ToAPIResponse() *api.Service {
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
apiTargets = append(apiTargets, api.ServiceTarget{
st := api.ServiceTarget{
Path: target.Path,
Host: &target.Host,
Port: target.Port,
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
})
}
st.Options = targetOptionsToAPI(target.Options)
apiTargets = append(apiTargets, st)
}

meta := api.ServiceMeta{
Expand Down Expand Up @@ -256,10 +269,14 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
if target.Path != nil {
path = *target.Path
}
pathMappings = append(pathMappings, &proto.PathMapping{

pm := &proto.PathMapping{
Path: path,
Target: targetURL.String(),
})
}

pm.Options = targetOptionsToProto(target.Options)
pathMappings = append(pathMappings, pm)
}

auth := &proto.Authentication{
Expand Down Expand Up @@ -312,13 +329,87 @@ func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}

func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
// PathRewriteMode controls how the request path is rewritten before forwarding.
type PathRewriteMode string

const (
PathRewritePreserve PathRewriteMode = "preserve"
)

func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
switch mode {
case PathRewritePreserve:
return proto.PathRewriteMode_PATH_REWRITE_PRESERVE
default:
return proto.PathRewriteMode_PATH_REWRITE_DEFAULT
}
}

func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
return nil
}
apiOpts := &api.ServiceTargetOptions{}
if opts.SkipTLSVerify {
apiOpts.SkipTlsVerify = &opts.SkipTLSVerify
}
if opts.RequestTimeout != 0 {
s := opts.RequestTimeout.String()
apiOpts.RequestTimeout = &s
}
if opts.PathRewrite != "" {
pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
apiOpts.PathRewrite = &pr
}
if len(opts.CustomHeaders) > 0 {
apiOpts.CustomHeaders = &opts.CustomHeaders
}
return apiOpts
}

func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
return nil
}
popts := &proto.PathTargetOptions{
SkipTlsVerify: opts.SkipTLSVerify,
PathRewrite: pathRewriteToProto(opts.PathRewrite),
CustomHeaders: opts.CustomHeaders,
}
if opts.RequestTimeout != 0 {
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
}
return popts
}

func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
var opts TargetOptions
if o.SkipTlsVerify != nil {
opts.SkipTLSVerify = *o.SkipTlsVerify
}
if o.RequestTimeout != nil {
d, err := time.ParseDuration(*o.RequestTimeout)
if err != nil {
return opts, fmt.Errorf("target %d: parse request_timeout %q: %w", idx, *o.RequestTimeout, err)
}
opts.RequestTimeout = d
}
if o.PathRewrite != nil {
opts.PathRewrite = PathRewriteMode(*o.PathRewrite)
}
if o.CustomHeaders != nil {
opts.CustomHeaders = *o.CustomHeaders
}
return opts, nil
}

func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) error {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID

targets := make([]*Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets {
for i, apiTarget := range req.Targets {
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
Expand All @@ -331,6 +422,13 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
if apiTarget.Options != nil {
opts, err := targetOptionsFromAPI(i, apiTarget.Options)
if err != nil {
return err
}
target.Options = opts
}
targets = append(targets, target)
}
s.Targets = targets
Expand Down Expand Up @@ -368,6 +466,8 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
}
s.Auth.BearerAuth = bearerAuth
}

return nil
}

func (s *Service) Validate() error {
Expand Down Expand Up @@ -400,11 +500,108 @@ func (s *Service) Validate() error {
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
if err := validateTargetOptions(i, &target.Options); err != nil {
return err
}
}

return nil
}

const (
maxRequestTimeout = 5 * time.Minute
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
)

// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
var httpHeaderNameRe = regexp.MustCompile(`^[!#$%&'*+\-.^_` + "`" + `|~0-9A-Za-z]+$`)

// hopByHopHeaders are headers that must not be set as custom headers
// because they are connection-level and stripped by the proxy.
var hopByHopHeaders = map[string]struct{}{
"Connection": {},
"Keep-Alive": {},
"Proxy-Authenticate": {},
"Proxy-Authorization": {},
"Proxy-Connection": {},
"Te": {},
"Trailer": {},
"Transfer-Encoding": {},
"Upgrade": {},
}

// reservedHeaders are set authoritatively by the proxy or control HTTP framing
// and cannot be overridden.
var reservedHeaders = map[string]struct{}{
"Content-Length": {},
"Content-Type": {},
"Cookie": {},
"Forwarded": {},
"X-Forwarded-For": {},
"X-Forwarded-Host": {},
"X-Forwarded-Port": {},
"X-Forwarded-Proto": {},
"X-Real-Ip": {},
}

func validateTargetOptions(idx int, opts *TargetOptions) error {
if opts.PathRewrite != "" && opts.PathRewrite != PathRewritePreserve {
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
}

if opts.RequestTimeout != 0 {
if opts.RequestTimeout <= 0 {
return fmt.Errorf("target %d: request_timeout must be positive", idx)
}
if opts.RequestTimeout > maxRequestTimeout {
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
}
}

if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
return err
}

return nil
}

func validateCustomHeaders(idx int, headers map[string]string) error {
if len(headers) > maxCustomHeaders {
return fmt.Errorf("target %d: custom_headers count %d exceeds maximum of %d", idx, len(headers), maxCustomHeaders)
}
for key, value := range headers {
if !httpHeaderNameRe.MatchString(key) {
return fmt.Errorf("target %d: custom header key %q is not a valid HTTP header name", idx, key)
}
if len(key) > maxHeaderKeyLen {
return fmt.Errorf("target %d: custom header key %q exceeds maximum length of %d", idx, key, maxHeaderKeyLen)
}
if len(value) > maxHeaderValueLen {
return fmt.Errorf("target %d: custom header %q value exceeds maximum length of %d", idx, key, maxHeaderValueLen)
}
if containsCRLF(key) || containsCRLF(value) {
return fmt.Errorf("target %d: custom header %q contains invalid characters", idx, key)
}
canonical := http.CanonicalHeaderKey(key)
if _, ok := hopByHopHeaders[canonical]; ok {
return fmt.Errorf("target %d: custom header %q is a hop-by-hop header and cannot be set", idx, key)
}
if _, ok := reservedHeaders[canonical]; ok {
return fmt.Errorf("target %d: custom header %q is managed by the proxy and cannot be overridden", idx, key)
}
if canonical == "Host" {
return fmt.Errorf("target %d: use pass_host_header instead of setting Host as a custom header", idx)
}
}
return nil
}

func containsCRLF(s string) bool {
return strings.ContainsAny(s, "\r\n")
}

func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
}
Expand All @@ -417,6 +614,12 @@ func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
if len(target.Options.CustomHeaders) > 0 {
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
for k, v := range target.Options.CustomHeaders {
targetCopy.Options.CustomHeaders[k] = v
}
}
targets[i] = &targetCopy
}

Expand Down
Loading
Loading