def _extract_platform_specific_deps(pkg_metadata, target_system, target_machine):
    """Collect dependencies that are gated on the *target* platform's markers.

    When resolving for a platform other than the one running the solver,
    requirements guarded by ``platform_system`` / ``platform_machine``
    environment markers are evaluated against the local machine and may be
    dropped. This helper scans a package's ``requires_dist`` metadata for
    such requirements so the caller can add them explicitly and re-solve.

    Example of a requirement this picks up for a Linux/x86_64 target::

        nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == "Linux" and platform_machine == "x86_64"

    Parameters
    ----------
    pkg_metadata : dict
        Package metadata; only the ``requires_dist`` key (a list of
        requirement strings, possibly missing) is read.
    target_system : str
        Value expected for ``platform_system`` on the target, e.g. ``"Linux"``.
    target_machine : str
        Value expected for ``platform_machine`` on the target, e.g. ``"x86_64"``.

    Returns
    -------
    dict
        Mapping of requirement name to its version specifier string. A
        leading ``==`` is stripped, so exact pins become a bare version.
        If a name appears more than once, the last matching entry wins.
    """
    deps = pkg_metadata.get("requires_dist")
    deps_with_markers = {}
    if not deps:
        return deps_with_markers

    # Capture the quoted value of every `<variable> == "<value>"` clause.
    # Matching on the quotes (instead of "anything up to the next whitespace")
    # also handles clauses at the very end of the marker and clauses that are
    # wrapped in parentheses. Single quotes are accepted as a precaution.
    system_clause = re.compile(r"""platform_system\s*==\s*["']([^"']*)["']""")
    machine_clause = re.compile(r"""platform_machine\s*==\s*["']([^"']*)["']""")

    for dep in deps:
        req = Requirement(dep)
        if not req.marker:
            continue

        marker_text = str(req.marker)
        systems = system_clause.findall(marker_text)
        machines = machine_clause.findall(marker_text)

        # Not a platform-gated requirement; the regular resolve handles it.
        if not systems and not machines:
            continue

        # Every constraint that is present has to agree with the target.
        # Otherwise a requirement meant for e.g. Linux/aarch64 would be
        # force-added to a Linux/x86_64 environment just because the system
        # half of the marker matches.
        if systems and target_system not in systems:
            continue
        if machines and target_machine not in machines:
            continue

        # This dependency must be added explicitly, as the default resolve
        # will not carry it due to the platform/machine mismatch.
        version = str(req.specifier)
        if version.startswith("=="):
            version = version[2:]
        deps_with_markers[req.name] = version

    return deps_with_markers
def markers_from_platform(platform):
    """Translate a conda platform string into environment-marker values.

    Parameters
    ----------
    platform : str
        Conda-style platform of the form ``<system>-<arch>``, e.g.
        ``"linux-64"``, ``"linux-aarch64"``, ``"osx-64"`` or ``"osx-arm64"``.

    Returns
    -------
    tuple of (str, str)
        ``(platform_system, platform_machine)`` as they would appear in
        requirement environment markers, e.g. ``("Linux", "x86_64")``.

    Raises
    ------
    ValueError
        If ``platform`` does not consist of exactly two ``-`` separated parts.
    KeyError
        If the system or architecture part is not recognised.
    """
    plat, mach = platform.split("-")

    platform_system = {"osx": "Darwin", "linux": "Linux"}[plat]

    if mach in ("arm64", "aarch64"):
        # 64-bit ARM is reported differently per OS: macOS uses "arm64",
        # Linux uses "aarch64". Conda names the platforms "osx-arm64" and
        # "linux-aarch64"; both spellings are accepted for either system.
        platform_machine = "arm64" if plat == "osx" else "aarch64"
    else:
        # NOTE(review): "x86" for 32-bit is kept from the original mapping;
        # real 32-bit Linux hosts may report e.g. "i686" - confirm if needed.
        platform_machine = {"32": "x86", "64": "x86_64"}[mach]

    return platform_system, platform_machine