Skip to content

Commit 5f7ae72

Browse files
authored
Replace Mypy with Pyright (#206)
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 7926e2b commit 5f7ae72

File tree

12 files changed

+45
-39
lines changed

12 files changed

+45
-39
lines changed

.vscode/extensions.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@
33
"charliermarsh.ruff",
44
"ms-python.python",
55
"ms-python.vscode-pylance",
6-
"ms-python.mypy-type-checker"
76
]
87
}

.vscode/settings.json

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515
}
1616
},
1717
"python.languageServer": "Pylance",
18+
"python.analysis.typeCheckingMode": "basic",
1819
"python.testing.pytestArgs": [
1920
"-vvv",
2021
"python"
2122
],
2223
"python.testing.unittestEnabled": false,
2324
"python.testing.pytestEnabled": true,
24-
"mypy-type-checker.args": [
25-
"--show-column-numbers",
26-
"--no-pretty"
27-
],
2825
"ruff.lint.args": [
2926
"--config=pyproject.toml"
3027
],

pyproject.toml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ dependencies = [
1717
"typing_extensions>=4.5.0",
1818
]
1919
optional-dependencies = { dev = [
20-
"mypy",
2120
"pylint",
21+
"pyright",
2222
"pytest",
2323
"pytest-asyncio",
2424
"pytest-recording",
@@ -39,11 +39,6 @@ packages = ["replicate"]
3939
[tool.setuptools.package-data]
4040
"replicate" = ["py.typed"]
4141

42-
[tool.mypy]
43-
plugins = "pydantic.mypy"
44-
exclude = ["tests/"]
45-
enable_incomplete_feature = ["Unpack"]
46-
4742
[tool.pylint.main]
4843
disable = [
4944
"C0301", # Line too long

replicate/collection.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Union
1+
from typing import Any, Dict, Iterator, List, Optional, Union, overload
22

33
from typing_extensions import deprecated
44

@@ -32,13 +32,24 @@ def id(self) -> str:
3232
"""
3333
return self.slug
3434

35-
def __iter__(self): # noqa: ANN204
36-
return iter(self.models)
35+
def __iter__(self) -> Iterator[Model]:
36+
if self.models is not None:
37+
return iter(self.models)
38+
return iter([])
39+
40+
@overload
41+
def __getitem__(self, index: int) -> Optional[Model]:
42+
...
3743

38-
def __getitem__(self, index) -> Optional[Model]:
44+
@overload
45+
def __getitem__(self, index: slice) -> Optional[List[Model]]:
46+
...
47+
48+
def __getitem__(
49+
self, index: Union[int, slice]
50+
) -> Union[Optional[Model], Optional[List[Model]]]:
3951
if self.models is not None:
4052
return self.models[index]
41-
4253
return None
4354

4455
def __len__(self) -> int:

replicate/json.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ def encode_json(
3131
if isinstance(obj, io.IOBase):
3232
return upload_file(obj)
3333
if HAS_NUMPY:
34-
if isinstance(obj, np.integer):
34+
if isinstance(obj, np.integer): # type: ignore
3535
return int(obj)
36-
if isinstance(obj, np.floating):
36+
if isinstance(obj, np.floating): # type: ignore
3737
return float(obj)
38-
if isinstance(obj, np.ndarray):
38+
if isinstance(obj, np.ndarray): # type: ignore
3939
return obj.tolist()
4040
return obj

replicate/pagination.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
pass
2525

2626

27-
class Page(pydantic.BaseModel, Generic[T]):
27+
class Page(pydantic.BaseModel, Generic[T]): # type: ignore
2828
"""
2929
A page of results from the API.
3030
"""

replicate/resource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from replicate.client import Client
1111

1212

13-
class Resource(pydantic.BaseModel):
13+
class Resource(pydantic.BaseModel): # type: ignore
1414
"""
1515
A base class for representing a single object on the server.
1616
"""

replicate/stream.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from replicate.prediction import Predictions
2828

2929

30-
class ServerSentEvent(pydantic.BaseModel):
30+
class ServerSentEvent(pydantic.BaseModel): # type: ignore
3131
"""
3232
A server-sent event.
3333
"""
@@ -136,10 +136,10 @@ def __iter__(self) -> Iterator[ServerSentEvent]:
136136
if sse is not None:
137137
if sse.event == "done":
138138
return
139-
elif sse.event == "error":
139+
if sse.event == "error":
140140
raise RuntimeError(sse.data)
141-
else:
142-
yield sse
141+
142+
yield sse
143143

144144
async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
145145
decoder = EventSource.Decoder()
@@ -149,10 +149,10 @@ async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
149149
if sse is not None:
150150
if sse.event == "done":
151151
return
152-
elif sse.event == "error":
152+
if sse.event == "error":
153153
raise RuntimeError(sse.data)
154-
else:
155-
yield sse
154+
155+
yield sse
156156

157157

158158
def stream(

replicate/training.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,8 @@ def create( # type: ignore
231231
Create a new training using the specified model version as a base.
232232
"""
233233

234+
url = None
235+
234236
# Support positional arguments for backwards compatibility
235237
if args:
236238
if shorthand := args[0] if len(args) > 0 else None:
@@ -245,12 +247,12 @@ def create( # type: ignore
245247
params["webhook_completed"] = args[4]
246248
if len(args) > 5:
247249
params["webhook_events_filter"] = args[5]
248-
249250
elif model and version:
250251
url = _create_training_url_from_model_and_version(model, version)
251252
elif model is None and isinstance(version, str):
252253
url = _create_training_url_from_shorthand(version)
253-
else:
254+
255+
if not url:
254256
raise ValueError("model and version or shorthand version must be specified")
255257

256258
body = _create_training_body(input, **params)
@@ -376,6 +378,8 @@ def _create_training_url_from_model_and_version(
376378
owner, name = model.owner, model.name
377379
elif isinstance(model, tuple):
378380
owner, name = model[0], model[1]
381+
else:
382+
raise ValueError("model must be a Model or a tuple of (owner, name)")
379383

380384
if isinstance(version, Version):
381385
version_id = version.id

requirements-dev.txt

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,8 @@ mccabe==0.7.0
3737
# via pylint
3838
multidict==6.0.4
3939
# via yarl
40-
mypy==1.4.1
41-
# via replicate (pyproject.toml)
42-
mypy-extensions==1.0.0
43-
# via mypy
40+
nodeenv==1.8.0
41+
# via pyright
4442
packaging==23.1
4543
# via
4644
# pytest
@@ -55,6 +53,8 @@ pydantic-core==2.3.0
5553
# via pydantic
5654
pylint==3.0.2
5755
# via replicate (pyproject.toml)
56+
pyright==1.1.337
57+
# via replicate (pyproject.toml)
5858
pytest==7.4.0
5959
# via
6060
# pytest-asyncio
@@ -79,7 +79,6 @@ tomlkit==0.12.1
7979
# via pylint
8080
typing-extensions==4.7.1
8181
# via
82-
# mypy
8382
# pydantic
8483
# pydantic-core
8584
# replicate (pyproject.toml)
@@ -89,3 +88,6 @@ wrapt==1.15.0
8988
# via vcrpy
9089
yarl==1.9.2
9190
# via vcrpy
91+
92+
# The following packages are considered to be unsafe in a requirements file:
93+
# setuptools

0 commit comments

Comments
 (0)