Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions mesa/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from __future__ import annotations

import heapq
import weakref
from collections import defaultdict

# mypy
Expand Down Expand Up @@ -64,7 +65,7 @@ def __init__(self, model: Model) -> None:
self.model = model
self.steps = 0
self.time: TimeT = 0
self._agents: dict[int, Agent] = {}
self._agents: dict[int, weakref.ref[Agent]] = {}

def add(self, agent: Agent) -> None:
"""Add an Agent object to the schedule.
Expand All @@ -78,15 +79,17 @@ def add(self, agent: Agent) -> None:
f"Agent with unique id {agent.unique_id!r} already added to scheduler"
)

self._agents[agent.unique_id] = agent
self._agents[agent.unique_id] = weakref.ref(agent)

def remove(self, agent: Agent) -> None:
"""Remove all instances of a given agent from the schedule.
def remove(self, agent: Agent | weakref.ref[Agent]) -> None:
"""Remove an agent from the schedule.

Args:
agent: An agent object.
agent: An agent object or a weak reference to an agent.
"""
del self._agents[agent.unique_id]
agent_id = (agent() if isinstance(agent, weakref.ref) else agent).unique_id
if agent_id in self._agents:
del self._agents[agent_id]

def step(self) -> None:
"""Execute the step of all the agents, one at a time."""
Expand All @@ -102,7 +105,12 @@ def get_agent_count(self) -> int:

@property
def agents(self) -> list[Agent]:
return list(self._agents.values())
"""Return a list of live agent instances."""
return [
agent_ref()
for agent_ref in self._agents.values()
if agent_ref() is not None
]

def get_agent_keys(self, shuffle: bool = False) -> list[int]:
# To be able to remove and/or add agents during stepping
Expand All @@ -113,13 +121,18 @@ def get_agent_keys(self, shuffle: bool = False) -> list[int]:
return agent_keys

def do_each(self, method, agent_keys=None, shuffle=False):
"""Performs a method on each agent, managing weak references."""
if agent_keys is None:
agent_keys = self.get_agent_keys()
agent_keys = list(self._agents.keys())
if shuffle:
self.model.random.shuffle(agent_keys)
for agent_key in agent_keys:
if agent_key in self._agents:
getattr(self._agents[agent_key], method)()

for agent_key in list(agent_keys): # Create a copy of the list to iterate over
agent_ref = self._agents.get(agent_key)
if agent_ref is not None:
agent = agent_ref()
if agent is not None:
getattr(agent, method)()


class RandomActivation(BaseScheduler):
Expand Down