Skip to content

Commit bb6d0f9

Browse files
authored
refactor: simplify pd.concat (#1465)
simplify concat
1 parent: 6e9a209 · commit: bb6d0f9

File tree

2 files changed: 58 additions (+58) and 146 deletions (-146).

pandas-stubs/core/reshape/concat.pyi

Lines changed: 43 additions & 141 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,9 @@ from typing import (
88
overload,
99
)
1010

11-
from pandas import (
12-
DataFrame,
13-
Series,
14-
)
11+
from pandas.core.frame import DataFrame
12+
from pandas.core.generic import NDFrame
13+
from pandas.core.series import Series
1514
from typing_extensions import Never
1615

1716
from pandas._typing import (
@@ -24,156 +23,59 @@ from pandas._typing import (
2423
HashableT4,
2524
)
2625

27-
@overload
28-
def concat( # type: ignore[overload-overlap]
29-
objs: Iterable[DataFrame] | Mapping[HashableT1, DataFrame],
30-
*,
31-
axis: Axis = ...,
32-
join: Literal["inner", "outer"] = ...,
33-
ignore_index: bool = ...,
34-
keys: Iterable[HashableT2] | None = ...,
35-
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = ...,
36-
names: list[HashableT4] | None = ...,
37-
verify_integrity: bool = ...,
38-
sort: bool = ...,
39-
copy: bool = ...,
40-
) -> DataFrame: ...
41-
@overload
42-
def concat( # pyright: ignore[reportOverlappingOverload]
43-
objs: Iterable[Series[S2]],
44-
*,
45-
axis: AxisIndex = ...,
46-
join: Literal["inner", "outer"] = ...,
47-
ignore_index: bool = ...,
48-
keys: Iterable[HashableT2] | None = ...,
49-
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = ...,
50-
names: list[HashableT4] | None = ...,
51-
verify_integrity: bool = ...,
52-
sort: bool = ...,
53-
copy: bool = ...,
54-
) -> Series[S2]: ...
55-
@overload
56-
def concat( # type: ignore[overload-overlap]
57-
objs: Iterable[Series] | Mapping[HashableT1, Series],
58-
*,
59-
axis: AxisIndex = ...,
60-
join: Literal["inner", "outer"] = ...,
61-
ignore_index: bool = ...,
62-
keys: Iterable[HashableT2] | None = ...,
63-
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = ...,
64-
names: list[HashableT4] | None = ...,
65-
verify_integrity: bool = ...,
66-
sort: bool = ...,
67-
copy: bool = ...,
68-
) -> Series: ...
69-
@overload
70-
def concat( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
71-
objs: Iterable[Series | DataFrame] | Mapping[HashableT1, Series | DataFrame],
72-
*,
73-
axis: Axis = ...,
74-
join: Literal["inner", "outer"] = ...,
75-
ignore_index: bool = ...,
76-
keys: Iterable[HashableT2] | None = ...,
77-
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = ...,
78-
names: list[HashableT4] | None = ...,
79-
verify_integrity: bool = ...,
80-
sort: bool = ...,
81-
copy: bool = ...,
82-
) -> DataFrame: ...
8326
@overload
8427
def concat(
8528
objs: Iterable[None] | Mapping[HashableT1, None],
8629
*,
87-
axis: Axis = ...,
88-
join: Literal["inner", "outer"] = ...,
89-
ignore_index: bool = ...,
90-
keys: Iterable[HashableT2] | None = ...,
91-
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = ...,
92-
names: list[HashableT4] | None = ...,
93-
verify_integrity: bool = ...,
94-
sort: bool = ...,
95-
copy: bool = ...,
30+
axis: Axis = 0,
31+
join: Literal["inner", "outer"] = "outer",
32+
ignore_index: bool = False,
33+
keys: Iterable[HashableT2] | None = None,
34+
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = None,
35+
names: list[HashableT4] | None = None,
36+
verify_integrity: bool = False,
37+
sort: bool = False,
38+
copy: bool = True,
9639
) -> Never: ...
9740
@overload
98-
def concat( # type: ignore[overload-overlap]
99-
objs: Iterable[DataFrame | None] | Mapping[HashableT1, DataFrame | None],
41+
def concat( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
42+
objs: Iterable[Series[S2] | None] | Mapping[HashableT1, Series[S2] | None],
10043
*,
101-
axis: Axis = ...,
102-
join: Literal["inner", "outer"] = ...,
103-
ignore_index: bool = ...,
104-
keys: Iterable[HashableT2] | None = ...,
105-
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = ...,
106-
names: list[HashableT4] | None = ...,
107-
verify_integrity: bool = ...,
108-
sort: bool = ...,
109-
copy: bool = ...,
110-
) -> DataFrame: ...
44+
axis: AxisIndex = 0,
45+
join: Literal["inner", "outer"] = "outer",
46+
ignore_index: bool = False,
47+
keys: Iterable[HashableT2] | None = None,
48+
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = None,
49+
names: list[HashableT4] | None = None,
50+
verify_integrity: bool = False,
51+
sort: bool = False,
52+
copy: bool = True,
53+
) -> Series[S2]: ...
11154
@overload
11255
def concat( # type: ignore[overload-overlap]
11356
objs: Iterable[Series | None] | Mapping[HashableT1, Series | None],
11457
*,
115-
axis: AxisIndex = ...,
116-
join: Literal["inner", "outer"] = ...,
117-
ignore_index: bool = ...,
118-
keys: Iterable[HashableT2] | None = ...,
119-
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = ...,
120-
names: list[HashableT4] | None = ...,
121-
verify_integrity: bool = ...,
122-
sort: bool = ...,
123-
copy: bool = ...,
58+
axis: AxisIndex = 0,
59+
join: Literal["inner", "outer"] = "outer",
60+
ignore_index: bool = False,
61+
keys: Iterable[HashableT2] | None = None,
62+
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = None,
63+
names: list[HashableT4] | None = None,
64+
verify_integrity: bool = False,
65+
sort: bool = False,
66+
copy: bool = True,
12467
) -> Series: ...
12568
@overload
12669
def concat(
127-
objs: (
128-
Iterable[Series | DataFrame | None]
129-
| Mapping[HashableT1, Series | DataFrame | None]
130-
),
70+
objs: Iterable[NDFrame | None] | Mapping[HashableT1, NDFrame | None],
13171
*,
132-
axis: Axis = ...,
133-
join: Literal["inner", "outer"] = ...,
134-
ignore_index: bool = ...,
135-
keys: Iterable[HashableT2] | None = ...,
136-
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = ...,
137-
names: list[HashableT4] | None = ...,
138-
verify_integrity: bool = ...,
139-
sort: bool = ...,
140-
copy: bool = ...,
72+
axis: Axis = 0,
73+
join: Literal["inner", "outer"] = "outer",
74+
ignore_index: bool = False,
75+
keys: Iterable[HashableT2] | None = None,
76+
levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] | None = None,
77+
names: list[HashableT4] | None = None,
78+
verify_integrity: bool = False,
79+
sort: bool = False,
80+
copy: bool = True,
14181
) -> DataFrame: ...
142-
143-
# Including either of the next 2 overloads causes mypy to complain about
144-
# test_pandas.py:test_types_concat() in assert_type(pd.concat([s, s2]), pd.Series)
145-
# It thinks that pd.concat([s, s2]) is Any . May be due to Series being
146-
# Generic, or the axis argument being unspecified, and then there is partial
147-
# overlap with the first 2 overloads.
148-
#
149-
# @overload
150-
# def concat(
151-
# objs: Union[
152-
# Iterable[Union[Series, DataFrame]], Mapping[HashableT, Union[Series, DataFrame]]
153-
# ],
154-
# axis: Literal[0, "index"] = ...,
155-
# join: str = ...,
156-
# ignore_index: bool = ...,
157-
# keys=...,
158-
# levels=...,
159-
# names=...,
160-
# verify_integrity: bool = ...,
161-
# sort: bool = ...,
162-
# copy: bool = ...,
163-
# ) -> DataFrame | Series: ...
164-
165-
# @overload
166-
# def concat(
167-
# objs: Union[
168-
# Iterable[Union[Series, DataFrame]], Mapping[HashableT, Union[Series, DataFrame]]
169-
# ],
170-
# axis: Axis = ...,
171-
# join: str = ...,
172-
# ignore_index: bool = ...,
173-
# keys=...,
174-
# levels=...,
175-
# names=...,
176-
# verify_integrity: bool = ...,
177-
# sort: bool = ...,
178-
# copy: bool = ...,
179-
# ) -> DataFrame | Series: ...

tests/test_pandas.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -107,14 +107,20 @@ def test_types_concat_none() -> None:
107107
series = pd.Series([7, -5, 10])
108108
df = pd.DataFrame({"a": [7, -5, 10]})
109109

110-
check(assert_type(pd.concat([None, series]), pd.Series), pd.Series)
110+
check(
111+
assert_type(pd.concat([None, series]), "pd.Series[int]"), pd.Series, np.integer
112+
)
111113
check(assert_type(pd.concat([None, df]), pd.DataFrame), pd.DataFrame)
112114
check(
113115
assert_type(pd.concat([None, series, df], axis=1), pd.DataFrame), pd.DataFrame
114116
)
115117
check(assert_type(pd.concat([None, series, df]), pd.DataFrame), pd.DataFrame)
116118

117-
check(assert_type(pd.concat({"a": None, "b": series}), pd.Series), pd.Series)
119+
check(
120+
assert_type(pd.concat({"a": None, "b": series}), "pd.Series[int]"),
121+
pd.Series,
122+
np.integer,
123+
)
118124
check(assert_type(pd.concat({"a": None, "b": df}), pd.DataFrame), pd.DataFrame)
119125
check(
120126
assert_type(pd.concat({"a": None, "b": series, "c": df}, axis=1), pd.DataFrame),
@@ -163,18 +169,22 @@ def test_types_concat() -> None:
163169

164170
# Depends on the axis
165171
check(
166-
assert_type(pd.concat({"a": s, "b": s2}), pd.Series),
172+
assert_type(pd.concat({"a": s, "b": s2}), "pd.Series[int]"),
167173
pd.Series,
174+
np.integer,
168175
)
169176
check(
170177
assert_type(pd.concat({"a": s, "b": s2}, axis=1), pd.DataFrame),
171178
pd.DataFrame,
172179
)
173-
check(assert_type(pd.concat({1: s, 2: s2}), pd.Series), pd.Series)
180+
check(
181+
assert_type(pd.concat({1: s, 2: s2}), "pd.Series[int]"), pd.Series, np.integer
182+
)
174183
check(assert_type(pd.concat({1: s, 2: s2}, axis=1), pd.DataFrame), pd.DataFrame)
175184
check(
176-
assert_type(pd.concat({1: s, None: s2}), pd.Series),
185+
assert_type(pd.concat({1: s, None: s2}), "pd.Series[int]"),
177186
pd.Series,
187+
np.integer,
178188
)
179189

180190
# https://github.com/microsoft/python-type-stubs/issues/69

0 commit comments

Comments
 (0)