Skip to content

Commit c08bede

Browse files
committed
feat: Implement experimental DataCollector API
1 parent 2dc485f commit c08bede

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

mesa/experimental/measure.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from collections import defaultdict
2+
3+
class Group:
4+
def __init__(self, model, fn):
5+
self.model = model
6+
self.fn = fn
7+
8+
@property
9+
def value(self):
10+
return self.fn(self.model)
11+
12+
class Measure:
13+
def __init__(self, group, measurer):
14+
self.group = group
15+
self.measurer = measurer
16+
17+
def _measure_group(self, group, measurer):
18+
# get an attribute
19+
if isinstance(measurer, str):
20+
return getattr(group, measurer)
21+
# apply
22+
return measurer(group)
23+
24+
@property
25+
def value(self):
26+
group_object = self.group
27+
if isinstance(self.group, Group):
28+
group_object = self.group.value
29+
return self._measure_group(group_object, self.measurer)
30+
31+
32+
class DataCollector:
33+
"""
34+
Example: a model consisting of a hybrid of Boltzmann wealth model and
35+
Epstein civil violence.
36+
37+
class EpsteinBoltzmannModel:
38+
def __init__(self):
39+
# Groups
40+
self.quiescents = Group(
41+
lambda model: model.agents.select(
42+
agent_type=Citizen, filter_func=lambda a: a.condition == "Quiescent"
43+
)
44+
)
45+
self.citizens = Group(lambda model: model.get_agents_of_type(Citizen))
46+
47+
# Measurements
48+
self.num_quiescents = Measure(self.quiescents, len)
49+
self.gini = Measure(
50+
self.agents, lambda agents: calculate_gini(agents.get("wealth"))
51+
)
52+
self.gini_quiescents = Measure(
53+
self.quiescents, lambda agents: calculate_gini(agents.get("wealth"))
54+
)
55+
self.condition = Measure(self.citizens, "condition")
56+
self.wealth = Measure(self.agents, "wealth")
57+
58+
59+
def run():
60+
model = EpsteinBoltzmannModel()
61+
datacollector = DataCollector(
62+
model, ["num_quiescents", "gini_quiescents", "wealth"]
63+
)
64+
65+
for _ in range(10):
66+
model.step()
67+
datacollector.collect()
68+
"""
69+
def __init__(self, model, attributes):
70+
self.model = model
71+
self.attributes = attributes
72+
self.data_collection = defaultdict(list)
73+
74+
def collect(self):
75+
for name in self.attributes:
76+
attribute = getattr(self.model, name)
77+
if isinstance(attribute, Measure):
78+
attribute = attribute.value
79+
self.data_collection[name] = attribute

0 commit comments

Comments
 (0)