diff --git a/cmd/viam-agent/main.go b/cmd/viam-agent/main.go index 4c4bae6b..ffa69827 100644 --- a/cmd/viam-agent/main.go +++ b/cmd/viam-agent/main.go @@ -213,7 +213,8 @@ func setupExitSignalHandling() (context.Context, context.CancelFunc) { case syscall.SIGABRT: fallthrough case syscall.SIGTERM: - globalLogger.Info("exiting") + exitMsg := fmt.Sprintf("Signal received. %s will now exit to be restarted by service manager", agent.SubsystemName) + globalLogger.Infow(exitMsg, "signal", sig) signal.Ignore(os.Interrupt, syscall.SIGTERM, syscall.SIGABRT) // keeping SIGQUIT for stack trace debugging return diff --git a/go.mod b/go.mod index 523f7327..afd3605d 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/jessevdk/go-flags v1.6.1 github.com/nightlyone/lockfile v1.0.0 github.com/pkg/errors v0.9.1 + github.com/samber/mo v1.16.0 github.com/schollz/progressbar/v3 v3.18.0 github.com/sergeymakinen/go-systemdconf/v2 v2.0.2 github.com/tidwall/jsonc v0.3.2 @@ -29,7 +30,7 @@ require ( require ( github.com/cenkalti/backoff v2.2.1+incompatible // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/desertbit/timer v0.0.0-20180107155436-c41aec40b27f // indirect github.com/dgottlieb/smarty-assertions v1.2.6 // indirect @@ -72,7 +73,7 @@ require ( github.com/pion/stun v0.6.1 // indirect github.com/pion/transport/v2 v2.2.10 // indirect github.com/pion/turn/v2 v2.1.6 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rs/cors v1.11.1 // indirect github.com/saltosystems/winrt-go v0.0.0-20240509164145-4f7860a3bd2b // indirect @@ -80,7 +81,7 @@ require ( github.com/soypat/cyw43439 v0.0.0-20241116210509-ae1ce0e084c5 // 
indirect github.com/soypat/seqs v0.0.0-20240527012110-1201bab640ef // indirect github.com/srikrsna/protoc-gen-gotag v0.6.2 // indirect - github.com/stretchr/testify v1.10.0 // indirect + github.com/stretchr/testify v1.11.1 // indirect github.com/tinygo-org/cbgo v0.0.4 // indirect github.com/tinygo-org/pio v0.0.0-20231216154340-cd888eb58899 // indirect github.com/viamrobotics/webrtc/v3 v3.99.10 // indirect diff --git a/go.sum b/go.sum index 56eebb07..44585be7 100644 --- a/go.sum +++ b/go.sum @@ -98,8 +98,9 @@ github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7Do github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/daixiang0/gci v0.2.8/go.mod h1:+4dZ7TISfSmqfAGv59ePaHfNzgGtIkHAhhdKggP1JAc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= @@ -539,8 +540,9 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/polyfloyd/go-errorlint v0.0.0-20201127212506-19bd8db6546f/go.mod h1:wi9BfjxjF/bwiZ701TzmfKu6UKC357IOAtNr0Td0Lvw= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= @@ -593,6 +595,8 @@ github.com/ryanrolds/sqlclosecheck v0.3.0/go.mod h1:1gREqxyTGR3lVtpngyFo3hZAgk0K github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/saltosystems/winrt-go v0.0.0-20240509164145-4f7860a3bd2b h1:du3zG5fd8snsFN6RBoLA7fpaYV9ZQIsyH9snlk2Zvik= github.com/saltosystems/winrt-go v0.0.0-20240509164145-4f7860a3bd2b/go.mod h1:CIltaIm7qaANUIvzr0Vmz71lmQMAIbGJ7cvgzX7FMfA= +github.com/samber/mo v1.16.0 h1:qpEPCI63ou6wXlsNDMLE0IIN8A+devbGX/K1xdgr4b4= +github.com/samber/mo v1.16.0/go.mod h1:DlgzJ4SYhOh41nP1L9kh9rDNERuf8IqWSAs+gj2Vxag= github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= github.com/sanposhiho/wastedassign v0.1.3/go.mod h1:LGpq5Hsv74QaqM47WtIsRSF/ik9kqk07kchgv66tLVE= github.com/schollz/progressbar/v3 v3.18.0 h1:uXdoHABRFmNIjUfte/Ex7WtuyVslrw2wVPQmCN62HpA= @@ -661,8 +665,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/stretchr/testify v1.10.0 
h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tdakkota/asciicheck v0.0.0-20200416200610-e657995f937b/go.mod h1:yHp0ai0Z9gUljN3o0xMhYJnH/IcvkdTBOX2fmJ93JEM= github.com/tetafro/godot v1.4.4/go.mod h1:FVDd4JuKliW3UgjswZfJfHq4vAx0bD/Jd5brJjGeaz4= diff --git a/manager.go b/manager.go index a5a655ff..ec07cd0b 100644 --- a/manager.go +++ b/manager.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/url" "os" "regexp" @@ -21,13 +22,18 @@ import ( "github.com/viamrobotics/agent/subsystems/viamserver" "github.com/viamrobotics/agent/utils" pb "go.viam.com/api/app/agent/v1" + apppb "go.viam.com/api/app/v1" "go.viam.com/rdk/logging" goutils "go.viam.com/utils" "go.viam.com/utils/rpc" ) const ( - minimalCheckInterval = time.Second * 5 + // The minimal (and default) interval for checking for config updates via DeviceAgentConfig. + minimalDeviceAgentConfigCheckInterval = time.Second * 5 + // The minimal (and default) interval for checking whether agent needs to be restarted. + minimalNeedsRestartCheckInterval = time.Second * 1 + defaultNetworkTimeout = time.Second * 15 // stopAllTimeout must be lower than systemd subsystems/viamagent/viam-agent.service timeout of 4mins // and higher than subsystems/viamserver/viamserver.go timeout of 2mins. 
@@ -42,7 +48,6 @@ type Manager struct { connMu sync.RWMutex conn rpc.ClientConn - client pb.AgentDeviceServiceClient cloudConfig *logging.CloudConfig logger logging.Logger @@ -209,7 +214,7 @@ func (m *Manager) SubsystemUpdates(ctx context.Context) { m.logger.Warn(err) } if m.viamAgentNeedsRestart { - m.Exit() + m.Exit(fmt.Sprintf("A new version of %s has been installed", SubsystemName)) return } } else { @@ -221,17 +226,19 @@ func (m *Manager) SubsystemUpdates(ctx context.Context) { needRestartConfigChange := m.viamServer.Update(ctx, m.cfg) if needRestart || needRestartConfigChange || m.viamServerNeedsRestart || m.viamAgentNeedsRestart { - if m.viamServer.(viamserver.RestartCheck).SafeToRestart(ctx) { + if m.viamServer.Property(ctx, viamserver.RestartPropertyRestartAllowed) { + m.logger.Infof("%s has allowed a restart; will restart", viamserver.SubsysName) if err := m.viamServer.Stop(ctx); err != nil { m.logger.Warn(err) } else { m.viamServerNeedsRestart = false } if m.viamAgentNeedsRestart { - m.Exit() + m.Exit(fmt.Sprintf("A new version of %s has been installed", SubsystemName)) return } } else { + m.logger.Warnf("%s has NOT allowed a restart; will NOT restart", viamserver.SubsysName) m.viamServerNeedsRestart = true } } @@ -280,26 +287,26 @@ func (m *Manager) SubsystemUpdates(ctx context.Context) { // CheckUpdates retrieves an updated config from the cloud, and then passes it to SubsystemUpdates(). 
func (m *Manager) CheckUpdates(ctx context.Context) time.Duration { defer utils.Recover(m.logger, nil) - m.logger.Debug("Checking cloud for update") - interval, err := m.GetConfig(ctx) + m.logger.Debug("Checking cloud for device agent config updates") + deviceAgentConfigCheckInterval, err := m.GetConfig(ctx) - if interval < minimalCheckInterval { - interval = minimalCheckInterval + if deviceAgentConfigCheckInterval < minimalDeviceAgentConfigCheckInterval { + deviceAgentConfigCheckInterval = minimalDeviceAgentConfigCheckInterval } // randomly fuzz the interval by +/- 5% - interval = utils.FuzzTime(interval, 0.05) + deviceAgentConfigCheckInterval = utils.FuzzTime(deviceAgentConfigCheckInterval, 0.05) // we already log in all error cases inside GetConfig, so // no need to log again. if err != nil { - return interval + return deviceAgentConfigCheckInterval } // update and (re)start subsystems m.SubsystemUpdates(ctx) - return interval + return deviceAgentConfigCheckInterval } func (m *Manager) setDebug(debug bool) { @@ -380,13 +387,51 @@ func (m *Manager) SubsystemHealthChecks(ctx context.Context) { } } +// CheckIfNeedsRestart returns the check restart interval and whether the agent (and +// therefore all its subsystems) has been forcibly restarted by app. +func (m *Manager) CheckIfNeedsRestart(ctx context.Context) (time.Duration, bool) { + m.logger.Debug("Checking cloud for forced restarts") + if m.cloudConfig == nil { + m.logger.Warn("can't CheckIfNeedsRestart until successful config load") + return minimalNeedsRestartCheckInterval, false + } + + // Only continue this check if viam-server does not handle restart checking itself + // (return early if viamserver _does_ handle restart checking). 
+ if !m.viamServer.Property(ctx, viamserver.RestartPropertyDoesNotHandleNeedsRestart) { + return minimalNeedsRestartCheckInterval, false + } + + m.logger.Debug("Checking cloud for forced restarts") + timeoutCtx, cancelFunc := context.WithTimeout(ctx, defaultNetworkTimeout) + defer cancelFunc() + + if err := m.dial(timeoutCtx); err != nil { + m.logger.Warn(errw.Wrapf(err, "dialing to check if restart needed")) + return minimalNeedsRestartCheckInterval, false + } + + robotServiceClient := apppb.NewRobotServiceClient(m.conn) + req := &apppb.NeedsRestartRequest{Id: m.cloudConfig.ID} + res, err := robotServiceClient.NeedsRestart(timeoutCtx, req) + if err != nil { + m.logger.Warn(errw.Wrapf(err, "checking if restart needed")) + return minimalNeedsRestartCheckInterval, false + } + + return res.GetRestartCheckInterval().AsDuration(), res.GetMustRestart() +} + // CloseAll stops all subsystems and closes the cloud connection. func (m *Manager) CloseAll() { ctx, cancel := context.WithCancel(context.Background()) // Use a slow goroutine watcher to log and continue if shutdown is taking too long. slowWatcher, slowWatcherCancel := goutils.SlowGoroutineWatcher( - stopAllTimeout, "Agent is taking a while to shut down,", m.logger) + stopAllTimeout, + fmt.Sprintf("Viam agent subsystems and/or background workers failed to shut down within %v", stopAllTimeout), + m.logger, + ) slowTicker := time.NewTicker(10 * time.Second) defer slowTicker.Stop() @@ -430,7 +475,6 @@ func (m *Manager) CloseAll() { } } - m.client = nil m.conn = nil }) @@ -479,7 +523,8 @@ func (m *Manager) CloseAll() { } } -// StartBackgroundChecks kicks off a go routine that loops on a timer to check for updates and health checks. +// StartBackgroundChecks kicks off go routines that loop on a timer to check for updates, +// health checks, and restarts. 
func (m *Manager) StartBackgroundChecks(ctx context.Context) { if ctx.Err() != nil { return @@ -495,18 +540,18 @@ func (m *Manager) StartBackgroundChecks(ctx context.Context) { }) defer m.activeBackgroundWorkers.Done() - checkInterval := minimalCheckInterval + deviceAgentConfigCheckInterval := minimalDeviceAgentConfigCheckInterval m.cfgMu.RLock() wait := m.cfg.AdvancedSettings.WaitForUpdateCheck.Get() m.cfgMu.RUnlock() if wait { - checkInterval = m.CheckUpdates(ctx) + deviceAgentConfigCheckInterval = m.CheckUpdates(ctx) } else { // premptively start things before we go into the regular update/check/restart m.SubsystemHealthChecks(ctx) } - timer := time.NewTimer(checkInterval) + timer := time.NewTimer(deviceAgentConfigCheckInterval) defer timer.Stop() for { if ctx.Err() != nil { @@ -516,9 +561,39 @@ func (m *Manager) StartBackgroundChecks(ctx context.Context) { case <-ctx.Done(): return case <-timer.C: - checkInterval = m.CheckUpdates(ctx) + deviceAgentConfigCheckInterval = m.CheckUpdates(ctx) m.SubsystemHealthChecks(ctx) - timer.Reset(checkInterval) + timer.Reset(deviceAgentConfigCheckInterval) + } + } + }() + + m.activeBackgroundWorkers.Add(1) + go func() { + defer m.activeBackgroundWorkers.Done() + + timer := time.NewTimer(minimalNeedsRestartCheckInterval) + defer timer.Stop() + for { + if ctx.Err() != nil { + return + } + select { + case <-ctx.Done(): + return + case <-timer.C: + needsRestartCheckInterval, needsRestart := m.CheckIfNeedsRestart(ctx) + if needsRestartCheckInterval < minimalNeedsRestartCheckInterval { + needsRestartCheckInterval = minimalNeedsRestartCheckInterval + } + if needsRestart { + // Do not mark m.agentNeedsRestart and instead Exit immediately; we do not want + // to wait for viam-server to allow a restart as it may be in a bad state. + m.Exit(fmt.Sprintf("A restart of %s was requested from app", SubsystemName)) + } + // As with the device agent config check interval, randomly fuzz the interval by + // +/- 5%. 
+ timer.Reset(utils.FuzzTime(needsRestartCheckInterval, 0.05)) } } }() @@ -531,11 +606,11 @@ func (m *Manager) dial(ctx context.Context) error { return ctx.Err() } if m.cloudConfig == nil { - return errors.New("cannot dial() until successful LoadConfig") + return errors.New("cannot dial() until successful config load") } m.connMu.Lock() defer m.connMu.Unlock() - if m.client != nil { + if m.conn != nil { return nil } @@ -564,7 +639,6 @@ func (m *Manager) dial(ctx context.Context) error { return err } m.conn = conn - m.client = pb.NewAgentDeviceServiceClient(m.conn) if m.netAppender != nil { m.netAppender.SetConn(conn, true) @@ -577,27 +651,28 @@ func (m *Manager) dial(ctx context.Context) error { // GetConfig retrieves the configuration from the cloud. func (m *Manager) GetConfig(ctx context.Context) (time.Duration, error) { if m.cloudConfig == nil { - err := errors.New("can't GetConfig until successful LoadConfig") + err := errors.New("can't GetConfig until successful config load") m.logger.Warn(err) - return minimalCheckInterval, err + return minimalDeviceAgentConfigCheckInterval, err } timeoutCtx, cancelFunc := context.WithTimeout(ctx, defaultNetworkTimeout) defer cancelFunc() if err := m.dial(timeoutCtx); err != nil { - m.logger.Warn(errw.Wrapf(err, "fetching %s config", SubsystemName)) - return minimalCheckInterval, err + m.logger.Warn(errw.Wrapf(err, "dialing to fetch %s config", SubsystemName)) + return minimalDeviceAgentConfigCheckInterval, err } + agentDeviceServiceClient := pb.NewAgentDeviceServiceClient(m.conn) req := &pb.DeviceAgentConfigRequest{ Id: m.cloudConfig.ID, HostInfo: m.getHostInfo(), VersionInfo: m.getVersions(), } - resp, err := m.client.DeviceAgentConfig(timeoutCtx, req) + resp, err := agentDeviceServiceClient.DeviceAgentConfig(timeoutCtx, req) if err != nil { m.logger.Warn(errw.Wrapf(err, "fetching %s config", SubsystemName)) - return minimalCheckInterval, err + return minimalDeviceAgentConfigCheckInterval, err } fixWindowsPaths(resp) @@ 
-699,7 +774,7 @@ func (m *Manager) getVersions() *pb.VersionInfo { return vers } -func (m *Manager) Exit() { - m.logger.Info("A new viam-agent has been installed. Will now exit to be restarted by service manager.") +func (m *Manager) Exit(reason string) { + m.logger.Infow(fmt.Sprintf("%s will now exit to be restarted by service manager", SubsystemName), "reason", reason) m.globalCancel() } diff --git a/subsystems/networking/networking_linux.go b/subsystems/networking/networking_linux.go index 5288511f..18f9bce3 100644 --- a/subsystems/networking/networking_linux.go +++ b/subsystems/networking/networking_linux.go @@ -425,3 +425,8 @@ func (n *Networking) writeWifiPowerSave(ctx context.Context) error { return nil } + +// Property is a noop for the networking subsystem. +func (n *Networking) Property(_ context.Context, _ string) bool { + return false +} diff --git a/subsystems/subsystems.go b/subsystems/subsystems.go index be1de316..aab9257d 100644 --- a/subsystems/subsystems.go +++ b/subsystems/subsystems.go @@ -19,6 +19,9 @@ type Subsystem interface { // HealthCheck reports if a subsystem is running correctly (it is restarted if not) HealthCheck(ctx context.Context) error + + // Property gets an arbitrary property about the running subsystem. + Property(ctx context.Context, property string) bool } // Dummy is a fake subsystem for when a particular OS doesn't (yet) have support. 
@@ -39,3 +42,7 @@ func (d *Dummy) Update(_ context.Context, _ utils.AgentConfig) bool { func (d *Dummy) HealthCheck(_ context.Context) error { return nil } + +func (d *Dummy) Property(_ context.Context, _ string) bool { + return false +} diff --git a/subsystems/syscfg/syscfg_linux.go b/subsystems/syscfg/syscfg_linux.go index 85c0c53c..70fee82b 100644 --- a/subsystems/syscfg/syscfg_linux.go +++ b/subsystems/syscfg/syscfg_linux.go @@ -114,3 +114,8 @@ func (s *syscfg) HealthCheck(ctx context.Context) error { } return errors.New("healthcheck failed") } + +// Property is a noop for the syscfg subsystem. +func (s *syscfg) Property(_ context.Context, _ string) bool { + return false +} diff --git a/subsystems/viamserver/restart_properties.go b/subsystems/viamserver/restart_properties.go new file mode 100644 index 00000000..b6b007af --- /dev/null +++ b/subsystems/viamserver/restart_properties.go @@ -0,0 +1,205 @@ +package viamserver + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "runtime" + "testing" + "time" + + errw "github.com/pkg/errors" + "github.com/samber/mo" + goutils "go.viam.com/utils" +) + +type ( + // RestartStatusResponse is the http/json response from viamserver's /restart_status URL. + RestartStatusResponse struct { + // RestartAllowed represents whether this instance of the viamserver can be + // safely restarted. + RestartAllowed bool `json:"restart_allowed"` + // DoesNotHandleNeedsRestart represents whether this instance of the viamserver does + // not check for the need to restart against app itself and, thus, needs agent to do so. + // Newer versions of viamserver (>= v0.9x.0) will report true for this value, while + // older versions won't report it at all, and agent should let viamserver handle + // NeedsRestart logic. 
+ DoesNotHandleNeedsRestart bool `json:"does_not_handle_needs_restart,omitempty"` + } + + // restartProperty is a property related to restarting about which agent can query + // viamserver. + restartProperty = string +) + +const ( + RestartPropertyRestartAllowed restartProperty = "restart allowed" + RestartPropertyDoesNotHandleNeedsRestart restartProperty = "does not handle needs restart" + + restartURLSuffix = "/restart_status" + + checkRestartPropertyTimeout = 10 * time.Second +) + +// Creates test URLs for property checks. Must be called with s.mu locked. +func (s *viamServer) makeTestURLs(rp restartProperty) ([]string, error) { + urls := []string{s.checkURL, s.checkURLAlt} + // On Windows, the local IPV4 addresses created below this check will not be reachable. + // Tests for checkRestartProperty are also unable to reach the local IPV4s created below + // due to how the test server is set up. + //nolint:goconst + if runtime.GOOS == "windows" || testing.Testing() { + return urls, nil + } + + port := "8080" + mainURL, err := url.Parse(s.checkURL) + if err != nil { + s.logger.Warnf("Cannot determine port for %s check, using default of 8080", rp) + } else { + port = mainURL.Port() + s.logger.Debugf("Using port %s for %s check", port, rp) + } + + ips, err := getAllLocalIPv4s() + if err != nil { + return []string{}, err + } + for _, ip := range ips { + urls = append(urls, fmt.Sprintf("https://%s:%s", ip, port)) + } + + return urls, nil +} + +// Gets all local IPV4s. Copied from goutils, but loopback checks are removed, as we DO +// want loopback adapters. Used in creating test URLS. 
+func getAllLocalIPv4s() ([]string, error) { + allInterfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + all := []string{} + + for _, i := range allInterfaces { + addrs, err := i.Addrs() + if err != nil { + return nil, err + } + + for _, addr := range addrs { + switch v := addr.(type) { + case *net.IPNet: + _, bits := v.Mask.Size() + if bits != 32 { + // this is what limits to ipv4 + continue + } + + all = append(all, v.IP.String()) + default: + return nil, fmt.Errorf("unknown address type: %T", v) + } + } + } + + return all, nil +} + +// Returns the value of the requested restart property (false if not determined) and any +// encountered errors. Must be called with s.mu held, as makeTestURLs is called. +func (s *viamServer) checkRestartProperty(ctx context.Context, rp restartProperty) (bool, error) { + urls, err := s.makeTestURLs(rp) + if err != nil { + return false, err + } + + // Create a buffered channel for Result[bool] values. Sending to this channel should not + // block, as we'll only ever have len(urls) goroutines trying to send one value. + resultChan := make(chan mo.Result[bool], len(urls)) + + timeoutCtx, cancelFunc := context.WithTimeout(ctx, checkRestartPropertyTimeout) + defer cancelFunc() + + // Disabling the cert verification because it doesn't work in offline mode (when + // connecting to localhost). + //nolint:gosec + client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}} + defer func() { + // CloseIdleConnections at the end of the method to ensure that any goroutine created + // below does not leave an idle HTTP connection open to the server. 
+ client.CloseIdleConnections() + }() + + for _, url := range urls { + go func() { + s.logger.Debugf("Starting %s check for %s using %s", rp, SubsysName, url) + + restartURL := url + restartURLSuffix + + req, err := http.NewRequestWithContext(timeoutCtx, http.MethodGet, restartURL, nil) + if err != nil { + resultChan <- mo.Err[bool]( + errw.Wrapf(err, "creating HTTP request for %s check for %s via %s", + rp, SubsysName, restartURL)) + return + } + + resp, err := client.Do(req) + if err != nil { + resultChan <- mo.Err[bool](errw.Wrapf(err, "sending HTTP request for %s check for %s via %s", + rp, SubsysName, restartURL)) + return + } + defer func() { + goutils.UncheckedError(resp.Body.Close()) + }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + // Interacting with older viam-server instances will result in a non-successful + // HTTP response status code, as the /restart_status endpoint will not be + // available. + resultChan <- mo.Errf[bool]("checking %s status via %s, got code: %d", SubsysName, restartURL, resp.StatusCode) + return + } + + var restartStatusResponse RestartStatusResponse + if err = json.NewDecoder(resp.Body).Decode(&restartStatusResponse); err != nil { + resultChan <- mo.Err[bool](errw.Wrapf(err, "decoding HTTP response for %s check for %s via %s", + rp, SubsysName, restartURL)) + return + } + + switch rp { + case RestartPropertyRestartAllowed: + resultChan <- mo.Ok(restartStatusResponse.RestartAllowed) + case RestartPropertyDoesNotHandleNeedsRestart: + resultChan <- mo.Ok(restartStatusResponse.DoesNotHandleNeedsRestart) + } + }() + } + + var combinedErr error + for range urls { + select { + case result := <-resultChan: + property, err := result.Get() + if err != nil { + combinedErr = errors.Join(combinedErr, err) + } else { + // We assume below that the first test URL through which we encountered the feature's + // value represents the actual value. 
+ return property, nil + } + case <-timeoutCtx.Done(): + return false, errors.Join(combinedErr, ctx.Err()) + } + } + return false, combinedErr +} diff --git a/subsystems/viamserver/restart_properties_test.go b/subsystems/viamserver/restart_properties_test.go new file mode 100644 index 00000000..cd632cc9 --- /dev/null +++ b/subsystems/viamserver/restart_properties_test.go @@ -0,0 +1,135 @@ +package viamserver + +import ( + "context" + "encoding/json" + "net" + "net/http" + "sync" + "testing" + + "go.viam.com/rdk/logging" + "go.viam.com/test" + "go.viam.com/utils" +) + +// Mimics an old server's response to the restart_status HTTP endpoint. +type oldRestartStatusResponse struct { + RestartAllowed bool `json:"restart_allowed"` +} + +// Ensures that checkRestartProperty works correctly for restart_allowed and +// does_not_handle_needs_restart against a fake viamserver instance (HTTP server). +func TestCheckRestartProperty(t *testing.T) { + logger := logging.NewTestLogger(t) + ctx := context.Background() + + targetAddr := "localhost:8080" + s := &viamServer{ + logger: logger, + // checkURL will normally be the .cloud address of the machine; use localhost instead + // here. + checkURL: "http://" + targetAddr, + // checkURLAlt is always 127.0.0.1:[bind-port] in agent code. + checkURLAlt: "http://127.0.0.1:8080", + } + + falseVal := false + trueVal := true + testCases := []struct { + name string + expectedRestartAllowed bool + // Can be unset (mimic old server), false, and true. 
+ expectedDoesNotHandleNeedsRestart *bool + }{ + { + "restart_allowed=false;does_not_handle_needs_restart=unset", + false, + nil, + }, + { + "restart_allowed=false;does_not_handle_needs_restart=false", + false, + &falseVal, + }, + { + "restart_allowed=false;does_not_handle_needs_restart=true", + false, + &trueVal, + }, + { + "restart_allowed=true;does_not_handle_needs_restart=unset", + true, + nil, + }, + { + "restart_allowed=true;does_not_handle_needs_restart=false", + true, + nil, + }, + { + "restart_allowed=true;does_not_handle_needs_restart=true", + true, + &trueVal, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var expectedRestartStatusResponse any + if tc.expectedDoesNotHandleNeedsRestart != nil { + expectedRestartStatusResponse = RestartStatusResponse{ + RestartAllowed: tc.expectedRestartAllowed, + DoesNotHandleNeedsRestart: *tc.expectedDoesNotHandleNeedsRestart, + } + } else { + expectedRestartStatusResponse = oldRestartStatusResponse{ + RestartAllowed: tc.expectedRestartAllowed, + } + } + + mux := http.NewServeMux() + mux.HandleFunc("/restart_status", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + test.That(t, json.NewEncoder(w).Encode(expectedRestartStatusResponse), test.ShouldBeNil) + }) + // Use NewPossiblySecureHTTPServer to mimic RDK's behavior. + httpServer, err := utils.NewPossiblySecureHTTPServer(mux, utils.HTTPServerOptions{ + Secure: false, + Addr: targetAddr, + }) + test.That(t, err, test.ShouldBeNil) + ln, err := net.Listen("tcp", targetAddr) + test.That(t, err, test.ShouldBeNil) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err := httpServer.Serve(ln) + // Should be "server closed" due to Shutdown below. 
+ test.That(t, err, test.ShouldBeError, http.ErrServerClosed) + }() + t.Cleanup(func() { + test.That(t, httpServer.Shutdown(ctx), test.ShouldBeNil) + wg.Wait() + }) + + s.mu.Lock() + t.Cleanup(s.mu.Unlock) + + restartAllowed, err := s.checkRestartProperty(ctx, RestartPropertyRestartAllowed) + test.That(t, err, test.ShouldBeNil) + test.That(t, restartAllowed, test.ShouldEqual, tc.expectedRestartAllowed) + + doesNotHandleNeedsRestart, err := s.checkRestartProperty(ctx, RestartPropertyDoesNotHandleNeedsRestart) + test.That(t, err, test.ShouldBeNil) + // does_not_handle_restart should be false if explicitly false or unset in the test + // case. + var expectedDoesNotHandleNeedsRestart bool + if tc.expectedDoesNotHandleNeedsRestart != nil { + expectedDoesNotHandleNeedsRestart = *tc.expectedDoesNotHandleNeedsRestart + } + test.That(t, doesNotHandleNeedsRestart, test.ShouldEqual, expectedDoesNotHandleNeedsRestart) + }) + } +} diff --git a/subsystems/viamserver/viamserver.go b/subsystems/viamserver/viamserver.go index 519e2254..22372546 100644 --- a/subsystems/viamserver/viamserver.go +++ b/subsystems/viamserver/viamserver.go @@ -3,13 +3,7 @@ package viamserver import ( "context" - "crypto/tls" - "encoding/json" "errors" - "fmt" - "net" - "net/http" - "net/url" "os" "os/exec" "path" @@ -24,23 +18,22 @@ import ( "github.com/viamrobotics/agent/subsystems" "github.com/viamrobotics/agent/utils" "go.viam.com/rdk/logging" - goutils "go.viam.com/utils" ) const ( // stopTermTimeout must be higher than viam-server shutdown timeout of 90 secs. stopTermTimeout = time.Minute * 2 stopKillTimeout = time.Second * 10 - SubsysName = "viam-server" -) -// RestartStatusResponse is the http/json response from viam_server's /health_check URL -// This MUST remain in sync with RDK. -type RestartStatusResponse struct { - // RestartAllowed represents whether this instance of the viam-server can be - // safely restarted. 
- RestartAllowed bool `json:"restart_allowed"` -} + SubsysName = "viam-server" + + // ViamAgentHandlesNeedsRestartChecking is the environment variable that viam-agent will + // set before starting viam-server to indicate that agent is a new enough version to + // have its own background loop that runs NeedsRestart against app.viam.com to determine + // if the system needs a restart. MUST be kept in line with the equivalent value in the + // rdk repo. + ViamAgentHandlesNeedsRestartChecking = "VIAM_AGENT_HANDLES_NEEDS_RESTART_CHECKING" +) type viamServer struct { mu sync.Mutex @@ -59,6 +52,12 @@ type viamServer struct { // for blocking start/stop/check ops while another is in progress startStopMu sync.Mutex + // whether this viamserver instance handles needs restart checking itself; calculated + // and cached at startup; used by the manager to determine whether agent should handle + // needs restart checking on viamserver's behalf (this is the case for new viamserver + // versions) + doesNotHandleNeedsRestart bool + logger logging.Logger } @@ -106,15 +105,18 @@ func (s *viamServer) Start(ctx context.Context) error { s.cmd.Stdout = stdio s.cmd.Stderr = stderr - if len(s.extraEnvVars) > 0 { - s.logger.Infow("Adding environment variables from config to viam-server startup", "extraEnvVars", s.extraEnvVars) + // if s.cmd.Env is not explicitly specified (nil), viam-server would inherit all env vars in Agent's environment + s.cmd.Env = s.cmd.Environ() + // TODO(RSDK-12057): Stop setting this environment variable once we fully remove all + // NeedsRestart checking logic from viam-server. 
+ s.cmd.Env = append(s.cmd.Env, ViamAgentHandlesNeedsRestartChecking+"=true") - // if s.cmd.Env is not explicitly specified (nil), viam-server would inherit all env vars in Agent's environment - s.cmd.Env = s.cmd.Environ() + if len(s.extraEnvVars) > 0 { + s.logger.Infow("Adding extra environment variables from config to viam-server startup", "extraEnvVars", s.extraEnvVars) for k, v := range s.extraEnvVars { s.cmd.Env = append(s.cmd.Env, k+"="+v) } - s.logger.Debugw("Starting viam-server with environment variables", "cmd.Env", s.cmd.Env) + s.logger.Debugw("Starting viam-server with extra environment variables", "cmd.Env", s.cmd.Env) } // watch for this line in the logs to indicate successful startup @@ -150,16 +152,19 @@ func (s *viamServer) Start(ctx context.Context) error { defer s.mu.Unlock() s.running = false s.logger.Infof("%s exited", SubsysName) - if err != nil { - s.logger.Error(errw.Wrap(err, "error while getting process status")) - } - if s.cmd.ProcessState != nil { - s.lastExit = s.cmd.ProcessState.ExitCode() - if s.lastExit != 0 { - s.logger.Errorf("non-zero exit code: %d", s.lastExit) - } - } + // Only log errors from Wait() or the exit code of the process state if subsystem + // exited unexpectedly (was not stopped by agent and is therefore still marked as + // shouldRun). 
if s.shouldRun { + if err != nil { + s.logger.Error(errw.Wrap(err, "error while getting process status")) + } + if s.cmd.ProcessState != nil { + s.lastExit = s.cmd.ProcessState.ExitCode() + if s.lastExit != 0 { + s.logger.Errorf("non-zero exit code: %d", s.lastExit) + } + } s.logger.Infof("%s exited unexpectedly and will be restarted shortly", SubsysName) } close(s.exitChan) @@ -168,9 +173,24 @@ func (s *viamServer) Start(ctx context.Context) error { select { case matches := <-c: s.checkURL = matches[1] - s.checkURLAlt = strings.Replace(matches[2], "0.0.0.0", "localhost", 1) - s.logger.Infof("viam-server restart allowed check URLs: %s %s", s.checkURL, s.checkURLAlt) + s.checkURLAlt = strings.Replace(matches[2], "0.0.0.0", "127.0.0.1", 1) s.logger.Infof("%s started", SubsysName) + s.logger.Infof("%s found serving at the following URLs: %s %s", SubsysName, s.checkURL, s.checkURLAlt) + + // Once the subsystem has successfully started, check whether it handles needs restart + // logic. We can calculate this value only once at startup and cache it, with the + // assumption that it will not change over the course of the lifetime of the + // subsystem. 
+ s.mu.Lock() + s.doesNotHandleNeedsRestart, err = s.checkRestartProperty(ctx, RestartPropertyDoesNotHandleNeedsRestart) + s.mu.Unlock() + if err != nil { + s.logger.Warn(err) + } + if !s.doesNotHandleNeedsRestart { + s.logger.Warnf("%s may already handle checking needs restart functionality; will not handle in agent", + SubsysName) + } return nil case <-ctx.Done(): return ctx.Err() @@ -203,7 +223,6 @@ func (s *viamServer) Stop(ctx context.Context) error { if err := utils.SignalForTermination(s.cmd.Process.Pid); err != nil { s.logger.Warn(errw.Wrap(err, "signaling viam-server process")) } - if s.waitForExit(ctx, stopTermTimeout) { s.logger.Infof("%s successfully stopped", SubsysName) return nil @@ -247,92 +266,6 @@ func (s *viamServer) HealthCheck(ctx context.Context) error { return nil } -// errUnsafeToRestart is reported to a result channel in isRestartAllowed below when any -// one of the viam-server test URLs explicitly reports that it is unsafe to restart -// the viam-server instance. -var errUnsafeToRestart = errors.New("viam-server reports it is unsafe to restart") - -// Must be called with `s.mu` held, as `s.checkURL` and `s.checkURLAlt` are -// both accessed. 
-func (s *viamServer) isRestartAllowed(ctx context.Context) (bool, error) { - urls, err := s.makeTestURLs() - if err != nil { - return false, err - } - - resultChan := make(chan error, len(urls)) - - timeoutCtx, cancelFunc := context.WithTimeout(ctx, time.Second*10) - defer cancelFunc() - - for _, url := range urls { - go func() { - s.logger.Debugf("starting restart allowed check for %s using %s", SubsysName, url) - - restartURL := url + "/restart_status" - - req, err := http.NewRequestWithContext(timeoutCtx, http.MethodGet, restartURL, nil) - if err != nil { - resultChan <- errw.Wrapf(err, "checking whether %s allows restart via %s", SubsysName, restartURL) - return - } - - // disabling the cert verification because it doesn't work in offline mode (when connecting to localhost) - //nolint:gosec - client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}} - - resp, err := client.Do(req) - if err != nil { - resultChan <- errw.Wrapf(err, "checking whether %s allows restart via %s", SubsysName, restartURL) - return - } - - defer func() { - goutils.UncheckedError(resp.Body.Close()) - }() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - // Interacting with older viam-server instances will result in a - // non-successful HTTP response status code, as the `restart_status` - // endpoint will not be available. Continue to next URL in this - // case. 
- resultChan <- errw.Wrapf(err, "checking %s status via %s, got code: %d", SubsysName, restartURL, resp.StatusCode) - return - } - - var restartStatusResponse RestartStatusResponse - if err = json.NewDecoder(resp.Body).Decode(&restartStatusResponse); err != nil { - resultChan <- errw.Wrapf(err, "checking whether %s allows restart via %s", SubsysName, restartURL) - return - } - if restartStatusResponse.RestartAllowed { - resultChan <- nil - return - } - resultChan <- errUnsafeToRestart - }() - } - - var combinedErr error - for i := 1; i <= len(urls); i++ { - result := <-resultChan - - // If any test URL reports it is explicitly _safe_ to restart (nil value sent to - // resultChan), we can assume we can restart viam-server. - if result == nil { - return true, nil - } - // If any test URL reports it is explicitly _unsafe_ to restart (errUnsafeToRestart), - // we can assume we should not restart viam-server. - if errors.Is(result, errUnsafeToRestart) { - return false, errUnsafeToRestart - } - - combinedErr = errors.Join(combinedErr, result) - } - return false, combinedErr -} - func (s *viamServer) Update(ctx context.Context, cfg utils.AgentConfig) (needRestart bool) { s.mu.Lock() defer s.mu.Unlock() @@ -348,28 +281,34 @@ func (s *viamServer) Update(ctx context.Context, cfg utils.AgentConfig) (needRes return false } -func (s *viamServer) SafeToRestart(ctx context.Context) bool { +// Property returns a single property of the currently running viamserver. +func (s *viamServer) Property(ctx context.Context, property string) bool { s.mu.Lock() defer s.mu.Unlock() - if !s.running || runtime.GOOS == "windows" { - return true - } + switch property { + case RestartPropertyRestartAllowed: + if !s.running || runtime.GOOS == "windows" { + // Assume agent can restart viamserver if the subsystem is not running or we are on + // Windows. + // + // TODO(RSDK-12271): Allow checks of restart_allowed on Windows. 
+ return true + } - // viam-server can be safely restarted even while running if the process - // has reported it is safe to do so through its `restart_status` HTTP - // endpoint. - restartAllowed, err := s.isRestartAllowed(ctx) - if err != nil { - s.logger.Warn(err) + restartAllowed, err := s.checkRestartProperty(ctx, RestartPropertyRestartAllowed) + if err != nil { + s.logger.Warn(err) + } return restartAllowed + case RestartPropertyDoesNotHandleNeedsRestart: + // We can use the cached value (calculated in Start) for handle needs restart + // property. + return s.doesNotHandleNeedsRestart + default: + s.logger.Errorw("Unknown property requested from viamserver", "property", property) + return false } - if restartAllowed { - s.logger.Infof("will restart %s to run new version, as it has reported allowance of a restart", SubsysName) - } else { - s.logger.Infof("will not restart %s version to run new version, as it has not reported allowance of a restart", SubsysName) - } - return restartAllowed } func NewSubsystem(ctx context.Context, logger logging.Logger, cfg utils.AgentConfig) subsystems.Subsystem { @@ -379,65 +318,3 @@ func NewSubsystem(ctx context.Context, logger logging.Logger, cfg utils.AgentCon extraEnvVars: cfg.AdvancedSettings.ViamServerExtraEnvVars, } } - -type RestartCheck interface { - SafeToRestart(ctx context.Context) bool -} - -// must be called with s.mu locked. 
-func (s *viamServer) makeTestURLs() ([]string, error) { - port := "8080" - mainURL, err := url.Parse(s.checkURL) - if err != nil { - s.logger.Warnf("cannot determine port for restart allowed check, using default of 8080") - } else { - port = mainURL.Port() - s.logger.Debugf("using port %s for restart allowed check", port) - } - - ips, err := GetAllLocalIPv4s() - if err != nil { - return []string{}, err - } - - urls := []string{s.checkURL, s.checkURLAlt} - for _, ip := range ips { - urls = append(urls, fmt.Sprintf("https://%s:%s", ip, port)) - } - - return urls, nil -} - -// GetAllLocalIPv4s is copied from goutils, but removed the loopback checks, as we DO want loopback adapters. -func GetAllLocalIPv4s() ([]string, error) { - allInterfaces, err := net.Interfaces() - if err != nil { - return nil, err - } - - all := []string{} - - for _, i := range allInterfaces { - addrs, err := i.Addrs() - if err != nil { - return nil, err - } - - for _, addr := range addrs { - switch v := addr.(type) { - case *net.IPNet: - _, bits := v.Mask.Size() - if bits != 32 { - // this is what limits to ipv4 - continue - } - - all = append(all, v.IP.String()) - default: - return nil, fmt.Errorf("unknown address type: %T", v) - } - } - } - - return all, nil -} diff --git a/version_control.go b/version_control.go index b52e2923..a6b8ed15 100644 --- a/version_control.go +++ b/version_control.go @@ -221,8 +221,8 @@ func (c *VersionCache) UpdateBinary(ctx context.Context, binary string) (bool, e shasum, err := utils.GetFileSum(verData.UnpackedPath) if err == nil { goodBytes = bytes.Equal(shasum, verData.UnpackedSHA) - } else { - c.logger.Warn(err) + } else if verData.UnpackedPath != "" { // custom file:// URLs will have an empty unpacked path; no need to warn + c.logger.Warnw("Could not calculate shasum", "path", verData.UnpackedPath, "error", err) } if data.TargetVersion == data.CurrentVersion {