Skip to content

Commit d72715f

Browse files
committed
feat: Implement proposal 4
As described in projectmesa#2013 (comment)
1 parent b4afc6e commit d72715f

File tree

1 file changed

+28
-21
lines changed

1 file changed

+28
-21
lines changed

mesa/experimental/observer.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,46 +8,53 @@ class DataCollector:
88
Example: a model consisting of a hybrid of Boltzmann wealth model and
99
Epstein civil violence.
1010
```
11-
def get_citizen():
12-
return model.get_agents_of_type(Citizen)
11+
groups = {
12+
"quiescents": lambda: model.agents.select(
13+
agent_type=Citizen, filter_func=lambda a: a.condition == "Quiescent"
14+
),
15+
"citizens": lambda: model.get_agents_of_type(Citizen),
16+
}
1317
1418
collectors = {
15-
model: {
16-
"n_quiescent": lambda model: len(
17-
model.agents.select(
18-
agent_type=Citizen,
19-
filter_func=lambda a: a.condition == "Quiescent"
20-
)
21-
),
22-
"gini": lambda model: calculate_gini(model.agents.get("wealth"))
23-
},
24-
get_citizen: {"condition": condition},
25-
# This is a string, because model.agents may refer to a different
26-
# object, over time.
27-
"agents": {"wealth": "wealth"}
19+
("n_quiescent", "quiescents"): len,
20+
("gini", "model"): lambda model: calculate_gini(model.agents.get("wealth")),
21+
# A better way to do the former:
22+
("gini", "agents"): lambda agents: calculate_gini(agents.get("wealth")),
23+
("gini_quiescent", "quiescents"): lambda agents: calculate_gini(
24+
agents.get("wealth")
25+
),
26+
("condition", "citizens"): "condition",
27+
# "agents" is a string, because model.agents may refer to a different
28+
# object, over time, when accessed from scratch each time.
29+
("wealth", "agents"): "wealth",
2830
}
31+
2932
# Then finally
3033
model.datacollector = DataCollector(model, collectors=collectors).collect()
3134
```
3235
"""
3336

34-
def __init__(self, model, collectors=None) -> "DataCollector":
37+
def __init__(self, model, groups=None, collectors=None) -> "DataCollector":
3538
self.model = model
39+
self.groups = groups
3640
self.collectors = collectors
3741
self.data_collection = defaultdict(list)
3842
return self
3943

4044
def collect(self) -> "DataCollector":
41-
for group, group_collector in self.collectors.items():
45+
group_data = defaultdict(dict)
46+
for (name, group), collector in self.collectors.items():
4247
group_object = group
4348
if group == "agents":
4449
group_object = self.model.agents
4550
elif callable(group):
4651
group_object = group()
47-
data = {
48-
name: self._collect_group(group_object, collector)
49-
for name, collector in group_collector.items()
50-
}
52+
elif isinstance(group, str):
53+
group_object = self.groups[group]
54+
55+
group_data[group][name] = self._collect_group(group_object, collector)
56+
57+
for group, data in group_data.items():
5158
self.data_collection[group].append(data)
5259
return self
5360

0 commit comments

Comments
 (0)