Skip to content

Add Config to FlagConfig struct #332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions web/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,9 @@ var extraHTTPHeaders = map[string][]string{
"Content-Security-Policy": nil,
}

func validateUsers(configPath string) error {
c, err := getConfig(configPath)
if err != nil {
return err
}

func validateUsers(c *Config) error {
for _, p := range c.Users {
_, err = bcrypt.Cost([]byte(p))
_, err := bcrypt.Cost([]byte(p))
if err != nil {
return err
}
Expand Down Expand Up @@ -77,6 +72,7 @@ HeadersLoop:

type webHandler struct {
tlsConfigPath string
config *Config
handler http.Handler
logger *slog.Logger
cache *cache
Expand All @@ -86,9 +82,23 @@ type webHandler struct {
}

func (u *webHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c, err := getConfig(u.tlsConfigPath)
var c *Config
var err error

if u.config == nil {
c, err = getConfig(u.tlsConfigPath)
if err != nil {
u.logger.Error("Unable to parse configuration", "err", err.Error())
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
} else {
c = u.config
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we check that u.config is well-defined? The code later loops over c.HTTPConfig.Header and accesses c.Users[user]. getConfig() also sets some defaults here:

c := &Config{

}

err = ValidateWebConfig(c)
if err != nil {
u.logger.Error("Unable to parse configuration", "err", err.Error())
u.logger.Error("Invalid web configuration", "err", err.Error())
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
Expand Down
96 changes: 73 additions & 23 deletions web/tls_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ type TLSConfig struct {
type FlagConfig struct {
WebListenAddresses *[]string
WebSystemdSocket *bool
WebConfigFile *string
WebConfigFile *string // Optional: path to the TLS config file. Ether this or TLSConfig must be set.
WebConfig *Config // Optional: Configuration. If set, it overrides WebConfigFile.
}

// SetDirectory joins any relative file paths with dir.
Expand Down Expand Up @@ -135,6 +136,11 @@ func getTLSConfig(configPath string) (*tls.Config, error) {
if err != nil {
return nil, err
}

if err := validateUsers(c); err != nil {
return nil, err
}

return ConfigToTLSConfig(&c.TLSConfig)
}

Expand Down Expand Up @@ -345,13 +351,28 @@ func parseVsockPort(address string) (uint32, error) {
// WebConfigFile in the FlagConfig, TLS or basic auth could be enabled.
func Serve(l net.Listener, server *http.Server, flags *FlagConfig, logger *slog.Logger) error {
logger.Info("Listening on", "address", l.Addr().String())
tlsConfigPath := *flags.WebConfigFile
if tlsConfigPath == "" {
logger.Info("TLS is disabled.", "http2", false, "address", l.Addr().String())
return server.Serve(l)
var c *Config
var err error

// WebConfig overrides WebConfigFile. If WebConfig field is not set, then WebConfigFile is used.
if flags.WebConfig == nil {
tlsConfigPath := *flags.WebConfigFile
if tlsConfigPath == "" {
logger.Info("TLS is disabled.", "http2", false, "address", l.Addr().String())
return server.Serve(l)
}

c, err = getConfig(tlsConfigPath)
if err != nil {
return err
}
} else {
// Use the provided config.
c = flags.WebConfig
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here,

c := &Config{
sets also certain defaults.

}

if err := validateUsers(tlsConfigPath); err != nil {
err = ValidateWebConfig(c)
if err != nil {
return err
}

Expand All @@ -361,16 +382,11 @@ func Serve(l net.Listener, server *http.Server, flags *FlagConfig, logger *slog.
handler = server.Handler
}

c, err := getConfig(tlsConfigPath)
if err != nil {
return err
}

server.Handler = &webHandler{
tlsConfigPath: tlsConfigPath,
logger: logger,
handler: handler,
cache: newCache(),
config: c,
logger: logger,
handler: handler,
cache: newCache(),
}

config, err := ConfigToTLSConfig(&c.TLSConfig)
Expand All @@ -395,12 +411,29 @@ func Serve(l net.Listener, server *http.Server, flags *FlagConfig, logger *slog.
// Set the GetConfigForClient method of the HTTPS server so that the config
// and certs are reloaded on new connections.
server.TLSConfig.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
config, err := getTLSConfig(tlsConfigPath)
if err != nil {
return nil, err
var tlsConfig *tls.Config
var err error
// Config overrides WebConfigFile. If Config fiels is not set, then WebConfigFile is used.
if flags.WebConfig == nil {
tlsConfigPath := *flags.WebConfigFile

tlsConfig, err = getTLSConfig(tlsConfigPath)
if err != nil {
return nil, err
}
} else {
err = ValidateWebConfig(flags.WebConfig)
if err != nil {
return nil, err
}
// Use the provided config.
tlsConfig, err = ConfigToTLSConfig(&flags.WebConfig.TLSConfig)
if err != nil {
return nil, err
}
}
config.NextProtos = server.TLSConfig.NextProtos
return config, nil
tlsConfig.NextProtos = server.TLSConfig.NextProtos
return tlsConfig, nil
}
return server.ServeTLS(l, "", "")
}
Expand All @@ -410,20 +443,37 @@ func Validate(tlsConfigPath string) error {
if tlsConfigPath == "" {
return nil
}
if err := validateUsers(tlsConfigPath); err != nil {
return err
}
c, err := getConfig(tlsConfigPath)
if err != nil {
return err
}
if err := validateUsers(c); err != nil {
return err
}
_, err = ConfigToTLSConfig(&c.TLSConfig)
if err == errNoTLSConfig {
return nil
}
return err
}

// ValidateWebConfig validates the web configuration, including the TLS config and HTTP headers.
func ValidateWebConfig(config *Config) error {
if config == nil {
return nil
}
if err := validateUsers(config); err != nil {
return err
}
if err := validateHeaderConfig(config.HTTPConfig.Header); err != nil {
return err
}
if _, err := ConfigToTLSConfig(&config.TLSConfig); err != nil && err != errNoTLSConfig {
return err
}
return nil
}

type Cipher uint16

func (c *Cipher) UnmarshalYAML(unmarshal func(interface{}) error) error {
Expand Down
Loading