Skip to content
92 changes: 85 additions & 7 deletions modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,15 @@ def __getitem__(self, key):
if is_list_like(key):
make_dataframe = True
else:
key = [key]
if self._as_index:
make_dataframe = False
else:
make_dataframe = True
key = [key]
internal_by = frozenset(self._internal_by)
cols_to_grab = internal_by.union(key)
key = [col for col in self._df.columns if col in cols_to_grab]
if make_dataframe:
internal_by = frozenset(self._internal_by)
if len(internal_by.intersection(key)) != 0:
ErrorMessage.missmatch_with_pandas(
operation="GroupBy.__getitem__",
Expand All @@ -443,8 +445,6 @@ def __getitem__(self, key):
+ "df.groupby(df['by_column'].copy())['by_column']"
),
)
cols_to_grab = internal_by.union(key)
key = [col for col in self._df.columns if col in cols_to_grab]
return DataFrameGroupBy(
self._df[key],
drop=self._drop,
Expand All @@ -461,7 +461,7 @@ def __getitem__(self, key):
)
return SeriesGroupBy(
self._df[key],
drop=False,
drop=True,
**kwargs,
)

Expand Down Expand Up @@ -687,8 +687,6 @@ def size(self):
if MODIN_UNNAMED_SERIES_LABEL in result.columns
else result
)
elif isinstance(self._df, Series):
result.name = self._df.name
else:
result.name = None
return result.fillna(0)
Expand Down Expand Up @@ -1194,6 +1192,65 @@ def groupby_on_multiple_columns(df, *args, **kwargs):

@_inherit_docstrings(pandas.core.groupby.SeriesGroupBy)
class SeriesGroupBy(SeriesGroupByCompat, DataFrameGroupBy):
def __init__(
self,
df,
by,
axis,
level,
as_index,
sort,
group_keys,
squeeze,
idx_name,
drop,
**kwargs,
):
super(SeriesGroupBy, self).__init__(
df,
by,
axis,
level,
as_index,
sort,
group_keys,
squeeze,
idx_name,
drop,
**kwargs,
)
self._squeeze = True

def _default_to_pandas(self, f, *args, **kwargs):
"""
Execute function `f` in default-to-pandas way.

Parameters
----------
f : callable
The function to apply to each group.
*args : list
Extra positional arguments to pass to `f`.
**kwargs : dict
Extra keyword arguments to pass to `f`.

Returns
-------
modin.pandas.Series or modin.pandas.DataFrame
A new Modin Series or DataFrame with the result of the pandas function.
"""
old_df = self._df
self._df = self._df[
next(
col_name
for col_name in self._df.columns
if col_name not in self._internal_by
)
]
intermediate = super(SeriesGroupBy, self)._default_to_pandas(f, *args, **kwargs)
self._df = old_df
return intermediate

@property
def ndim(self):
"""
Expand All @@ -1210,6 +1267,18 @@ def ndim(self):
"""
return 1 # ndim is always 1 for Series

def size(self):
intermediate = super(SeriesGroupBy, self).size()
if isinstance(self._df, Series):
intermediate.name = self._df.name
else:
intermediate.name = next(
col_name
for col_name in self._df.columns
if col_name not in self._internal_by
)
return intermediate

@property
def _iter(self):
"""
Expand Down Expand Up @@ -1247,6 +1316,15 @@ def _iter(self):
for k in (sorted(group_ids) if self._sort else group_ids)
)

def aggregate(self, func=None, *args, **kwargs):
if isinstance(func, (list, dict)):
self._squeeze = False
result = super(SeriesGroupBy, self).aggregate(func, *args, **kwargs)
self._squeeze = True
return result

agg = aggregate


if IsExperimental.get():
from modin.experimental.cloud.meta_magic import make_wrapped_class
Expand Down