diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index f2396b5852..e58ffbdcee 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -24,6 +24,7 @@ import ( "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/http" "github.com/replicate/cog/pkg/image" r8_path "github.com/replicate/cog/pkg/path" "github.com/replicate/cog/pkg/predict" @@ -31,6 +32,7 @@ import ( "github.com/replicate/cog/pkg/util/console" "github.com/replicate/cog/pkg/util/files" "github.com/replicate/cog/pkg/util/mime" + "github.com/replicate/cog/pkg/web" ) const StdinPath = "-" @@ -42,6 +44,7 @@ var ( setupTimeout uint32 useReplicateAPIToken bool inputJSON string + replicateUsername string ) func newPredictCommand() *cobra.Command { @@ -69,6 +72,7 @@ the prediction on that.`, addFastFlag(cmd) addLocalImage(cmd) addConfigFlag(cmd) + addReplicateUsernameFlag(cmd) cmd.Flags().StringArrayVarP(&inputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg") cmd.Flags().StringVarP(&outPath, "output", "o", "", "Output path") @@ -253,6 +257,19 @@ func cmdPredict(cmd *cobra.Command, args []string) error { } } + if replicateUsername != "" { + client, err := http.ProvideHTTPClient(ctx, dockerClient) + if err != nil { + return err + } + webClient := web.NewClient(dockerClient, client) + token, err := webClient.FetchAPIToken(ctx, replicateUsername) + if err != nil { + return err + } + envFlags = append(envFlags, fmt.Sprintf("REPLICATE_API_TOKEN=%s", token)) + } + console.Info("") console.Infof("Starting Docker image %s and running setup()...", imageName) @@ -643,3 +660,7 @@ func parseInputFlags(inputs []string, schema *openapi3.T) (predict.Inputs, error func addSetupTimeoutFlag(cmd *cobra.Command) { cmd.Flags().Uint32Var(&setupTimeout, "setup-timeout", 5*60, "The timeout for a container to setup (in seconds).") } + +func addReplicateUsernameFlag(cmd *cobra.Command) { + cmd.Flags().StringVarP(&replicateUsername, "replicate-username", "u", "", "The principal to use if the prediction requires a token.") +}