1+ from __future__ import annotations
2+
13from collections .abc import Mapping
2- from typing import Union
34
45import numpy
56
6- from openfisca_core import types as t
7+ from . import types as t
78
89
910def apply_thresholds (
10- input : t .Array [numpy .float64 ],
11+ input : t .Array [numpy .float32 ],
1112 thresholds : t .ArrayLike [float ],
1213 choices : t .ArrayLike [float ],
13- ) -> t .Array [numpy .float64 ]:
14+ ) -> t .Array [numpy .float32 ]:
1415 """Makes a choice based on an input and thresholds.
1516
1617 From a list of ``choices``, this function selects one of these values
@@ -38,26 +39,29 @@ def apply_thresholds(
3839 array([10, 10, 15, 15, 20])
3940
4041 """
41- condlist : list [Union [t .Array [numpy .bool_ ], bool ]]
42+
43+ condlist : list [t .Array [numpy .bool_ ] | bool ]
4244 condlist = [input <= threshold for threshold in thresholds ]
4345
4446 if len (condlist ) == len (choices ) - 1 :
4547 # If a choice is provided for input > highest threshold, last condition
4648 # must be true to return it.
4749 condlist += [True ]
4850
49- assert len (condlist ) == len (
50- choices
51- ), "'apply_thresholds' must be called with the same number of thresholds than choices, or one more choice."
51+ msg = (
52+ "'apply_thresholds' must be called with the same number of thresholds "
53+ "than choices, or one more choice."
54+ )
55+ assert len (condlist ) == len (choices ), msg
5256
5357 return numpy .select (condlist , choices )
5458
5559
5660def concat (
57- this : Union [ t .Array [numpy .str_ ], t .ArrayLike [str ] ],
58- that : Union [ t .Array [numpy .str_ ], t .ArrayLike [str ] ],
61+ this : t .Array [numpy .str_ ] | t .ArrayLike [str ],
62+ that : t .Array [numpy .str_ ] | t .ArrayLike [str ],
5963) -> t .Array [numpy .str_ ]:
60- """Concatenates the values of two arrays.
64+ """Concatenate the values of two arrays.
6165
6266 Args:
6367 this: An array to concatenate.
@@ -84,10 +88,10 @@ def concat(
8488
8589
8690def switch (
87- conditions : t .Array [numpy .float64 ],
91+ conditions : t .Array [numpy .float32 ],
8892 value_by_condition : Mapping [float , float ],
89- ) -> t .Array [numpy .float64 ]:
90- """Mimicks a switch statement.
93+ ) -> t .Array [numpy .float32 ]:
94+ """Mimick a switch statement.
9195
9296 Given an array of conditions, returns an array of the same size,
9397 replacing each condition item with the matching given value.
@@ -117,3 +121,6 @@ def switch(
117121 condlist = [conditions == condition for condition in value_by_condition ]
118122
119123 return numpy .select (condlist , tuple (value_by_condition .values ()))
124+
125+
126+ __all__ = ["apply_thresholds" , "concat" , "switch" ]
0 commit comments