diff --git a/pygeoapi/provider/base.py b/pygeoapi/provider/base.py index 538c076a9..0374adfc9 100644 --- a/pygeoapi/provider/base.py +++ b/pygeoapi/provider/base.py @@ -31,6 +31,14 @@ import logging from enum import Enum from http import HTTPStatus +from typing import Literal, Optional, TypedDict +import sys + +if sys.version_info >= (3, 11): + from typing import NotRequired +else: + from typing_extensions import NotRequired + from pygeoapi.error import GenericError @@ -44,10 +52,32 @@ class SchemaType(Enum): replace = 'replace' +# All the potential properties on a field +# as defined by the OGC API features spec +FieldProperties = TypedDict( + "FieldProperties", + { + "type": Literal["string", "number", "integer", + "boolean", "object", "array"], + "title": str, + "description": str, + "format": NotRequired[str], + "x-ogc-unit": NotRequired[str], + "x-ogc-role": NotRequired[str], + "enum": NotRequired[list[str]], + }, +) + + +# Dict type representing a mapping of the field +# to its associated data type +FieldMapping = dict[str, FieldProperties] + + class BaseProvider: """generic Provider ABC""" - def __init__(self, provider_def): + def __init__(self, provider_def: dict): """ Initialize object @@ -73,7 +103,7 @@ def __init__(self, provider_def): self.title_field = provider_def.get('title_field') self.properties = provider_def.get('properties', []) self.file_types = provider_def.get('file_types', []) - self._fields = {} + self._fields: FieldMapping = {} self.filename = None # for coverage providers @@ -81,7 +111,7 @@ def __init__(self, provider_def): self.crs = None self.num_bands = None - def get_fields(self): + def get_fields(self) -> FieldMapping: """ Get provider field information (names, types) @@ -94,7 +124,7 @@ def get_fields(self): raise NotImplementedError() @property - def fields(self) -> dict: + def fields(self) -> FieldMapping: """ Store provider field information (names, types) @@ -110,7 +140,7 @@ def fields(self) -> dict: else: return self.get_fields() - def get_schema(self, schema_type: SchemaType = SchemaType.item): + def get_schema(self, schema_type: SchemaType = SchemaType.item) -> tuple: """ Get provider schema model @@ -122,7 +152,7 @@ def get_schema(self, schema_type: SchemaType = SchemaType.item): raise NotImplementedError() - def get_data_path(self, baseurl, urlpath, dirpath): + def get_data_path(self, baseurl: str, urlpath: str, dirpath: str) -> dict: """ Gets directory listing or file description or raw file dump @@ -145,7 +175,7 @@ def get_metadata(self): raise NotImplementedError() - def get_domains(self, properties=[], current=False): + def get_domains(self, properties: list[str] = [], current=False): """ Get domains from dataset @@ -168,7 +198,7 @@ def query(self): raise NotImplementedError() - def get(self, identifier, **kwargs): + def get(self, identifier: str, **kwargs): """ query the provider by id @@ -179,7 +209,7 @@ def get(self, identifier, **kwargs): raise NotImplementedError() - def create(self, item): + def create(self, item: dict) -> str: """ Create a new item @@ -190,7 +220,7 @@ def create(self, item): raise NotImplementedError() - def update(self, identifier, item): + def update(self, identifier: str, item: dict) -> bool: """ Updates an existing item @@ -202,7 +232,7 @@ def update(self, identifier, item): raise NotImplementedError() - def delete(self, identifier): + def delete(self, identifier: str) -> bool: """ Deletes an existing item @@ -213,7 +243,8 @@ def delete(self, identifier): raise NotImplementedError() - def _load_and_prepare_item(self, item, identifier=None, + def _load_and_prepare_item(self, item: str, + identifier: Optional[str] = None, accept_missing_identifier=False, raise_if_exists=True): """