Skip to content

Commit b197a8b

Browse files
authored
refactor(arithmetic): ➕ simplify addition (pandas-dev#1382)
* refactor(arithmetic): simplify addition * chore: reduce unused * fix: use SupportsAdd * fix(comment): pandas-dev#1382 (comment)
1 parent bc48e74 commit b197a8b

File tree

10 files changed

+327
-439
lines changed

10 files changed

+327
-439
lines changed

pandas-stubs/_libs/tslibs/period.pyi

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,6 @@ class Period(PeriodMixin):
171171
@overload
172172
def __radd__(self, other: Index) -> Index: ...
173173
@overload
174-
def __radd__(self, other: Series[Timedelta]) -> PeriodSeries: ...
175-
@overload
176174
def __radd__(self, other: NaTType) -> NaTType: ...
177175
@property
178176
def day(self) -> int: ...

pandas-stubs/_libs/tslibs/timedeltas.pyi

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -137,20 +137,16 @@ class Timedelta(timedelta):
137137
def resolution_string(self) -> str: ...
138138
# Override due to more types supported than dt.timedelta
139139
@overload # type: ignore[override]
140-
def __add__(self, other: timedelta | Timedelta | np.timedelta64) -> Timedelta: ...
140+
def __add__(self, other: dt.datetime | np.datetime64) -> Timestamp: ...
141141
@overload
142-
def __add__(self, other: dt.datetime | np.datetime64 | Timestamp) -> Timestamp: ...
142+
def __add__(self, other: timedelta | np.timedelta64) -> Self: ...
143143
@overload
144144
def __add__(self, other: NaTType) -> NaTType: ...
145145
@overload
146146
def __add__(self, other: Period) -> Period: ...
147147
@overload
148148
def __add__(self, other: dt.date) -> dt.date: ...
149149
@overload
150-
def __add__(self, other: PeriodIndex) -> PeriodIndex: ...
151-
@overload
152-
def __add__(self, other: DatetimeIndex) -> DatetimeIndex: ...
153-
@overload
154150
def __add__(
155151
self, other: np_ndarray[ShapeT, np.timedelta64]
156152
) -> np_ndarray[ShapeT, np.timedelta64]: ...
@@ -159,29 +155,21 @@ class Timedelta(timedelta):
159155
self, other: np_ndarray[ShapeT, np.datetime64]
160156
) -> np_ndarray[ShapeT, np.datetime64]: ...
161157
@overload
162-
def __add__(self, other: pd.TimedeltaIndex) -> pd.TimedeltaIndex: ...
163-
@overload
164-
def __add__(self, other: Series[Timedelta]) -> Series[Timedelta]: ...
165-
@overload
166-
def __add__(self, other: Series[Timestamp]) -> Series[Timestamp]: ...
158+
def __radd__(self, other: dt.datetime | np.datetime64) -> Timestamp: ... # type: ignore[misc]
167159
@overload
168-
def __radd__(self, other: np.datetime64) -> Timestamp: ...
169-
@overload
170-
def __radd__(self, other: timedelta | Timedelta | np.timedelta64) -> Timedelta: ...
160+
def __radd__(self, other: timedelta | np.timedelta64) -> Self: ...
171161
@overload
172162
def __radd__(self, other: NaTType) -> NaTType: ...
173163
@overload
164+
def __radd__(self, other: dt.date) -> dt.date: ...
165+
@overload
174166
def __radd__(
175167
self, other: np_ndarray[ShapeT, np.timedelta64]
176168
) -> np_ndarray[ShapeT, np.timedelta64]: ...
177169
@overload
178170
def __radd__(
179171
self, other: np_ndarray[ShapeT, np.datetime64]
180172
) -> np_ndarray[ShapeT, np.datetime64]: ...
181-
@overload
182-
def __radd__(self, other: pd.TimedeltaIndex) -> pd.TimedeltaIndex: ...
183-
@overload
184-
def __radd__(self, other: pd.PeriodIndex) -> pd.PeriodIndex: ...
185173
# Override due to more types supported than dt.timedelta
186174
@overload # type: ignore[override]
187175
def __sub__(self, other: timedelta | Timedelta | np.timedelta64) -> Timedelta: ...

pandas-stubs/_typing.pyi

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -847,25 +847,33 @@ MaskType: TypeAlias = Series[bool] | np_ndarray_bool | list[bool]
847847

848848
T_INT = TypeVar("T_INT", bound=int)
849849
T_COMPLEX = TypeVar("T_COMPLEX", bound=complex)
850-
SeriesDType: TypeAlias = (
850+
SeriesDTypeNoDateTime: TypeAlias = (
851851
str
852852
| bytes
853-
| datetime.date
854-
| datetime.time
855853
| bool
856854
| int
857855
| float
858856
| complex
859857
| Dtype
860-
| datetime.datetime # includes pd.Timestamp
861-
| datetime.timedelta # includes pd.Timedelta
862858
| Period
863859
| Interval
864860
| CategoricalDtype
865861
| BaseOffset
866862
| list[str]
867863
)
864+
SeriesDType: TypeAlias = (
865+
SeriesDTypeNoDateTime
866+
| datetime.date
867+
| datetime.time
868+
| datetime.datetime # includes pd.Timestamp
869+
| datetime.timedelta # includes pd.Timedelta
870+
)
868871
S1 = TypeVar("S1", bound=SeriesDType, default=Any)
872+
S1_CT_NDT = TypeVar(
873+
"S1_CT_NDT", bound=SeriesDTypeNoDateTime, default=Any, contravariant=True
874+
)
875+
S1_CO = TypeVar("S1_CO", bound=SeriesDType, default=Any, covariant=True)
876+
S1_CT = TypeVar("S1_CT", bound=SeriesDType, default=Any, contravariant=True)
869877
# Like S1, but without `default=Any`.
870878
S2 = TypeVar("S2", bound=SeriesDType)
871879
S3 = TypeVar("S3", bound=SeriesDType)

pandas-stubs/core/indexes/base.pyi

Lines changed: 45 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ from collections.abc import (
33
Callable,
44
Hashable,
55
Iterable,
6-
Iterator,
76
Sequence,
87
)
98
from datetime import (
@@ -20,6 +19,10 @@ from typing import (
2019
type_check_only,
2120
)
2221

22+
from _typeshed import (
23+
SupportsAdd,
24+
SupportsRAdd,
25+
)
2326
import numpy as np
2427
from pandas import (
2528
DataFrame,
@@ -49,8 +52,9 @@ from pandas._libs.interval import _OrderableT
4952
from pandas._typing import (
5053
C2,
5154
S1,
55+
S1_CO,
56+
S1_CT,
5257
T_COMPLEX,
53-
T_INT,
5458
AnyAll,
5559
ArrayLike,
5660
AxesData,
@@ -471,7 +475,6 @@ class Index(IndexOpsMixin[S1]):
471475
def shape(self) -> tuple[int, ...]: ...
472476
# Extra methods from old stubs
473477
def __eq__(self, other: object) -> np_1darray[np.bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
474-
def __iter__(self) -> Iterator[S1]: ...
475478
def __ne__(self, other: object) -> np_1darray[np.bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
476479
def __le__(self, other: Self | S1) -> np_1darray[np.bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
477480
def __ge__(self, other: Self | S1) -> np_1darray[np.bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
@@ -484,65 +487,43 @@ class Index(IndexOpsMixin[S1]):
484487
@overload
485488
def __add__(self, other: Index[Never]) -> Index: ...
486489
@overload
487-
def __add__(
488-
self: Index[bool],
489-
other: T_COMPLEX | Sequence[T_COMPLEX] | Index[T_COMPLEX],
490-
) -> Index[T_COMPLEX]: ...
491-
@overload
492-
def __add__(self: Index[bool], other: np_ndarray_bool) -> Index[bool]: ...
490+
def __add__(self: Index[bool], other: bool | Sequence[bool]) -> Index[bool]: ...
493491
@overload
494-
def __add__(self: Index[bool], other: np_ndarray_anyint) -> Index[int]: ...
492+
def __add__(self: Index[int], other: bool | Sequence[bool]) -> Index[int]: ...
495493
@overload
496-
def __add__(self: Index[bool], other: np_ndarray_float) -> Index[float]: ...
494+
def __add__(self: Index[float], other: int | Sequence[int]) -> Index[float]: ...
497495
@overload
498-
def __add__(self: Index[bool], other: np_ndarray_complex) -> Index[complex]: ...
496+
def __add__(
497+
self: Index[complex], other: float | Sequence[float]
498+
) -> Index[complex]: ...
499499
@overload
500500
def __add__(
501-
self: Index[int],
502-
other: (
503-
bool | Sequence[bool] | np_ndarray_bool | np_ndarray_anyint | Index[bool]
504-
),
505-
) -> Index[int]: ...
501+
self: Index[S1_CT],
502+
other: SupportsRAdd[S1_CT, S1_CO] | Sequence[SupportsRAdd[S1_CT, S1_CO]],
503+
) -> Index[S1_CO]: ...
506504
@overload
507505
def __add__(
508-
self: Index[int],
509-
other: T_COMPLEX | Sequence[T_COMPLEX] | Index[T_COMPLEX],
506+
self: Index[T_COMPLEX], other: np_ndarray_bool | Index[bool]
510507
) -> Index[T_COMPLEX]: ...
511508
@overload
512-
def __add__(self: Index[int], other: np_ndarray_float) -> Index[float]: ...
509+
def __add__(
510+
self: Index[bool], other: np_ndarray_anyint | Index[int]
511+
) -> Index[int]: ...
513512
@overload
514-
def __add__(self: Index[int], other: np_ndarray_complex) -> Index[complex]: ...
513+
def __add__(
514+
self: Index[T_COMPLEX], other: np_ndarray_anyint | Index[int]
515+
) -> Index[T_COMPLEX]: ...
515516
@overload
516517
def __add__(
517-
self: Index[float],
518-
other: (
519-
int
520-
| Sequence[int]
521-
| np_ndarray_bool
522-
| np_ndarray_anyint
523-
| np_ndarray_float
524-
| Index[T_INT]
525-
),
518+
self: Index[bool] | Index[int], other: np_ndarray_float | Index[float]
526519
) -> Index[float]: ...
527520
@overload
528521
def __add__(
529-
self: Index[float],
530-
other: T_COMPLEX | Sequence[T_COMPLEX] | Index[T_COMPLEX],
522+
self: Index[T_COMPLEX], other: np_ndarray_float | Index[float]
531523
) -> Index[T_COMPLEX]: ...
532524
@overload
533-
def __add__(self: Index[float], other: np_ndarray_complex) -> Index[complex]: ...
534-
@overload
535525
def __add__(
536-
self: Index[complex],
537-
other: (
538-
T_COMPLEX
539-
| Sequence[T_COMPLEX]
540-
| np_ndarray_bool
541-
| np_ndarray_anyint
542-
| np_ndarray_float
543-
| np_ndarray_complex
544-
| Index[T_COMPLEX]
545-
),
526+
self: Index[T_COMPLEX], other: np_ndarray_complex | Index[complex]
546527
) -> Index[complex]: ...
547528
@overload
548529
def __add__(
@@ -560,60 +541,43 @@ class Index(IndexOpsMixin[S1]):
560541
@overload
561542
def __radd__(self: Index[Never], other: complex | _ListLike | Index) -> Index: ...
562543
@overload
563-
def __radd__(
564-
self: Index[bool],
565-
other: T_COMPLEX | Sequence[T_COMPLEX] | Index[T_COMPLEX],
566-
) -> Index[T_COMPLEX]: ...
544+
def __radd__(self: Index[bool], other: bool | Sequence[bool]) -> Index[bool]: ...
567545
@overload
568-
def __radd__(self: Index[bool], other: np_ndarray_bool) -> Index[bool]: ...
546+
def __radd__(self: Index[int], other: bool | Sequence[bool]) -> Index[int]: ...
569547
@overload
570-
def __radd__(self: Index[bool], other: np_ndarray_anyint) -> Index[int]: ...
548+
def __radd__(self: Index[float], other: int | Sequence[int]) -> Index[float]: ...
571549
@overload
572-
def __radd__(self: Index[bool], other: np_ndarray_float) -> Index[float]: ...
550+
def __radd__(
551+
self: Index[complex], other: float | Sequence[float]
552+
) -> Index[complex]: ...
573553
@overload
574554
def __radd__(
575-
self: Index[int],
576-
other: (
577-
bool | Sequence[bool] | np_ndarray_bool | np_ndarray_anyint | Index[bool]
578-
),
579-
) -> Index[int]: ...
555+
self: Index[S1_CT],
556+
other: SupportsAdd[S1_CT, S1_CO] | Sequence[SupportsAdd[S1_CT, S1_CO]],
557+
) -> Index[S1_CO]: ...
580558
@overload
581559
def __radd__(
582-
self: Index[int], other: T_COMPLEX | Sequence[T_COMPLEX] | Index[T_COMPLEX]
560+
self: Index[T_COMPLEX], other: np_ndarray_bool | Index[bool]
583561
) -> Index[T_COMPLEX]: ...
584562
@overload
585-
def __radd__(self: Index[int], other: np_ndarray_float) -> Index[float]: ...
586-
@overload
587563
def __radd__(
588-
self: Index[float],
589-
other: (
590-
int
591-
| Sequence[int]
592-
| np_ndarray_bool
593-
| np_ndarray_anyint
594-
| np_ndarray_float
595-
| Index[T_INT]
596-
),
597-
) -> Index[float]: ...
564+
self: Index[bool], other: np_ndarray_anyint | Index[int]
565+
) -> Index[int]: ...
598566
@overload
599567
def __radd__(
600-
self: Index[float], other: T_COMPLEX | Sequence[T_COMPLEX] | Index[T_COMPLEX]
568+
self: Index[T_COMPLEX], other: np_ndarray_anyint | Index[int]
601569
) -> Index[T_COMPLEX]: ...
602570
@overload
603571
def __radd__(
604-
self: Index[complex],
605-
other: (
606-
T_COMPLEX
607-
| Sequence[T_COMPLEX]
608-
| np_ndarray_bool
609-
| np_ndarray_anyint
610-
| np_ndarray_float
611-
| Index[T_COMPLEX]
612-
),
613-
) -> Index[complex]: ...
572+
self: Index[bool] | Index[int], other: np_ndarray_float | Index[float]
573+
) -> Index[float]: ...
574+
@overload
575+
def __radd__(
576+
self: Index[T_COMPLEX], other: np_ndarray_float | Index[float]
577+
) -> Index[T_COMPLEX]: ...
614578
@overload
615579
def __radd__(
616-
self: Index[T_COMPLEX], other: np_ndarray_complex
580+
self: Index[T_COMPLEX], other: np_ndarray_complex | Index[complex]
617581
) -> Index[complex]: ...
618582
@overload
619583
def __radd__(

pandas-stubs/core/indexes/datetimes.pyi

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,11 @@ class DatetimeIndex(
6666

6767
# various ignores needed for mypy, as we do want to restrict what can be used in
6868
# arithmetic for these types
69-
def __add__( # pyright: ignore[reportIncompatibleMethodOverride]
70-
self, other: timedelta | Timedelta | TimedeltaIndex | BaseOffset # type: ignore[override]
69+
def __add__( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
70+
self, other: timedelta | TimedeltaIndex | BaseOffset
71+
) -> DatetimeIndex: ...
72+
def __radd__( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
73+
self, other: timedelta | TimedeltaIndex | BaseOffset
7174
) -> DatetimeIndex: ...
7275
@overload # type: ignore[override]
7376
def __sub__(

pandas-stubs/core/indexes/period.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ class PeriodIndex(DatetimeIndexOpsMixin[pd.Period, np.object_], PeriodIndexField
3636
) -> Self: ...
3737
@property
3838
def values(self) -> np_1darray[np.object_]: ...
39+
def __add__( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
40+
self, other: datetime.timedelta
41+
) -> Self: ...
42+
def __radd__( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
43+
self, other: datetime.timedelta
44+
) -> Self: ...
3945
@overload # type: ignore[override]
4046
def __sub__(self, other: Period) -> Index: ...
4147
@overload

pandas-stubs/core/indexes/timedeltas.pyi

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@ from pandas.core.indexes.period import PeriodIndex
2222
from pandas.core.series import Series
2323
from typing_extensions import Self
2424

25-
from pandas._libs import (
26-
Timedelta,
27-
Timestamp,
28-
)
25+
from pandas._libs import Timedelta
2926
from pandas._libs.tslibs import BaseOffset
3027
from pandas._typing import (
3128
AxesData,
@@ -53,12 +50,19 @@ class TimedeltaIndex(
5350
@overload # type: ignore[override]
5451
def __add__(self, other: Period) -> PeriodIndex: ...
5552
@overload
56-
def __add__(self, other: DatetimeIndex) -> DatetimeIndex: ...
53+
def __add__(self, other: dt.datetime | DatetimeIndex) -> DatetimeIndex: ...
5754
@overload
5855
def __add__( # pyright: ignore[reportIncompatibleMethodOverride]
59-
self, other: dt.timedelta | Timedelta | Self
56+
self, other: dt.timedelta | Self
57+
) -> Self: ...
58+
@overload # type: ignore[override]
59+
def __radd__(self, other: Period) -> PeriodIndex: ...
60+
@overload
61+
def __radd__(self, other: dt.datetime | DatetimeIndex) -> DatetimeIndex: ...
62+
@overload
63+
def __radd__( # pyright: ignore[reportIncompatibleMethodOverride]
64+
self, other: dt.timedelta | Self
6065
) -> Self: ...
61-
def __radd__(self, other: dt.datetime | Timestamp | DatetimeIndex) -> DatetimeIndex: ... # type: ignore[override]
6266
def __sub__( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
6367
self, other: dt.timedelta | np.timedelta64 | np_ndarray_td | Self
6468
) -> Self: ...

0 commit comments

Comments
 (0)