diff --git a/mesa/agent.py b/mesa/agent.py index 06d8f451b51..e6ac5473ff6 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -162,30 +162,39 @@ class AgentSet(MutableSet, Sequence): which means that agents not referenced elsewhere in the program may be automatically removed from the AgentSet. Notes: - A `UserWarning` is issued if `random=None`. You can resolve this warning by explicitly - passing a random number generator. In most cases, this will be the seeded random number - generator in the model. So, you would do `random=self.random` in a `Model` or `Agent` instance. + If random is None then the random number generator in the model of the first agent is used. + If the agents list is empty and random is also None a user warning is issued and the AgentSet + is an empty list and a default random number generator. This can make models non-reproducible. + If your code may create an AgentSet with no agents please pass a random number generator explicitly. """ - def __init__(self, agents: Iterable[Agent], random: Random | None = None): + def __init__( + self, + agents: Iterable[Agent], + random: Random | np.random.Generator | None = None, + ): """Initializes the AgentSet with a collection of agents and a reference to the model. Args: agents (Iterable[Agent]): An iterable of Agent objects to be included in the set. - random (Random): the random number generator + random (Random | np.random.Generator | None): the random number generator """ - if random is None: + self._agents = weakref.WeakKeyDictionary(dict.fromkeys(agents)) + if (len(self._agents) == 0) and random is None: warnings.warn( - "Random number generator not specified, this can make models non-reproducible. Please pass a random number generator explicitly", + "No Agents specified in creation of AgentSet and no random number generator specified. " + "This can make models non-reproducible. Please pass a random number generator explicitly", UserWarning, stacklevel=2, ) - random = ( - Random() - ) # FIXME see issue 1981, how to get the central rng from model - self.random = random - self._agents = weakref.WeakKeyDictionary(dict.fromkeys(agents)) + random = Random() + + if random is not None: + self.random = random + else: + # all agents in an AgentSet should share the same model, just take it from first + self.random = self._agents.keys().__next__().model.random def __len__(self) -> int: """Return the number of agents in the AgentSet.""" diff --git a/tests/test_agent.py b/tests/test_agent.py index fb49d401db6..900533cf2d1 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -62,7 +62,7 @@ def test_agentset(): model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) assert agents[0] in agentset assert len(agentset) == len(agents) @@ -118,7 +118,7 @@ def test_function(agent): # because AgentSet uses weakrefs, we need hard refs as well.... other_agents, another_set = pickle.loads( # noqa: S301 - pickle.dumps([agents, AgentSet(agents, random=model.random)]) + pickle.dumps([agents, AgentSet(agents)]) ) assert all( a1.unique_id == a2.unique_id for a1, a2 in zip(another_set, other_agents) @@ -131,17 +131,33 @@ def test_agentset_initialization(): model = Model() empty_agentset = AgentSet([], random=model.random) assert len(empty_agentset) == 0 + with pytest.warns(UserWarning): + empty_agentset2 = AgentSet([]) + assert len(empty_agentset2) == 0 agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) assert len(agentset) == 10 +def test_agentset_initialization_w_random(): + """Test agentset initialization.""" + model = Model() + empty_agentset = AgentSet([], random=model.random) + assert len(empty_agentset) == 0 + assert empty_agentset.random == model.random + + agents = [AgentTest(model) for _ in range(10)] + agentset = AgentSet(agents) + assert len(agentset) == 10 + assert agentset.random == model.random + + def test_agentset_serialization(): """Test pickleability of agentset.""" model = Model() agents = [AgentTest(model) for _ in range(5)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) serialized = pickle.dumps(agentset) deserialized = pickle.loads(serialized) # noqa: S301 @@ -156,7 +172,7 @@ def test_agent_membership(): """Test agent membership in AgentSet.""" model = Model() agents = [AgentTest(model) for _ in range(5)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) assert agents[0] in agentset assert AgentTest(model) not in agentset @@ -218,7 +234,7 @@ def test_agentset_get_item(): """Test integer based access to AgentSet.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) assert agentset[0] == agents[0] assert agentset[-1] == agents[-1] @@ -232,7 +248,7 @@ def test_agentset_do_str(): """Test AgentSet.do with str.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) with pytest.raises(AttributeError): agentset.do("non_existing_method") @@ -245,7 +261,7 @@ def test_agentset_do_str(): n = 10 model = Model() agents = [AgentDoTest(model) for _ in range(n)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) for agent in agents: agent.agent_set = agentset @@ -255,7 +271,7 @@ def test_agentset_do_str(): # setup model = Model() agents = [AgentDoTest(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) for agent in agents: agent.agent_set = agentset @@ -267,7 +283,7 @@ def test_agentset_do_callable(): """Test AgentSet.do with callable.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) # Test callable with non-existent function with pytest.raises(AttributeError): @@ -281,7 +297,7 @@ def test_agentset_do_callable(): n = 10 model = Model() agents = [AgentDoTest(model) for _ in range(n)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) for agent in agents: agent.agent_set = agentset @@ -292,7 +308,7 @@ def test_agentset_do_callable(): # setup again for lambda function tests model = Model() agents = [AgentDoTest(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) for agent in agents: agent.agent_set = agentset @@ -310,7 +326,7 @@ def remove_function(agent): # setup again for actual function tests model = Model() agents = [AgentDoTest(model) for _ in range(n)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) for agent in agents: agent.agent_set = agentset @@ -321,7 +337,7 @@ def remove_function(agent): # setup again for actual function tests model = Model() agents = [AgentDoTest(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) for agent in agents: agent.agent_set = agentset @@ -386,7 +402,7 @@ def test_agentset_agg(): agent.energy = i + 1 agent.wealth = 10 * (i + 1) - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) # Test min aggregation min_energy = agentset.agg("energy", min) @@ -435,7 +451,7 @@ def __init__(self, model, age=None): model = Model() agents = [TestAgentWithAttribute(model, age=i) for i in range(5)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) # Set a new attribute "health" and an existing attribute "age" for all agents agentset.set("health", 100).set("age", 50).set("status", "active") @@ -454,7 +470,7 @@ def test_agentset_map_str(): """Test AgentSet.map with strings.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) with pytest.raises(AttributeError): agentset.do("non_existing_method") @@ -467,7 +483,7 @@ def test_agentset_map_callable(): """Test AgentSet.map with callable.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) # Test callable with non-existent function with pytest.raises(AttributeError): @@ -494,7 +510,7 @@ def test_method(self): self.called = True agents = [TestAgentShuffleDo(model) for _ in range(100)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) # Test shuffle_do with a string method name agentset.shuffle_do("test_method") @@ -544,7 +560,7 @@ def test_agentset_get_attribute(): """Test AgentSet.get for attributes.""" model = Model() agents = [AgentTest(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) unique_ids = agentset.get("unique_id") assert unique_ids == [agent.unique_id for agent in agents] @@ -558,7 +574,7 @@ def test_agentset_get_attribute(): agent = AgentTest(model) agent.i = i**2 agents.append(agent) - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) values = agentset.get(["unique_id", "i"]) @@ -634,7 +650,7 @@ def get_unique_identifier(self): model = Model() agents = [TestAgent(model) for _ in range(10)] - agentset = AgentSet(agents, random=model.random) + agentset = AgentSet(agents) groups = agentset.groupby("even") assert len(groups.groups[True]) == 5