diff --git a/mesa/time.py b/mesa/time.py index ca54bab7a70..4c097ab97ee 100644 --- a/mesa/time.py +++ b/mesa/time.py @@ -26,6 +26,7 @@ from __future__ import annotations import heapq +import weakref from collections import defaultdict # mypy @@ -64,7 +65,7 @@ def __init__(self, model: Model) -> None: self.model = model self.steps = 0 self.time: TimeT = 0 - self._agents: dict[int, Agent] = {} + self._agents: dict[int, weakref.ref[Agent]] = {} def add(self, agent: Agent) -> None: """Add an Agent object to the schedule. @@ -78,15 +79,17 @@ def add(self, agent: Agent) -> None: f"Agent with unique id {agent.unique_id!r} already added to scheduler" ) - self._agents[agent.unique_id] = agent + self._agents[agent.unique_id] = weakref.ref(agent) - def remove(self, agent: Agent) -> None: - """Remove all instances of a given agent from the schedule. + def remove(self, agent: Agent | weakref.ref[Agent]) -> None: + """Remove an agent from the schedule. Args: - agent: An agent object. + agent: An agent object or a weak reference to an agent. """ - del self._agents[agent.unique_id] + agent_id = (agent() if isinstance(agent, weakref.ref) else agent).unique_id + if agent_id in self._agents: + del self._agents[agent_id] def step(self) -> None: """Execute the step of all the agents, one at a time.""" @@ -102,7 +105,12 @@ def get_agent_count(self) -> int: @property def agents(self) -> list[Agent]: - return list(self._agents.values()) + """Return a list of live agent instances.""" + return [ + agent_ref() + for agent_ref in self._agents.values() + if agent_ref() is not None + ] def get_agent_keys(self, shuffle: bool = False) -> list[int]: # To be able to remove and/or add agents during stepping @@ -113,13 +121,18 @@ def get_agent_keys(self, shuffle: bool = False) -> list[int]: return agent_keys def do_each(self, method, agent_keys=None, shuffle=False): + """Performs a method on each agent, managing weak references.""" if agent_keys is None: - agent_keys = self.get_agent_keys() + agent_keys = list(self._agents.keys()) if shuffle: self.model.random.shuffle(agent_keys) - for agent_key in agent_keys: - if agent_key in self._agents: - getattr(self._agents[agent_key], method)() + + for agent_key in list(agent_keys): # Create a copy of the list to iterate over + agent_ref = self._agents.get(agent_key) + if agent_ref is not None: + agent = agent_ref() + if agent is not None: + getattr(agent, method)() class RandomActivation(BaseScheduler):