Skip to content

Commit 65c0283

Browse files
committed
Fix type hints.
1 parent 3188667 commit 65c0283

File tree

5 files changed

+21
-13
lines changed

5 files changed

+21
-13
lines changed

src/sqlalchemydiff/comparer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,14 @@ def compare(
133133
db_one_info = self._get_db_info(ignore_specs, inspector, self.db_one_engine)
134134
db_two_info = self._get_db_info(ignore_specs, inspector, self.db_two_engine)
135135

136-
if None not in [db_one_info, db_two_info]:
136+
if db_one_info is not None and db_two_info is not None:
137137
result[key] = inspector.diff(db_one_info, db_two_info)
138138

139139
return self.compare_result_class(result, one_alias=one_alias, two_alias=two_alias)
140140

141-
def _filter_inspectors(self, ignore_inspectors: Optional[set[str]]) -> list[str]:
141+
def _filter_inspectors(
142+
self, ignore_inspectors: Optional[set[str]]
143+
) -> list[tuple[str, type[BaseInspector]]]:
142144
if not ignore_inspectors:
143145
ignore_inspectors = set()
144146

src/sqlalchemydiff/inspection/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def diff(self, one: dict, two: dict) -> dict: ... # pragma: no cover
6868
@abc.abstractmethod
6969
def _is_supported(self, inspector: Inspector) -> bool: ... # pragma: no cover
7070

71-
def _filter_ignorers(self, specs: list[IgnoreSpecType]) -> IgnoreClauses:
71+
def _filter_ignorers(self, specs: Optional[list[IgnoreSpecType]]) -> IgnoreClauses:
7272
tables, enums, clauses = [], [], []
7373

7474
for spec in specs or []:

src/sqlalchemydiff/inspection/ignore.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
22
from typing import NamedTuple, Optional, Union
33

44

@@ -60,9 +60,10 @@ def create_specs(
6060

6161
@dataclass
6262
class IgnoreClauses:
63-
tables: Optional[list[str]] = None
64-
enums: Optional[list[str]] = None
65-
clauses: Optional[list[IgnoreSpecType]] = None
63+
tables: list[str] = field(default_factory=list)
64+
enums: list[str] = field(default_factory=list)
65+
clauses: list[IgnoreSpecType] = field(default_factory=list)
6666

67-
def is_clause(self, table_name: str, inspector_key: str, object_name: str) -> bool:
68-
return (table_name, inspector_key, object_name) in self.clauses
67+
def is_clause(self, table_name: str, inspector_key: str, object_name: Optional[str]) -> bool:
68+
clause = TableIgnoreSpec(table_name, inspector_key, object_name)
69+
return clause in self.clauses

src/sqlalchemydiff/inspection/inspectors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, cast
22

33
from sqlalchemy.engine import Engine
44

@@ -209,7 +209,7 @@ def _format_unique_constraint(self, inspector: Inspector, table_name: str) -> li
209209
if not name:
210210
name = f"unique_{table_name}_{'_'.join(constraint.get('column_names'))}"
211211
constraint["name"] = name
212-
return result
212+
return cast(list[dict], result)
213213

214214
def _is_supported(self, inspector: Inspector) -> bool:
215215
return hasattr(inspector, "get_unique_constraints")
@@ -255,7 +255,7 @@ def inspect(
255255
inspector = self._get_inspector(engine)
256256

257257
ignore_clauses = self._filter_ignorers(ignore_specs)
258-
enums = inspector.get_enums() or []
258+
enums = getattr(inspector, "get_enums", lambda: [])() or []
259259
return [enum for enum in enums if enum["name"] not in ignore_clauses.enums]
260260

261261
def diff(self, one: dict, two: dict) -> dict:

src/sqlalchemydiff/inspection/mixins.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ class DiffMixin:
77
Provides methods used by the inspectors to diff the results.
88
"""
99

10+
one_alias: str
11+
two_alias: str
12+
one_only_alias: str
13+
two_only_alias: str
14+
1015
def _get_empty_result(self) -> dict:
1116
return {
1217
self.one_only_alias: [],
@@ -79,7 +84,7 @@ def _listdiff(self, one: Mapping[str, Iterable], two: Mapping[str, Iterable]) ->
7984

8085
return result
8186

82-
def _itemsdiff(self, items_in_one: Iterable[Mapping], items_in_two: Iterable[Mapping]) -> list:
87+
def _itemsdiff(self, items_in_one: Iterable[Mapping], items_in_two: Iterable[Mapping]) -> dict:
8388
"""Diff iterables of items in mapping format.
8489
8590
`items_in_one` and `items_in_two` are iterables of items in mapping format.

0 commit comments

Comments
 (0)