Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 82 additions & 2 deletions metaflow/plugins/pypi/pip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from metaflow.exception import MetaflowException

from .micromamba import Micromamba
from .utils import pip_tags, wheel_tags
from .utils import pip_tags, wheel_tags, conda_platform, markers_from_platform


class PipException(MetaflowException):
Expand Down Expand Up @@ -44,6 +44,7 @@ def __init__(self, error):

METADATA_FILE = "{prefix}/.pip/metadata"
INSTALLATION_MARKER = "{prefix}/.pip/id"
MAX_SOLVE_ITERATIONS = 1

# TODO:
# 1. Support local dirs, non-wheel like packages
Expand All @@ -69,7 +70,7 @@ def _get_resolved_python_version(self, prefix):
except Exception:
return None

def solve(self, id_, packages, python, platform):
def solve(self, id_, packages, python, platform, iterations=0):
prefix = self.micromamba.path_to_environment(id_)
if prefix is None:
msg = "Unable to locate a Micromamba managed virtual environment\n"
Expand Down Expand Up @@ -150,6 +151,85 @@ def _format(dl_info):
res["hash"] = vcs_info["commit_id"]
return res

from metaflow._vendor.packaging.requirements import Requirement

def _extract_platform_specific_deps(
pkg_metadata, target_system, target_machine
):
# We are interested in the outputs 'metadata.requires_dist', and whether it contains platform_system or platform_machine markers
# that are different than the environment performing the resolving.
# In this case we want to do a second pass, which will try to add these packages _without_ the environment markers
# in order to assure that all relevant packages are present in the target environment
# e.g.
# "nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
deps = pkg_metadata.get("requires_dist")
deps_with_markers = {}
if deps is None:
return deps_with_markers
for dep in deps:
req = Requirement(dep)
if not req.marker:
continue
match_system = re.match(
r"^.*platform_system == (.*?)\s", str(req.marker)
)
plat_system = (
match_system.groups()[0].strip('"') if match_system else None
)

match_machine = re.match(
r"^.*platform_machine == (.*?)\s", str(req.marker)
)
plat_machine = (
match_machine.groups()[0].strip('"') if match_machine else None
)

if plat_system is None and plat_machine is None:
continue

if plat_system == target_system or plat_machine == target_machine:
# we must make sure that this dependency gets added to the list,
# as it will not be carried by the default resolve due to platform/machine mismatch.
version = str(req.specifier)
if version.startswith("=="):
version = version[2:]
deps_with_markers[req.name] = version

return deps_with_markers

# NOTE: Make sure to run this only if current platform and target platform are a mismatch.
# i.e. we are doing a cross-platform resolve!
if (conda_platform() != platform) and iterations < MAX_SOLVE_ITERATIONS:
debug.conda_exec(
"Current platform differs from target platform. Performing a check for environment markers in package dependencies that might end up being not included otherwise."
)
requested_sys, requested_machine = markers_from_platform(platform)
debug.conda_exec(
f"Checking for environment markers 'platform_system == {requested_sys}' and 'platform_machine == {requested_machine}'"
)
with open(report, mode="r", encoding="utf-8") as f:
deps_to_add = {
k: v
for item in json.load(f)["install"]
for k, v in _extract_platform_specific_deps(
item.get("metadata", {}), requested_sys, requested_machine
).items()
}

added_deps = False
for dep, ver in deps_to_add.items():
if dep not in packages:
added_deps = True
packages[dep] = ver

if added_deps:
debug.conda_exec(
"Added dependencies due to environment markers, have to re-solve the environment with new ones."
)
return self.solve(
id_, packages, python, platform, iterations=iterations + 1
)

with open(report, mode="r", encoding="utf-8") as f:
return [
_format(item["download_info"]) for item in json.load(f)["install"]
Expand Down
9 changes: 9 additions & 0 deletions metaflow/plugins/pypi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ def conda_platform():
return "osx-64"


def markers_from_platform(platform):
plat, mach = platform.split("-")

platform_system = {"osx": "Darwin", "linux": "Linux"}[plat]
platform_machine = {"32": "x86", "64": "x86_64", "arm64": "aarch64"}[mach]

return platform_system, platform_machine


def wheel_tags(wheel):
_, _, _, tags = parse_wheel_filename(wheel)
return list(tags)
Expand Down
Loading