diff --git a/go.mod b/go.mod index 7507397498..a648f3f173 100644 --- a/go.mod +++ b/go.mod @@ -176,6 +176,7 @@ require ( github.com/jgautheron/goconst v1.7.1 // indirect github.com/jingyugao/rowserrcheck v1.1.1 // indirect github.com/jjti/go-spancheck v0.6.4 // indirect + github.com/joho/godotenv v1.5.1 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/julz/importas v0.2.0 // indirect github.com/karamaru-alpha/copyloopvar v1.2.1 // indirect diff --git a/go.sum b/go.sum index 25938fc8d1..77d2429bb3 100644 --- a/go.sum +++ b/go.sum @@ -338,6 +338,8 @@ github.com/jingyugao/rowserrcheck v1.1.1 h1:zibz55j/MJtLsjP1OF4bSdgXxwL1b+Vn7Tjz github.com/jingyugao/rowserrcheck v1.1.1/go.mod h1:4yvlZSDb3IyDTUZJUmpZfm2Hwok+Dtp+nu2qOq+er9c= github.com/jjti/go-spancheck v0.6.4 h1:Tl7gQpYf4/TMU7AT84MN83/6PutY21Nb9fuQjFTpRRc= github.com/jjti/go-spancheck v0.6.4/go.mod h1:yAEYdKJ2lRkDA8g7X+oKUHXOWVAXSBJRv04OhF+QUjk= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/julz/importas v0.2.0 h1:y+MJN/UdL63QbFJHws9BVC5RpA2iq0kpjrFajTGivjQ= diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index f2396b5852..8353faf118 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -17,6 +17,7 @@ import ( "time" "github.com/getkin/kin-openapi/openapi3" + "github.com/joho/godotenv" "github.com/mitchellh/go-homedir" "github.com/spf13/cobra" "golang.org/x/sys/unix" @@ -69,10 +70,10 @@ the prediction on that.`, addFastFlag(cmd) addLocalImage(cmd) addConfigFlag(cmd) + addEnvFlag(cmd) cmd.Flags().StringArrayVarP(&inputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg") cmd.Flags().StringVarP(&outPath, "output", "o", "", "Output path") - cmd.Flags().StringArrayVarP(&envFlags, "env", "e", []string{}, "Environment variables, in the form name=value") cmd.Flags().BoolVar(&useReplicateAPIToken, "use-replicate-token", false, "Pass REPLICATE_API_TOKEN from local environment into the model context") cmd.Flags().StringVar(&inputJSON, "json", "", "Pass inputs as JSON object, read from file (@inputs.json) or via stdin (@-)") @@ -643,3 +644,21 @@ func parseInputFlags(inputs []string, schema *openapi3.T) (predict.Inputs, error func addSetupTimeoutFlag(cmd *cobra.Command) { cmd.Flags().Uint32Var(&setupTimeout, "setup-timeout", 5*60, "The timeout for a container to setup (in seconds).") } + +func addEnvFlag(cmd *cobra.Command) { + defaultEnv := []string{} + + // Attempt to fill default environment variables with a .env file + exists, _ := files.Exists(".env") + if exists { + var dotEnv map[string]string + dotEnv, err := godotenv.Read() + if err == nil { + for key, value := range dotEnv { + defaultEnv = append(defaultEnv, fmt.Sprintf("%s=%s", key, value)) + } + } + } + + cmd.Flags().StringArrayVarP(&envFlags, "env", "e", defaultEnv, "Environment variables, in the form name=value") +} diff --git a/test-integration/test_integration/fixtures/env-cli-project/.env b/test-integration/test_integration/fixtures/env-cli-project/.env new file mode 100644 index 0000000000..92e90afffc --- /dev/null +++ b/test-integration/test_integration/fixtures/env-cli-project/.env @@ -0,0 +1 @@ +TEST_VAR=test_value diff --git a/test-integration/test_integration/fixtures/env-cli-project/cog.yaml b/test-integration/test_integration/fixtures/env-cli-project/cog.yaml new file mode 100644 index 0000000000..cdf01d2b46 --- /dev/null +++ b/test-integration/test_integration/fixtures/env-cli-project/cog.yaml @@ -0,0 +1 @@ +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/env-cli-project/predict.py b/test-integration/test_integration/fixtures/env-cli-project/predict.py new file mode 100644 index 0000000000..6029e472c4 --- /dev/null +++ b/test-integration/test_integration/fixtures/env-cli-project/predict.py @@ -0,0 +1,7 @@ +from cog import BasePredictor +import os + + +class Predictor(BasePredictor): + def predict(self, name: str) -> str: + return f"ENV[{name}]={os.getenv(name)}" diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index ff1d4fa626..926e44a922 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -801,3 +801,17 @@ def test_predict_future_annotations(cog_binary): timeout=120.0, ) assert result.returncode == 0 + + +def test_predict_dotenv(cog_binary): + project_dir = Path(__file__).parent / "fixtures/env-cli-project" + + result = subprocess.run( + [cog_binary, "predict", "--debug", "-i", "name=TEST_VAR"], + cwd=project_dir, + capture_output=True, + text=True, + timeout=120.0, + ) + assert result.returncode == 0 + assert result.stdout == "ENV[TEST_VAR]=test_value\n"