diff --git a/metaflow/plugins/pypi/conda_environment.py b/metaflow/plugins/pypi/conda_environment.py
index 292d1fb7535..6c5973b1270 100644
--- a/metaflow/plugins/pypi/conda_environment.py
+++ b/metaflow/plugins/pypi/conda_environment.py
@@ -55,7 +55,7 @@ def validate_environment(self, logger, datastore_type):
         # Initialize necessary virtual environments for all Metaflow tasks.
         # Use Micromamba for solving conda packages and Pip for solving pypi packages.
         from .micromamba import Micromamba
-        from .pip import Pip
+        from .pip_resolver import Pip
 
         print_lock = threading.Lock()
 
diff --git a/metaflow/plugins/pypi/pip_patcher.py b/metaflow/plugins/pypi/pip_patcher.py
new file mode 100644
index 00000000000..11cf2c314e9
--- /dev/null
+++ b/metaflow/plugins/pypi/pip_patcher.py
@@ -0,0 +1,50 @@
+"""Run pip with the reported platform overridden for cross-platform resolution.
+
+Invoked as a script by the pip resolver inside the target environment; all
+command line arguments are forwarded to pip unchanged.
+"""
+
+import os
+import platform
+import sys
+from unittest.mock import patch
+
+# Keep references to the real implementations. The patched lambdas below fall
+# back to them, so they must be captured before the patches are applied.
+_real_machine = platform.machine
+_real_system = platform.system
+
+
+# Because pip does not offer a direct way to set target platform_system and
+# platform_machine values for resolving a package's transitive dependencies, we
+# need to manually patch the target architecture values so that pip can resolve
+# the whole dependency tree successfully.
+# This is necessary for packages that have conditional dependencies dependent on
+# machine/system, e.g. Torch.
+# The overrides are read from PIP_PATCH_MACHINE and PIP_PATCH_SYSTEM; when a
+# variable is unset or empty, the value of the running interpreter is used.
+@patch(
+    "platform.machine", lambda: os.environ.get("PIP_PATCH_MACHINE") or _real_machine()
+)
+@patch(
+    "platform.system", lambda: os.environ.get("PIP_PATCH_SYSTEM") or _real_system()
+)
+def _main(args):
+    """Run the pip command line with `args` and exit with pip's exit code.
+
+    The patches on `platform.machine` and `platform.system` are active for the
+    whole pip invocation.
+    """
+    try:
+        # Same entry point that `python -m pip` uses. Importing it directly avoids
+        # the "old script wrapper" warning printed by the legacy `pip.main`.
+        from pip._internal.cli.main import main
+    except ImportError:
+        # Fall back to the legacy wrapper on pip releases without that module.
+        from pip import main
+
+    sys.exit(main(args))
+
+
+if __name__ == "__main__":
+    _main(sys.argv[1:])
diff --git a/metaflow/plugins/pypi/pip.py b/metaflow/plugins/pypi/pip_resolver.py
similarity index 97%
rename from metaflow/plugins/pypi/pip.py
rename to metaflow/plugins/pypi/pip_resolver.py
index 5aeb03ec5f4..3c24ee76e93 100644
--- a/metaflow/plugins/pypi/pip.py
+++ b/metaflow/plugins/pypi/pip_resolver.py
@@ -102,7 +102,10 @@ def solve(self, id_, packages, python, platform):
                 else:
                     cmd.append(f"{package}=={version}")
             try:
-                self._call(prefix, cmd)
+                env = {}
+                if platform == "linux-64":
+                    env = {"PIP_PATCH_SYSTEM": "Linux", "PIP_PATCH_MACHINE": "x86_64"}
+                self._call(prefix, cmd, env)
             except PipPackageNotFound as ex:
                 # pretty print package errors
                 raise PipException(
@@ -306,7 +309,8 @@ def _call(self, prefix, args, env=None, isolated=True):
                     "run",
                     "--prefix",
                     prefix,
-                    "pip3",
+                    "python",
+                    os.path.join(os.path.dirname(__file__), "pip_patcher.py"),
                     "--disable-pip-version-check",
                     "--no-color",
                 ]