Skip to content

Commit eccd744

Browse files
added docstrings
1 parent d26561e commit eccd744

File tree

2 files changed

+123
-19
lines changed

2 files changed

+123
-19
lines changed

mesa_frames/abstract/space.py

Lines changed: 120 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type):
5959
from typing_extensions import Any, Self
6060

6161

62-
from mesa_frames.concrete.polars.agentset import AgentSetPolars
6362
from mesa_frames.concrete.agents import AgentsDF
6463
from mesa_frames.abstract.agents import AgentContainer, AgentSetDF
6564
from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin
@@ -77,12 +76,13 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type):
7776
Series,
7877
SpaceCoordinate,
7978
SpaceCoordinates,
79+
AgentLike
8080
)
8181

8282
ESPG = int
8383

8484

85-
AgentLike = Union[AgentSetPolars, pl.DataFrame]
85+
8686

8787
if TYPE_CHECKING:
8888
from mesa_frames.concrete.model import ModelDF
@@ -1050,36 +1050,98 @@ def move_to(
10501050
include_center: bool = True,
10511051
shuffle: bool = True
10521052
) -> None:
1053+
"""
1054+
Move agents to new positions based on neighborhood ranking.
1055+
1056+
This method determines each agent's potential moves by computing their
1057+
local neighborhood (with optional inclusion of the center cell). For each
1058+
agent, the method ranks all possible moves according to specified attribute(s)
1059+
and rank order(s). If multiple agents contend for the same cell, the method
1060+
applies a tie-breaking approach. Agents can optionally be processed in a
1061+
randomized order to break ties. The final position of each agent is then
1062+
updated in-place.
1063+
1064+
Parameters
1065+
----------
1066+
agents : AgentLike
1067+
A DataFrame-like structure containing agent information. Must include
1068+
at least the following columns:
1069+
- ``agent_id``: a unique identifier for each agent
1070+
- ``dim_0``, ``dim_1``: the current positions of agents
1071+
- Optionally ``vision`` if ``radius`` is not provided
1072+
attr_names : str or list of str
1073+
The name(s) of the attribute(s) used for ranking the neighborhood cells.
1074+
If multiple attributes are provided, each should have a corresponding
1075+
entry in ``rank_order``.
1076+
rank_order : str or list of str, optional
1077+
The ranking order for each attribute. Accepts:
1078+
- ``"max"`` (default) for descending order
1079+
- ``"min"`` for ascending order
1080+
1081+
If a single string is provided, it is applied to all attributes in
1082+
``attr_names``.
1083+
radius : int or pl.Series, optional
1084+
The radius (or per-agent radii) defining the neighborhood around agents.
1085+
If not provided, this method attempts to use the ``vision`` column from
1086+
``agents``. If ``vision`` is not found, a ``ValueError`` is raised.
1087+
include_center : bool, optional
1088+
If ``True`` (default), the agent's current position is included in its
1089+
neighborhood.
1090+
shuffle : bool, optional
1091+
If ``True`` (default), the order of agents is randomized to break ties.
1092+
If ``False``, agents are processed in the order they appear in the data.
1093+
1094+
Returns
1095+
-------
1096+
None
1097+
This method updates agent positions in-place based on the computed best moves.
1098+
"""
1099+
# Ensure attr_names and rank_order are lists of the same length
10531100
if isinstance(attr_names, str):
10541101
attr_names = [attr_names]
10551102
if isinstance(rank_order, str):
10561103
rank_order = [rank_order] * len(attr_names)
10571104
if len(attr_names) != len(rank_order):
10581105
raise ValueError("attr_names and rank_order must have the same length")
1106+
1107+
# Handle the neighborhood radius
10591108
if radius is None:
10601109
if "vision" in agents.columns:
10611110
radius = agents["vision"]
10621111
else:
1063-
raise ValueError("radius must be specified if agents do not have a 'vision' attribute")
1112+
raise ValueError(
1113+
"radius must be specified if agents do not have a 'vision' attribute"
1114+
)
1115+
1116+
# Get neighborhood and join with cell information
10641117
neighborhood = self.get_neighborhood(
1065-
radius=radius,
1066-
agents=agents,
1118+
radius=radius,
1119+
agents=agents,
10671120
include_center=include_center
10681121
)
10691122
neighborhood = neighborhood.join(self.cells, on=["dim_0", "dim_1"])
1123+
1124+
# Determine the agent identifier column
1125+
agent_id_col = "agent_id" if "agent_id" in agents.columns else "unique_id"
1126+
1127+
# Add a column to identify the center agent
1128+
join_result = neighborhood.join(
1129+
agents.select(["dim_0", "dim_1", agent_id_col]),
1130+
left_on=["dim_0_center", "dim_1_center"],
1131+
right_on=["dim_0", "dim_1"]
1132+
)
1133+
10701134
neighborhood = neighborhood.with_columns(
1071-
agent_id_center=neighborhood.join(
1072-
agents.pos,
1073-
left_on=["dim_0_center", "dim_1_center"],
1074-
right_on=["dim_0", "dim_1"],
1075-
)["unique_id"]
1135+
agent_id_center=join_result[agent_id_col]
10761136
)
1137+
1138+
# Determine the processing order of agents
10771139
if shuffle:
10781140
agent_order = (
10791141
neighborhood
10801142
.unique(subset=["agent_id_center"], keep="first")
10811143
.select("agent_id_center")
1082-
.sample(fraction=1.0, seed=self.model.random.integers(0, 2**31-1))
1144+
.sample(fraction=1.0, seed=self.model.random.integers(0, 2**31 - 1))
10831145
.with_row_index("agent_order")
10841146
)
10851147
else:
@@ -1089,16 +1151,24 @@ def move_to(
10891151
.with_row_index("agent_order")
10901152
.select(["agent_id_center", "agent_order"])
10911153
)
1154+
1155+
# Join the processing order with the neighborhood
10921156
neighborhood = neighborhood.join(agent_order, on="agent_id_center")
1157+
1158+
# Prepare sorting columns and order
10931159
sort_cols = []
10941160
sort_desc = []
10951161
for attr, order in zip(attr_names, rank_order):
10961162
sort_cols.append(attr)
10971163
sort_desc.append(order.lower() == "max")
1164+
1165+
# Sort the neighborhood cells by specified attributes and then by location
10981166
neighborhood = neighborhood.sort(
10991167
sort_cols + ["radius", "dim_0", "dim_1"],
11001168
descending=sort_desc + [False, False, False]
11011169
)
1170+
1171+
# Join to track if another agent has blocked a cell
11021172
neighborhood = neighborhood.join(
11031173
agent_order.select(
11041174
pl.col("agent_id_center").alias("agent_id"),
@@ -1107,39 +1177,71 @@ def move_to(
11071177
on="agent_id",
11081178
how="left",
11091179
).rename({"agent_id": "blocking_agent_id"})
1180+
1181+
# Iteratively select the best moves
11101182
best_moves = pl.DataFrame()
1111-
while len(best_moves) < len(agents):
1183+
max_iterations = min(len(agents) * 2, 1000) # Safeguard against infinite loops
1184+
iteration_count = 0
1185+
1186+
while len(best_moves) < len(agents) and iteration_count < max_iterations:
1187+
iteration_count += 1
1188+
1189+
# Count how many times each (dim_0, dim_1) is being claimed
11121190
neighborhood = neighborhood.with_columns(
11131191
priority=pl.col("agent_order").cum_count().over(["dim_0", "dim_1"])
11141192
)
1193+
11151194
new_best_moves = (
11161195
neighborhood.group_by("agent_id_center", maintain_order=True)
11171196
.first()
11181197
.unique(subset=["dim_0", "dim_1"], keep="first", maintain_order=True)
11191198
)
1120-
condition = pl.col("blocking_agent_id").is_null() | (
1121-
pl.col("blocking_agent_id") == pl.col("agent_id_center")
1199+
1200+
condition = (
1201+
pl.col("blocking_agent_id").is_null()
1202+
| (pl.col("blocking_agent_id") == pl.col("agent_id_center"))
11221203
)
1204+
11231205
if len(best_moves) > 0:
11241206
condition = condition | pl.col("blocking_agent_id").is_in(
11251207
best_moves["agent_id_center"]
11261208
)
1209+
11271210
condition = condition & (pl.col("priority") == 1)
11281211
new_best_moves = new_best_moves.filter(condition)
1212+
11291213
if len(new_best_moves) == 0:
11301214
break
1215+
11311216
best_moves = pl.concat([best_moves, new_best_moves])
1217+
1218+
# Update neighborhood to exclude agents that already have a move
1219+
# and cells that are already claimed
11321220
neighborhood = neighborhood.filter(
11331221
~pl.col("agent_id_center").is_in(best_moves["agent_id_center"])
11341222
)
11351223
neighborhood = neighborhood.join(
11361224
best_moves.select(["dim_0", "dim_1"]), on=["dim_0", "dim_1"], how="anti"
11371225
)
1226+
1227+
# Move agents to their new positions
11381228
if len(best_moves) > 0:
1139-
self.move_agents(
1140-
best_moves.sort("agent_order")["agent_id_center"],
1141-
best_moves.sort("agent_order").select(["dim_0", "dim_1"])
1142-
)
1229+
try:
1230+
self.move_agents(
1231+
best_moves.sort("agent_order")["agent_id_center"],
1232+
best_moves.sort("agent_order").select(["dim_0", "dim_1"])
1233+
)
1234+
except Exception as e:
1235+
# Check if the agent exists in the model
1236+
available_agents = set(self.model.agents[agent_id_col].to_list()) if hasattr(self.model, 'agents') else set()
1237+
missing_agents = [a for a in best_moves["agent_id_center"].to_list() if a not in available_agents]
1238+
1239+
if missing_agents and available_agents:
1240+
raise ValueError(f"Some agents are not present in the model: {missing_agents}")
1241+
else:
1242+
raise ValueError(f"Error moving agents: {e}")
1243+
1244+
11431245

11441246

11451247
@property

mesa_frames/types_.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""Type aliases for the mesa_frames package."""
22

33
from collections.abc import Collection, Sequence
4-
from typing import Literal
4+
from typing import Literal, Union
55

66
# import geopandas as gpd
77
# import geopolars as gpl
88
import pandas as pd
99
import polars as pl
1010
from numpy import ndarray
1111
from typing_extensions import Any
12+
from mesa_frames.concrete.polars.agentset import AgentSetPolars
1213

1314
####----- Agnostic Types -----####
1415
AgnosticMask = (
@@ -76,3 +77,4 @@
7677
NetworkCapacity = DataFrame
7778

7879
DiscreteSpaceCapacity = GridCapacity | NetworkCapacity
80+
AgentLike = Union[AgentSetPolars, pl.DataFrame]

0 commit comments

Comments
 (0)