From baa28f7317c4fa9928733c78a7091e453b1ec0d9 Mon Sep 17 00:00:00 2001 From: Thomas Montague Date: Thu, 10 Sep 2020 13:26:42 -0400 Subject: [PATCH 1/2] Add option to control if a cli is allowed to run as root. --- cli/app.go | 8 +++++--- cli/run.go | 7 +++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/cli/app.go b/cli/app.go index 415446a0b..209e3cbd4 100644 --- a/cli/app.go +++ b/cli/app.go @@ -25,6 +25,7 @@ type App struct { OnExit OnExit ContextConfig func(Context, context.Context) context.Context ContextOptions []ContextOption + AllowRoot bool } type Option func(*App) @@ -33,9 +34,10 @@ type ContextOption func(*Context) func NewApp(opts ...Option) *App { app := &App{ - Stdout: os.Stdout, - Stderr: os.Stderr, - OnExit: newOnExit(), + Stdout: os.Stdout, + Stderr: os.Stderr, + OnExit: newOnExit(), + AllowRoot: false, } for _, opt := range opts { opt(app) diff --git a/cli/run.go b/cli/run.go index 1fbd0a7bb..0bd954a2a 100644 --- a/cli/run.go +++ b/cli/run.go @@ -37,6 +37,13 @@ func (app *App) Run(args []string) (exitStatus int) { return 0 } + if !app.AllowRoot { + if syscall.Getuid() == 0 { + ctx.Errorf("%v\n", "Root is not allowed to run this program.") + return 1 + } + } + baseContext := context.Background() if app.ContextConfig != nil { baseContext = app.ContextConfig(ctx, baseContext) From 2e930514e9127fd776fe0bcf0ba183a584025faa Mon Sep 17 00:00:00 2001 From: Thomas Montague Date: Thu, 10 Sep 2020 13:46:21 -0400 Subject: [PATCH 2/2] Allow tests to run as root. --- cli/cliviper/cliviper_test.go | 1 + cli/debugapp_test.go | 1 + cli/exitcode_test.go | 1 + cli/flagvalue_test.go | 2 ++ cli/parse_test.go | 1 + cli/run_test.go | 5 +++++ 6 files changed, 11 insertions(+) diff --git a/cli/cliviper/cliviper_test.go b/cli/cliviper/cliviper_test.go index 136ab3447..748d5c8f5 100644 --- a/cli/cliviper/cliviper_test.go +++ b/cli/cliviper/cliviper_test.go @@ -22,6 +22,7 @@ func TestCLIViperApp(t *testing.T) { // set cliviper.App() option app := cli.NewApp(cliviper.App()) + app.AllowRoot = true app.Flags = []flag.Flag{ flag.StringFlag{Name: msgFlag}, } diff --git a/cli/debugapp_test.go b/cli/debugapp_test.go index 38db56ba1..fdef50516 100644 --- a/cli/debugapp_test.go +++ b/cli/debugapp_test.go @@ -50,6 +50,7 @@ func TestNewDebugApp(t *testing.T) { {err: testExitCoder{error: fmt.Errorf("foo"), exitCode: 2}, wantExitCode: 2, errorStringer: testErrorStringer, wantDebugFalse: "^foo\n$", wantDebugTrue: "^error-stringer\n$"}, } { app := cli.NewApp(cli.DebugHandler(currCase.errorStringer)) + app.AllowRoot = true app.Action = func(ctx cli.Context) error { return currCase.err } diff --git a/cli/exitcode_test.go b/cli/exitcode_test.go index 3b7a59f98..bfb040295 100644 --- a/cli/exitcode_test.go +++ b/cli/exitcode_test.go @@ -28,6 +28,7 @@ import ( func main() { app := cli.NewApp() + app.AllowRoot = true app.Action = func(ctx cli.Context) error { %v } diff --git a/cli/flagvalue_test.go b/cli/flagvalue_test.go index eb1785348..e5300ffc5 100644 --- a/cli/flagvalue_test.go +++ b/cli/flagvalue_test.go @@ -52,6 +52,7 @@ func TestBindFlagValues(t *testing.T) { }, } { app := cli.NewApp() + app.AllowRoot = true app.Command = cli.Command{ Name: "foo", Flags: []flag.Flag{ @@ -111,6 +112,7 @@ func TestBindFlagValuesStringParam(t *testing.T) { }, } { app := cli.NewApp() + app.AllowRoot = true app.Command = cli.Command{ Name: "foo", Flags: []flag.Flag{ diff --git a/cli/parse_test.go b/cli/parse_test.go index 674167fe9..ebd5afc5d 100644 --- a/cli/parse_test.go +++ b/cli/parse_test.go @@ -294,6 +294,7 @@ func TestParseFlags(t *testing.T) { t.Run(currCase.name, func(t *testing.T) { app := cli.NewApp() app.Name = "test" + app.AllowRoot = true output := &bytes.Buffer{} app.Subcommands = []cli.Command{ diff --git a/cli/run_test.go b/cli/run_test.go index 632bf3941..c8a442308 100644 --- a/cli/run_test.go +++ b/cli/run_test.go @@ -34,6 +34,7 @@ func TestRunErrorOutput(t *testing.T) { for i, currCase := range cases { app := cli.NewApp() + app.AllowRoot = true app.Action = func(ctx cli.Context) error { return currCase.err } @@ -79,6 +80,7 @@ func TestRunErrorHandler(t *testing.T) { for i, currCase := range cases { app := cli.NewApp() + app.AllowRoot = true app.ErrorHandler = currCase.handler app.Action = func(ctx cli.Context) error { return currCase.err @@ -118,6 +120,7 @@ func TestRunContext(t *testing.T) { name: "check that context is propagated to app action", check: func(t *testing.T) { app := cli.NewApp() + app.AllowRoot = true app.Command.Flags = []flag.Flag{ flag.StringFlag{ @@ -136,6 +139,7 @@ func TestRunContext(t *testing.T) { name: "check that context is propagated to app error handler", check: func(t *testing.T) { app := cli.NewApp() + app.AllowRoot = true app.Command.Flags = []flag.Flag{ flag.StringFlag{ @@ -162,6 +166,7 @@ func TestRunContext(t *testing.T) { name: "check that context is propagated to app subcommand", check: func(t *testing.T) { app := cli.NewApp() + app.AllowRoot = true app.Command.Flags = []flag.Flag{ flag.StringFlag{