Skip to content

Commit 639f234

Browse files
committed
Add support for typing use() function
1 parent 35e66dc commit 639f234

File tree

3 files changed

+136
-21
lines changed

3 files changed

+136
-21
lines changed

README.md

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -592,16 +592,49 @@ for output in outputs:
592592
print(get_url_path(output)) # "https://replicate.delivery/xyz"
593593
```
594594

595+
### Typing
596+
597+
By default `use()` knows nothing about the interface of the model. To provide a better developer experience we provide two methods to add type annotations to the function returned by the `use()` helper.
598+
599+
**1. Provide a function signature**
600+
601+
The use method accepts a function signature as an additional `hint` keyword argument. When provided it will use this signature for the `model()` and `model.create()` functions.
602+
603+
```py
604+
# Flux takes a required prompt string and optional image and seed.
605+
def hint(*, prompt: str, image: Path | None = None, seed: int | None = None) -> str: ...
606+
607+
flux_dev = use("black-forest-labs/flux-dev", hint=hint)
608+
output1 = flux_dev() # will warn that `prompt` is missing
609+
output2 = flux_dev(prompt="str") # output2 will be typed as `str`
610+
```
611+
612+
**2. Provide a class**
613+
614+
The second method requires creating a callable class with a `name` field. The name will be used as the function reference when passed to `use()`.
615+
616+
```py
617+
class FluxDev:
618+
name = "black-forest-labs/flux-dev"
619+
620+
def __call__( self, *, prompt: str, image: Path | None = None, seed: int | None = None ) -> str: ...
621+
622+
flux_dev = use(FluxDev)
623+
output1 = flux_dev() # will warn that `prompt` is missing
624+
output2 = flux_dev(prompt="str") # output2 will be typed as `str`
625+
```
626+
627+
In future we hope to provide tooling to generate and provide these models as packages to make working with them easier. For now you may wish to create your own.
628+
595629
### TODO
596630

597631
There are several key things still outstanding:
598632

599633
1. Support for asyncio.
600-
2. Support for typing the return value.
601-
3. Support for streaming text when available (rather than polling)
602-
4. Support for streaming files when available (rather than polling)
603-
5. Support for cleaning up downloaded files.
604-
6. Support for streaming logs using `OutputIterator`.
634+
2. Support for streaming text when available (rather than polling)
635+
3. Support for streaming files when available (rather than polling)
636+
4. Support for cleaning up downloaded files.
637+
5. Support for streaming logs using `OutputIterator`.
605638

606639
## Development
607640

replicate/use.py

Lines changed: 77 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,24 @@
44
# - [ ] Support asyncio variant
55
import inspect
66
import os
7+
import sys
78
import tempfile
89
from dataclasses import dataclass
910
from functools import cached_property
1011
from pathlib import Path
11-
from typing import Any, Iterator, Optional, Tuple, Union
12+
from typing import (
13+
Any,
14+
Callable,
15+
Generic,
16+
Iterator,
17+
Optional,
18+
ParamSpec,
19+
Protocol,
20+
Tuple,
21+
TypeVar,
22+
cast,
23+
overload,
24+
)
1225
from urllib.parse import urlparse
1326

1427
import httpx
@@ -24,16 +37,35 @@
2437
__all__ = ["use", "get_path_url"]
2538

2639

40+
def _in_repl() -> bool:
41+
return bool(
42+
sys.flags.interactive # python -i
43+
or hasattr(sys, "ps1") # prompt strings exist
44+
or (
45+
sys.stdin.isatty() # tty
46+
and sys.stdout.isatty()
47+
)
48+
or ("get_ipython" in globals())
49+
)
50+
51+
2752
def _in_module_scope() -> bool:
2853
"""
2954
Returns True when called from top level module scope.
3055
"""
3156
if os.getenv("REPLICATE_ALWAYS_ALLOW_USE"):
3257
return True
3358

59+
# If we're running in a REPL.
60+
if _in_repl():
61+
return True
62+
3463
if frame := inspect.currentframe():
64+
print(frame)
3565
if caller := frame.f_back:
66+
print(caller.f_code.co_name)
3667
return caller.f_code.co_name == "<module>"
68+
3769
return False
3870

3971

@@ -248,16 +280,26 @@ def get_path_url(path: Any) -> str | None:
248280
return None
249281

250282

283+
Input = ParamSpec("Input")
284+
Output = TypeVar("Output")
285+
286+
287+
class FunctionRef(Protocol, Generic[Input, Output]):
288+
name: str
289+
290+
__call__: Callable[Input, Output]
291+
292+
251293
@dataclass
252-
class Run:
294+
class Run[O]:
253295
"""
254296
Represents a running prediction with access to its version.
255297
"""
256298

257299
prediction: Prediction
258300
schema: dict
259301

260-
def output(self) -> Union[Any, Iterator[Any]]:
302+
def output(self) -> O:
261303
"""
262304
Wait for the prediction to complete and return its output.
263305
"""
@@ -269,10 +311,13 @@ def output(self) -> Union[Any, Iterator[Any]]:
269311
# Return an OutputIterator for iterator output types (including concatenate iterators)
270312
if _has_iterator_output_type(self.schema):
271313
is_concatenate = _has_concatenate_iterator_output_type(self.schema)
272-
return OutputIterator(
273-
lambda: self.prediction.output_iterator(),
274-
self.schema,
275-
is_concatenate=is_concatenate,
314+
return cast(
315+
O,
316+
OutputIterator(
317+
lambda: self.prediction.output_iterator(),
318+
self.schema,
319+
is_concatenate=is_concatenate,
320+
),
276321
)
277322

278323
# Process output for file downloads based on schema
@@ -288,7 +333,7 @@ def logs(self) -> Optional[str]:
288333

289334

290335
@dataclass
291-
class Function:
336+
class Function(Generic[Input, Output]):
292337
"""
293338
A wrapper for a Replicate model that can be called as a function.
294339
"""
@@ -333,11 +378,10 @@ def _version(self) -> Version | None:
333378

334379
return version
335380

336-
def __call__(self, **inputs: Any) -> Any:
337-
run = self.create(**inputs)
338-
return run.output()
381+
def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output:
382+
return self.create(*args, **inputs).output()
339383

340-
def create(self, **inputs: Any) -> Run:
384+
def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
341385
"""
342386
Start a prediction with the specified inputs.
343387
"""
@@ -365,14 +409,14 @@ def create(self, **inputs: Any) -> Run:
365409
return Run(prediction, self.openapi_schema)
366410

367411
@property
368-
def default_example(self) -> Optional[Prediction]:
412+
def default_example(self) -> Optional[dict[str, Any]]:
369413
"""
370414
Get the default example for this model.
371415
"""
372416
raise NotImplementedError("This property has not yet been implemented")
373417

374418
@cached_property
375-
def openapi_schema(self) -> dict[Any, Any]:
419+
def openapi_schema(self) -> dict[str, Any]:
376420
"""
377421
Get the OpenAPI schema for this model version.
378422
"""
@@ -387,7 +431,19 @@ def openapi_schema(self) -> dict[Any, Any]:
387431
return schema
388432

389433

390-
def use(function_ref: str) -> Function:
434+
@overload
435+
def use(ref: FunctionRef[Input, Output]) -> Function[Input, Output]: ...
436+
437+
438+
@overload
439+
def use(ref: str, *, hint: Callable[Input, Output]) -> Function[Input, Output]: ...
440+
441+
442+
def use(
443+
ref: str | FunctionRef[Input, Output],
444+
*,
445+
hint: Callable[Input, Output] | None = None,
446+
) -> Function[Input, Output]:
391447
"""
392448
Use a Replicate model as a function.
393449
@@ -402,4 +458,9 @@ def use(function_ref: str) -> Function:
402458
if not _in_module_scope():
403459
raise RuntimeError("You may only call replicate.use() at the top level.")
404460

405-
return Function(function_ref)
461+
try:
462+
ref = ref.name # type: ignore
463+
except AttributeError:
464+
pass
465+
466+
return Function(str(ref))

tests/test_use.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,27 @@ async def test_use_with_version_identifier(client_mode):
280280
assert output == "not hotdog"
281281

282282

283+
@pytest.mark.asyncio
284+
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT])
285+
@respx.mock
286+
async def test_use_with_function_ref(client_mode):
287+
mock_model_endpoints()
288+
mock_prediction_endpoints()
289+
290+
class HotdogDetector:
291+
name = "acme/hotdog-detector:xyz123"
292+
293+
def __call__(self, prompt: str) -> str: ...
294+
295+
hotdog_detector = replicate.use(HotdogDetector())
296+
297+
# Call function with prompt="hello world"
298+
output = hotdog_detector(prompt="hello world")
299+
300+
# Assert that output is the completed output from the prediction request
301+
assert output == "not hotdog"
302+
303+
283304
@pytest.mark.asyncio
284305
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT])
285306
@respx.mock

0 commit comments

Comments
 (0)