Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions openfisca_core/commons/formulas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from typing import Any, Dict, Sequence, TypeVar
from openfisca_core.typing import ArrayLike, ArrayType

import numpy

from openfisca_core.types import ArrayLike, ArrayType

T = TypeVar("T")


Expand Down
5 changes: 3 additions & 2 deletions openfisca_core/commons/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TypeVar
from __future__ import annotations

from openfisca_core.types import ArrayType
from typing import TypeVar
from openfisca_core.typing import ArrayType

T = TypeVar("T")

Expand Down
5 changes: 3 additions & 2 deletions openfisca_core/commons/rates.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from typing import Optional
from openfisca_core.typing import ArrayLike, ArrayType

import numpy

from openfisca_core.types import ArrayLike, ArrayType


def average_rate(
target: ArrayType[float],
Expand Down
30 changes: 21 additions & 9 deletions openfisca_core/data_storage/on_disk_storage.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
from __future__ import annotations

from typing import Any, AbstractSet, MutableMapping
from openfisca_core.typing import ArrayType, PeriodProtocol

import os
import shutil

import numpy

from openfisca_core import periods
from openfisca_core.indexed_enums import EnumArray
from openfisca_core.indexed_enums import Enum, EnumArray


class OnDiskStorage:
"""
Low-level class responsible for storing and retrieving calculated vectors on disk
"""

def __init__(self, storage_dir, is_eternal = False, preserve_storage_dir = False):
self._files = {}
self._enums = {}
def __init__(
self,
storage_dir: str,
is_eternal: bool = False,
preserve_storage_dir: bool = False,
) -> None:

self._files: MutableMapping[PeriodProtocol, ArrayType[Any]] = {}
self._enums: MutableMapping[str, Enum] = {}
self.is_eternal = is_eternal
self.preserve_storage_dir = preserve_storage_dir
self.storage_dir = storage_dir
Expand All @@ -26,7 +37,7 @@ def _decode_file(self, file):
else:
return numpy.load(file)

def get(self, period):
def get(self, period: PeriodProtocol) -> ArrayType[Any]:
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)
Expand All @@ -36,7 +47,7 @@ def get(self, period):
return None
return self._decode_file(values)

def put(self, value, period):
def put(self, value: ArrayType[Any], period: PeriodProtocol) -> None:
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)
Expand Down Expand Up @@ -65,10 +76,11 @@ def delete(self, period = None):
if not period.contains(period_item)
}

def get_known_periods(self):
def get_known_periods(self) -> AbstractSet[PeriodProtocol]:
return self._files.keys()

def restore(self):
def restore(self) -> None:
files: MutableMapping[PeriodProtocol, ArrayType[Any]]
self._files = files = {}
# Restore self._files from content of storage_dir.
for filename in os.listdir(self.storage_dir):
Expand All @@ -79,7 +91,7 @@ def restore(self):
period = periods.period(filename_core)
files[period] = path

def __del__(self):
def __del__(self) -> None:
if self.preserve_storage_dir:
return
shutil.rmtree(self.storage_dir) # Remove the holder temporary files
Expand Down
10 changes: 9 additions & 1 deletion openfisca_core/errors/variable_not_found_error.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from openfisca_core.typing import TaxBenefitSystemProtocol

import os


Expand All @@ -6,7 +10,11 @@ class VariableNotFoundError(Exception):
Exception raised when a variable has been queried but is not defined in the TaxBenefitSystem.
"""

def __init__(self, variable_name, tax_benefit_system):
def __init__(
self,
variable_name: str,
tax_benefit_system: TaxBenefitSystemProtocol,
) -> None:
"""
:param variable_name: Name of the variable that was queried.
:param tax_benefit_system: Tax benefits system that does not contain `variable_name`
Expand Down
24 changes: 20 additions & 4 deletions openfisca_core/holders/holder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from __future__ import annotations

from typing import Any, Optional, Sequence
from openfisca_core.typing import ArrayType

import os
import warnings

Expand Down Expand Up @@ -48,7 +53,12 @@ def clone(self, population):

return new

def create_disk_storage(self, directory = None, preserve = False):
def create_disk_storage(
self,
directory: Optional[str] = None,
preserve: bool = False,
) -> OnDiskStorage:

if directory is None:
directory = self.simulation.data_storage_dir
storage_dir = os.path.join(directory, self.variable.name)
Expand All @@ -71,12 +81,13 @@ def delete_arrays(self, period = None):
if self._disk_storage:
self._disk_storage.delete(period)

def get_array(self, period):
def get_array(self, period: periods.Period) -> Any:
"""
Get the value of the variable for the given period.

If the value is not known, return ``None``.
"""

if self.variable.is_neutralized:
return self.default_array()
value = self._memory_storage.get(period)
Expand Down Expand Up @@ -122,7 +133,7 @@ def get_memory_usage(self):

return usage

def get_known_periods(self):
def get_known_periods(self) -> Sequence[periods.Period]:
"""
Get the list of periods the variable value is known for.
"""
Expand Down Expand Up @@ -227,7 +238,12 @@ def _set(self, period, value):
else:
self._memory_storage.put(value, period)

def put_in_cache(self, value, period):
def put_in_cache(
self,
value: ArrayType[Any],
period: periods.Period,
) -> None:

if self._do_not_store:
return

Expand Down
11 changes: 8 additions & 3 deletions openfisca_core/populations/population.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from __future__ import annotations

from typing import Optional
from openfisca_core.typing import ArrayLike

import traceback

import numpy
Expand All @@ -13,8 +18,8 @@ def __init__(self, entity):
self.simulation = None
self.entity = entity
self._holders = {}
self.count = 0
self.ids = []
self.count: Optional[int] = 0
self.ids: ArrayLike[str] = []

def clone(self, simulation):
result = Population(self.entity)
Expand All @@ -36,7 +41,7 @@ def __getattr__(self, attribute):
raise AttributeError("You tried to use the '{}' of '{}' but that is not a known attribute.".format(attribute, self.entity.key))
return projector

def get_index(self, id):
def get_index(self, id: str) -> int:
return self.ids.index(id)

# Calculations
Expand Down
43 changes: 34 additions & 9 deletions openfisca_core/simulations/simulation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
from __future__ import annotations

from typing import Any, Mapping, Optional, Set, Tuple
from openfisca_core.typing import (
ArrayType,
HolderProtocol,
PeriodProtocol,
PopulationProtocol,
TaxBenefitSystemProtocol,
)

import tempfile
import warnings

Expand All @@ -18,14 +29,15 @@ class Simulation:

def __init__(
self,
tax_benefit_system,
populations
):
tax_benefit_system: TaxBenefitSystemProtocol,
populations: Mapping[str, PopulationProtocol]
) -> None:
"""
This constructor is reserved for internal use; see :any:`SimulationBuilder`,
which is the preferred way to obtain a Simulation initialized with a consistent
set of Entities.
"""

self.tax_benefit_system = tax_benefit_system
assert tax_benefit_system is not None

Expand All @@ -34,7 +46,7 @@ def __init__(
self.link_to_entities_instances()
self.create_shortcuts()

self.invalidated_caches = set()
self.invalidated_caches: Set[Tuple[str, PeriodProtocol]] = set()

self.debug = False
self.trace = False
Expand Down Expand Up @@ -83,7 +95,11 @@ def data_storage_dir(self):

# ----- Calculation methods ----- #

def calculate(self, variable_name, period):
def calculate(
self,
variable_name: str,
period: Optional[Any],
) -> ArrayType[Any]:
"""Calculate ``variable_name`` for ``period``."""

if period is not None and not isinstance(period, Period):
Expand Down Expand Up @@ -291,10 +307,15 @@ def _check_for_cycle(self, variable: str, period):
message = "Quasicircular definition detected on formula {}@{} involving {}".format(variable, period, self.tracer.stack)
raise SpiralError(message, variable)

def invalidate_cache_entry(self, variable: str, period):
def invalidate_cache_entry(
self,
variable: str,
period: PeriodProtocol,
) -> None:

self.invalidated_caches.add((variable, period))

def invalidate_spiral_variables(self, variable: str):
def invalidate_spiral_variables(self, variable: str) -> None:
# Visit the stack, from the bottom (most recent) up; we know that we'll find
# the variable implicated in the spiral (max_spiral_loops+1) times; we keep the
# intermediate values computed (to avoid impacting performance) but we mark them
Expand All @@ -319,7 +340,7 @@ def get_array(self, variable_name, period):
period = periods.period(period)
return self.get_holder(variable_name).get_array(period)

def get_holder(self, variable_name):
def get_holder(self, variable_name: str) -> HolderProtocol:
"""
Get the :obj:`.Holder` associated with the variable ``variable_name`` for the simulation
"""
Expand Down Expand Up @@ -414,7 +435,11 @@ def get_variable_population(self, variable_name):
variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)
return self.populations[variable.entity.key]

def get_population(self, plural = None):
def get_population(
self,
plural: Optional[str] = None,
) -> Optional[PopulationProtocol]:

return next((population for population in self.populations.values() if population.entity.plural == plural), None)

def get_entity(self, plural = None):
Expand Down
27 changes: 21 additions & 6 deletions openfisca_core/simulations/simulation_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from __future__ import annotations

import typing
from typing import Any, Mapping, Optional, Sequence
from openfisca_core.typing import (
ArrayType,
AxisSchema,
TaxBenefitSystemProtocol,
)

import copy
import dpath
import typing

import numpy

Expand All @@ -14,12 +23,14 @@

class SimulationBuilder:

def __init__(self):
default_period: Optional[str]

def __init__(self) -> None:
self.default_period = None # Simulation period used for variables when no period is defined
self.persons_plural = None # Plural name for person entity in current tax and benefits system

# JSON input - Memory of known input values. Indexed by variable or axis name.
self.input_buffer: typing.Dict[Variable.name, typing.Dict[str(periods.period), numpy.array]] = {}
self.input_buffer: typing.Dict[Variable.name, typing.Dict[str, ArrayType]] = {}
self.populations: typing.Dict[Entity.key, Population] = {}
# JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes.
self.entity_counts: typing.Dict[Entity.plural, int] = {}
Expand All @@ -32,13 +43,17 @@ def __init__(self):

self.variable_entities: typing.Dict[Variable.name, Entity] = {}

self.axes = [[]]
self.axes: Sequence[Sequence[AxisSchema]] = [[]]
self.axes_entity_counts: typing.Dict[Entity.plural, int] = {}
self.axes_entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {}
self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = {}
self.axes_roles: typing.Dict[Entity.plural, typing.List[int]] = {}

def build_from_dict(self, tax_benefit_system, input_dict):
def build_from_dict(
self,
tax_benefit_system: TaxBenefitSystemProtocol,
input_dict: Mapping[str, Any],
) -> Simulation:
"""
Build a simulation from ``input_dict``

Expand Down Expand Up @@ -322,7 +337,7 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json):
self.roles[entity.plural] = self.roles[entity.plural].tolist()
self.memberships[entity.plural] = self.memberships[entity.plural].tolist()

def set_default_period(self, period_str):
def set_default_period(self, period_str: Optional[str]) -> None:
if period_str:
self.default_period = str(periods.period(period_str))

Expand Down
Loading