Skip to content

Commit 2f84638

Browse files
authored
Add support for paginated results in list methods (#189)
This PR defines a new generic `Page` type, which represents a page of results returned from methods like `replicate.models.list`. This PR also updates those list methods to take a `cursor` argument. When `page.next` or `page.previous` is not `None`, you can pass it to a `list` method to get the next or previous page of results. --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 7a058fb commit 2f84638

File tree

12 files changed

+8761
-3155
lines changed

12 files changed

+8761
-3155
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,15 @@ replicate.predictions.list()
165165
# [<Prediction: 8b0ba5ab4d85>, <Prediction: 494900564e8c>]
166166
```
167167

168+
Lists of predictions are paginated. You can get the next page of predictions by passing the `next` property as an argument to the `list` method:
169+
170+
```python
171+
page1 = replicate.predictions.list()
172+
173+
if page1.next:
174+
page2 = replicate.predictions.list(page1.next)
175+
```
176+
168177
## Load output files
169178

170179
Output files are returned as HTTPS URLs. You can load an output file as a buffer:

replicate/hardware.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ class Hardwares(Namespace):
2828

2929
def list(self) -> List[Hardware]:
3030
"""
31-
List all public models.
31+
List all hardware available for you to run models on Replicate.
3232
3333
Returns:
34-
A list of models.
34+
List[Hardware]: A list of hardware.
3535
"""
3636

3737
resp = self._client._request("GET", "/v1/hardware")

replicate/model.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Dict, List, Optional, Union
1+
from typing import Dict, Optional, Union
22

33
from typing_extensions import deprecated
44

55
from replicate.exceptions import ReplicateException
6+
from replicate.pagination import Page
67
from replicate.prediction import Prediction
78
from replicate.resource import Namespace, Resource
89
from replicate.version import Version, Versions
@@ -123,18 +124,23 @@ class Models(Namespace):
123124

124125
model = Model
125126

126-
def list(self) -> List[Model]:
127+
def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F821
127128
"""
128129
List all public models.
129130
131+
Parameters:
132+
cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`.
130133
Returns:
131-
A list of models.
134+
Page[Model]: A page of of models.
135+
Raises:
136+
ValueError: If `cursor` is `None`.
132137
"""
133138

134-
resp = self._client._request("GET", "/v1/models")
135-
# TODO: paginate
136-
models = resp.json()["results"]
137-
return [self._prepare_model(obj) for obj in models]
139+
if cursor is None:
140+
raise ValueError("cursor cannot be None")
141+
142+
resp = self._client._request("GET", "/v1/models" if cursor is ... else cursor)
143+
return Page[Model](self._client, self, **resp.json())
138144

139145
def get(self, key: str) -> Model:
140146
"""

replicate/pagination.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import (
2+
TYPE_CHECKING,
3+
Dict,
4+
Generic,
5+
List,
6+
Optional,
7+
TypeVar,
8+
Union,
9+
)
10+
11+
try:
12+
from pydantic import v1 as pydantic # type: ignore
13+
except ImportError:
14+
import pydantic # type: ignore
15+
16+
from replicate.resource import Namespace, Resource
17+
18+
T = TypeVar("T", bound=Resource)
19+
20+
if TYPE_CHECKING:
21+
from .client import Client
22+
23+
24+
class Page(pydantic.BaseModel, Generic[T]):
25+
"""
26+
A page of results from the API.
27+
"""
28+
29+
_client: "Client" = pydantic.PrivateAttr()
30+
_namespace: Namespace = pydantic.PrivateAttr()
31+
32+
previous: Optional[str] = None
33+
"""A pointer to the previous page of results"""
34+
35+
next: Optional[str] = None
36+
"""A pointer to the next page of results"""
37+
38+
results: List[T]
39+
"""The results on this page"""
40+
41+
def __init__(
42+
self,
43+
client: "Client",
44+
namespace: Namespace[T],
45+
*,
46+
results: Optional[List[Union[T, Dict]]] = None,
47+
**kwargs,
48+
) -> None:
49+
self._client = client
50+
self._namespace = namespace
51+
52+
super().__init__(
53+
results=[self._namespace._prepare_model(r) for r in results]
54+
if results
55+
else None,
56+
**kwargs,
57+
)
58+
59+
def __iter__(self): # noqa: ANN204
60+
return iter(self.results)
61+
62+
def __getitem__(self, index: int) -> T:
63+
return self.results[index]
64+
65+
def __len__(self) -> int:
66+
return len(self.results)

replicate/prediction.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from replicate.exceptions import ModelError
77
from replicate.files import upload_file
88
from replicate.json import encode_json
9+
from replicate.pagination import Page
910
from replicate.resource import Namespace, Resource
1011
from replicate.version import Version
1112

@@ -157,21 +158,25 @@ class Predictions(Namespace):
157158

158159
model = Prediction
159160

160-
def list(self) -> List[Prediction]:
161+
def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Prediction]: # noqa: F821
161162
"""
162163
List your predictions.
163164
165+
Parameters:
166+
cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`.
164167
Returns:
165-
A list of prediction objects.
168+
Page[Prediction]: A page of of predictions.
169+
Raises:
170+
ValueError: If `cursor` is `None`.
166171
"""
167172

168-
resp = self._client._request("GET", "/v1/predictions")
169-
# TODO: paginate
170-
predictions = resp.json()["results"]
171-
for prediction in predictions:
172-
# HACK: resolve this? make it lazy somehow?
173-
del prediction["version"]
174-
return [self._prepare_model(obj) for obj in predictions]
173+
if cursor is None:
174+
raise ValueError("cursor cannot be None")
175+
176+
resp = self._client._request(
177+
"GET", "/v1/predictions" if cursor is ... else cursor
178+
)
179+
return Page[Prediction](self._client, self, **resp.json())
175180

176181
def get(self, id: str) -> Prediction: # pylint: disable=invalid-name
177182
"""

replicate/training.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from replicate.exceptions import ReplicateException
77
from replicate.files import upload_file
88
from replicate.json import encode_json
9+
from replicate.pagination import Page
910
from replicate.resource import Namespace, Resource
1011
from replicate.version import Version
1112

@@ -90,21 +91,25 @@ class CreateParams(TypedDict):
9091
webhook_completed: NotRequired[str]
9192
webhook_events_filter: NotRequired[List[str]]
9293

93-
def list(self) -> List[Training]:
94+
def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa: F821
9495
"""
9596
List your trainings.
9697
98+
Parameters:
99+
cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`.
97100
Returns:
98-
List[Training]: A list of training objects.
101+
Page[Training]: A page of trainings.
102+
Raises:
103+
ValueError: If `cursor` is `None`.
99104
"""
100105

101-
resp = self._client._request("GET", "/v1/trainings")
102-
# TODO: paginate
103-
trainings = resp.json()["results"]
104-
for training in trainings:
105-
# HACK: resolve this? make it lazy somehow?
106-
del training["version"]
107-
return [self._prepare_model(obj) for obj in trainings]
106+
if cursor is None:
107+
raise ValueError("cursor cannot be None")
108+
109+
resp = self._client._request(
110+
"GET", "/v1/trainings" if cursor is ... else cursor
111+
)
112+
return Page[Training](self._client, self, **resp.json())
108113

109114
def get(self, id: str) -> Training: # pylint: disable=invalid-name
110115
"""

0 commit comments

Comments
 (0)