Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ repos:
hooks:
- id: pyupgrade
args: [--py310-plus]
- repo: https://github.com/4Catalyzer/fourmat
rev: v1.0.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.3
hooks:
- id: fourmat
additional_dependencies: [setuptools>74]
- id: ruff
args: [ --fix ]
- id: ruff-format
- repo: https://github.com/asottile/blacken-docs
rev: 1.18.0
hooks:
Expand Down
2 changes: 1 addition & 1 deletion example/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
app = Flask(__name__)
app.config.from_object(settings)

from . import routes # noqa: F401 isort:skip
from . import routes # noqa: E402 F401
8 changes: 2 additions & 6 deletions example/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ class Author(db.Model):

id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.Text, nullable=False)
created_at = db.Column(
db.DateTime, default=dt.datetime.utcnow, nullable=False
)
created_at = db.Column(db.DateTime, default=dt.datetime.utcnow, nullable=False)


class Book(db.Model):
Expand All @@ -27,6 +25,4 @@ class Book(db.Model):
author_id = db.Column(db.Integer, db.ForeignKey(Author.id), nullable=False)
author = db.relationship(Author, backref=db.backref("books"))
published_at = db.Column(db.DateTime, nullable=False)
created_at = db.Column(
db.DateTime, default=dt.datetime.utcnow, nullable=False
)
created_at = db.Column(db.DateTime, default=dt.datetime.utcnow, nullable=False)
4 changes: 1 addition & 3 deletions flask_resty/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def get_error_from_http_exception(cls, exc):
}

@classmethod
def from_validation_error(
cls, status_code, error, format_validation_error
):
def from_validation_error(cls, status_code, error, format_validation_error):
return cls(
status_code,
*(
Expand Down
8 changes: 2 additions & 6 deletions flask_resty/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ class RelatedItem(fields.Nested):

def _deserialize(self, value, *args, **kwargs):
if self.many and not marshmallow.utils.is_collection(value):
raise self.make_error(
"type", input=value, type=value.__class__.__name__
)
raise self.make_error("type", input=value, type=value.__class__.__name__)

# Do partial load of related item, as we only need the id.
return self.schema.load(value, partial=True)
Expand All @@ -46,9 +44,7 @@ class DelimitedList(fields.List):

delimiter = ","

def __init__(
self, cls_or_instance, delimiter=None, as_string=False, **kwargs
):
def __init__(self, cls_or_instance, delimiter=None, as_string=False, **kwargs):
super().__init__(cls_or_instance, **kwargs)

self.delimiter = delimiter or self.delimiter
Expand Down
8 changes: 2 additions & 6 deletions flask_resty/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ class FieldFilterBase(ArgFilterBase):
throwing an API error.
"""

def __init__(
self, *, separator=",", allow_empty=False, skip_invalid=False
):
def __init__(self, *, separator=",", allow_empty=False, skip_invalid=False):
self._separator = separator
self._allow_empty = allow_empty
self._skip_invalid = skip_invalid
Expand Down Expand Up @@ -101,9 +99,7 @@ def get_default_filter(self, view):
if field.required:
raise ApiError(400, {"code": "invalid_filter.missing"})

load_default = (
field.load_default if _USE_LOAD_DEFAULT else field.missing
)
load_default = field.load_default if _USE_LOAD_DEFAULT else field.missing
value = load_default() if callable(load_default) else load_default
if value is marshmallow.missing:
return None
Expand Down
17 changes: 4 additions & 13 deletions flask_resty/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,15 @@ def decode(self, jwt, **kwargs):

# jwt.decode will also check this, but this is more defensive.
if alg not in kwargs["algorithms"]:
raise InvalidAlgorithmError(
"The specified alg value is not allowed"
)
raise InvalidAlgorithmError("The specified alg value is not allowed")

return super().decode(
jwt, key=self.get_key_from_jwk(jwk, alg), **kwargs
)
return super().decode(jwt, key=self.get_key_from_jwk(jwk, alg), **kwargs)

def get_jwk_for_jwt(self, unverified_header):
try:
token_kid = unverified_header["kid"]
except KeyError as e:
raise InvalidTokenError(
"Key ID header parameter is missing"
) from e
raise InvalidTokenError("Key ID header parameter is missing") from e

for jwk in self.jwk_set["keys"]:
if jwk["kid"] == token_kid:
Expand Down Expand Up @@ -128,7 +122,4 @@ def _pyjwt(self):

@property
def jwk_set(self):
return (
self._jwk_set
or flask.current_app.config[self.get_config_key("jwk_set")]
)
return self._jwk_set or flask.current_app.config[self.get_config_key("jwk_set")]
26 changes: 7 additions & 19 deletions flask_resty/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,7 @@ def get_limit(self):
def reversed(self):
return self.get_cursor_info().reversed

def adjust_sort_ordering(
self, view: ModelView, field_orderings
) -> FieldOrderings:
def adjust_sort_ordering(self, view: ModelView, field_orderings) -> FieldOrderings:
"""Ensure the query is sorted correctly and get the field orderings.

The implementation of cursor-based pagination in Flask-RESTy requires
Expand Down Expand Up @@ -456,8 +454,7 @@ def parse_cursor(

deserializer = view.deserializer
column_fields = (
deserializer.fields[field_name]
for field_name, _ in field_orderings
deserializer.fields[field_name] for field_name, _ in field_orderings
)

try:
Expand Down Expand Up @@ -503,9 +500,7 @@ def deserialize_value(self, field, value):
def format_validation_error(self, message, path):
return {"code": "invalid_cursor", "detail": message}

def get_filter(
self, view, field_orderings: FieldOrderings, cursor: Cursor
):
def get_filter(self, view, field_orderings: FieldOrderings, cursor: Cursor):
"""Build the filter clause corresponding to a cursor.

Given the field orderings and the cursor as above, this will construct
Expand Down Expand Up @@ -539,8 +534,7 @@ def get_previous_clause(column_cursors):
if not column_cursors:
return None
clauses = [
column.isnot_distinct_from(value)
for column, _, value in column_cursors
column.isnot_distinct_from(value) for column, _, value in column_cursors
]

return sa.and_(*clauses)
Expand All @@ -565,9 +559,7 @@ def _prepare_current_clause(self, column, asc, value):
if value is None:
return None
elif value is not None:
current_clause = self._handle_nullable(
column, value, is_nullable
)
current_clause = self._handle_nullable(column, value, is_nullable)
else:
current_clause = column > value
else:
Expand Down Expand Up @@ -626,9 +618,7 @@ def make_cursor(self, item, view, field_orderings):

def get_column_fields(self, view, field_orderings):
serializer = view.serializer
return tuple(
serializer.fields[field_name] for field_name, _ in field_orderings
)
return tuple(serializer.fields[field_name] for field_name, _ in field_orderings)

def render_cursor(self, item, column_fields):
cursor = tuple(
Expand Down Expand Up @@ -728,9 +718,7 @@ def get_page(self, query, view):
# Relay expects a cursor for each item.
cursors_out = self.make_cursors(items, view, field_orderings)

page_info = self.get_page_info(
query, view, field_orderings, cursor_in, items
)
page_info = self.get_page_info(query, view, field_orderings, cursor_in, items)

meta.update_response_meta({"cursors": cursors_out, **page_info})

Expand Down
16 changes: 4 additions & 12 deletions flask_resty/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
| _| | | (_| \__ \ <_____| _ <| |___ ___) || || |_| |
|_| |_|\__,_|___/_|\_\ |_| \_\_____|____/ |_| \__, |
|___/
""".strip(
"\n"
)
""".strip("\n")

DEFAULTS = dict(
RESTY_SHELL_CONTEXT={},
Expand Down Expand Up @@ -116,9 +114,7 @@ def context_formatter(
- schema_context.keys()
- model_context.keys()
)
additional_context = {
key: full_context[key] for key in additional_context_keys
}
additional_context = {key: full_context[key] for key in additional_context_keys}
if additional_context:
sections.append(("Additional", additional_context))
return "\n".join([format_section(*section) for section in sections])
Expand All @@ -136,9 +132,7 @@ def cli(shell: str, sqlalchemy_echo: bool):
"""An improved Flask shell command."""
from flask.globals import current_app

options = {
key: current_app.config.get(key, DEFAULTS[key]) for key in DEFAULTS
}
options = {key: current_app.config.get(key, DEFAULTS[key]) for key in DEFAULTS}
current_app.config["SQLALCHEMY_ECHO"] = sqlalchemy_echo
base_context = {"app": current_app}
flask_context = current_app.make_shell_context()
Expand Down Expand Up @@ -180,9 +174,7 @@ def cli(shell: str, sqlalchemy_echo: bool):
ipy_extensions=options["RESTY_SHELL_IPY_EXTENSIONS"],
ipy_autoreload=options["RESTY_SHELL_IPY_AUTORELOAD"],
ipy_colors=options["RESTY_SHELL_IPY_COLORS"],
ipy_highlighting_style=options[
"RESTY_SHELL_IPY_HIGHLIGHTING_STYLE"
],
ipy_highlighting_style=options["RESTY_SHELL_IPY_HIGHLIGHTING_STYLE"],
)
)
if Path(".konchrc.local").exists(): # pragma: no cover
Expand Down
4 changes: 1 addition & 3 deletions flask_resty/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ def get_field_orderings(self, fields):
:return: A sequence of field orderings. See :py:meth:`get_criterion`.
:rtype: tuple
"""
return tuple(
self.get_field_ordering(field) for field in fields.split(",")
)
return tuple(self.get_field_ordering(field) for field in fields.split(","))

def get_field_ordering(self, field: str) -> FieldOrdering:
if field and field[0] == "-":
Expand Down
20 changes: 5 additions & 15 deletions flask_resty/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ class ApiClient(FlaskClient):
"""

def open(self, path, *args, **kwargs):
full_path = "{}{}".format(
self.application.extensions["resty"].api.prefix, path
)
full_path = "{}{}".format(self.application.extensions["resty"].api.prefix, path)

if "data" in kwargs:
kwargs.setdefault("content_type", "application/json")
Expand Down Expand Up @@ -61,9 +59,7 @@ def assert_shape(actual, expected, key=None):

if key is not None:
suffix = (
" for parent "
+ ("index" if isinstance(key, int) else "key")
+ f" {key!r}"
" for parent " + ("index" if isinstance(key, int) else "key") + f" {key!r}"
)

if isinstance(expected, Mapping):
Expand All @@ -84,9 +80,7 @@ def assert_shape(actual, expected, key=None):
elif isinstance(expected, (str, bytes)):
assert expected == actual
elif isinstance(expected, Sequence):
assert isinstance(actual, Sequence), (
f"{actual!r} is not a Sequence" + suffix
)
assert isinstance(actual, Sequence), f"{actual!r} is not a Sequence" + suffix

actual_len = len(actual)
expected_len = len(expected)
Expand All @@ -100,18 +94,14 @@ def assert_shape(actual, expected, key=None):
)
+ suffix
)
for idx, (actual_item, expected_item) in enumerate(
zip(actual, expected)
):
for idx, (actual_item, expected_item) in enumerate(zip(actual, expected)):
assert_shape(actual_item, expected_item, key=idx)
elif isinstance(expected, float):
assert (
abs(actual - expected) < 1e-6
), "float not within the allowed tolerance of 1e-6"
else:
assert expected == actual, (
f"{actual!r} is not equal to {expected!r}" + suffix
)
assert expected == actual, f"{actual!r} is not equal to {expected!r}" + suffix


def Shape(expected):
Expand Down
20 changes: 5 additions & 15 deletions flask_resty/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,7 @@ def get_location(self, item):
:return: The canonical URL for `item`.
:rtype: str
"""
id_dict = {
id_field: getattr(item, id_field) for id_field in self.id_fields
}
id_dict = {id_field: getattr(item, id_field) for id_field in self.id_fields}
return flask.url_for(flask.request.endpoint, _method="GET", **id_dict)

def get_request_data(self, **kwargs):
Expand Down Expand Up @@ -304,9 +302,7 @@ def format_validation_error(self, message, path):
:return: The formatted validation error.
:rtype: dict
"""
pointer = "/data/{}".format(
"/".join(str(field_key) for field_key in path)
)
pointer = "/data/{}".format("/".join(str(field_key) for field_key in path))

return {
"code": "invalid_data",
Expand Down Expand Up @@ -383,9 +379,7 @@ def request_args(self):
# KeyError for args that aren't present.
continue

if isinstance(field, fields.List) and not isinstance(
field, DelimitedList
):
if isinstance(field, fields.List) and not isinstance(field, DelimitedList):
value = args.getlist(field_name)
else:
value = args.get(field_name)
Expand Down Expand Up @@ -960,9 +954,7 @@ def resolve_integrity_error(self, error):
# a schema bug, so we emit an interal server error instead.
return error

flask.current_app.logger.warning(
"handled integrity error", exc_info=error
)
flask.current_app.logger.warning("handled integrity error", exc_info=error)
return ApiError(409, {"code": "invalid_data.conflict"})

def set_item_response_meta(self, item):
Expand Down Expand Up @@ -1049,9 +1041,7 @@ def retrieve(self, id, *, create_transient_stub=False):
:return: An HTTP 200 response.
:rtype: :py:class:`flask.Response`
"""
item = self.get_item_or_404(
id, create_transient_stub=create_transient_stub
)
item = self.get_item_or_404(id, create_transient_stub=create_transient_stub)
return self.make_item_response(item)

def create(self, *, allow_client_id=False):
Expand Down
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
"jwt": ("PyJWT>=2.0.0", "cryptography>=2.0.0"),
"tests": ("coverage", "psycopg2-binary", "pytest"),
}
EXTRAS_REQUIRE["dev"] = (
EXTRAS_REQUIRE["docs"] + EXTRAS_REQUIRE["tests"] + ("tox",)
)
EXTRAS_REQUIRE["dev"] = EXTRAS_REQUIRE["docs"] + EXTRAS_REQUIRE["tests"] + ("tox",)


def find_version(fname):
Expand Down
4 changes: 1 addition & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def app():

@pytest.fixture
def db(app):
app.config["SQLALCHEMY_DATABASE_URI"] = os.environ.get(
"DATABASE_URL", "sqlite://"
)
app.config["SQLALCHEMY_DATABASE_URI"] = os.environ.get("DATABASE_URL", "sqlite://")

# TODO: Remove once this is the default.
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
Expand Down
Loading
Loading