From 1ac959ea5056fbf1329efec6408efdec6b8f7d8f Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Sat, 22 Mar 2025 23:31:40 +0100 Subject: [PATCH 1/2] changing polars to use a lazyframe --- mesa_frames/concrete/polars/agentset.py | 219 +++++++----- mesa_frames/concrete/polars/mixin.py | 433 +++++++++++++----------- tests/polars/test_agentset_polars.py | 402 ++++++++++++++++++---- tests/polars/test_mixin_polars.py | 318 +++++++++-------- 4 files changed, 866 insertions(+), 506 deletions(-) diff --git a/mesa_frames/concrete/polars/agentset.py b/mesa_frames/concrete/polars/agentset.py index 4bec1ea5..3ddf29bf 100644 --- a/mesa_frames/concrete/polars/agentset.py +++ b/mesa_frames/concrete/polars/agentset.py @@ -80,7 +80,7 @@ def step(self): class AgentSetPolars(AgentSetDF, PolarsMixin): """Polars-based implementation of AgentSetDF.""" - _agents: pl.DataFrame + _agents: pl.LazyFrame _copy_with_method: dict[str, tuple[str, list[str]]] = { "_agents": ("clone", []), } @@ -96,19 +96,19 @@ def __init__(self, model: "ModelDF") -> None: The model that the agent set belongs to. """ self._model = model - self._agents = pl.DataFrame(schema={"unique_id": pl.Int64}) - self._mask = pl.repeat(True, len(self._agents), dtype=pl.Boolean, eager=True) + self._agents = pl.LazyFrame(schema={"unique_id": pl.Int64}) + self._mask = pl.repeat(True, 0, dtype=pl.Boolean) def add( self, - agents: pl.DataFrame | Sequence[Any] | dict[str, Any], + agents: pl.DataFrame | pl.LazyFrame | Sequence[Any] | dict[str, Any], inplace: bool = True, ) -> Self: """Add agents to the AgentSetPolars. Parameters ---------- - agents : pl.DataFrame | Sequence[Any] | dict[str, Any] + agents : pl.DataFrame | pl.LazyFrame | Sequence[Any] | dict[str, Any] The agents to add. inplace : bool, optional Whether to add the agents in place, by default True. @@ -119,35 +119,37 @@ def add( The updated AgentSetPolars. """ obj = self._get_obj(inplace) - if isinstance(agents, pl.DataFrame): + if isinstance(agents, (pl.DataFrame, pl.LazyFrame)): if "unique_id" not in agents.columns: raise KeyError("DataFrame must have a unique_id column.") - new_agents = agents + new_agents = agents.lazy() if isinstance(agents, pl.DataFrame) else agents elif isinstance(agents, dict): if "unique_id" not in agents: raise KeyError("Dictionary must have a unique_id key.") - new_agents = pl.DataFrame(agents) + new_agents = pl.LazyFrame(agents) else: if len(agents) != len(obj._agents.columns): raise ValueError( "Length of data must match the number of columns in the AgentSet if being added as a Collection." ) - new_agents = pl.DataFrame([agents], schema=obj._agents.schema) + new_agents = pl.LazyFrame([agents], schema=obj._agents.schema) - if new_agents["unique_id"].dtype != pl.Int64: + # Collect schema to check unique_id type + if new_agents.schema["unique_id"] != pl.Int64: raise TypeError("unique_id column must be of type int64.") # If self._mask is pl.Expr, then new mask is the same. # If self._mask is pl.Series[bool], then new mask has to be updated. 
if isinstance(obj._mask, pl.Series): + # original_active_indices = obj._agents.filter(obj._mask).collect()["unique_id"] original_active_indices = obj._agents.filter(obj._mask)["unique_id"] obj._agents = pl.concat([obj._agents, new_agents], how="diagonal_relaxed") if isinstance(obj._mask, pl.Series): - obj._update_mask(original_active_indices, new_agents["unique_id"]) - + # obj._update_mask(original_active_indices, new_agents.collect()["unique_id"]) + obj._update_mask(original_active_indices, new_agents.collect()["unique_id"]) return obj @overload @@ -160,12 +162,14 @@ def contains( self, agents: PolarsIdsLike, ) -> bool | pl.Series: + # Need to collect for containment check + agent_ids = self._agents.select("unique_id").collect()["unique_id"] if isinstance(agents, pl.Series): - return agents.is_in(self._agents["unique_id"]) + return agents.is_in(agent_ids) elif isinstance(agents, Collection): - return pl.Series(agents).is_in(self._agents["unique_id"]) + return pl.Series(agents).is_in(agent_ids) else: - return agents in self._agents["unique_id"] + return agents in agent_ids def get( self, @@ -173,10 +177,14 @@ def get( mask: AgentPolarsMask = None, ) -> pl.Series | pl.DataFrame: masked_df = self._get_masked_df(mask) - attr_names = self.agents.select(attr_names).columns.copy() + attr_names = ( + self.agents.select(attr_names).collect().columns.copy() + if attr_names + else [] + ) if not attr_names: - return masked_df - masked_df = masked_df.select(attr_names) + return masked_df.collect() + masked_df = masked_df.select(attr_names).collect() if masked_df.shape[1] == 1: return masked_df[masked_df.columns[0]] return masked_df @@ -193,14 +201,19 @@ def set( masked_df = obj._get_masked_df(mask) if not attr_names: - attr_names = masked_df.columns + attr_names = masked_df.collect().columns attr_names.remove("unique_id") def process_single_attr( - masked_df: pl.DataFrame, attr_name: str, values: Any - ) -> pl.DataFrame: - if isinstance(values, pl.DataFrame): - return masked_df.with_columns(values.to_series().alias(attr_name)) + masked_df: pl.LazyFrame, attr_name: str, values: Any + ) -> pl.LazyFrame: + if isinstance(values, (pl.DataFrame, pl.LazyFrame)): + values_series = ( + values.collect() if isinstance(values, pl.LazyFrame) else values + ) + return masked_df.with_columns( + values_series.to_series().alias(attr_name) + ) elif isinstance(values, pl.Expr): return masked_df.with_columns(values.alias(attr_name)) if isinstance(values, pl.Series): @@ -209,7 +222,7 @@ def process_single_attr( if isinstance(values, Collection): values = pl.Series(values) else: - values = pl.repeat(values, len(masked_df)) + values = pl.repeat(values, masked_df.collect().height) return masked_df.with_columns(values.alias(attr_name)) if isinstance(attr_names, str) and values is not None: @@ -247,8 +260,12 @@ def select( if filter_func: mask = mask & filter_func(obj) if n is not None: - mask = (obj._agents["unique_id"]).is_in( - obj._agents.filter(mask).sample(n)["unique_id"] + # Need to collect for sampling + sample_ids = obj._agents.filter(mask).collect().sample(n)["unique_id"] + mask = ( + (obj._agents.select("unique_id")) + .collect()["unique_id"] + .is_in(sample_ids) ) if negate: mask = mask.not_() @@ -257,10 +274,15 @@ def select( def shuffle(self, inplace: bool = True) -> Self: obj = self._get_obj(inplace) - obj._agents = obj._agents.sample( - fraction=1, - shuffle=True, - seed=obj.random.integers(np.iinfo(np.int32).max), + # Collect to perform shuffle, then convert back to LazyFrame + obj._agents = ( + 
obj._agents.collect() + .sample( + fraction=1, + shuffle=True, + seed=obj.random.integers(np.iinfo(np.int32).max), + ) + .lazy() ) return obj @@ -283,13 +305,14 @@ def to_pandas(self) -> "AgentSetPandas": from mesa_frames.concrete.pandas.agentset import AgentSetPandas new_obj = AgentSetPandas(self._model) - new_obj._agents = self._agents.to_pandas() + new_obj._agents = self._agents.collect().to_pandas() if isinstance(self._mask, pl.Series): new_obj._mask = self._mask.to_pandas() else: # self._mask is Expr new_obj._mask = ( - self._agents["unique_id"] - .is_in(self._agents.filter(self._mask)["unique_id"]) + self._agents.select("unique_id") + .collect()["unique_id"] + .is_in(self._agents.filter(self._mask).collect()["unique_id"]) .to_pandas() ) return new_obj @@ -302,8 +325,9 @@ def _concatenate_agentsets( original_masked_index: pl.Series | None = None, ) -> Self: if not duplicates_allowed: - indices_list = [self._agents["unique_id"]] + [ - agentset._agents["unique_id"] for agentset in agentsets + indices_list = [self._agents.select("unique_id").collect()["unique_id"]] + [ + agentset._agents.select("unique_id").collect()["unique_id"] + for agentset in agentsets ] all_indices = pl.concat(indices_list) if all_indices.is_duplicated().any(): @@ -315,26 +339,37 @@ def _concatenate_agentsets( max_length = max(len(agentset) for agentset in agentsets) for agentset in agentsets: if len(agentset) == max_length: - original_index = agentset._agents["unique_id"] + original_index = agentset._agents.select("unique_id").collect()[ + "unique_id" + ] final_dfs = [self._agents] - final_active_indices = [self._agents["unique_id"]] - final_indices = self._agents["unique_id"].clone() + final_active_indices = [ + self._agents.filter(self._mask).collect()["unique_id"] + ] + final_indices = ( + self._agents.select("unique_id").collect()["unique_id"].clone() + ) for obj in iter(agentsets): # Remove agents that are already in the final DataFrame final_dfs.append( - obj._agents.filter(pl.col("unique_id").is_in(final_indices).not_()) + obj._agents.filter(~pl.col("unique_id").is_in(final_indices)) ) # Add the indices of the active agents of current AgentSet - final_active_indices.append(obj._agents.filter(obj._mask)["unique_id"]) + final_active_indices.append( + obj._agents.filter(obj._mask).collect()["unique_id"] + ) # Update the indices of the agents in the final DataFrame final_indices = pl.concat( - [final_indices, final_dfs[-1]["unique_id"]], how="vertical" + [ + final_indices, + final_dfs[-1].select("unique_id").collect()["unique_id"], + ], + how="vertical", ) # Left-join original index with concatenated dfs to keep original ids order - final_df = original_index.to_frame().join( + final_df = pl.LazyFrame({"unique_id": original_index}).join( pl.concat(final_dfs, how="diagonal_relaxed"), on="unique_id", how="left" ) - # final_active_index = pl.concat(final_active_indices, how="vertical") else: @@ -342,15 +377,24 @@ def _concatenate_agentsets( [obj._agents for obj in agentsets], how="diagonal_relaxed" ) final_active_index = pl.concat( - [obj._agents.filter(obj._mask)["unique_id"] for obj in agentsets] + [ + obj._agents.filter(obj._mask).collect()["unique_id"] + for obj in agentsets + ] ) - final_mask = final_df["unique_id"].is_in(final_active_index) + final_mask = ( + final_df.select("unique_id") + .collect()["unique_id"] + .is_in(final_active_index) + ) self._agents = final_df self._mask = final_mask # If some ids were removed in the do-method, we need to remove them also from final_df if not 
isinstance(original_masked_index, type(None)): ids_to_remove = original_masked_index.filter( - original_masked_index.is_in(self._agents["unique_id"]).not_() + ~original_masked_index.is_in( + self._agents.select("unique_id").collect()["unique_id"] + ) ) if not ids_to_remove.is_empty(): self.remove(ids_to_remove, inplace=True) @@ -364,26 +408,27 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: if ( isinstance(mask, pl.Series) and mask.dtype == pl.Boolean - and len(mask) == len(self._agents) + and len(mask) == len(self._agents.collect()) ): return mask - return self._agents["unique_id"].is_in(mask) + return self._agents.select("unique_id").collect()["unique_id"].is_in(mask) if isinstance(mask, pl.Expr): return mask elif isinstance(mask, pl.Series): return bool_mask_from_series(mask) - elif isinstance(mask, pl.DataFrame): - if "unique_id" in mask.columns: - return bool_mask_from_series(mask["unique_id"]) - elif len(mask.columns) == 1 and mask.dtypes[0] == pl.Boolean: - return bool_mask_from_series(mask[mask.columns[0]]) + elif isinstance(mask, (pl.DataFrame, pl.LazyFrame)): + mask_df = mask.collect() if isinstance(mask, pl.LazyFrame) else mask + if "unique_id" in mask_df.columns: + return bool_mask_from_series(mask_df["unique_id"]) + elif len(mask_df.columns) == 1 and mask_df.dtypes[0] == pl.Boolean: + return bool_mask_from_series(mask_df[mask_df.columns[0]]) else: raise KeyError( "DataFrame must have a 'unique_id' column or a single boolean column." ) elif mask is None or mask == "all": - return pl.repeat(True, len(self._agents)) + return pl.repeat(True, len(self._agents.collect())) elif mask == "active": return self._mask elif isinstance(mask, Collection): @@ -394,25 +439,28 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: def _get_masked_df( self, mask: AgentPolarsMask = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: if (isinstance(mask, pl.Series) and mask.dtype == pl.Boolean) or isinstance( mask, pl.Expr ): return self._agents.filter(mask) - elif isinstance(mask, pl.DataFrame): - if not mask["unique_id"].is_in(self._agents["unique_id"]).all(): + elif isinstance(mask, (pl.DataFrame, pl.LazyFrame)): + mask_df = mask.collect() if isinstance(mask, pl.LazyFrame) else mask + agents_ids = self._agents.select("unique_id").collect()["unique_id"] + if not mask_df["unique_id"].is_in(agents_ids).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." ) - return mask.select("unique_id").join( + return pl.LazyFrame({"unique_id": mask_df["unique_id"]}).join( self._agents, on="unique_id", how="left" ) elif isinstance(mask, pl.Series): - if not mask.is_in(self._agents["unique_id"]).all(): + agents_ids = self._agents.select("unique_id").collect()["unique_id"] + if not mask.is_in(agents_ids).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." ) - mask_df = mask.to_frame("unique_id") + mask_df = pl.LazyFrame({"unique_id": mask}) return mask_df.join(self._agents, on="unique_id", how="left") elif mask is None or mask == "all": return self._agents @@ -423,11 +471,12 @@ def _get_masked_df( mask_series = pl.Series(mask) else: mask_series = pl.Series([mask]) - if not mask_series.is_in(self._agents["unique_id"]).all(): + agents_ids = self._agents.select("unique_id").collect()["unique_id"] + if not mask_series.is_in(agents_ids).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." 
) - mask_df = mask_series.to_frame("unique_id") + mask_df = pl.LazyFrame({"unique_id": mask_series}) return mask_df.join(self._agents, on="unique_id", how="left") @overload @@ -436,14 +485,21 @@ def _get_obj_copy(self, obj: pl.Series) -> pl.Series: ... @overload def _get_obj_copy(self, obj: pl.DataFrame) -> pl.DataFrame: ... - def _get_obj_copy(self, obj: pl.Series | pl.DataFrame) -> pl.Series | pl.DataFrame: + @overload + def _get_obj_copy(self, obj: pl.LazyFrame) -> pl.LazyFrame: ... + + def _get_obj_copy( + self, obj: pl.Series | pl.DataFrame | pl.LazyFrame + ) -> pl.Series | pl.DataFrame | pl.LazyFrame: return obj.clone() def _discard(self, ids: PolarsIdsLike) -> Self: mask = self._get_bool_mask(ids) if isinstance(self._mask, pl.Series): - original_active_indices = self._agents.filter(self._mask)["unique_id"] + original_active_indices = self._agents.filter(self._mask).collect()[ + "unique_id" + ] self._agents = self._agents.filter(mask.not_()) @@ -455,16 +511,17 @@ def _discard(self, ids: PolarsIdsLike) -> Self: def _update_mask( self, original_active_indices: pl.Series, new_indices: pl.Series | None = None ) -> None: + agent_ids = self._agents.select("unique_id").collect()["unique_id"] if new_indices is not None: - self._mask = self._agents["unique_id"].is_in( - original_active_indices - ) | self._agents["unique_id"].is_in(new_indices) + self._mask = agent_ids.is_in(original_active_indices) | agent_ids.is_in( + new_indices + ) else: - self._mask = self._agents["unique_id"].is_in(original_active_indices) + self._mask = agent_ids.is_in(original_active_indices) def __getattr__(self, key: str) -> pl.Series: super().__getattr__(key) - return self._agents[key] + return self._agents.select(key).collect()[key] @overload def __getitem__( @@ -503,26 +560,28 @@ def __getitem__( return attr def __iter__(self) -> Iterator[dict[str, Any]]: - return iter(self._agents.iter_rows(named=True)) + return iter(self._agents.collect().iter_rows(named=True)) def __len__(self) -> int: - return len(self._agents) + return self._agents.collect().height def __reversed__(self) -> Iterator: - return reversed(iter(self._agents.iter_rows(named=True))) + return reversed(list(self._agents.collect().iter_rows(named=True))) @property - def agents(self) -> pl.DataFrame: + def agents(self) -> pl.LazyFrame: return self._agents @agents.setter - def agents(self, agents: pl.DataFrame) -> None: - if "unique_id" not in agents.columns: + def agents(self, agents: pl.DataFrame | pl.LazyFrame) -> None: + if "unique_id" not in ( + agents.columns if isinstance(agents, pl.LazyFrame) else agents.columns + ): raise KeyError("DataFrame must have a unique_id column.") - self._agents = agents + self._agents = agents.lazy() if isinstance(agents, pl.DataFrame) else agents @property - def active_agents(self) -> pl.DataFrame: + def active_agents(self) -> pl.LazyFrame: return self.agents.filter(self._mask) @active_agents.setter @@ -530,13 +589,13 @@ def active_agents(self, mask: AgentPolarsMask) -> None: self.select(mask=mask, inplace=True) @property - def inactive_agents(self) -> pl.DataFrame: + def inactive_agents(self) -> pl.LazyFrame: return self.agents.filter(~self._mask) @property def index(self) -> pl.Series: - return self._agents["unique_id"] + return self._agents.select("unique_id").collect()["unique_id"] @property - def pos(self) -> pl.DataFrame: + def pos(self) -> pl.LazyFrame: return super().pos diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py index d9825dad..3c819394 100644 --- 
a/mesa_frames/concrete/polars/mixin.py +++ b/mesa_frames/concrete/polars/mixin.py @@ -9,11 +9,11 @@ Classes: PolarsMixin(DataFrameMixin): A Polars-based implementation of DataFrame operations. This class provides - methods for manipulating and analyzing data stored in Polars DataFrames, + methods for manipulating and analyzing data stored in Polars LazyFrames, tailored for use in mesa-frames components like AgentSetPolars and GridPolars. The PolarsMixin class is designed to be used as a mixin with other mesa-frames -classes, providing them with Polars-specific DataFrame functionality. It implements +classes, providing them with Polars-specific LazyFrame functionality. It implements the abstract methods defined in the DataFrameMixin, ensuring consistent DataFrame operations across the mesa-frames package. @@ -26,7 +26,7 @@ class AgentSetPolars(AgentSetDF, PolarsMixin): def __init__(self, model): super().__init__(model) - self.agents = pl.DataFrame() # Initialize empty DataFrame + self.agents = pl.LazyFrame() # Initialize empty LazyFrame def some_method(self): # Use Polars operations provided by the mixin @@ -34,8 +34,8 @@ def some_method(self): # ... further processing ... Features: - - High-performance DataFrame operations using Polars - - Support for both eager and lazy evaluation + - High-performance LazyFrame operations using Polars + - Support for lazy evaluation with improved query optimization - Efficient memory usage and fast computation - Integration with Polars' query optimization capabilities @@ -55,18 +55,18 @@ def some_method(self): class PolarsMixin(DataFrameMixin): - """Polars-specific implementation of DataFrame operations.""" + """Polars-specific implementation of DataFrame operations using LazyFrames.""" # TODO: complete with other dtypes _dtypes_mapping: dict[str, Any] = {"int64": pl.Int64, "bool": pl.Boolean} def _df_add( self, - df: pl.DataFrame, - other: pl.DataFrame | Sequence[float | int], + df: pl.LazyFrame, + other: pl.LazyFrame | Sequence[float | int], axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -77,21 +77,23 @@ def _df_add( def _df_all( self, - df: pl.DataFrame, + df: pl.LazyFrame, name: str = "all", axis: Literal["index", "columns"] = "columns", - ) -> pl.Series: + ) -> pl.Expr: if axis == "index": - return pl.Series(name, df.select(pl.col("*").all()).row(0)) - return df.with_columns(pl.all_horizontal("*").alias(name))[name] + # Return an expression that will evaluate to all values across index + return pl.all(pl.col("*")).alias(name) + # Return an expression for all values across columns + return pl.all_horizontal(pl.col("*")).alias(name) def _df_and( self, - df: pl.DataFrame, - other: pl.DataFrame | Sequence[float | int], + df: pl.LazyFrame, + other: pl.LazyFrame | Sequence[float | int], axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -100,73 +102,85 @@ def _df_and( index_cols=index_cols, ) - def _df_column_names(self, df: pl.DataFrame) -> list[str]: + def _df_column_names(self, df: pl.LazyFrame) -> list[str]: + # This operation requires schema inspection which is available on LazyFrame return df.columns def _df_combine_first( self, - original_df: pl.DataFrame, - new_df: pl.DataFrame, + original_df: pl.LazyFrame, + new_df: pl.LazyFrame, index_cols: str | list[str], - ) -> 
pl.DataFrame: - original_df = original_df.with_columns(_index=pl.int_range(0, len(original_df))) + ) -> pl.LazyFrame: + # Create a sequential index using with_row_count instead of int_range + original_df = original_df.with_row_count("_index") common_cols = set(original_df.columns) & set(new_df.columns) merged_df = original_df.join(new_df, on=index_cols, how="full", suffix="_right") - merged_df = ( - merged_df.with_columns( - pl.coalesce(pl.col(col), pl.col(f"{col}_right")).alias(col) - for col in common_cols - ) - .select(pl.exclude("^.*_right$")) - .sort("_index") - .drop("_index") - ) - return merged_df + + # Use expressions to coalesce values + coalesce_exprs = [ + pl.coalesce(pl.col(col), pl.col(f"{col}_right")).alias(col) + for col in common_cols + if col in merged_df.columns and f"{col}_right" in merged_df.columns + ] + + # Apply coalesce expressions and drop right columns + merged_df = merged_df.with_columns(coalesce_exprs) + right_cols = [col for col in merged_df.columns if col.endswith("_right")] + merged_df = merged_df.drop(right_cols) + + # Sort by index and drop index column + return merged_df.sort("_index").drop("_index") @overload def _df_concat( self, - objs: Collection[pl.DataFrame], + objs: Collection[pl.LazyFrame], how: Literal["horizontal"] | Literal["vertical"] = "vertical", ignore_index: bool = False, index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: ... + ) -> pl.LazyFrame: ... @overload def _df_concat( self, - objs: Collection[pl.Series], + objs: Collection[pl.Expr], how: Literal["vertical"] = "vertical", ignore_index: bool = False, index_cols: str | list[str] | None = None, - ) -> pl.Series: ... + ) -> pl.Expr: ... @overload def _df_concat( self, - objs: Collection[pl.Series], + objs: Collection[pl.Expr], how: Literal["horizontal"] = "horizontal", ignore_index: bool = False, index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: ... + ) -> pl.LazyFrame: ... 
def _df_concat( self, - objs: Collection[pl.DataFrame] | Collection[pl.Series], + objs: Collection[pl.LazyFrame] | Collection[pl.Expr], how: Literal["horizontal"] | Literal["vertical"] = "vertical", ignore_index: bool = False, index_cols: str | None = None, - ) -> pl.Series | pl.DataFrame: - if isinstance(objs[0], pl.DataFrame) and how == "vertical": + ) -> pl.LazyFrame | pl.Expr: + if isinstance(next(iter(objs), None), pl.LazyFrame) and how == "vertical": how = "diagonal_relaxed" - if isinstance(objs[0], pl.Series) and how == "horizontal": - obj = pl.DataFrame().hstack(objs, in_place=True) + + if isinstance(next(iter(objs), None), pl.Expr) and how == "horizontal": + # Convert expressions to LazyFrames for horizontal concat + obj = pl.LazyFrame().with_columns(list(objs)) else: + # Use concat on LazyFrames directly obj = pl.concat(objs, how=how) - if isinstance(obj, pl.DataFrame) and how == "horizontal" and ignore_index: - obj = obj.rename( - {c: str(i) for c, i in zip(obj.columns, range(len(obj.columns)))} - ) + + if isinstance(obj, pl.LazyFrame) and how == "horizontal" and ignore_index: + # Rename columns if ignore_index is True + rename_dict = {c: str(i) for i, c in enumerate(obj.columns)} + obj = obj.rename(rename_dict) + return obj def _df_constructor( @@ -176,42 +190,50 @@ def _df_constructor( index: Sequence[Hashable] | None = None, index_cols: str | list[str] | None = None, dtypes: dict[str, str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: if dtypes is not None: dtypes = {k: self._dtypes_mapping.get(v, v) for k, v in dtypes.items()} + + # Convert pandas DataFrame to Polars if isinstance(data, pd.DataFrame): data = data.reset_index() - df = pl.DataFrame( - data=data, schema=columns, schema_overrides=dtypes, orient="row" - ) + df = pl.from_pandas(data).lazy() + else: + # Create LazyFrame directly + df = pl.LazyFrame(data=data, schema=columns, schema_overrides=dtypes) + if index is not None: if index_cols is not None: if isinstance(index_cols, str): index_cols = [index_cols] - index_df = pl.DataFrame(index, index_cols) + index_df = pl.LazyFrame({col: index for col in index_cols}) else: - index_df = pl.DataFrame(index) - if len(df) != len(index_df) and len(df) == 1: - df = index_df.join(df, how="cross") + index_df = pl.LazyFrame({"index": index}) + + if len(df.schema) == 0: + # Empty LazyFrame case + df = index_df else: - df = index_df.hstack(df) + # Use cross join for single row df or regular join otherwise + df = index_df.join(df, how="cross") + return df def _df_contains( self, - df: pl.DataFrame, + df: pl.LazyFrame, column: str, values: Sequence[Any], - ) -> pl.Series: - return pl.Series("contains", values).is_in(df[column]) + ) -> pl.Expr: + return pl.col(column).is_in(values) def _df_div( self, - df: pl.DataFrame, - other: pl.DataFrame | Sequence[float | int], + df: pl.LazyFrame, + other: pl.LazyFrame | Sequence[float | int], axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -222,48 +244,36 @@ def _df_div( def _df_drop_columns( self, - df: pl.DataFrame, + df: pl.LazyFrame, columns: str | list[str], - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return df.drop(columns) def _df_drop_duplicates( self, - df: pl.DataFrame, + df: pl.LazyFrame, subset: str | list[str] | None = None, keep: Literal["first", "last", False] = "first", - ) -> pl.DataFrame: + ) -> pl.LazyFrame: # If subset is None, use all columns if subset is None: subset = 
df.columns - original_col_order = df.columns + if keep == "first": - return ( - df.group_by(subset, maintain_order=True) - .first() - .select(original_col_order) - ) + return df.unique(subset=subset, keep="first") elif keep == "last": - return ( - df.group_by(subset, maintain_order=True) - .last() - .select(original_col_order) - ) + return df.unique(subset=subset, keep="last") else: - return ( - df.with_columns(pl.len().over(subset)) - .filter(pl.col("len") < 2) - .drop("len") - .select(original_col_order) - ) + # For keep=False, drop all duplicates + return df.filter(~pl.col(subset).is_duplicated()) def _df_ge( self, - df: pl.DataFrame, - other: pl.DataFrame | Sequence[float | int], + df: pl.LazyFrame, + other: pl.LazyFrame | Sequence[float | int], axis: Literal["index", "columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -274,48 +284,49 @@ def _df_ge( def _df_get_bool_mask( self, - df: pl.DataFrame, + df: pl.LazyFrame, index_cols: str | list[str] | None = None, mask: PolarsMask = None, negate: bool = False, - ) -> pl.Series | pl.Expr: - def bool_mask_from_series(mask: pl.Series) -> pl.Series: - if ( - isinstance(mask, pl.Series) - and mask.dtype == pl.Boolean - and len(mask) == len(df) - ): - return mask - assert isinstance(index_cols, str) - return df[index_cols].is_in(mask) - - def bool_mask_from_df(mask: pl.DataFrame) -> pl.Series: - assert index_cols, list[str] - mask = mask[index_cols].unique() - mask = mask.with_columns(in_it=True) - return df.join(mask, on=index_cols, how="left")["in_it"].fill_null(False) + ) -> pl.Expr: + def bool_mask_from_expr(mask: pl.Expr) -> pl.Expr: + return mask + + def bool_mask_from_lazyframe(mask: pl.LazyFrame) -> pl.Expr: + if index_cols is None: + raise ValueError( + "index_cols must be provided when using LazyFrame mask" + ) - if isinstance(mask, pl.Expr): - result = mask - elif isinstance(mask, pl.Series): - result = bool_mask_from_series(mask) - elif isinstance(mask, pl.DataFrame): - if index_cols in mask.columns: - result = bool_mask_from_series(mask[index_cols]) - elif all(col in mask.columns for col in index_cols): - result = bool_mask_from_df(mask[index_cols]) - elif len(mask.columns) == 1 and mask.dtypes[0] == pl.Boolean: - result = bool_mask_from_series(mask[mask.columns[0]]) + if isinstance(index_cols, str): + return pl.col(index_cols).is_in(mask.select(index_cols)) else: - raise KeyError( - f"Mask must have {index_cols} column(s) or a single boolean column." 
+ # For multiple index columns, create an expression to check if in the mask + join_cols = [pl.col(col) for col in index_cols] + return pl.struct(join_cols).is_in(mask.select(index_cols)) + + def bool_mask_from_values(values) -> pl.Expr: + if index_cols is None: + raise ValueError("index_cols must be provided when using value mask") + + if isinstance(index_cols, str): + return pl.col(index_cols).is_in(values) + else: + # This is simplified and may need adjustment for multi-column case + raise NotImplementedError( + "Multi-column masking with raw values not implemented" ) + + if isinstance(mask, pl.Expr): + result = bool_mask_from_expr(mask) + elif isinstance(mask, pl.LazyFrame): + result = bool_mask_from_lazyframe(mask) elif mask is None or mask == "all": - result = pl.Series([True] * len(df)) + result = pl.lit(True) elif isinstance(mask, Collection): - result = bool_mask_from_series(pl.Series(mask)) + result = bool_mask_from_values(mask) else: - result = bool_mask_from_series(pl.Series([mask])) + result = bool_mask_from_values([mask]) if negate: result = ~result @@ -324,32 +335,32 @@ def bool_mask_from_df(mask: pl.DataFrame) -> pl.Series: def _df_get_masked_df( self, - df: pl.DataFrame, + df: pl.LazyFrame, index_cols: str | list[str] | None = None, mask: PolarsMask | None = None, columns: list[str] | None = None, negate: bool = False, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: b_mask = self._df_get_bool_mask(df, index_cols, mask, negate=negate) if columns: - return df.filter(b_mask)[columns] + return df.filter(b_mask).select(columns) return df.filter(b_mask) def _df_groupby_cumcount( - self, df: pl.DataFrame, by: str | list[str], name="cum_count" - ) -> pl.Series: - return df.with_columns(pl.cum_count(by).over(by).alias(name))[name] + self, df: pl.LazyFrame, by: str | list[str], name="cum_count" + ) -> pl.Expr: + return pl.cumcount().over(by).alias(name) - def _df_index(self, df: pl.DataFrame, index_col: str | list[str]) -> pl.Series: - return df[index_col] + def _df_index(self, df: pl.LazyFrame, index_col: str | list[str]) -> pl.Expr: + return pl.col(index_col) - def _df_iterator(self, df: pl.DataFrame) -> Iterator[dict[str, Any]]: - return iter(df.iter_rows(named=True)) + def _df_iterator(self, df: pl.LazyFrame) -> Iterator[dict[str, Any]]: + return iter(df.collect().iter_rows(named=True)) def _df_join( self, - left: pl.DataFrame, - right: pl.DataFrame, + left: pl.LazyFrame, + right: pl.LazyFrame, index_cols: str | list[str] | None = None, on: str | list[str] | None = None, left_on: str | list[str] | None = None, @@ -360,7 +371,7 @@ def _df_join( | Literal["outer"] | Literal["cross"] = "left", suffix="_right", - ) -> pl.DataFrame: + ) -> pl.LazyFrame: if how == "outer": how = "full" if how == "right": @@ -373,11 +384,11 @@ def _df_join( def _df_lt( self, - df: pl.DataFrame, - other: pl.DataFrame | Sequence[float | int], + df: pl.LazyFrame, + other: pl.LazyFrame | Sequence[float | int], axis: Literal["index", "columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -388,11 +399,11 @@ def _df_lt( def _df_mod( self, - df: pl.DataFrame, - other: pl.DataFrame | Sequence[float | int], + df: pl.LazyFrame, + other: pl.LazyFrame | Sequence[float | int], axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -403,11 +414,11 @@ def _df_mod( def _df_mul( self, - df: 
pl.DataFrame, - other: pl.DataFrame | Sequence[float | int], + df: pl.LazyFrame, + other: pl.LazyFrame | Sequence[float | int], axis: Literal["index", "columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -419,27 +430,27 @@ def _df_mul( @overload def _df_norm( self, - df: pl.DataFrame, + df: pl.LazyFrame, srs_name: str = "norm", include_cols: Literal[False] = False, - ) -> pl.Series: ... + ) -> pl.Expr: ... @overload def _df_norm( self, - df: pl.Series, + df: pl.Expr, srs_name: str = "norm", include_cols: Literal[True] = True, - ) -> pl.DataFrame: ... + ) -> pl.LazyFrame: ... def _df_norm( self, - df: pl.DataFrame, + df: pl.LazyFrame, srs_name: str = "norm", include_cols: bool = False, - ) -> pl.Series | pl.DataFrame: + ) -> pl.Expr | pl.LazyFrame: srs = ( - df.with_columns(pl.col("*").pow(2)).sum_horizontal().sqrt().rename(srs_name) + df.with_columns(pl.col("*").pow(2)).sum_horizontal().sqrt().alias(srs_name) ) if include_cols: return df.with_columns(srs) @@ -447,57 +458,65 @@ def _df_norm( def _df_operation( self, - df: pl.DataFrame, - other: pl.DataFrame | Sequence[float | int], + df: pl.LazyFrame, + other: pl.LazyFrame | Sequence[float | int], operation: Callable[[pl.Expr, pl.Expr], pl.Expr], axis: Literal["index", "columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: - if isinstance(other, pl.DataFrame): + ) -> pl.LazyFrame: + if isinstance(other, pl.LazyFrame): if index_cols is not None: + # Join with the other LazyFrame op_df = df.join(other, how="left", on=index_cols, suffix="_op") else: - assert len(df) == len( - other - ), "DataFrames must have the same length if index_cols is not specified" - index_cols = [] - other = other.rename(lambda col: col + "_op") + # Without index cols, assume matching order and do a horizontal concat + other = other.rename({col: f"{col}_op" for col in other.columns}) op_df = pl.concat([df, other], how="horizontal") - return op_df.with_columns( - operation(pl.col(col), pl.col(f"{col}_op")).alias(col) - for col in df.columns - if col not in index_cols - ).select(df.columns) - elif isinstance( - other, (Sequence, pl.Series) - ): # Currently, pl.Series is not a Sequence + + # Apply the operation to matching columns + expr_list = [] + for col in df.columns: + if col not in (index_cols or []): + if f"{col}_op" in op_df.columns: + expr_list.append( + operation(pl.col(col), pl.col(f"{col}_op")).alias(col) + ) + else: + expr_list.append(pl.col(col)) + else: + expr_list.append(pl.col(col)) + + return op_df.with_columns(expr_list).select(df.columns) + elif isinstance(other, (Sequence, pl.Series)): if axis == "index": - assert len(df) == len( - other - ), "Sequence must have the same length as df if axis is 'index'" - other_series = pl.Series("operand", other) - return df.with_columns( - operation(pl.col(col), other_series).alias(col) - for col in df.columns - ) + # Apply operation row-wise + if isinstance(other, pl.Series): + # Convert Series to an expression + other_expr = pl.lit(other.to_list()) + else: + other_expr = pl.lit(list(other)) + + expr_list = [ + operation(pl.col(col), other_expr).alias(col) for col in df.columns + ] + return df.with_columns(expr_list) else: - assert ( - len(df.columns) == len(other) - ), "Sequence must have the same length as df.columns if axis is 'columns'" - return df.with_columns( + # Apply operation column-wise + expr_list = [ operation(pl.col(col), pl.lit(other[i])).alias(col) for i, col in 
enumerate(df.columns) - ) + ] + return df.with_columns(expr_list) else: - raise ValueError("other must be a DataFrame or a Sequence") + raise ValueError("other must be a LazyFrame or a Sequence") def _df_or( self, - df: pl.DataFrame, - other: pl.DataFrame | Sequence[float | int], + df: pl.LazyFrame, + other: pl.LazyFrame | Sequence[float | int], axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -508,17 +527,17 @@ def _df_or( def _df_reindex( self, - df: pl.DataFrame, - other: Sequence[Hashable] | pl.DataFrame, + df: pl.LazyFrame, + other: Sequence[Hashable] | pl.LazyFrame, new_index_cols: str | list[str], original_index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: - # If other is a DataFrame, extract the index columns - if isinstance(other, pl.DataFrame): + ) -> pl.LazyFrame: + # If other is a LazyFrame, extract the index columns + if isinstance(other, pl.LazyFrame): other = other.select(new_index_cols) else: - # If other is a sequence, create a DataFrame with it - other = pl.Series(name=new_index_cols, values=other).to_frame() + # If other is a sequence, create a LazyFrame with it + other = pl.LazyFrame({new_index_cols: other}) # Perform a left join to reindex if original_index_cols is None: @@ -529,16 +548,16 @@ def _df_reindex( return result def _df_rename_columns( - self, df: pl.DataFrame, old_columns: list[str], new_columns: list[str] - ) -> pl.DataFrame: + self, df: pl.LazyFrame, old_columns: list[str], new_columns: list[str] + ) -> pl.LazyFrame: return df.rename(dict(zip(old_columns, new_columns))) def _df_reset_index( self, - df: pl.DataFrame, + df: pl.LazyFrame, index_cols: str | list[str] | None = None, drop: bool = False, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: if drop and index_cols is not None: return df.drop(index_cols) else: @@ -546,13 +565,13 @@ def _df_reset_index( def _df_sample( self, - df: pl.DataFrame, + df: pl.LazyFrame, n: int | None = None, frac: float | None = None, with_replacement: bool = False, shuffle: bool = False, seed: int | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return df.sample( n=n, fraction=frac, @@ -563,32 +582,32 @@ def _df_sample( def _df_set_index( self, - df: pl.DataFrame, + df: pl.LazyFrame, index_name: str, new_index: Sequence[Hashable] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: if new_index is None: return df return df.with_columns(**{index_name: new_index}) def _df_with_columns( self, - original_df: pl.DataFrame, - data: Sequence | pl.DataFrame | Sequence[Sequence] | dict[str | Any] | Any, + original_df: pl.LazyFrame, + data: Sequence | pl.LazyFrame | Sequence[Sequence] | dict[str | Any] | Any, new_columns: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: if ( (isinstance(data, Sequence) and isinstance(data[0], Sequence)) or isinstance( - data, pl.DataFrame - ) # Currently, pl.DataFrame is not a Sequence + data, pl.LazyFrame + ) # Currently, pl.LazyFrame is not a Sequence or ( isinstance(data, dict) and isinstance(data[list(data.keys())[0]], Sequence) ) ): # This means that data is a Sequence of Sequences (rows) - data = pl.DataFrame(data, new_columns) + data = pl.LazyFrame(data, new_columns) original_df = original_df.select(pl.exclude(data.columns)) return original_df.hstack(data) if not isinstance(data, dict): @@ -614,10 +633,10 @@ def _srs_contains( self, srs: Collection[Any], values: Any | Sequence[Any], - ) -> pl.Series: + ) -> 
pl.Expr: if not isinstance(values, Collection): values = [values] - return pl.Series(values).is_in(srs) + return pl.lit(values).is_in(srs) def _srs_range( self, @@ -626,12 +645,12 @@ def _srs_range( end: int, step: int = 1, ) -> pl.Series: - return pl.arange(start=start, end=end, step=step, eager=True).rename(name) + return pl.arange(start=start, end=end, step=step, eager=True).alias(name) def _srs_to_df( self, srs: pl.Series, index: pl.Series | None = None - ) -> pl.DataFrame: - df = srs.to_frame() + ) -> pl.LazyFrame: + df = srs.to_frame().lazy() if index: return df.with_columns({index.name: index}) return df diff --git a/tests/polars/test_agentset_polars.py b/tests/polars/test_agentset_polars.py index 9c311727..6b1388f5 100644 --- a/tests/polars/test_agentset_polars.py +++ b/tests/polars/test_agentset_polars.py @@ -63,9 +63,14 @@ def test__init__(self): agents = ExampleAgentSetPolars(model) agents.add({"unique_id": [0, 1, 2, 3]}) assert agents.model == model - assert isinstance(agents.agents, pl.DataFrame) - assert agents.agents["unique_id"].to_list() == [0, 1, 2, 3] - assert isinstance(agents._mask, pl.Series) + assert isinstance(agents.agents, pl.LazyFrame) + assert agents.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 0, + 1, + 2, + 3, + ] + assert isinstance(agents._mask, pl.Expr) assert isinstance(agents.random, Generator) assert agents.starting_wealth.to_list() == [1, 2, 3, 4] @@ -79,19 +84,67 @@ def test_add( # Test with a DataFrame result = agents.add(agents2.agents, inplace=False) - assert result.agents["unique_id"].to_list() == [0, 1, 2, 3, 4, 5, 6, 7] + assert result.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + ] # Test with a list (Sequence[Any]) result = agents.add([10, 5, 10], inplace=False) - assert result.agents["unique_id"].to_list() == [0, 1, 2, 3, 10] - assert result.agents["wealth"].to_list() == [1, 2, 3, 4, 5] - assert result.agents["age"].to_list() == [10, 20, 30, 40, 10] + assert result.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 0, + 1, + 2, + 3, + 10, + ] + assert result.agents.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] + assert result.agents.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 10, + ] # Test with a dict[str, Any] agents.add({"unique_id": [4, 5], "wealth": [5, 6], "age": [50, 60]}) - assert agents.agents["wealth"].to_list() == [1, 2, 3, 4, 5, 6] - assert agents.agents["unique_id"].to_list() == [0, 1, 2, 3, 4, 5] - assert agents.agents["age"].to_list() == [10, 20, 30, 40, 50, 60] + assert agents.agents.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + assert agents.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 0, + 1, + 2, + 3, + 4, + 5, + ] + assert agents.agents.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 50, + 60, + ] def test_contains(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars @@ -122,7 +175,11 @@ def test_discard(self, fix1_AgentSetPolars_with_pos: ExampleAgentSetPolars): # Test with a single value result = agents.discard(0, inplace=False) - assert result.agents["unique_id"].to_list() == [1, 2, 3] + assert result.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 1, + 2, + 3, + ] assert result.pos["unique_id"].to_list() == [1, 2, 3] assert result.pos["dim_0"].to_list() == [1, None, None] assert result.pos["dim_1"].to_list() == [1, None, None] @@ 
-130,14 +187,20 @@ def test_discard(self, fix1_AgentSetPolars_with_pos: ExampleAgentSetPolars): # Test with a list result = agents.discard([0, 1], inplace=False) - assert result.agents["unique_id"].to_list() == [2, 3] + assert result.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 2, + 3, + ] assert result.pos["unique_id"].to_list() == [2, 3] assert result.pos["dim_0"].to_list() == [None, None] assert result.pos["dim_1"].to_list() == [None, None] # Test with a pl.DataFrame result = agents.discard(pl.DataFrame({"unique_id": [0, 1]}), inplace=False) - assert result.agents["unique_id"].to_list() == [2, 3] + assert result.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 2, + 3, + ] assert result.pos["unique_id"].to_list() == [2, 3] assert result.pos["dim_0"].to_list() == [None, None] assert result.pos["dim_1"].to_list() == [None, None] @@ -145,29 +208,52 @@ def test_discard(self, fix1_AgentSetPolars_with_pos: ExampleAgentSetPolars): # Test with active_agents agents.active_agents = [0, 1] result = agents.discard("active", inplace=False) - assert result.agents["unique_id"].to_list() == [2, 3] + assert result.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 2, + 3, + ] assert result.pos["unique_id"].to_list() == [2, 3] assert result.pos["dim_0"].to_list() == [None, None] assert result.pos["dim_1"].to_list() == [None, None] # Test with empty list result = agents.discard([], inplace=False) - assert result.agents["unique_id"].to_list() == [0, 1, 2, 3] + assert result.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 0, + 1, + 2, + 3, + ] def test_do(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with no return_results, no mask agents.do("add_wealth", 1) - assert agents.agents["wealth"].to_list() == [2, 3, 4, 5] + assert agents.agents.select("wealth").collect()["wealth"].to_list() == [ + 2, + 3, + 4, + 5, + ] # Test with return_results=True, no mask assert agents.do("add_wealth", 1, return_results=True) is None - assert agents.agents["wealth"].to_list() == [3, 4, 5, 6] + assert agents.agents.select("wealth").collect()["wealth"].to_list() == [ + 3, + 4, + 5, + 6, + ] # Test with a mask agents.do("add_wealth", 1, mask=agents["wealth"] > 3) - assert agents.agents["wealth"].to_list() == [3, 5, 6, 7] + assert agents.agents.select("wealth").collect()["wealth"].to_list() == [ + 3, + 5, + 6, + 7, + ] def test_get(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars @@ -179,16 +265,24 @@ def test_get(self, fix1_AgentSetPolars: ExampleAgentSetPolars): result = agents.get(["wealth", "age"]) assert isinstance(result, pl.DataFrame) assert result.columns == ["wealth", "age"] - assert result["wealth"].to_list() == agents.agents["wealth"].to_list() + assert ( + result["wealth"].to_list() + == agents.agents.select("wealth").collect()["wealth"].to_list() + ) # Test with a single attribute and a mask - selected = agents.select(agents.agents["wealth"] > 1, inplace=False) + selected = agents.select( + agents.agents.select("wealth").collect()["wealth"] > 1, inplace=False + ) assert selected.get("wealth", mask="active").to_list() == [2, 3, 4] def test_remove(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars agents.remove([0, 1]) - assert agents.agents["unique_id"].to_list() == [2, 3] + assert agents.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 2, + 3, + ] with pytest.raises(KeyError): agents.remove([1]) @@ -198,76 +292,127 @@ def 
test_select(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with default arguments. Should select all agents selected = agents.select(inplace=False) assert ( - selected.active_agents["wealth"].to_list() - == agents.agents["wealth"].to_list() + selected.active_agents.select("wealth").collect()["wealth"].to_list() + == agents.agents.select("wealth").collect()["wealth"].to_list() ) # Test with a pl.Series[bool] mask = pl.Series("mask", [True, False, True, True], dtype=pl.Boolean) selected = agents.select(mask, inplace=False) - assert selected.active_agents["unique_id"].to_list() == [0, 2, 3] + assert selected.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [0, 2, 3] # Test with a ListLike mask = [0, 2] selected = agents.select(mask, inplace=False) - assert selected.active_agents["unique_id"].to_list() == [0, 2] + assert selected.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [0, 2] # Test with a pl.DataFrame mask = pl.DataFrame({"unique_id": [0, 1]}) selected = agents.select(mask, inplace=False) - assert selected.active_agents["unique_id"].to_list() == [0, 1] + assert selected.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [0, 1] # Test with filter_func def filter_func(agentset: AgentSetPolars) -> pl.Series: - return agentset.agents["wealth"] > 1 + return agentset.agents.select("wealth").collect()["wealth"] > 1 selected = agents.select(filter_func=filter_func, inplace=False) - assert selected.active_agents["unique_id"].to_list() == [1, 2, 3] + assert selected.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [1, 2, 3] # Test with n selected = agents.select(n=3, inplace=False) - assert len(selected.active_agents) == 3 + assert len(selected.active_agents.collect()) == 3 # Test with n, filter_func and mask mask = pl.Series("mask", [True, False, True, True], dtype=pl.Boolean) selected = agents.select(mask, filter_func=filter_func, n=1, inplace=False) - assert any(el in selected.active_agents["unique_id"].to_list() for el in [2, 3]) + assert any( + el + in selected.active_agents.select("unique_id") + .collect()["unique_id"] + .to_list() + for el in [2, 3] + ) def test_set(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with a single attribute result = agents.set("wealth", 0, inplace=False) - assert result.agents["wealth"].to_list() == [0, 0, 0, 0] + assert result.agents.select("wealth").collect()["wealth"].to_list() == [ + 0, + 0, + 0, + 0, + ] # Test with a list of attributes result = agents.set(["wealth", "age"], 1, inplace=False) - assert result.agents["wealth"].to_list() == [1, 1, 1, 1] - assert result.agents["age"].to_list() == [1, 1, 1, 1] + assert result.agents.select("wealth").collect()["wealth"].to_list() == [ + 1, + 1, + 1, + 1, + ] + assert result.agents.select("age").collect()["age"].to_list() == [1, 1, 1, 1] # Test with a single attribute and a mask - selected = agents.select(agents.agents["wealth"] > 1, inplace=False) + selected = agents.select( + agents.agents.select("wealth").collect()["wealth"] > 1, inplace=False + ) selected.set("wealth", 0, mask="active") - assert selected.agents["wealth"].to_list() == [1, 0, 0, 0] + assert selected.agents.select("wealth").collect()["wealth"].to_list() == [ + 1, + 0, + 0, + 0, + ] # Test with a dictionary agents.set({"wealth": 10, "age": 20}) - assert agents.agents["wealth"].to_list() == [10, 10, 10, 10] - assert agents.agents["age"].to_list() == [20, 20, 20, 20] + assert 
agents.agents.select("wealth").collect()["wealth"].to_list() == [ + 10, + 10, + 10, + 10, + ] + assert agents.agents.select("age").collect()["age"].to_list() == [ + 20, + 20, + 20, + 20, + ] def test_shuffle(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars for _ in range(10): - original_order = agents.agents["unique_id"].to_list() + original_order = ( + agents.agents.select("unique_id").collect()["unique_id"].to_list() + ) agents.shuffle() - if original_order != agents.agents["unique_id"].to_list(): + if ( + original_order + != agents.agents.select("unique_id").collect()["unique_id"].to_list() + ): return assert False def test_sort(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars agents.sort("wealth", ascending=False) - assert agents.agents["wealth"].to_list() == [4, 3, 2, 1] + assert agents.agents.select("wealth").collect()["wealth"].to_list() == [ + 4, + 3, + 2, + 1, + ] def test__add__( self, @@ -279,19 +424,54 @@ def test__add__( # Test with an AgentSetPolars and a DataFrame agents3 = agents + agents2.agents - assert agents3.agents["unique_id"].to_list() == [0, 1, 2, 3, 4, 5, 6, 7] + assert agents3.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + ] # Test with an AgentSetPolars and a list (Sequence[Any]) agents3 = agents + [10, 5, 5] # unique_id, wealth, age - assert agents3.agents["unique_id"].to_list()[:-1] == [0, 1, 2, 3] - assert len(agents3.agents) == 5 - assert agents3.agents["wealth"].to_list() == [1, 2, 3, 4, 5] - assert agents3.agents["age"].to_list() == [10, 20, 30, 40, 5] + assert agents3.agents.select("unique_id").collect()["unique_id"].to_list()[ + :-1 + ] == [0, 1, 2, 3] + assert len(agents3.agents.collect()) == 5 + assert agents3.agents.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] + assert agents3.agents.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 5, + ] # Test with an AgentSetPolars and a dict agents3 = agents + {"unique_id": 10, "wealth": 5} - assert agents3.agents["unique_id"].to_list() == [0, 1, 2, 3, 10] - assert agents3.agents["wealth"].to_list() == [1, 2, 3, 4, 5] + assert agents3.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 0, + 1, + 2, + 3, + 10, + ] + assert agents3.agents.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] def test__contains__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with a single value @@ -349,21 +529,56 @@ def test__iadd__( # Test with an AgentSetPolars and a DataFrame agents = deepcopy(fix1_AgentSetPolars) agents += agents2.agents - assert agents.agents["unique_id"].to_list() == [0, 1, 2, 3, 4, 5, 6, 7] + assert agents.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + ] # Test with an AgentSetPolars and a list agents = deepcopy(fix1_AgentSetPolars) agents += [10, 5, 5] # unique_id, wealth, age - assert agents.agents["unique_id"].to_list()[:-1] == [0, 1, 2, 3] - assert len(agents.agents) == 5 - assert agents.agents["wealth"].to_list() == [1, 2, 3, 4, 5] - assert agents.agents["age"].to_list() == [10, 20, 30, 40, 5] + assert agents.agents.select("unique_id").collect()["unique_id"].to_list()[ + :-1 + ] == [0, 1, 2, 3] + assert len(agents.agents.collect()) == 5 + assert agents.agents.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] + assert agents.agents.select("age").collect()["age"].to_list() == [ + 10, + 
20, + 30, + 40, + 5, + ] # Test with an AgentSetPolars and a dict agents = deepcopy(fix1_AgentSetPolars) agents += {"unique_id": 10, "wealth": 5} - assert agents.agents["unique_id"].to_list() == [0, 1, 2, 3, 10] - assert agents.agents["wealth"].to_list() == [1, 2, 3, 4, 5] + assert agents.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 0, + 1, + 2, + 3, + 10, + ] + assert agents.agents.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] def test__iter__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars @@ -375,7 +590,7 @@ def test__isub__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with an AgentSetPolars and a DataFrame agents = deepcopy(fix1_AgentSetPolars) agents -= agents.agents - assert agents.agents.is_empty() + assert agents.agents.collect().is_empty() def test__len__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars @@ -399,21 +614,46 @@ def test__setitem__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with key=str, value=Any agents["wealth"] = 0 - assert agents.agents["wealth"].to_list() == [0, 0, 0, 0] + assert agents.agents.select("wealth").collect()["wealth"].to_list() == [ + 0, + 0, + 0, + 0, + ] # Test with key=list[str], value=Any agents[["wealth", "age"]] = 1 - assert agents.agents["wealth"].to_list() == [1, 1, 1, 1] - assert agents.agents["age"].to_list() == [1, 1, 1, 1] + assert agents.agents.select("wealth").collect()["wealth"].to_list() == [ + 1, + 1, + 1, + 1, + ] + assert agents.agents.select("age").collect()["age"].to_list() == [1, 1, 1, 1] # Test with key=tuple, value=Any agents[0, "wealth"] = 5 - assert agents.agents["wealth"].to_list() == [5, 1, 1, 1] + assert ( + agents.agents.select("wealth") + .collect() + .filter(pl.col("unique_id") == 0)["wealth"][0] + == 5 + ) # Test with key=AgentMask, value=Any agents[0] = [9, 99] - assert agents.agents.item(0, "wealth") == 9 - assert agents.agents.item(0, "age") == 99 + assert ( + agents.agents.select(["unique_id", "wealth"]) + .collect() + .filter(pl.col("unique_id") == 0)["wealth"][0] + == 9 + ) + assert ( + agents.agents.select(["unique_id", "age"]) + .collect() + .filter(pl.col("unique_id") == 0)["age"][0] + == 99 + ) def test__str__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents: ExampleAgentSetPolars = fix1_AgentSetPolars @@ -422,8 +662,13 @@ def test__str__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): def test__sub__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents: ExampleAgentSetPolars = fix1_AgentSetPolars agents2: ExampleAgentSetPolars = agents - agents.agents - assert agents2.agents.is_empty() - assert agents.agents["wealth"].to_list() == [1, 2, 3, 4] + assert agents2.agents.collect().is_empty() + assert agents.agents.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + ] def test_get_obj(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars @@ -437,28 +682,43 @@ def test_agents( ): agents = fix1_AgentSetPolars agents2 = fix2_AgentSetPolars - assert isinstance(agents.agents, pl.DataFrame) + assert isinstance(agents.agents, pl.LazyFrame) # Test agents.setter agents.agents = agents2.agents - assert agents.agents["unique_id"].to_list() == [4, 5, 6, 7] + assert agents.agents.select("unique_id").collect()["unique_id"].to_list() == [ + 4, + 5, + 6, + 7, + ] def test_active_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with select - 
agents.select(agents.agents["wealth"] > 2, inplace=True) - assert agents.active_agents["unique_id"].to_list() == [2, 3] + agents.select( + agents.agents.select("wealth").collect()["wealth"] > 2, inplace=True + ) + assert agents.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [2, 3] # Test with active_agents.setter - agents.active_agents = agents.agents["wealth"] > 2 - assert agents.active_agents["unique_id"].to_list() == [2, 3] + agents.active_agents = agents.agents.select("wealth").collect()["wealth"] > 2 + assert agents.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [2, 3] def test_inactive_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars - agents.select(agents.agents["wealth"] > 2, inplace=True) - assert agents.inactive_agents["unique_id"].to_list() == [0, 1] + agents.select( + agents.agents.select("wealth").collect()["wealth"] > 2, inplace=True + ) + assert agents.inactive_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [0, 1] def test_pos(self, fix1_AgentSetPolars_with_pos: ExampleAgentSetPolars): pos = fix1_AgentSetPolars_with_pos.pos diff --git a/tests/polars/test_mixin_polars.py b/tests/polars/test_mixin_polars.py index d1ec3e60..e122a72a 100644 --- a/tests/polars/test_mixin_polars.py +++ b/tests/polars/test_mixin_polars.py @@ -15,7 +15,7 @@ def mixin(self): @pytest.fixture def df_0(self): - return pl.DataFrame( + return pl.LazyFrame( { "unique_id": ["x", "y", "z"], "A": [1, 2, 3], @@ -27,7 +27,7 @@ def df_0(self): @pytest.fixture def df_1(self): - return pl.DataFrame( + return pl.LazyFrame( { "unique_id": ["z", "a", "b"], "A": [4, 5, 6], @@ -37,33 +37,33 @@ def df_1(self): }, ) - def test_df_add(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_add(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test adding a DataFrame and a sequence element-wise along the rows (axis='index') - result = mixin._df_add(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_add(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [5, 7, 9] assert result["D"].to_list() == [5, 7, 9] # Test adding a DataFrame and a sequence element-wise along the column (axis='columns') - result = mixin._df_add(df_0[["A", "D"]], [1, 2], axis="columns") + result = mixin._df_add(df_0[["A", "D"]], [1, 2], axis="columns").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [2, 3, 4] assert result["D"].to_list() == [3, 4, 5] # Test adding DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_add( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, 7] assert result["D"].to_list() == [None, None, 4] def test_df_all(self, mixin: PolarsMixin): - df = pl.DataFrame( + df = pl.LazyFrame( { "A": [True, False, True], "B": [True, True, True], @@ -71,28 +71,32 @@ def test_df_all(self, mixin: PolarsMixin): ) # Test with axis='columns' - result = mixin._df_all(df["A", "B"], axis="columns") + result = mixin._df_all(df["A", "B"], axis="columns").collect() assert isinstance(result, pl.Series) assert result.name == "all" assert result.to_list() == [True, False, True] # Test with 
axis='index' - result = mixin._df_all(df["A", "B"], axis="index") + result = mixin._df_all(df["A", "B"], axis="index").collect() assert isinstance(result, pl.Series) assert result.name == "all" assert result.to_list() == [False, True] - def test_df_and(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_and(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test comparing the DataFrame with a sequence element-wise along the rows (axis='index') - df_0 = df_0.with_columns(F=pl.Series([True, True, False])) - df_1 = df_1.with_columns(F=pl.Series([False, False, True])) - result = mixin._df_and(df_0[["C", "F"]], df_1["F"], axis="index") + df_0_with_f = df_0.with_columns(F=pl.lit([True, True, False])) + df_1_with_f = df_1.with_columns(F=pl.lit([False, False, True])) + result = mixin._df_and( + df_0_with_f[["C", "F"]], df_1_with_f["F"], axis="index" + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [False, False, True] assert result["F"].to_list() == [False, False, False] # Test comparing the DataFrame with a sequence element-wise along the columns (axis='columns') - result = mixin._df_and(df_0[["C", "F"]], [True, False], axis="columns") + result = mixin._df_and( + df_0_with_f[["C", "F"]], [True, False], axis="columns" + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [True, False, True] assert result["F"].to_list() == [False, False, False] @@ -103,22 +107,22 @@ def test_df_and(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame df_1[["unique_id", "C", "F"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [None, False, False] assert result["F"].to_list() == [None, None, False] - def test_df_column_names(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_column_names(self, mixin: PolarsMixin, df_0: pl.LazyFrame): cols = mixin._df_column_names(df_0) assert isinstance(cols, list) assert all(isinstance(c, str) for c in cols) assert set(mixin._df_column_names(df_0)) == {"unique_id", "A", "B", "C", "D"} def test_df_combine_first( - self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame + self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame ): # Test with df_0 and df_1 - result = mixin._df_combine_first(df_0, df_1, "unique_id") + result = mixin._df_combine_first(df_0, df_1, "unique_id").collect() result = result.sort("A") assert isinstance(result, pl.DataFrame) assert set(result.columns) == {"unique_id", "A", "B", "C", "D", "E"} @@ -130,7 +134,7 @@ def test_df_combine_first( assert result["E"].to_list() == [None, None, 1, 2, 3] # Test with df_1 and df_0 - result = mixin._df_combine_first(df_1, df_0, "unique_id") + result = mixin._df_combine_first(df_1, df_0, "unique_id").collect() result = result.sort("E", nulls_last=True) assert isinstance(result, pl.DataFrame) assert set(result.columns) == {"unique_id", "A", "B", "C", "D", "E"} @@ -142,14 +146,14 @@ def test_df_combine_first( assert result["E"].to_list() == [1, 2, 3, None, None] def test_df_concat( - self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame + self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame ): ### Test vertical concatenation ## With DataFrames for ignore_index in [False, True]: vertical = mixin._df_concat( [df_0, df_1], how="vertical", ignore_index=ignore_index - ) + ).collect() assert isinstance(vertical, pl.DataFrame) assert vertical.columns == ["unique_id", "A", "B", 
"C", "D", "E"] assert len(vertical) == 6 @@ -164,7 +168,7 @@ def test_df_concat( for ignore_index in [True, False]: vertical = mixin._df_concat( [df_0["A"], df_1["A"]], how="vertical", ignore_index=ignore_index - ) + ).collect() assert isinstance(vertical, pl.Series) assert len(vertical) == 6 assert vertical.to_list() == [1, 2, 3, 4, 5, 6] @@ -174,10 +178,10 @@ def test_df_concat( ## With DataFrames # Error With same column names with pytest.raises(pl.exceptions.DuplicateError): - mixin._df_concat([df_0, df_1], how="horizontal") + mixin._df_concat([df_0, df_1], how="horizontal").collect() # With ignore_index = False - df_1 = df_1.rename(lambda c: f"{c}_1") - horizontal = mixin._df_concat([df_0, df_1], how="horizontal") + df_1_renamed = df_1.rename(lambda c: f"{c}_1") + horizontal = mixin._df_concat([df_0, df_1_renamed], how="horizontal").collect() assert isinstance(horizontal, pl.DataFrame) assert horizontal.columns == [ "unique_id", @@ -205,10 +209,10 @@ def test_df_concat( # With ignore_index = True horizontal_ignore_index = mixin._df_concat( - [df_0, df_1], + [df_0, df_1_renamed], how="horizontal", ignore_index=True, - ) + ).collect() assert isinstance(horizontal_ignore_index, pl.DataFrame) assert horizontal_ignore_index.columns == [ "0", @@ -237,8 +241,8 @@ def test_df_concat( ## With Series # With ignore_index = False horizontal = mixin._df_concat( - [df_0["A"], df_1["B_1"]], how="horizontal", ignore_index=False - ) + [df_0["A"], df_1_renamed["B_1"]], how="horizontal", ignore_index=False + ).collect() assert isinstance(horizontal, pl.DataFrame) assert horizontal.columns == ["A", "B_1"] assert len(horizontal) == 3 @@ -247,8 +251,8 @@ def test_df_concat( # With ignore_index = True horizontal = mixin._df_concat( - [df_0["A"], df_1["B_1"]], how="horizontal", ignore_index=True - ) + [df_0["A"], df_1_renamed["B_1"]], how="horizontal", ignore_index=True + ).collect() assert isinstance(horizontal, pl.DataFrame) assert horizontal.columns == ["0", "1"] assert len(horizontal) == 3 @@ -258,7 +262,7 @@ def test_df_concat( def test_df_constructor(self, mixin: PolarsMixin): # Test with dictionary data = {"num": [1, 2, 3], "letter": ["a", "b", "c"]} - df = mixin._df_constructor(data) + df = mixin._df_constructor(data).collect() assert isinstance(df, pl.DataFrame) assert list(df.columns) == ["num", "letter"] assert df["num"].to_list() == [1, 2, 3] @@ -268,7 +272,7 @@ def test_df_constructor(self, mixin: PolarsMixin): data = [[1, "a"], [2, "b"], [3, "c"]] df = mixin._df_constructor( data, columns=["num", "letter"], dtypes={"num": "int64"} - ) + ).collect() assert isinstance(df, pl.DataFrame) assert list(df.columns) == ["num", "letter"] assert df["num"].dtype == pl.Int64 @@ -277,7 +281,7 @@ def test_df_constructor(self, mixin: PolarsMixin): # Test with pandas DataFrame data = pd.DataFrame({"num": [1, 2, 3], "letter": ["a", "b", "c"]}) - df = mixin._df_constructor(data) + df = mixin._df_constructor(data).collect() assert isinstance(df, pl.DataFrame) assert list(df.columns) == ["index", "num", "letter"] assert df["index"].to_list() == [0, 1, 2] @@ -288,65 +292,65 @@ def test_df_constructor(self, mixin: PolarsMixin): data = {"a": 5} df = mixin._df_constructor( data, index=pl.int_range(5, eager=True), index_cols="index" - ) + ).collect() assert isinstance(df, pl.DataFrame) assert list(df.columns) == ["index", "a"] assert df["a"].to_list() == [5, 5, 5, 5, 5] assert df["index"].to_list() == [0, 1, 2, 3, 4] - def test_df_contains(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_contains(self, 
mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with list - result = mixin._df_contains(df_0, "A", [5, 2, 3]) + result = mixin._df_contains(df_0, "A", [5, 2, 3]).collect() assert isinstance(result, pl.Series) assert result.name == "contains" assert result.to_list() == [False, True, True] - def test_df_div(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_div(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test dividing the DataFrame by a sequence element-wise along the rows (axis='index') - result = mixin._df_div(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_div(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [0.25, 0.4, 0.5] assert result["D"].to_list() == [0.25, 0.4, 0.5] # Test dividing the DataFrame by a sequence element-wise along the columns (axis='columns') - result = mixin._df_div(df_0[["A", "D"]], [1, 2], axis="columns") + result = mixin._df_div(df_0[["A", "D"]], [1, 2], axis="columns").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [1, 2, 3] assert result["D"].to_list() == [0.5, 1, 1.5] # Test dividing DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_div( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, 0.75] assert result["D"].to_list() == [None, None, 3] - def test_df_drop_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_drop_columns(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with str - dropped = mixin._df_drop_columns(df_0, "A") + dropped = mixin._df_drop_columns(df_0, "A").collect() assert isinstance(dropped, pl.DataFrame) assert dropped.columns == ["unique_id", "B", "C", "D"] # Test with list - dropped = mixin._df_drop_columns(df_0, ["A", "C"]) + dropped = mixin._df_drop_columns(df_0, ["A", "C"]).collect() assert dropped.columns == ["unique_id", "B", "D"] - def test_df_drop_duplicates(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_drop_duplicates(self, mixin: PolarsMixin, df_0: pl.LazyFrame): new_df = pl.concat([df_0, df_0], how="vertical") assert len(new_df) == 6 # Test with all columns - dropped = mixin._df_drop_duplicates(new_df) + dropped = mixin._df_drop_duplicates(new_df).collect() assert isinstance(dropped, pl.DataFrame) assert len(dropped) == 3 assert dropped.columns == ["unique_id", "A", "B", "C", "D"] # Test with subset (str) - other_df = pl.DataFrame( + other_df = pl.LazyFrame( { "unique_id": ["x", "y", "z"], "A": [1, 2, 3], @@ -356,156 +360,164 @@ def test_df_drop_duplicates(self, mixin: PolarsMixin, df_0: pl.DataFrame): }, ) new_df = pl.concat([df_0, other_df], how="vertical") - dropped = mixin._df_drop_duplicates(new_df, subset="unique_id") + dropped = mixin._df_drop_duplicates(new_df, subset="unique_id").collect() assert isinstance(dropped, pl.DataFrame) assert len(dropped) == 3 # Test with subset (list) - dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"]) + dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"]).collect() assert isinstance(dropped, pl.DataFrame) assert len(dropped) == 5 assert dropped.columns == ["unique_id", "A", "B", "C", "D"] assert dropped["B"].to_list() == ["a", "b", "c", "e", "f"] # Test with subset 
(list) and keep='last' - dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"], keep="last") + dropped = mixin._df_drop_duplicates( + new_df, subset=["A", "C"], keep="last" + ).collect() assert isinstance(dropped, pl.DataFrame) assert len(dropped) == 5 assert dropped.columns == ["unique_id", "A", "B", "C", "D"] assert dropped["B"].to_list() == ["d", "b", "c", "e", "f"] # Test with subset (list) and keep=False - dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"], keep=False) + dropped = mixin._df_drop_duplicates( + new_df, subset=["A", "C"], keep=False + ).collect() assert isinstance(dropped, pl.DataFrame) assert len(dropped) == 4 assert dropped.columns == ["unique_id", "A", "B", "C", "D"] assert dropped["B"].to_list() == ["b", "c", "e", "f"] - def test_df_ge(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_ge(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test comparing the DataFrame with a sequence element-wise along the rows (axis='index') - result = mixin._df_ge(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_ge(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [False, False, False] assert result["D"].to_list() == [False, False, False] # Test comparing the DataFrame with a sequence element-wise along the columns (axis='columns') - result = mixin._df_ge(df_0[["A", "D"]], [1, 2], axis="columns") + result = mixin._df_ge(df_0[["A", "D"]], [1, 2], axis="columns").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [True, True, True] assert result["D"].to_list() == [False, True, True] # Test comparing DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_ge( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, False] assert result["D"].to_list() == [None, None, True] - def test_df_get_bool_mask(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_get_bool_mask(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with pl.Series[bool] - mask = mixin._df_get_bool_mask(df_0, "A", pl.Series([True, False, True])) + mask = mixin._df_get_bool_mask( + df_0, "A", pl.Series([True, False, True]) + ).collect() assert mask.to_list() == [True, False, True] # Test with DataFrame - mask_df = pl.DataFrame({"A": [1, 3]}) - mask = mixin._df_get_bool_mask(df_0, "A", mask_df) + mask_df = pl.LazyFrame({"A": [1, 3]}) + mask = mixin._df_get_bool_mask(df_0, "A", mask_df).collect() assert mask.to_list() == [True, False, True] # Test with single value - mask = mixin._df_get_bool_mask(df_0, "A", 1) + mask = mixin._df_get_bool_mask(df_0, "A", 1).collect() assert mask.to_list() == [True, False, False] # Test with list of values - mask = mixin._df_get_bool_mask(df_0, "A", [1, 3]) + mask = mixin._df_get_bool_mask(df_0, "A", [1, 3]).collect() assert mask.to_list() == [True, False, True] # Test with negate=True - mask = mixin._df_get_bool_mask(df_0, "A", [1, 3], negate=True) + mask = mixin._df_get_bool_mask(df_0, "A", [1, 3], negate=True).collect() assert mask.to_list() == [False, True, False] - def test_df_get_masked_df(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_get_masked_df(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # 
Test with pl.Series[bool] - masked_df = mixin._df_get_masked_df(df_0, "A", pl.Series([True, False, True])) + masked_df = mixin._df_get_masked_df( + df_0, "A", pl.Series([True, False, True]) + ).collect() assert masked_df["A"].to_list() == [1, 3] assert masked_df["unique_id"].to_list() == ["x", "z"] # Test with DataFrame - mask_df = pl.DataFrame({"A": [1, 3]}) - masked_df = mixin._df_get_masked_df(df_0, "A", mask_df) + mask_df = pl.LazyFrame({"A": [1, 3]}) + masked_df = mixin._df_get_masked_df(df_0, "A", mask_df).collect() assert masked_df["A"].to_list() == [1, 3] assert masked_df["unique_id"].to_list() == ["x", "z"] # Test with single value - masked_df = mixin._df_get_masked_df(df_0, "A", 1) + masked_df = mixin._df_get_masked_df(df_0, "A", 1).collect() assert masked_df["A"].to_list() == [1] assert masked_df["unique_id"].to_list() == ["x"] # Test with list of values - masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3]) + masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3]).collect() assert masked_df["A"].to_list() == [1, 3] assert masked_df["unique_id"].to_list() == ["x", "z"] # Test with columns - masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3], columns=["B"]) + masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3], columns=["B"]).collect() assert list(masked_df.columns) == ["B"] assert masked_df["B"].to_list() == ["a", "c"] # Test with negate=True - masked = mixin._df_get_masked_df(df_0, "A", [1, 3], negate=True) + masked = mixin._df_get_masked_df(df_0, "A", [1, 3], negate=True).collect() assert len(masked) == 1 - def test_df_groupby_cumcount(self, df_0: pl.DataFrame, mixin: PolarsMixin): - result = mixin._df_groupby_cumcount(df_0, "C") + def test_df_groupby_cumcount(self, df_0: pl.LazyFrame, mixin: PolarsMixin): + result = mixin._df_groupby_cumcount(df_0, "C").collect() assert result.to_list() == [1, 1, 2] - def test_df_index(self, mixin: PolarsMixin, df_0: pl.DataFrame): - index = mixin._df_index(df_0, "unique_id") + def test_df_index(self, mixin: PolarsMixin, df_0: pl.LazyFrame): + index = mixin._df_index(df_0, "unique_id").collect() assert isinstance(index, pl.Series) assert index.to_list() == ["x", "y", "z"] - def test_df_iterator(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_iterator(self, mixin: PolarsMixin, df_0: pl.LazyFrame): iterator = mixin._df_iterator(df_0) first_item = next(iterator) assert first_item == {"unique_id": "x", "A": 1, "B": "a", "C": True, "D": 1} def test_df_join(self, mixin: PolarsMixin): - left = pl.DataFrame({"A": [1, 2], "B": ["a", "b"]}) - right = pl.DataFrame({"A": [1, 3], "C": ["x", "y"]}) + left = pl.LazyFrame({"A": [1, 2], "B": ["a", "b"]}) + right = pl.LazyFrame({"A": [1, 3], "C": ["x", "y"]}) # Test with 'on' (left join) - joined = mixin._df_join(left, right, on="A") + joined = mixin._df_join(left, right, on="A").collect() assert set(joined.columns) == {"A", "B", "C"} assert joined["A"].to_list() == [1, 2] # Test with 'left_on' and 'right_on' (left join) - right_1 = pl.DataFrame({"D": [1, 2], "C": ["x", "y"]}) - joined = mixin._df_join(left, right_1, left_on="A", right_on="D") + right_1 = pl.LazyFrame({"D": [1, 2], "C": ["x", "y"]}) + joined = mixin._df_join(left, right_1, left_on="A", right_on="D").collect() assert set(joined.columns) == {"A", "B", "C"} assert joined["A"].to_list() == [1, 2] # Test with 'right' join - joined = mixin._df_join(left, right, on="A", how="right") + joined = mixin._df_join(left, right, on="A", how="right").collect() assert set(joined.columns) == {"A", "B", "C"} assert joined["A"].to_list() == [1, 
3] # Test with 'inner' join - joined = mixin._df_join(left, right, on="A", how="inner") + joined = mixin._df_join(left, right, on="A", how="inner").collect() assert set(joined.columns) == {"A", "B", "C"} assert joined["A"].to_list() == [1] # Test with 'outer' join - joined = mixin._df_join(left, right, on="A", how="outer") + joined = mixin._df_join(left, right, on="A", how="outer").collect() assert set(joined.columns) == {"A", "B", "A_right", "C"} assert joined["A"].to_list() == [1, None, 2] assert joined["A_right"].to_list() == [1, 3, None] # Test with 'cross' join - joined = mixin._df_join(left, right, how="cross") + joined = mixin._df_join(left, right, how="cross").collect() assert set(joined.columns) == {"A", "B", "A_right", "C"} assert len(joined) == 4 assert joined.row(0) == (1, "a", 1, "x") @@ -514,7 +526,7 @@ def test_df_join(self, mixin: PolarsMixin): assert joined.row(3) == (2, "b", 3, "y") # Test with different 'suffix' - joined = mixin._df_join(left, right, suffix="_r", how="cross") + joined = mixin._df_join(left, right, suffix="_r", how="cross").collect() assert set(joined.columns) == {"A", "B", "A_r", "C"} assert len(joined) == 4 assert joined.row(0) == (1, "a", 1, "x") @@ -522,109 +534,113 @@ def test_df_join(self, mixin: PolarsMixin): assert joined.row(2) == (2, "b", 1, "x") assert joined.row(3) == (2, "b", 3, "y") - def test_df_lt(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_lt(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test comparing the DataFrame with a sequence element-wise along the rows (axis='index') - result = mixin._df_lt(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_lt(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [True, True, True] assert result["D"].to_list() == [True, True, True] # Test comparing the DataFrame with a sequence element-wise along the columns (axis='columns') - result = mixin._df_lt(df_0[["A", "D"]], [2, 3], axis="columns") + result = mixin._df_lt(df_0[["A", "D"]], [2, 3], axis="columns").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [True, False, False] assert result["D"].to_list() == [True, True, False] # Test comparing DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_lt( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, True] assert result["D"].to_list() == [None, None, False] - def test_df_mod(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_mod(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test taking the modulo of the DataFrame by a sequence element-wise along the rows (axis='index') - result = mixin._df_mod(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_mod(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [1, 2, 3] assert result["D"].to_list() == [1, 2, 3] # Test taking the modulo of the DataFrame by a sequence element-wise along the columns (axis='columns') - result = mixin._df_mod(df_0[["A", "D"]], [1, 2], axis="columns") + result = mixin._df_mod(df_0[["A", "D"]], [1, 2], axis="columns").collect() assert 
isinstance(result, pl.DataFrame) assert result["A"].to_list() == [0, 0, 0] assert result["D"].to_list() == [1, 0, 1] # Test taking the modulo of DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_mod( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, 3] assert result["D"].to_list() == [None, None, 0] - def test_df_mul(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_mul(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test multiplying the DataFrame by a sequence element-wise along the rows (axis='index') - result = mixin._df_mul(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_mul(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [4, 10, 18] assert result["D"].to_list() == [4, 10, 18] # Test multiplying the DataFrame by a sequence element-wise along the columns (axis='columns') - result = mixin._df_mul(df_0[["A", "D"]], [1, 2], axis="columns") + result = mixin._df_mul(df_0[["A", "D"]], [1, 2], axis="columns").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [1, 2, 3] assert result["D"].to_list() == [2, 4, 6] # Test multiplying DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_mul( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, 12] assert result["D"].to_list() == [None, None, 3] def test_df_norm(self, mixin: PolarsMixin): - df = pl.DataFrame({"A": [3, 4], "B": [4, 3]}) + df = pl.LazyFrame({"A": [3, 4], "B": [4, 3]}) # If include_cols = False - norm = mixin._df_norm(df) + norm = mixin._df_norm(df).collect() assert isinstance(norm, pl.Series) assert len(norm) == 2 assert norm[0] == 5 assert norm[1] == 5 # If include_cols = True - norm = mixin._df_norm(df, include_cols=True) + norm = mixin._df_norm(df, include_cols=True).collect() assert isinstance(norm, pl.DataFrame) assert len(norm) == 2 assert norm.columns == ["A", "B", "norm"] assert norm.row(0, named=True)["norm"] == 5 assert norm.row(1, named=True)["norm"] == 5 - def test_df_or(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_or(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test comparing the DataFrame with a sequence element-wise along the rows (axis='index') - df_0 = df_0.with_columns(F=pl.Series([True, True, False])) - df_1 = df_1.with_columns(F=pl.Series([False, False, True])) - result = mixin._df_or(df_0[["C", "F"]], df_1["F"], axis="index") + df_0_with_f = df_0.with_columns(F=pl.lit([True, True, False])) + df_1_with_f = df_1.with_columns(F=pl.lit([False, False, True])) + result = mixin._df_or( + df_0_with_f[["C", "F"]], df_1_with_f["F"], axis="index" + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [True, False, True] assert result["F"].to_list() == [True, True, True] # Test comparing the DataFrame with a sequence element-wise along the columns (axis='columns') - result = 
mixin._df_or(df_0[["C", "F"]], [True, False], axis="columns") + result = mixin._df_or( + df_0_with_f[["C", "F"]], [True, False], axis="columns" + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [True, True, True] assert result["F"].to_list() == [True, True, False] @@ -635,16 +651,16 @@ def test_df_or(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame) df_1[["unique_id", "C", "F"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [True, None, True] assert result["F"].to_list() == [True, True, False] def test_df_reindex( - self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame + self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame ): # Test with DataFrame - reindexed = mixin._df_reindex(df_0, df_1, "unique_id") + reindexed = mixin._df_reindex(df_0, df_1, "unique_id").collect() assert isinstance(reindexed, pl.DataFrame) assert reindexed["unique_id"].to_list() == ["z", "a", "b"] assert reindexed["A"].to_list() == [3, None, None] @@ -653,7 +669,7 @@ def test_df_reindex( assert reindexed["D"].to_list() == [3, None, None] # Test with list - reindexed = mixin._df_reindex(df_0, ["z", "a", "b"], "unique_id") + reindexed = mixin._df_reindex(df_0, ["z", "a", "b"], "unique_id").collect() assert isinstance(reindexed, pl.DataFrame) assert reindexed["unique_id"].to_list() == ["z", "a", "b"] assert reindexed["A"].to_list() == [3, None, None] @@ -667,7 +683,7 @@ def test_df_reindex( ["z", "a", "b"], new_index_cols="new_index", original_index_cols="unique_id", - ) + ).collect() assert isinstance(reindexed, pl.DataFrame) assert reindexed["new_index"].to_list() == ["z", "a", "b"] assert reindexed["A"].to_list() == [3, None, None] @@ -675,61 +691,63 @@ def test_df_reindex( assert reindexed["C"].to_list() == [True, None, None] assert reindexed["D"].to_list() == [3, None, None] - def test_df_rename_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame): - renamed = mixin._df_rename_columns(df_0, ["A", "B"], ["X", "Y"]) + def test_df_rename_columns(self, mixin: PolarsMixin, df_0: pl.LazyFrame): + renamed = mixin._df_rename_columns(df_0, ["A", "B"], ["X", "Y"]).collect() assert renamed.columns == ["unique_id", "X", "Y", "C", "D"] - def test_df_reset_index(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_reset_index(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # with drop = False - new_df = mixin._df_reset_index(df_0) + new_df = mixin._df_reset_index(df_0).collect() assert mixin._df_all(new_df == df_0).all() # with drop = True - new_df = mixin._df_reset_index(df_0, index_cols="unique_id", drop=True) + new_df = mixin._df_reset_index( + df_0, index_cols="unique_id", drop=True + ).collect() assert new_df.columns == ["A", "B", "C", "D"] assert len(new_df) == len(df_0) for col in new_df.columns: assert (new_df[col] == df_0[col]).all() - def test_df_remove(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_remove(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with list - removed = mixin._df_remove(df_0, [1, 3], "A") + removed = mixin._df_remove(df_0, [1, 3], "A").collect() assert len(removed) == 1 assert removed["unique_id"].to_list() == ["y"] - def test_df_sample(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_sample(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with n - sampled = mixin._df_sample(df_0, n=2, seed=42) + sampled = mixin._df_sample(df_0, n=2, seed=42).collect() assert len(sampled) == 2 # Test with frac - 
sampled = mixin._df_sample(df_0, frac=2 / 3, seed=42) + sampled = mixin._df_sample(df_0, frac=2 / 3, seed=42).collect() assert len(sampled) == 2 # Test with replacement - sampled = mixin._df_sample(df_0, n=4, with_replacement=True, seed=42) + sampled = mixin._df_sample(df_0, n=4, with_replacement=True, seed=42).collect() assert len(sampled) == 4 assert sampled.n_unique() < 4 - def test_df_set_index(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_set_index(self, mixin: PolarsMixin, df_0: pl.LazyFrame): index = pl.int_range(len(df_0), eager=True) - new_df = mixin._df_set_index(df_0, "index", index) + new_df = mixin._df_set_index(df_0, "index", index).collect() assert (new_df["index"] == index).all() - def test_df_with_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_with_columns(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with list new_df = mixin._df_with_columns( df_0, data=[[4, "d"], [5, "e"], [6, "f"]], new_columns=["D", "E"], - ) + ).collect() assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"] assert new_df["D"].to_list() == [4, 5, 6] assert new_df["E"].to_list() == ["d", "e", "f"] # Test with pl.DataFrame - second_df = pl.DataFrame({"D": [4, 5, 6], "E": ["d", "e", "f"]}) - new_df = mixin._df_with_columns(df_0, second_df) + second_df = pl.LazyFrame({"D": [4, 5, 6], "E": ["d", "e", "f"]}) + new_df = mixin._df_with_columns(df_0, second_df).collect() assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"] assert new_df["D"].to_list() == [4, 5, 6] assert new_df["E"].to_list() == ["d", "e", "f"] @@ -737,18 +755,22 @@ def test_df_with_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame): # Test with dictionary new_df = mixin._df_with_columns( df_0, data={"D": [4, 5, 6], "E": ["d", "e", "f"]} - ) + ).collect() assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"] assert new_df["D"].to_list() == [4, 5, 6] assert new_df["E"].to_list() == ["d", "e", "f"] # Test with numpy array - new_df = mixin._df_with_columns(df_0, data=np.array([4, 5, 6]), new_columns="D") + new_df = mixin._df_with_columns( + df_0, data=np.array([4, 5, 6]), new_columns="D" + ).collect() assert "D" in new_df.columns assert new_df["D"].to_list() == [4, 5, 6] # Test with pl.Series - new_df = mixin._df_with_columns(df_0, pl.Series([4, 5, 6]), new_columns="D") + new_df = mixin._df_with_columns( + df_0, pl.Series([4, 5, 6]), new_columns="D" + ).collect() assert "D" in new_df.columns assert new_df["D"].to_list() == [4, 5, 6] @@ -767,29 +789,29 @@ def test_srs_contains(self, mixin: PolarsMixin): srs = [1, 2, 3, 4, 5] # Test with single value - result = mixin._srs_contains(srs, 3) + result = mixin._srs_contains(srs, 3).collect() assert result.to_list() == [True] # Test with list - result = mixin._srs_contains(srs, [1, 3, 6]) + result = mixin._srs_contains(srs, [1, 3, 6]).collect() assert result.to_list() == [True, True, False] # Test with numpy array - result = mixin._srs_contains(srs, np.array([1, 3, 6])) + result = mixin._srs_contains(srs, np.array([1, 3, 6])).collect() assert result.to_list() == [True, True, False] def test_srs_range(self, mixin: PolarsMixin): # Test with default step - srs = mixin._srs_range("test", 0, 5) + srs = mixin._srs_range("test", 0, 5).collect() assert srs.name == "test" assert srs.to_list() == [0, 1, 2, 3, 4] # Test with custom step - srs = mixin._srs_range("test", 0, 10, step=2) + srs = mixin._srs_range("test", 0, 10, step=2).collect() assert srs.to_list() == [0, 2, 4, 6, 8] def test_srs_to_df(self, mixin: 
PolarsMixin): srs = pl.Series("test", [1, 2, 3]) - df = mixin._srs_to_df(srs) + df = mixin._srs_to_df(srs).collect() assert isinstance(df, pl.DataFrame) assert df["test"].to_list() == [1, 2, 3] From ac19ddc68736f34ab10b36bdfa3405b8bb3cf1e5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Jul 2025 18:14:02 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa_frames/concrete/agentset.py | 34 ++----- tests/test_agentset.py | 148 ++++++++++++++++++++++++++----- 2 files changed, 134 insertions(+), 48 deletions(-) diff --git a/mesa_frames/concrete/agentset.py b/mesa_frames/concrete/agentset.py index 0954b2ae..9fe33323 100644 --- a/mesa_frames/concrete/agentset.py +++ b/mesa_frames/concrete/agentset.py @@ -192,11 +192,7 @@ def get( if attr_names is None: # Return all columns except unique_id return masked_df.select(pl.exclude("unique_id")).collect() - attr_names = ( - self._df.select(attr_names).columns.copy() - if attr_names - else [] - ) + attr_names = self._df.select(attr_names).columns.copy() if attr_names else [] if not attr_names: return masked_df.collect() masked_df = masked_df.select(attr_names).collect() @@ -258,9 +254,9 @@ def process_single_attr( unique_id_column = None unique_id_column = None if "unique_id" not in obj._df: - unique_id_column = self._generate_unique_ids(len(masked_df.collect())).alias( - "unique_id" - ) + unique_id_column = self._generate_unique_ids( + len(masked_df.collect()) + ).alias("unique_id") obj._df = obj._df.with_columns(unique_id_column) masked_df = masked_df.with_columns(unique_id_column) b_mask = obj._get_bool_mask(mask) @@ -286,10 +282,7 @@ def select( if n is not None: # Need to collect for sampling sample_ids = obj._df.filter(mask).collect().sample(n)["unique_id"] - mask = ( - (obj._df.collect()["unique_id"]) - .is_in(sample_ids) - ) + mask = (obj._df.collect()["unique_id"]).is_in(sample_ids) if negate: mask = mask.not_() obj._mask = mask @@ -347,12 +340,8 @@ def _concatenate_agentsets( if len(agentset) == max_length: original_index = agentset._df.collect()["unique_id"] final_dfs = [self._df] - final_active_indices = [ - self._df.filter(self._mask).collect()["unique_id"] - ] - final_indices = ( - self._df.collect()["unique_id"].clone() - ) + final_active_indices = [self._df.filter(self._mask).collect()["unique_id"]] + final_indices = self._df.collect()["unique_id"].clone() for obj in iter(agentsets): # Remove agents that are already in the final DataFrame final_dfs.append( @@ -382,10 +371,7 @@ def _concatenate_agentsets( final_active_index = pl.concat( [obj._df.filter(obj._mask).collect()["unique_id"] for obj in agentsets] ) - final_mask = ( - final_df.collect()["unique_id"] - .is_in(final_active_index) - ) + final_mask = final_df.collect()["unique_id"].is_in(final_active_index) self._df = final_df self._mask = final_mask # If some ids were removed in the do-method, we need to remove them also from final_df @@ -493,9 +479,7 @@ def _discard(self, ids: PolarsIdsLike) -> Self: mask = self._get_bool_mask(ids) if isinstance(self._mask, pl.Series): - original_active_indices = self._df.filter(self._mask).collect()[ - "unique_id" - ] + original_active_indices = self._df.filter(self._mask).collect()["unique_id"] self._df = self._df.filter(mask.not_()) diff --git a/tests/test_agentset.py b/tests/test_agentset.py index d042b6ce..dd37a20d 100644 --- a/tests/test_agentset.py +++ b/tests/test_agentset.py @@ -113,18 
+113,58 @@ def test_add( result = agents.add( pl.DataFrame({"wealth": [5, 6], "age": [50, 60]}), inplace=False ) - assert result.df.select("wealth").collect()["wealth"].to_list() == [1, 2, 3, 4, 5, 6] - assert result.df.select("age").collect()["age"].to_list() == [10, 20, 30, 40, 50, 60] + assert result.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + assert result.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 50, + 60, + ] # Test with a list (Sequence[Any]) result = agents.add([5, 10], inplace=False) - assert result.df.select("wealth").collect()["wealth"].to_list() == [1, 2, 3, 4, 5] - assert result.df.select("age").collect()["age"].to_list() == [10, 20, 30, 40, 10] + assert result.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] + assert result.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 10, + ] # Test with a dict[str, Any] agents.add({"wealth": [5, 6], "age": [50, 60]}) - assert agents.df.select("wealth").collect()["wealth"].to_list() == [1, 2, 3, 4, 5, 6] - assert agents.df.select("age").collect()["age"].to_list() == [10, 20, 30, 40, 50, 60] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + assert agents.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 50, + 60, + ] # Test ValueError for dictionary with unique_id key (Line 131) with pytest.raises( @@ -405,7 +445,12 @@ def test_set(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with Collection values (Line 213) - using list as Collection result = agents.set("wealth", [100, 200, 300, 400], inplace=False) - assert result.df.select("wealth").collect()["wealth"].to_list() == [100, 200, 300, 400] + assert result.df.select("wealth").collect()["wealth"].to_list() == [ + 100, + 200, + 300, + 400, + ] def test_shuffle(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars @@ -439,19 +484,54 @@ def test__add__( # Test with an AgentSetPolars and a DataFrame agents3 = agents + pl.DataFrame({"wealth": [5, 6], "age": [50, 60]}) - assert agents3.df.select("wealth").collect()["wealth"].to_list() == [1, 2, 3, 4, 5, 6] - assert agents3.df.select("age").collect()["age"].to_list() == [10, 20, 30, 40, 50, 60] + assert agents3.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + assert agents3.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 50, + 60, + ] # Test with an AgentSetPolars and a list (Sequence[Any]) agents3 = agents + [5, 5] # unique_id, wealth, age - assert all(agents3.df.select("unique_id").collect()["unique_id"].to_list()[:-1] == agents["unique_id"]) + assert all( + agents3.df.select("unique_id").collect()["unique_id"].to_list()[:-1] + == agents["unique_id"] + ) assert len(agents3.df.collect()) == 5 - assert agents3.df.select("wealth").collect()["wealth"].to_list() == [1, 2, 3, 4, 5] - assert agents3.df.select("age").collect()["age"].to_list() == [10, 20, 30, 40, 5] + assert agents3.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] + assert agents3.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 5, + ] # Test with an AgentSetPolars and a dict agents3 = agents + {"age": 10, "wealth": 5} - assert agents3.df.select("wealth").collect()["wealth"].to_list() == [1, 2, 3, 4, 5] + assert agents3.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] 
def test__contains__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with a single value @@ -505,8 +585,22 @@ def test__iadd__( # Test with an AgentSetPolars and a DataFrame agents = deepcopy(fix1_AgentSetPolars) agents += pl.DataFrame({"wealth": [5, 6], "age": [50, 60]}) - assert agents.df.select("wealth").collect()["wealth"].to_list() == [1, 2, 3, 4, 5, 6] - assert agents.df.select("age").collect()["age"].to_list() == [10, 20, 30, 40, 50, 60] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + assert agents.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 50, + 60, + ] # Test with an AgentSetPolars and a list agents = deepcopy(fix1_AgentSetPolars) @@ -516,13 +610,25 @@ def test__iadd__( == fix1_AgentSetPolars["unique_id"][0, 1, 2, 3] ) assert len(agents.df.collect()) == 5 - assert agents.df.select("wealth").collect()["wealth"].to_list() == [1, 2, 3, 4, 5] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] assert agents.df.select("age").collect()["age"].to_list() == [10, 20, 30, 40, 5] # Test with an AgentSetPolars and a dict agents = deepcopy(fix1_AgentSetPolars) agents += {"age": 10, "wealth": 5} - assert agents.df.select("wealth").collect()["wealth"].to_list() == [1, 2, 3, 4, 5] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] def test__iter__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars @@ -626,9 +732,7 @@ def test_active_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with select - agents.select( - agents.df.select("wealth").collect()["wealth"] > 2, inplace=True - ) + agents.select(agents.df.select("wealth").collect()["wealth"] > 2, inplace=True) assert agents.active_agents.select("unique_id").collect()[ "unique_id" ].to_list() == [2, 3] @@ -642,9 +746,7 @@ def test_active_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): def test_inactive_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars - agents.select( - agents.df.select("wealth").collect()["wealth"] > 2, inplace=True - ) + agents.select(agents.df.select("wealth").collect()["wealth"] > 2, inplace=True) assert agents.inactive_agents.select("unique_id").collect()[ "unique_id" ].to_list() == [0, 1]
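For reference (not part of the applied patch), a minimal standalone sketch of the eager-vs-lazy access pattern these updated tests rely on: a pl.LazyFrame has no direct column indexing, so values are read through select(...).collect() before comparison, and boolean masks are built from the collected Series. Frame contents and column names below are illustrative only, not taken from the fixtures.

    import polars as pl

    # Eager frame: columns are directly indexable as Series.
    eager = pl.DataFrame({"unique_id": [0, 1, 2, 3], "wealth": [1, 2, 3, 4]})
    assert eager["wealth"].to_list() == [1, 2, 3, 4]

    # Lazy frame: build a query plan, then collect before reading values.
    lazy = eager.lazy()
    assert lazy.select("wealth").collect()["wealth"].to_list() == [1, 2, 3, 4]

    # Masks for select()/active_agents are likewise materialized first,
    # mirroring `agents.agents.select("wealth").collect()["wealth"] > 2` above.
    mask = lazy.select("wealth").collect()["wealth"] > 2
    assert lazy.collect().filter(mask)["unique_id"].to_list() == [2, 3]

Collecting at every assertion is presumably fine in tests; in the library code the patch instead defers collection until a materialized result (sampling, containment checks, mask updates) is actually required.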