diff --git a/flask_resty/api.py b/flask_resty/api.py index 454fc1d..6db75df 100644 --- a/flask_resty/api.py +++ b/flask_resty/api.py @@ -1,10 +1,13 @@ import functools import posixpath +from typing import Optional import flask +from flask import Flask from werkzeug.exceptions import HTTPException from .exceptions import ApiError +from .view import ApiView # ----------------------------------------------------------------------------- @@ -14,11 +17,11 @@ # ----------------------------------------------------------------------------- -def handle_api_error(error): +def handle_api_error(error: ApiError): return error.response -def handle_http_exception(error): +def handle_http_exception(error: HTTPException): return ApiError.from_http_exception(error).response @@ -47,7 +50,9 @@ class Api: :param str prefix: The API path prefix. """ - def __init__(self, app=None, prefix=""): + _app: Optional[Flask] + + def __init__(self, app: Optional[Flask] = None, prefix: str = "") -> None: if app: self._app = app self.init_app(app) @@ -56,7 +61,7 @@ def __init__(self, app=None, prefix=""): self.prefix = prefix - def init_app(self, app): + def init_app(self, app: Flask) -> None: """Initialize an application for use with Flask-RESTy. :param app: The Flask application object. @@ -67,21 +72,21 @@ def init_app(self, app): app.register_error_handler(ApiError, handle_api_error) app.register_error_handler(HTTPException, handle_http_exception) - def _get_app(self, app): + def _get_app(self, app: Optional[Flask]) -> Flask: app = app or self._app assert app, "no application specified" return app def add_resource( self, - base_rule, - base_view, - alternate_view=None, + base_rule: str, + base_view: ApiView, + alternate_view: Optional[ApiView] = None, *, - alternate_rule=None, - id_rule=None, - app=None, - ): + alternate_rule: Optional[str] = None, + id_rule: Optional[str] = None, + app: Optional[Flask] = None, + ) -> None: """Add a REST resource. :param str base_rule: The URL rule for the resource. This will be @@ -150,7 +155,9 @@ def view_func(*args, **kwargs): methods=alternate_view.methods, ) - def _get_endpoint(self, base_view, alternate_view): + def _get_endpoint( + self, base_view: ApiView, alternate_view: Optional[ApiView] + ) -> str: base_view_name = base_view.__name__ if not alternate_view: return base_view_name @@ -161,7 +168,9 @@ def _get_endpoint(self, base_view, alternate_view): else: return base_view_name - def add_ping(self, rule, *, status_code=200, app=None): + def add_ping( + self, rule: str, *, status_code: int = 200, app: Optional[Flask] = None + ) -> None: """Add a ping route. :param str rule: The URL rule. This will not use the API prefix, as the @@ -185,5 +194,5 @@ def ping(): class FlaskRestyState: - def __init__(self, api): + def __init__(self, api: Api) -> None: self.api = api diff --git a/flask_resty/routing.py b/flask_resty/routing.py index 167da7c..a0b7774 100644 --- a/flask_resty/routing.py +++ b/flask_resty/routing.py @@ -4,7 +4,7 @@ from werkzeug.routing import RequestPath except ImportError: # pragma: no cover # werkzeug<1.0 - from werkzeug.routing import RequestSlash as RequestPath + from werkzeug.routing import RequestSlash as RequestPath # type: ignore # ----------------------------------------------------------------------------- diff --git a/flask_resty/view.py b/flask_resty/view.py index 1ceb4f9..a8a6b39 100644 --- a/flask_resty/view.py +++ b/flask_resty/view.py @@ -1,25 +1,48 @@ import itertools +from typing import ( + Generic, + List, + Literal, + Optional, + Sequence, + Type, + TypeVar, + Union, +) import flask +from flask import Response from flask.views import MethodView -from marshmallow import ValidationError, fields +from marshmallow import Schema, ValidationError, fields from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Load +from sqlalchemy.orm import Load, Query, Session from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm.strategy_options import loader_option from werkzeug.exceptions import NotFound from . import meta -from .authentication import NoOpAuthentication -from .authorization import NoOpAuthorization +from .authentication import AuthenticationBase, NoOpAuthentication +from .authorization import AuthorizationBase, NoOpAuthorization from .decorators import request_cached_property from .exceptions import ApiError from .fields import DelimitedList +from .filtering import Filtering +from .pagination import PaginationBase +from .related import Related +from .sorting import SortingBase from .utils import settable_property # ----------------------------------------------------------------------------- +TItem = TypeVar("TItem") +_TDataIn = TypeVar("_TDataIn", bound=dict) +_TArgs = TypeVar("_TArgs", bound=dict) +_TId = TypeVar("_TId") -class ApiView(MethodView): +# ----------------------------------------------------------------------------- + + +class ApiView(Generic[TItem], MethodView): """Base class for views that expose API endpoints. `ApiView` extends :py:class:`flask.views.MethodView` exposes functionality @@ -29,19 +52,19 @@ class ApiView(MethodView): #: The :py:class:`marshmallow.Schema` for serialization and #: deserialization. - schema = None + schema: Optional[Schema] = None #: The identifying fields for the model. - id_fields = ("id",) + id_fields: Sequence[str] = ("id",) #: The :py:class:`marshmallow.Schema` for deserializing the query params in #: the :py:data:`flask.Request.args`. - args_schema = None + args_schema: Optional[Schema] = None #: The authentication component. See :py:class:`AuthenticationBase`. - authentication = NoOpAuthentication() + authentication: AuthenticationBase = NoOpAuthentication() #: The authorization component. See :py:class:`AuthorizationBase`. - authorization = NoOpAuthorization() + authorization: AuthorizationBase = NoOpAuthorization() - def dispatch_request(self, *args, **kwargs): + def dispatch_request(self, *args, **kwargs) -> Response: """Handle an incoming request. By default, this checks request-level authentication and authorization @@ -52,7 +75,7 @@ def dispatch_request(self, *args, **kwargs): return super().dispatch_request(*args, **kwargs) - def serialize(self, item, **kwargs): + def serialize(self, item: Union[TItem, Sequence[TItem]], **kwargs) -> dict: """Dump an item using the :py:attr:`serializer`. This doesn't technically serialize the item; it instead uses @@ -66,10 +89,11 @@ def serialize(self, item, **kwargs): :return: The serialized object :rtype: dict """ + assert self.serializer, "cannot serialize without serializer" return self.serializer.dump(item, **kwargs) @settable_property - def serializer(self): + def serializer(self) -> Optional[Schema]: """The :py:class:`marshmallow.Schema` for serialization. By default, this is :py:attr:`ApiView.schema`. This can be overridden @@ -77,7 +101,7 @@ def serializer(self): """ return self.schema - def make_items_response(self, items, *args): + def make_items_response(self, items: Sequence[TItem], *args) -> Response: """Build a response for a sequence of multiple items. This serializes the items, then builds an response with the list of @@ -94,7 +118,7 @@ def make_items_response(self, items, *args): data_out = self.serialize(items, many=True) return self.make_response(data_out, *args, items=items) - def make_item_response(self, item, *args): + def make_item_response(self, item: TItem, *args) -> Response: """Build a response for a single item. This serializes the item, then builds an response with the serialized @@ -119,7 +143,7 @@ def make_item_response(self, item, *args): return response - def set_item_response_meta(self, item): + def set_item_response_meta(self, item: TItem) -> None: """Hook for setting additional metadata for an item. This should call `meta.update_response_meta` to set any metadata values @@ -130,7 +154,7 @@ def set_item_response_meta(self, item): """ pass - def make_response(self, data, *args, **kwargs): + def make_response(self, data, *args, **kwargs) -> Response: """Build a response for arbitrary dumped data. This builds the response body given the data and any metadata from the @@ -142,7 +166,7 @@ def make_response(self, data, *args, **kwargs): body = self.render_response_body(data, meta.get_response_meta()) return self.make_raw_response(body, *args, **kwargs) - def render_response_body(self, data, response_meta): + def render_response_body(self, data, response_meta: dict) -> Response: """Render the response data and metadata into a body. This is the final step of building the response payload before @@ -157,7 +181,7 @@ def render_response_body(self, data, response_meta): return flask.jsonify(body) - def make_raw_response(self, *args, **kwargs): + def make_raw_response(self, *args, **kwargs) -> Response: """Convenience method for creating a :py:class:`flask.Response`. Any supplied keyword arguments are defined as attributes on the @@ -171,7 +195,7 @@ def make_raw_response(self, *args, **kwargs): setattr(response, key, value) return response - def make_empty_response(self, **kwargs): + def make_empty_response(self, **kwargs) -> Response: """Build an empty response. This response has a status code of 204 and an empty body. @@ -181,7 +205,7 @@ def make_empty_response(self, **kwargs): """ return self.make_raw_response("", 204, **kwargs) - def make_created_response(self, item): + def make_created_response(self, item: TItem) -> Response: """Build a response for a newly created item. This response will be for the item data and will have a status code of @@ -194,7 +218,7 @@ def make_created_response(self, item): """ return self.make_item_response(item, 201) - def make_deleted_response(self, item): + def make_deleted_response(self, item: TItem) -> Response: """Build a response for a deleted item. By default, this will be an empty response. The empty response will @@ -206,7 +230,7 @@ def make_deleted_response(self, item): """ return self.make_empty_response(item=item) - def get_location(self, item): + def get_location(self, item: TItem) -> str: """Get the canonical URL for an item. Override this to return ``None`` if no such URL is available. @@ -220,7 +244,7 @@ def get_location(self, item): } return flask.url_for(flask.request.endpoint, _method="GET", **id_dict) - def get_request_data(self, **kwargs): + def get_request_data(self, **kwargs) -> _TDataIn: """Deserialize and load data from the body of the current request. By default, this will look for the value under the ``data`` key in a @@ -232,13 +256,14 @@ def get_request_data(self, **kwargs): data_raw = self.parse_request_data() return self.deserialize(data_raw, **kwargs) - def parse_request_data(self): + def parse_request_data(self) -> dict: """Deserialize the data for the current request. This will deserialize the request data from the request body into a native Python object that can be loaded by marshmallow. - :return: The deserialized request data. + :return: The deserialized request data + :rtype: dict """ try: data_raw = flask.request.get_json()["data"] @@ -249,7 +274,13 @@ def parse_request_data(self): return data_raw - def deserialize(self, data_raw, *, expected_id=None, **kwargs): + def deserialize( + self, + data_raw: dict, + *, + expected_id: Union[_TId, Literal[False], None] = None, + **kwargs, + ) -> _TDataIn: """Load data using the :py:attr:`deserializer`. This doesn't technically deserialize the data; it instead uses @@ -259,12 +290,14 @@ def deserialize(self, data_raw, *, expected_id=None, **kwargs): Any provided `**kwargs` will be passed to :py:meth:`marshmallow.Schema.load`. - :param data_raw: The request data to load. + :param dict data_raw: The request data to load. :param expected_id: The expected ID in the request data. See `validate_request_id`. :return: The deserialized data :rtype: dict """ + assert self.deserializer, "cannot deserialize without deserializer" + try: data = self.deserializer.load(data_raw, **kwargs) except ValidationError as e: @@ -276,7 +309,7 @@ def deserialize(self, data_raw, *, expected_id=None, **kwargs): return data @settable_property - def deserializer(self): + def deserializer(self) -> Optional[Schema]: """The :py:class:`marshmallow.Schema` for serialization. By default, this is :py:attr:`ApiView.schema`. This can be overridden @@ -284,7 +317,9 @@ def deserializer(self): """ return self.schema - def format_validation_error(self, message, path): + def format_validation_error( + self, message: str, path: Sequence[str] + ) -> dict: """Convert marshmallow validation error data to a serializable form. This converts marshmallow validation error data to a standard @@ -314,7 +349,7 @@ def format_validation_error(self, message, path): "source": {"pointer": pointer}, } - def validate_request_id(self, data, expected_id): + def validate_request_id(self, data: dict, expected_id: _TId) -> None: """Check that the request data has the expected ID. This is generally used to assert that update operations include the @@ -346,7 +381,7 @@ def validate_request_id(self, data, expected_id): if id != expected_id: raise ApiError(409, {"code": "invalid_id.mismatch"}) - def get_data_id(self, data): + def get_data_id(self, data: dict) -> _TId: """Get the ID as a scalar or tuple from request data. The ID will be a scalar if :py:attr:`id_fields` contains a single @@ -360,7 +395,7 @@ def get_data_id(self, data): return tuple(data[id_field] for id_field in self.id_fields) @request_cached_property - def request_args(self): + def request_args(self) -> _TArgs: """The query arguments for the current request. This uses :py:attr:`args_schema` to load the current query args. This @@ -394,7 +429,7 @@ def request_args(self): return self.deserialize_args(data_raw) - def deserialize_args(self, data_raw, **kwargs): + def deserialize_args(self, data_raw: dict, **kwargs) -> _TArgs: """Load parsed query arg data using :py:attr:`args_schema`. As with `deserialize`, contra the name, this handles loading with a @@ -419,7 +454,9 @@ def deserialize_args(self, data_raw, **kwargs): return data - def format_parameter_validation_error(self, message, parameter): + def format_parameter_validation_error( + self, message: str, parameter: str + ) -> dict: """Convert a parameter validation error to a serializable form. This closely follows `format_validation_error`, but produces error @@ -444,7 +481,7 @@ def format_parameter_validation_error(self, message, parameter): "source": {"parameter": parameter}, } - def get_id_dict(self, id): + def get_id_dict(self, id: _TId) -> dict: """Convert an ID from `get_data_id` to dictionary form. This converts an ID from `get_data_id` into a dictionary where each ID @@ -461,7 +498,7 @@ def get_id_dict(self, id): return dict(zip(self.id_fields, id)) -class ModelView(ApiView): +class ModelView(ApiView[TItem]): """Base class for API views tied to SQLAlchemy models. `ModelView` implements additional methods on top of those provided by @@ -478,25 +515,25 @@ class ModelView(ApiView): """ #: A declarative SQLAlchemy model. - model = None + model: Type[TItem] = None #: An instance of :py:class:`filtering.Filtering`. - filtering = None + filtering: Optional[Filtering] = None #: An instance of :py:class:`sorting.SortingBase`. - sorting = None + sorting: Optional[SortingBase] = None #: An instance of :py:class:`pagination.PaginationBase`. - pagination = None + pagination: Optional[PaginationBase] = None #: An instance of :py:class:`related.Related`. - related = None + related: Optional[Related] = None @settable_property - def session(self): + def session(self) -> Session: """Convenience property for the current SQLAlchemy session.""" return flask.current_app.extensions["sqlalchemy"].db.session @settable_property - def query_raw(self): + def query_raw(self) -> Query: """The raw SQLAlchemy query for the view. This is the base query, without authorization filters or query options. @@ -506,7 +543,7 @@ def query_raw(self): return self.model.query @settable_property - def query(self): + def query(self) -> Query: """The SQLAlchemy query for the view. Override this to customize the query to fetch items in this view. @@ -531,10 +568,10 @@ def query(self): #: For example, set this to ``(raiseload('*', sql_only=True),)`` to prevent #: all implicit SQL-emitting relationship loading, and force all #: relationship loading to be explicitly defined via `query_options`. - base_query_options = () + base_query_options: Sequence[loader_option] = () @settable_property - def query_options(self): + def query_options(self) -> Sequence[loader_option]: """Options to apply to the query for the view. Set this to configure relationship and column loading. @@ -551,7 +588,7 @@ def query_options(self): return self.serializer.get_query_options(Load(self.model)) - def get_list(self): + def get_list(self) -> List[TItem]: """Retrieve a list of items. This takes the output of `get_list_query` and applies pagination. @@ -561,7 +598,7 @@ def get_list(self): """ return self.paginate_list_query(self.get_list_query()) - def get_list_query(self): + def get_list_query(self) -> Query: """Build the query to retrieve a filtered and sorted list of items. :return: The list query. @@ -572,7 +609,7 @@ def get_list_query(self): query = self.sort_list_query(query) return query - def filter_list_query(self, query): + def filter_list_query(self, query: Query) -> Query: """Apply filtering as specified to the provided `query`. :param: A SQL query @@ -585,7 +622,7 @@ def filter_list_query(self, query): return self.filtering.filter_query(query, self) - def sort_list_query(self, query): + def sort_list_query(self, query: Query) -> Query: """Apply sorting as specified to the provided `query`. :param: A SQL query @@ -598,7 +635,7 @@ def sort_list_query(self, query): return self.sorting.sort_query(query, self) - def paginate_list_query(self, query): + def paginate_list_query(self, query: Query) -> List[TItem]: """Retrieve the requested page from `query`. If :py:attr:`pagination` is configured, this will retrieve the page as @@ -615,7 +652,7 @@ def paginate_list_query(self, query): return self.pagination.get_page(query, self) - def get_item_or_404(self, id, **kwargs): + def get_item_or_404(self, id: _TId, **kwargs) -> TItem: """Get an item by ID; raise a 404 if it not found. This will get an item by ID per `get_item` below. If no item is found, @@ -633,8 +670,12 @@ def get_item_or_404(self, id, **kwargs): return item def get_item( - self, id, *, with_for_update=False, create_transient_stub=False, - ): + self, + id: _TId, + *, + with_for_update: bool = False, + create_transient_stub: bool = False, + ) -> TItem: """Get an item by ID. The ID should be the scalar ID value if `id_fields` specifies a single @@ -675,7 +716,7 @@ def get_item( return item - def deserialize(self, data_raw, **kwargs): + def deserialize(self, data_raw: dict, **kwargs) -> _TDataIn: """Load data using the :py:attr:`deserializer`. In addition to the functionality of :py:meth:`ApiView.deserialize`, @@ -684,7 +725,7 @@ def deserialize(self, data_raw, **kwargs): data = super().deserialize(data_raw, **kwargs) return self.resolve_related(data) - def resolve_related(self, data): + def resolve_related(self, data: dict) -> _TDataIn: """Resolve all related fields per :py:attr:`related`. :param object data: A deserialized object @@ -696,7 +737,7 @@ def resolve_related(self, data): return self.related.resolve_related(data) - def resolve_related_item(self, data, **kwargs): + def resolve_related_item(self, data: dict, **kwargs) -> TItem: """Retrieve the related item corresponding to the provided data stub. This is used by `Related` when this view is set for a field. @@ -712,7 +753,7 @@ def resolve_related_item(self, data, **kwargs): return self.resolve_related_id(id, **kwargs) - def resolve_related_id(self, id, **kwargs): + def resolve_related_id(self, id: _TId, **kwargs) -> TItem: """Retrieve the related item corresponding to the provided ID. This is used by `Related` when a field is specified as a `RelatedId`. @@ -728,7 +769,7 @@ def resolve_related_id(self, id, **kwargs): return item - def create_stub_item(self, id): + def create_stub_item(self, id: _TId) -> TItem: """Create a stub item that corresponds to the provided ID. This is used by `get_item` when `create_transient_stub` is set. @@ -741,7 +782,7 @@ def create_stub_item(self, id): """ return self.create_item(self.get_id_dict(id)) - def create_item(self, data): + def create_item(self, data: _TDataIn) -> TItem: """Create an item using the provided data. This will invoke `authorize_create_item` on the created item. @@ -757,7 +798,7 @@ def create_item(self, data): self.authorization.authorize_create_item(item) return item - def create_item_raw(self, data): + def create_item_raw(self, data: _TDataIn) -> TItem: """As with `create_item`, but without the authorization check. This is used by `create_item`, which then applies the authorization @@ -773,7 +814,7 @@ def create_item_raw(self, data): """ return self.model(**data) - def add_item(self, item): + def add_item(self, item: TItem) -> None: """Add an item to the current session. This will invoke `authorize_save_item` on the item to add. @@ -783,7 +824,7 @@ def add_item(self, item): self.add_item_raw(item) self.authorization.authorize_save_item(item) - def add_item_raw(self, item): + def add_item_raw(self, item: TItem) -> None: """As with `add_item`, but without the authorization check. This is used by `add_item`, which then applies the authorization check. @@ -792,7 +833,7 @@ def add_item_raw(self, item): """ self.session.add(item) - def create_and_add_item(self, data): + def create_and_add_item(self, data: _TDataIn) -> TItem: """Create an item using the provided data, then add it to the session. This uses `create_item` and `add_item`. Correspondingly, it will invoke @@ -806,7 +847,7 @@ def create_and_add_item(self, data): self.add_item(item) return item - def update_item(self, item, data): + def update_item(self, item: TItem, data: _TDataIn) -> TItem: """Update an existing item with the provided data. This will invoke `authorize_update_item` using the provided item and @@ -826,7 +867,7 @@ def update_item(self, item, data): self.authorization.authorize_save_item(item) return item - def update_item_raw(self, item, data): + def update_item_raw(self, item: TItem, data: _TDataIn) -> TItem: """As with `update_item`, but without the authorization checks. Override this instead of `update_item` when applying other @@ -840,8 +881,9 @@ def update_item_raw(self, item, data): """ for key, value in data.items(): setattr(item, key, value) + return item - def delete_item(self, item): + def delete_item(self, item: TItem) -> TItem: """Delete an existing item. This will run `authorize_delete_item` on the item before deleting it. @@ -854,7 +896,7 @@ def delete_item(self, item): item = self.delete_item_raw(item) or item return item - def delete_item_raw(self, item): + def delete_item_raw(self, item: TItem) -> TItem: """As with `delete_item`, but without the authorization check. Override this to customize the delete behavior, e.g. by replacing the @@ -863,8 +905,9 @@ def delete_item_raw(self, item): :param object item: The item to delete. """ self.session.delete(item) + return item - def flush(self, *, objects=None): + def flush(self, *, objects=None) -> None: """Flush pending changes to the database. This will check database level invariants, and will throw exceptions as @@ -906,7 +949,7 @@ def commit(self): except IntegrityError as e: raise self.resolve_integrity_error(e) from e - def resolve_integrity_error(self, error): + def resolve_integrity_error(self, error: IntegrityError) -> Exception: """Convert integrity errors to HTTP error responses as appropriate. Certain kinds of database integrity errors cannot easily be caught by @@ -938,7 +981,7 @@ def resolve_integrity_error(self, error): ) return ApiError(409, {"code": "invalid_data.conflict"}) - def set_item_response_meta(self, item): + def set_item_response_meta(self, item: TItem) -> None: """Set the appropriate response metadata for the response item. By default, this adds the item metadata from the pagination component. @@ -948,7 +991,7 @@ def set_item_response_meta(self, item): super().set_item_response_meta(item) self.set_item_response_meta_pagination(item) - def set_item_response_meta_pagination(self, item): + def set_item_response_meta_pagination(self, item: TItem) -> None: """Set pagination metadata for the response item. This uses the configured pagination component to set pagination @@ -962,7 +1005,7 @@ def set_item_response_meta_pagination(self, item): meta.update_response_meta(self.pagination.get_item_meta(item, self)) -class GenericModelView(ModelView): +class GenericModelView(ModelView[TItem]): """Base class for API views implementing CRUD methods. `GenericModelView` provides basic implementations of the standard CRUD @@ -999,7 +1042,7 @@ def delete(self, id): the methods in `MethodView`. """ - def list(self): + def list(self) -> Response: """Return a list of items. This is the standard GET handler on a list view. @@ -1010,7 +1053,9 @@ def list(self): items = self.get_list() return self.make_items_response(items) - def retrieve(self, id, *, create_transient_stub=False): + def retrieve( + self, id: _TId, *, create_transient_stub: bool = False + ) -> Response: """Retrieve an item by ID. This is the standard ``GET`` handler on a detail view. @@ -1027,7 +1072,7 @@ def retrieve(self, id, *, create_transient_stub=False): ) return self.make_item_response(item) - def create(self, *, allow_client_id=False): + def create(self, *, allow_client_id: bool = False) -> Response: """Create a new item using the request data. This is the standard ``POST`` handler on a list view. @@ -1046,8 +1091,12 @@ def create(self, *, allow_client_id=False): return self.make_created_response(item) def update( - self, id, *, with_for_update=False, partial=False, - ): + self, + id: _TId, + *, + with_for_update: bool = False, + partial: bool = False, + ) -> Response: """Update the item for the specified ID with the request data. This is the standard ``PUT`` handler on a detail view if `partial` is @@ -1069,7 +1118,7 @@ def update( return self.make_item_response(item) - def upsert(self, id, *, with_for_update=False): + def upsert(self, id: _TId, *, with_for_update: bool = False) -> Response: """Upsert the item for the specified ID with the request data. This will update the item for the given ID, if that item exists. @@ -1096,7 +1145,7 @@ def upsert(self, id, *, with_for_update=False): return self.make_item_response(item) - def destroy(self, id): + def destroy(self, id: _TId) -> Response: """Delete the item for the specified ID. :param id: The item ID.