Skip to content
33 changes: 21 additions & 12 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,30 +162,39 @@ class AgentSet(MutableSet, Sequence):
which means that agents not referenced elsewhere in the program may be automatically removed from the AgentSet.

Notes:
A `UserWarning` is issued if `random=None`. You can resolve this warning by explicitly
passing a random number generator. In most cases, this will be the seeded random number
generator in the model. So, you would do `random=self.random` in a `Model` or `Agent` instance.
If random is None then the random number generator in the model of the first agent is used.
If the agents list is empty and random is also None a user warning is issued and the AgentSet
is an empty list and a default random number generator. This can make models non-reproducible.
If your code may create an AgentSet with no agents please pass a random number generator explicitly.

"""

def __init__(self, agents: Iterable[Agent], random: Random | None = None):
def __init__(
self,
agents: Iterable[Agent],
random: Random | np.random.Generator | None = None,
):
"""Initializes the AgentSet with a collection of agents and a reference to the model.

Args:
agents (Iterable[Agent]): An iterable of Agent objects to be included in the set.
random (Random): the random number generator
random (Random | np.random.Generator | None): the random number generator
"""
if random is None:
self._agents = weakref.WeakKeyDictionary(dict.fromkeys(agents))
if (len(self._agents) == 0) and random is None:
warnings.warn(
"Random number generator not specified, this can make models non-reproducible. Please pass a random number generator explicitly",
"No Agents specified in creation of AgentSet and no random number generator specified. "
"This can make models non-reproducible. Please pass a random number generator explicitly",
UserWarning,
stacklevel=2,
)
random = (
Random()
) # FIXME see issue 1981, how to get the central rng from model
self.random = random
self._agents = weakref.WeakKeyDictionary(dict.fromkeys(agents))
random = Random()

if random is not None:
self.random = random
else:
# all agents in an AgentSet should share the same model, just take it from first
self.random = self._agents.keys().__next__().model.random

def __len__(self) -> int:
"""Return the number of agents in the AgentSet."""
Expand Down
60 changes: 38 additions & 22 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_agentset():
model = Model()
agents = [AgentTest(model) for _ in range(10)]

agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

assert agents[0] in agentset
assert len(agentset) == len(agents)
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_function(agent):

# because AgentSet uses weakrefs, we need hard refs as well....
other_agents, another_set = pickle.loads( # noqa: S301
pickle.dumps([agents, AgentSet(agents, random=model.random)])
pickle.dumps([agents, AgentSet(agents)])
)
assert all(
a1.unique_id == a2.unique_id for a1, a2 in zip(another_set, other_agents)
Expand All @@ -131,17 +131,33 @@ def test_agentset_initialization():
model = Model()
empty_agentset = AgentSet([], random=model.random)
assert len(empty_agentset) == 0
with pytest.warns(UserWarning):
empty_agentset2 = AgentSet([])
assert len(empty_agentset2) == 0

agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
assert len(agentset) == 10


def test_agentset_initialization_w_random():
"""Test agentset initialization."""
model = Model()
empty_agentset = AgentSet([], random=model.random)
assert len(empty_agentset) == 0
assert empty_agentset.random == model.random

agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents)
assert len(agentset) == 10
assert agentset.random == model.random


def test_agentset_serialization():
"""Test pickleability of agentset."""
model = Model()
agents = [AgentTest(model) for _ in range(5)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

serialized = pickle.dumps(agentset)
deserialized = pickle.loads(serialized) # noqa: S301
Expand All @@ -156,7 +172,7 @@ def test_agent_membership():
"""Test agent membership in AgentSet."""
model = Model()
agents = [AgentTest(model) for _ in range(5)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

assert agents[0] in agentset
assert AgentTest(model) not in agentset
Expand Down Expand Up @@ -218,7 +234,7 @@ def test_agentset_get_item():
"""Test integer based access to AgentSet."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

assert agentset[0] == agents[0]
assert agentset[-1] == agents[-1]
Expand All @@ -232,7 +248,7 @@ def test_agentset_do_str():
"""Test AgentSet.do with str."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

with pytest.raises(AttributeError):
agentset.do("non_existing_method")
Expand All @@ -245,7 +261,7 @@ def test_agentset_do_str():
n = 10
model = Model()
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -255,7 +271,7 @@ def test_agentset_do_str():
# setup
model = Model()
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -267,7 +283,7 @@ def test_agentset_do_callable():
"""Test AgentSet.do with callable."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

# Test callable with non-existent function
with pytest.raises(AttributeError):
Expand All @@ -281,7 +297,7 @@ def test_agentset_do_callable():
n = 10
model = Model()
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -292,7 +308,7 @@ def test_agentset_do_callable():
# setup again for lambda function tests
model = Model()
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -310,7 +326,7 @@ def remove_function(agent):
# setup again for actual function tests
model = Model()
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand All @@ -321,7 +337,7 @@ def remove_function(agent):
# setup again for actual function tests
model = Model()
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)
for agent in agents:
agent.agent_set = agentset

Expand Down Expand Up @@ -386,7 +402,7 @@ def test_agentset_agg():
agent.energy = i + 1
agent.wealth = 10 * (i + 1)

agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

# Test min aggregation
min_energy = agentset.agg("energy", min)
Expand Down Expand Up @@ -435,7 +451,7 @@ def __init__(self, model, age=None):

model = Model()
agents = [TestAgentWithAttribute(model, age=i) for i in range(5)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

# Set a new attribute "health" and an existing attribute "age" for all agents
agentset.set("health", 100).set("age", 50).set("status", "active")
Expand All @@ -454,7 +470,7 @@ def test_agentset_map_str():
"""Test AgentSet.map with strings."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

with pytest.raises(AttributeError):
agentset.do("non_existing_method")
Expand All @@ -467,7 +483,7 @@ def test_agentset_map_callable():
"""Test AgentSet.map with callable."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

# Test callable with non-existent function
with pytest.raises(AttributeError):
Expand All @@ -494,7 +510,7 @@ def test_method(self):
self.called = True

agents = [TestAgentShuffleDo(model) for _ in range(100)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

# Test shuffle_do with a string method name
agentset.shuffle_do("test_method")
Expand Down Expand Up @@ -544,7 +560,7 @@ def test_agentset_get_attribute():
"""Test AgentSet.get for attributes."""
model = Model()
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

unique_ids = agentset.get("unique_id")
assert unique_ids == [agent.unique_id for agent in agents]
Expand All @@ -558,7 +574,7 @@ def test_agentset_get_attribute():
agent = AgentTest(model)
agent.i = i**2
agents.append(agent)
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

values = agentset.get(["unique_id", "i"])

Expand Down Expand Up @@ -634,7 +650,7 @@ def get_unique_identifier(self):

model = Model()
agents = [TestAgent(model) for _ in range(10)]
agentset = AgentSet(agents, random=model.random)
agentset = AgentSet(agents)

groups = agentset.groupby("even")
assert len(groups.groups[True]) == 5
Expand Down
Loading