Skip to content

Commit b9d817a

Browse files
committed
feat: Implement experimental DataCollector API
1 parent 2dc485f commit b9d817a

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed

mesa/experimental/measure.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from collections import defaultdict
2+
3+
4+
class Group:
5+
def __init__(self, model, fn):
6+
self.model = model
7+
self.fn = fn
8+
9+
@property
10+
def value(self):
11+
return self.fn(self.model)
12+
13+
14+
class Measure:
15+
def __init__(self, group, measurer):
16+
self.group = group
17+
self.measurer = measurer
18+
19+
def _measure_group(self, group, measurer):
20+
# get an attribute
21+
if isinstance(measurer, str):
22+
return getattr(group, measurer)
23+
# apply
24+
return measurer(group)
25+
26+
@property
27+
def value(self):
28+
group_object = self.group
29+
if isinstance(self.group, Group):
30+
group_object = self.group.value
31+
return self._measure_group(group_object, self.measurer)
32+
33+
34+
class DataCollector:
35+
"""
36+
Example: a model consisting of a hybrid of Boltzmann wealth model and
37+
Epstein civil violence.
38+
39+
class EpsteinBoltzmannModel:
40+
def __init__(self):
41+
# Groups
42+
self.quiescents = Group(
43+
lambda model: model.agents.select(
44+
agent_type=Citizen, filter_func=lambda a: a.condition == "Quiescent"
45+
)
46+
)
47+
self.citizens = Group(lambda model: model.get_agents_of_type(Citizen))
48+
49+
# Measurements
50+
self.num_quiescents = Measure(self.quiescents, len)
51+
self.gini = Measure(
52+
self.agents, lambda agents: calculate_gini(agents.get("wealth"))
53+
)
54+
self.gini_quiescents = Measure(
55+
self.quiescents, lambda agents: calculate_gini(agents.get("wealth"))
56+
)
57+
self.condition = Measure(self.citizens, "condition")
58+
self.wealth = Measure(self.agents, "wealth")
59+
60+
61+
def run():
62+
model = EpsteinBoltzmannModel()
63+
datacollector = DataCollector(
64+
model, ["num_quiescents", "gini_quiescents", "wealth"]
65+
)
66+
67+
for _ in range(10):
68+
model.step()
69+
datacollector.collect()
70+
"""
71+
72+
def __init__(self, model, attributes):
73+
self.model = model
74+
self.attributes = attributes
75+
self.data_collection = defaultdict(list)
76+
77+
def collect(self):
78+
for name in self.attributes:
79+
attribute = getattr(self.model, name)
80+
if isinstance(attribute, Measure):
81+
attribute = attribute.value
82+
self.data_collection[name] = attribute

0 commit comments

Comments
 (0)