diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 4edbcdac00c..9296bb43614 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -350,9 +350,10 @@ def __init__( raise MetaflowNamespaceMismatch(self._current_namespace) def _get_object(self, *path_components): - result = self._metaflow.metadata.get_object( + result_iter = self._metaflow.metadata.get_object( self._NAME, "self", None, self._attempt, *path_components ) + result = next(result_iter, None) if not result: raise MetaflowNotFound("%s does not exist" % self) return result diff --git a/metaflow/metadata_provider/metadata.py b/metaflow/metadata_provider/metadata.py index b7256eb2924..6264920070b 100644 --- a/metaflow/metadata_provider/metadata.py +++ b/metaflow/metadata_provider/metadata.py @@ -394,7 +394,7 @@ def get_object(cls, obj_type, sub_type, filters, attempt, *args): Return ------ - object or list : + object or iterator : Depending on the call, the type of object return varies """ type_order = ObjectOrder.type_to_order(obj_type) diff --git a/metaflow/plugins/metadata_providers/service.py b/metaflow/plugins/metadata_providers/service.py index 1bcc2a1ef2a..c832e5a4635 100644 --- a/metaflow/plugins/metadata_providers/service.py +++ b/metaflow/plugins/metadata_providers/service.py @@ -55,6 +55,7 @@ class ServiceMetadataProvider(MetadataProvider): _supports_attempt_gets = None _supports_tag_mutation = None + _supports_pagination = None def __init__(self, environment, flow, event_logger, monitor): super(ServiceMetadataProvider, self).__init__( @@ -252,6 +253,12 @@ def _mutate_user_tags_for_run( def _get_object_internal( cls, obj_type, obj_order, sub_type, sub_order, filters, attempt, *args ): + if cls._supports_pagination is None: + version = cls._version(None) + cls._supports_pagination = version is not None and version_parse( + version + ) >= version_parse("2.5.0") + if attempt is not None: if cls._supports_attempt_gets is None: version = cls._version(None) @@ -275,10 +282,12 @@ def _get_object_internal( url = ServiceMetadataProvider._obj_path(*args[:obj_order]) try: v, _ = cls._request(None, url, "GET") - return MetadataProvider._apply_filter([v], filters)[0] + yield MetadataProvider._apply_filter([v], filters)[0] + return except ServiceException as ex: if ex.http_code == 404: - return None + yield None + return raise # For the other types, we locate all the objects we need to find and return them @@ -292,12 +301,36 @@ def _get_object_internal( url += "/attempt/%s/artifacts" % attempt else: url += "/%ss" % sub_type + + # make the request paginated if we are querying for child objects + paginated_results = cls._supports_pagination and ( + obj_type != sub_type and sub_type != "self" + ) try: - v, _ = cls._request(None, url, "GET") - return MetadataProvider._apply_filter(v, filters) + if paginated_results: + limit = 100 + url += "?_limit=%s" % limit + page = 1 + while True: + # print("paginated request: page %s - limit %s" % (page, limit)) + _url = url + "&_page=%s" % page + v, _ = cls._request(None, _url, "GET") + for obj in v: + yield obj + if len(v) < limit: + # print("REACHED THE END OF PAGINATION") + # no more results expected, we are on the last page. + break + page += 1 + else: + # print("REGULAR REQUEST") + v, _ = cls._request(None, url, "GET") + yield MetadataProvider._apply_filter(v, filters) + return except ServiceException as ex: if ex.http_code == 404: - return None + yield None + return raise def _new_run(self, run_id=None, tags=None, sys_tags=None):