[PoC] Reimplement Specifiers and SpecifierSets using intervals#1120

Draft
notatallshaw wants to merge 4 commits into pypa:main from notatallshaw:is-unsatisfiable/specifier_intervals

@notatallshaw
Member

This is a proof of concept of switching the internal mechanics of Specifier and SpecifierSet to using intervals instead of iterative version filters. In general this speeds up any mildly complex specifier and can make very complex specifiers significantly faster (more than 10x).

However, this is a large overhaul. The plan is to first land #1119, which implements just enough of this to introduce a new is_unsatisfiable method.
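As a rough illustration of why an interval representation makes an unsatisfiability check cheap: once each specifier is a list of ranges, a set of specifiers is unsatisfiable exactly when the intersection of those ranges is empty. The sketch below is a toy model, not the code in #1119. It assumes integer tuples as stand-ins for versions and half-open intervals, and the names `Interval`, `intersect`, and `is_unsatisfiable` are hypothetical. It ignores prereleases, local versions, and inclusive upper bounds.

```python
from typing import Optional

# Toy model (assumption): a "version" is a tuple of ints, an interval is
# half-open [lower, upper), and None means unbounded on that side.
Version = tuple[int, ...]
Interval = tuple[Optional[Version], Optional[Version]]


def intersect(a: list[Interval], b: list[Interval]) -> list[Interval]:
    """Intersect two sorted, non-overlapping interval lists."""
    out: list[Interval] = []
    for a_lo, a_hi in a:
        for b_lo, b_hi in b:
            # The tighter lower bound is the larger one (None = -infinity).
            lo = a_lo if b_lo is None or (a_lo is not None and a_lo > b_lo) else b_lo
            # The tighter upper bound is the smaller one (None = +infinity).
            hi = a_hi if b_hi is None or (a_hi is not None and a_hi < b_hi) else b_hi
            # Keep the overlap only if it is non-empty.
            if lo is None or hi is None or lo < hi:
                out.append((lo, hi))
    return out


def is_unsatisfiable(specs: list[list[Interval]]) -> bool:
    """True if no version can satisfy every interval list in ``specs``."""
    current: list[Interval] = [(None, None)]  # start with "everything"
    for intervals in specs:
        current = intersect(current, intervals)
        if not current:
            return True  # early exit: nothing left to match
    return False


ge_2 = [((2,), None)]  # stand-in for ">=2"
lt_1 = [(None, (1,))]  # stand-in for "<1"
lt_3 = [(None, (3,))]  # stand-in for "<3"

print(is_unsatisfiable([ge_2, lt_1]))
print(is_unsatisfiable([ge_2, lt_3]))
```

The nested loop is quadratic in the number of intervals, which is fine for a sketch. A real implementation would more likely merge the two sorted lists in a single pass.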

This implementation is designed to make it easy to build a public API on top of it. I imagine an intervals() method that returns a list[SpecifierIntervals], where each SpecifierIntervals is a tuple[SpecifierBound, SpecifierBound] and a SpecifierBound encapsulates a Version object but applies PEP 440 specifier ordering logic.
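To make the shape of such an API concrete, here is a minimal sketch of what a caller might receive and how they might consume it. Everything in it is an assumption for illustration: the `SpecifierBound` fields, the `inclusive` flag, and the `in_interval` helper are hypothetical, and plain integer tuples stand in for `Version`. The PEP 440 ordering rules that a real bound would need to apply (prereleases, post-releases, local versions) are not modeled.

```python
from dataclasses import dataclass
from typing import Optional

Version = tuple[int, ...]  # stand-in for packaging.version.Version


@dataclass(frozen=True)
class SpecifierBound:
    """Hypothetical bound: a version plus whether the endpoint is included."""

    version: Optional[Version]  # None means unbounded on this side
    inclusive: bool = False


SpecifierInterval = tuple[SpecifierBound, SpecifierBound]


def in_interval(version: Version, interval: SpecifierInterval) -> bool:
    """Check one version against one (lower, upper) pair of bounds."""
    lower, upper = interval
    if lower.version is not None:
        if version < lower.version:
            return False
        if version == lower.version and not lower.inclusive:
            return False
    if upper.version is not None:
        if version > upper.version:
            return False
        if version == upper.version and not upper.inclusive:
            return False
    return True


# What a hypothetical intervals() call might return for ">=1.0,<2.0,!=1.5":
# the excluded point splits the range into two intervals.
intervals: list[SpecifierInterval] = [
    (SpecifierBound((1, 0), inclusive=True), SpecifierBound((1, 5))),
    (SpecifierBound((1, 5)), SpecifierBound((2, 0))),
]


def contains(version: Version) -> bool:
    return any(in_interval(version, interval) for interval in intervals)


print(contains((1, 7)), contains((1, 5)))
```

Keeping inclusivity on the bound rather than on the interval means an exclusion such as `!=1.5` is expressed as two adjacent intervals with open endpoints, so no separate exclusion list is needed.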

However, to make this more performant, especially for very simple one-off cases where it is currently slower, the internal machinery probably needs to move away from bounds logic and implement hot paths and side channels. I will leave this in draft until after #1119 lands.

@notatallshaw
Member Author

I'll remove this from pip: pypa/pip#13850

But I think it's best to wrap it in a deprecation and/or make a public version of "operators".

@henryiii
Contributor

I played around with asking Copilot in VS Code to optimize this branch. It got about 0.7 or so on both complex and simple filtering. This is still slower for simple specifiers, but it's about 1.2x slower instead of 1.6x or so. Here's the AI-generated summary of the things it thinks it optimized:

Compared against notatallshaw/is-unsatisfiable/specifier_intervals, the optimizations are mostly in five areas:

  1. Exclusion-bound comparisons got cheaper.
  • _ExclusionBound now avoids re-trimming release tuples on every comparison, rejects mismatches earlier, and adds a direct __gt__ fast path.
  • This reduces work in the hot comparison path used by interval membership checks.
  2. Prerelease handling moved into the interval filter.
  • _filter_by_intervals now handles prereleases=None itself, including the PEP 440 “buffer prereleases unless no finals match” behavior.
  • That removes an extra wrapper layer and keeps filtering decisions in one place.
  3. Single-interval cases got a dedicated fast path.
  • Both filtering and membership checks now special-case the common len(intervals) == 1 case.
  • This avoids the generic nested interval loop for simple specifiers.
  4. contains now uses direct interval membership.
  • A new _version_in_intervals helper lets Specifier.contains and SpecifierSet.contains do one parsed-version membership check instead of routing through more generic logic.
  5. SpecifierSet uses cheaper fast paths when possible.
  • Single-specifier sets delegate directly to that one specifier.
  • Multi-specifier sets use cached intersected intervals when possible.
  • The slower per-specifier fallback is now mostly reserved for === cases.
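The first two items can be sketched in isolation. This is a simplified, standalone illustration rather than the code from the branch: `release_matches_trimmed` and `filter_default_prereleases` are hypothetical names, release segments are plain integer tuples, and items are assumed to arrive as `(version_string, is_prerelease)` pairs that have already passed the interval check.

```python
from collections.abc import Iterable, Iterator


def release_matches_trimmed(
    release: tuple[int, ...], trimmed: tuple[int, ...]
) -> bool:
    """True if ``release`` equals ``trimmed`` once trailing zeros are ignored.

    ``trimmed`` is assumed to already have its trailing zeros stripped, so
    no new tuple has to be built for ``release`` on each comparison.
    """
    if len(release) < len(trimmed):
        return False
    for idx, value in enumerate(trimmed):
        if release[idx] != value:
            return False
    # Anything past the trimmed prefix must be padding zeros.
    for idx in range(len(trimmed), len(release)):
        if release[idx] != 0:
            return False
    return True


def filter_default_prereleases(
    matches: Iterable[tuple[str, bool]],
) -> Iterator[str]:
    """Yield finals as they arrive; fall back to prereleases if none matched."""
    buffered: list[str] = []
    found_final = False
    for version, is_prerelease in matches:
        if is_prerelease:
            # Hold prereleases back until we know whether any final matched.
            buffered.append(version)
        else:
            found_final = True
            yield version
    if not found_final:
        yield from buffered


print(release_matches_trimmed((1, 2, 0, 0), (1, 2)))
print(list(filter_default_prereleases([("1.0a1", True), ("1.0", False)])))
```

Final releases are streamed immediately, so the buffer only ever holds prereleases, and it is emitted only when the input is exhausted without a single final match.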

And the diff:

diff --git a/src/packaging/specifiers.py b/src/packaging/specifiers.py
index c341232..62cc181 100644
--- a/src/packaging/specifiers.py
+++ b/src/packaging/specifiers.py
@@ -84,12 +84,22 @@ class _ExclusionBound:
     def _is_family(self, other: Version) -> bool:
         """Is ``other`` a version that this sentinel sorts above?"""
         v = self.version
-        if not (
-            other.epoch == v.epoch
-            and _trim_release(other.release) == self._trimmed_release
-            and other.pre == v.pre
-        ):
+        if other.epoch != v.epoch or other.pre != v.pre:
+            return False
+
+        # Compare trimmed release equality without allocating a new tuple
+        # for ``other.release`` on each call.
+        other_release = other.release
+        trimmed = self._trimmed_release
+        if len(other_release) < len(trimmed):
             return False
+        for idx, value in enumerate(trimmed):
+            if other_release[idx] != value:
+                return False
+        for value in other_release[len(trimmed) :]:
+            if value != 0:
+                return False
+
         if self._kind == _AFTER_LOCALS:
             # Local family: exact same public version (any local label).
             return other.post == v.post and other.dev == v.dev
@@ -107,8 +117,25 @@ class _ExclusionBound:
                 return self.version < other.version
             return self._kind < other._kind
         assert isinstance(other, Version)
+        # Cheap reject first: if ``other`` is not above ``V``,
+        # ``self < other`` can never hold.
+        if not (self.version < other):
+            return False
         # self < other iff other is NOT in the family and other > V
-        return not self._is_family(other) and self.version < other
+        return not self._is_family(other)
+
+    def __gt__(self, other: object) -> bool:
+        if isinstance(other, _ExclusionBound):
+            if self.version != other.version:
+                return self.version > other.version
+            return self._kind > other._kind
+        assert isinstance(other, Version)
+        # Fast path: base version already dominates — no family check needed.
+        if self.version >= other:
+            return True
+        # Slower path: other > V, but might still be in the family
+        # (e.g. a post-release counted as V.postN with AFTER_POSTS semantics).
+        return self._is_family(other)
 
     def __hash__(self) -> int:
         return hash((self.version, self._kind))
@@ -203,7 +230,7 @@ def _filter_by_intervals(
     intervals: list[_SpecifierInterval],
     iterable: Iterable[Any],
     key: Callable[[Any], UnparsedVersion] | None,
-    prereleases: bool,
+    prereleases: bool | None,
 ) -> Iterator[Any]:
     """Filter versions against precomputed intervals.
 
@@ -211,12 +238,116 @@ def _filter_by_intervals(
     use :class:`_ExclusionBound` to handle local-version semantics.
 
     Used by both :class:`Specifier` and :class:`SpecifierSet`.
-    Prerelease buffering (PEP 440 default) is NOT handled here;
-    callers wrap the result with :func:`_pep440_filter_prereleases`
-    when needed.
+    When ``prereleases`` is ``None``, PEP 440 default semantics apply:
+    prereleases are excluded unless no final releases match.
     """
+    if not intervals:
+        return
+
+    # PEP 440 default behavior: exclude prereleases unless no finals match.
+    if prereleases is None:
+        prereleases_buffer: list[Any] = []
+        found_final = False
+
+        if len(intervals) == 1:
+            (
+                (lower_version, lower_inclusive),
+                (
+                    upper_version,
+                    upper_inclusive,
+                ),
+            ) = intervals[0]
+
+            for item in iterable:
+                parsed = _coerce_version(item if key is None else key(item))
+                if parsed is None:
+                    continue
+                if lower_version is not None:
+                    if lower_inclusive:
+                        if parsed < lower_version:
+                            continue
+                    elif not (parsed > lower_version):
+                        continue
+                if upper_version is not None:
+                    if upper_inclusive:
+                        if parsed > upper_version:
+                            continue
+                    elif not (parsed < upper_version):
+                        continue
+                if parsed.is_prerelease:
+                    prereleases_buffer.append(item)
+                else:
+                    found_final = True
+                    yield item
+            if not found_final:
+                yield from prereleases_buffer
+            return
+
+        for item in iterable:
+            parsed = _coerce_version(item if key is None else key(item))
+            if parsed is None:
+                continue
+            # Check if version falls within any interval. Intervals are sorted
+            # and non-overlapping, so at most one can match.
+            for (lower_version, lower_inclusive), (
+                upper_version,
+                upper_inclusive,
+            ) in intervals:
+                if lower_version is not None:
+                    if lower_inclusive:
+                        if parsed < lower_version:
+                            break
+                    elif not (parsed > lower_version):
+                        break
+                if upper_version is None:
+                    matched = True
+                elif upper_inclusive:
+                    matched = not (parsed > upper_version)
+                else:
+                    matched = parsed < upper_version
+                if matched:
+                    if parsed.is_prerelease:
+                        prereleases_buffer.append(item)
+                    else:
+                        found_final = True
+                        yield item
+                    break
+        if not found_final:
+            yield from prereleases_buffer
+        return
+
     exclude_prereleases = prereleases is False
 
+    if len(intervals) == 1:
+        (
+            (lower_version, lower_inclusive),
+            (
+                upper_version,
+                upper_inclusive,
+            ),
+        ) = intervals[0]
+
+        for item in iterable:
+            parsed = _coerce_version(item if key is None else key(item))
+            if parsed is None:
+                continue
+            if exclude_prereleases and parsed.is_prerelease:
+                continue
+            if lower_version is not None:
+                if lower_inclusive:
+                    if parsed < lower_version:
+                        continue
+                elif not (parsed > lower_version):
+                    continue
+            if upper_version is not None:
+                if upper_inclusive:
+                    if parsed > upper_version:
+                        continue
+                elif not (parsed < upper_version):
+                    continue
+            yield item
+        return
+
     for item in iterable:
         parsed = _coerce_version(item if key is None else key(item))
         if parsed is None:
@@ -229,20 +360,67 @@ def _filter_by_intervals(
             upper_version,
             upper_inclusive,
         ) in intervals:
-            if lower_version is not None and (
-                parsed < lower_version
-                or (parsed == lower_version and not lower_inclusive)
-            ):
-                break
-            if (
-                upper_version is None
-                or parsed < upper_version
-                or (parsed == upper_version and upper_inclusive)
-            ):
+            if lower_version is not None:
+                if lower_inclusive:
+                    if parsed < lower_version:
+                        break
+                elif not (parsed > lower_version):
+                    break
+            if upper_version is None:
+                matched = True
+            elif upper_inclusive:
+                matched = not (parsed > upper_version)
+            else:
+                matched = parsed < upper_version
+            if matched:
                 yield item
                 break
 
 
+def _version_in_intervals(
+    version: Version, intervals: list[_SpecifierInterval]
+) -> bool:
+    """Return whether ``version`` falls within any of ``intervals``."""
+    if not intervals:
+        return False
+
+    if len(intervals) == 1:
+        (
+            (lower_version, lower_inclusive),
+            (
+                upper_version,
+                upper_inclusive,
+            ),
+        ) = intervals[0]
+        if lower_version is not None:
+            if lower_inclusive:
+                if version < lower_version:
+                    return False
+            elif not (version > lower_version):
+                return False
+        if upper_version is None:
+            return True
+        if upper_inclusive:
+            return not (version > upper_version)
+        return version < upper_version
+
+    for (lower_version, lower_inclusive), (upper_version, upper_inclusive) in intervals:
+        if lower_version is not None:
+            if lower_inclusive:
+                if version < lower_version:
+                    break
+            elif not (version > lower_version):
+                break
+        if upper_version is None:
+            return True
+        if upper_inclusive:
+            if not (version > upper_version):
+                return True
+        elif version < upper_version:
+            return True
+    return False
+
+
 def _pep440_filter_prereleases(
     iterable: Iterable[Any], key: Callable[[Any], UnparsedVersion] | None
 ) -> Iterator[Any]:
@@ -820,7 +998,24 @@ class Specifier(BaseSpecifier):
         True
         """
 
-        return bool(list(self.filter([item], prereleases=prereleases)))
+        if self.operator == "===":
+            return str(item).lower() == self.version.lower()
+
+        version = _coerce_version(item)
+        if version is None:
+            return False
+
+        if prereleases is None:
+            if self._prereleases is not None:
+                prereleases = self._prereleases
+            elif self.prereleases:
+                prereleases = True
+
+        resolve_pre = True if prereleases is None else prereleases
+        if not resolve_pre and version.is_prerelease:
+            return False
+
+        return _version_in_intervals(version, self._to_intervals())
 
     @typing.overload
     def filter(
@@ -890,22 +1085,15 @@ class Specifier(BaseSpecifier):
             elif self.prereleases:
                 prereleases = True
 
-        # When prereleases is still None, pass True to include all versions
-        # and let _pep440_filter_prereleases handle the buffering.
-        resolve_pre = True if prereleases is None else prereleases
-
-        filtered = _filter_by_intervals(
+        # _filter_by_intervals handles prereleases=None (PEP 440 semantics)
+        # directly, so no wrapper needed.
+        yield from _filter_by_intervals(
             self._to_intervals(),
             iterable,
             key,
-            prereleases=resolve_pre,
+            prereleases=prereleases,
         )
 
-        if prereleases is not None:
-            yield from filtered
-        else:
-            yield from _pep440_filter_prereleases(filtered, key)
-
 
 class SpecifierSet(BaseSpecifier):
     """This class abstracts handling of a set of version specifiers.
@@ -1244,8 +1432,30 @@ class SpecifierSet(BaseSpecifier):
         if version is not None and installed and version.is_prerelease:
             prereleases = True
 
-        check_item = item if version is None else version
-        return bool(list(self.filter([check_item], prereleases=prereleases)))
+        if prereleases is None:
+            default_prereleases = self.prereleases
+            if default_prereleases is not None:
+                prereleases = default_prereleases
+
+        allow_prereleases = True if prereleases is None else prereleases
+
+        if self._specs:
+            intervals = self._get_intervals()
+            if intervals is not None:
+                if version is None:
+                    return False
+                if not allow_prereleases and version.is_prerelease:
+                    return False
+                return _version_in_intervals(version, intervals)
+
+            candidate = item if version is None else version
+            return all(
+                s.contains(candidate, prereleases=allow_prereleases)
+                for s in self._specs
+            )
+
+        # Empty SpecifierSet matches everything unless prereleases are disabled.
+        return allow_prereleases or version is None or not version.is_prerelease
 
     @typing.overload
     def filter(
@@ -1313,37 +1523,48 @@ class SpecifierSet(BaseSpecifier):
         # Determine if we're forcing a prerelease or not, if we're not forcing
         # one for this particular filter call, then we'll use whatever the
         # SpecifierSet thinks for whether or not we should support prereleases.
-        if prereleases is None and self.prereleases is not None:
-            prereleases = self.prereleases
+        if prereleases is None:
+            default_prereleases = self.prereleases
+            if default_prereleases is not None:
+                prereleases = default_prereleases
 
         # Filter versions that match all specifiers.
         if self._specs:
-            resolve_pre = True if prereleases is None else prereleases
+            # Fast path: a single specifier can delegate directly.
+            # This avoids an extra PEP 440 pass in the common one-spec case.
+            if len(self._specs) == 1:
+                return self._specs[0].filter(
+                    iterable,
+                    prereleases=prereleases,
+                    key=key,
+                )
 
-            filtered: Iterator[Any]
             intervals = self._get_intervals()
             if intervals is not None:
-                filtered = _filter_by_intervals(
+                # _filter_by_intervals handles prereleases=None (PEP 440
+                # semantics) directly.
+                return _filter_by_intervals(
                     intervals,
                     iterable,
                     key,
-                    prereleases=resolve_pre,
+                    prereleases=prereleases,
                 )
-            else:
-                # _get_intervals returns None when specs include ===
-                # (arbitrary string matching, not version comparison).
-                specs = self._specs
-                filtered = (
-                    item
-                    for item in iterable
-                    if all(
-                        s.contains(
-                            item if key is None else key(item),
-                            prereleases=resolve_pre,
-                        )
-                        for s in specs
+
+            # _get_intervals returns None when specs include ===
+            # (arbitrary string matching, not version comparison).
+            allow_prereleases = True if prereleases is None else prereleases
+            specs = self._specs
+            filtered: Iterator[Any] = (
+                item
+                for item in iterable
+                if all(
+                    s.contains(
+                        item if key is None else key(item),
+                        prereleases=allow_prereleases,
                     )
+                    for s in specs
                 )
+            )
 
             if prereleases is not None:
                 return filtered
