|
1 |
| -from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union |
| 1 | +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union |
2 | 2 |
|
3 | 3 | from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
|
4 | 4 |
|
5 | 5 | from replicate.exceptions import ReplicateException
|
| 6 | +from replicate.identifier import ModelIdentifier |
6 | 7 | from replicate.pagination import Page
|
7 |
| -from replicate.prediction import Prediction |
| 8 | +from replicate.prediction import ( |
| 9 | + Prediction, |
| 10 | + _create_prediction_body, |
| 11 | + _json_to_prediction, |
| 12 | +) |
8 | 13 | from replicate.resource import Namespace, Resource
|
9 | 14 | from replicate.version import Version, Versions
|
10 | 15 |
|
|
16 | 21 |
|
17 | 22 | if TYPE_CHECKING:
|
18 | 23 | from replicate.client import Client
|
| 24 | + from replicate.prediction import Predictions |
19 | 25 |
|
20 | 26 |
|
21 | 27 | class Model(Resource):
|
@@ -140,6 +146,14 @@ class Models(Namespace):
|
140 | 146 |
|
141 | 147 | model = Model
|
142 | 148 |
|
| 149 | + @property |
| 150 | + def predictions(self) -> "ModelsPredictions": |
| 151 | + """ |
| 152 | + Get a namespace for operations related to predictions on a model. |
| 153 | + """ |
| 154 | + |
| 155 | + return ModelsPredictions(client=self._client) |
| 156 | + |
143 | 157 | def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Model]: # noqa: F821
|
144 | 158 | """
|
145 | 159 | List all public models.
|
@@ -275,6 +289,54 @@ async def async_create(
|
275 | 289 | return _json_to_model(self._client, resp.json())
|
276 | 290 |
|
277 | 291 |
|
| 292 | +class ModelsPredictions(Namespace): |
| 293 | + """ |
| 294 | + Namespace for operations related to predictions in a deployment. |
| 295 | + """ |
| 296 | + |
| 297 | + def create( |
| 298 | + self, |
| 299 | + model: Optional[Union[str, Tuple[str, str], "Model"]], |
| 300 | + input: Dict[str, Any], |
| 301 | + **params: Unpack["Predictions.CreatePredictionParams"], |
| 302 | + ) -> Prediction: |
| 303 | + """ |
| 304 | + Create a new prediction with the deployment. |
| 305 | + """ |
| 306 | + |
| 307 | + url = _create_prediction_url_from_model(model) |
| 308 | + body = _create_prediction_body(version=None, input=input, **params) |
| 309 | + |
| 310 | + resp = self._client._request( |
| 311 | + "POST", |
| 312 | + url, |
| 313 | + json=body, |
| 314 | + ) |
| 315 | + |
| 316 | + return _json_to_prediction(self._client, resp.json()) |
| 317 | + |
| 318 | + async def async_create( |
| 319 | + self, |
| 320 | + model: Optional[Union[str, Tuple[str, str], "Model"]], |
| 321 | + input: Dict[str, Any], |
| 322 | + **params: Unpack["Predictions.CreatePredictionParams"], |
| 323 | + ) -> Prediction: |
| 324 | + """ |
| 325 | + Create a new prediction with the deployment. |
| 326 | + """ |
| 327 | + |
| 328 | + url = _create_prediction_url_from_model(model) |
| 329 | + body = _create_prediction_body(version=None, input=input, **params) |
| 330 | + |
| 331 | + resp = await self._client._async_request( |
| 332 | + "POST", |
| 333 | + url, |
| 334 | + json=body, |
| 335 | + ) |
| 336 | + |
| 337 | + return _json_to_prediction(self._client, resp.json()) |
| 338 | + |
| 339 | + |
278 | 340 | def _create_model_body( # pylint: disable=too-many-arguments
|
279 | 341 | owner: str,
|
280 | 342 | name: str,
|
@@ -318,3 +380,22 @@ def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model:
|
318 | 380 | if model.default_example is not None:
|
319 | 381 | model.default_example._client = client
|
320 | 382 | return model
|
| 383 | + |
| 384 | + |
| 385 | +def _create_prediction_url_from_model( |
| 386 | + model: Union[str, Tuple[str, str], "Model"] |
| 387 | +) -> str: |
| 388 | + owner, name = None, None |
| 389 | + if isinstance(model, Model): |
| 390 | + owner, name = model.owner, model.name |
| 391 | + elif isinstance(model, tuple): |
| 392 | + owner, name = model[0], model[1] |
| 393 | + elif isinstance(model, str): |
| 394 | + owner, name = ModelIdentifier.parse(model) |
| 395 | + |
| 396 | + if owner is None or name is None: |
| 397 | + raise ValueError( |
| 398 | + "model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'" |
| 399 | + ) |
| 400 | + |
| 401 | + return f"/v1/models/{owner}/{name}/predictions" |
0 commit comments