Skip to content

Commit e137a60

Browse files
authored
Make batch_run pytestable by adding main() functions (#143)
Added a main() function to bank_reserves `batch_run.py` and sugarscape_g1mt `run.py` scripts to facilitate testing and script execution. The `main()` function encapsulates the primary script logic, allowing for easier modular testing and execution. By defining script operations within `main()`, we can directly invoke this function in testing environments without relying on command-line execution. This practice enhances code readability, maintainability, and testability, providing a clear entry point for the script's functionality.
1 parent df3b9e0 commit e137a60

File tree

2 files changed

+47
-71
lines changed

2 files changed

+47
-71
lines changed

examples/bank_reserves/batch_run.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -185,38 +185,16 @@ def run_model(self):
185185
"reserve_percent": 5,
186186
}
187187

188-
if __name__ == "__main__":
188+
189+
def main():
190+
# The existing batch run logic here
189191
data = mesa.batch_run(
190192
BankReservesModel,
191193
br_params,
192194
)
193195
br_df = pd.DataFrame(data)
194196
br_df.to_csv("BankReservesModel_Data.csv")
195197

196-
# The commented out code below is the equivalent code as above, but done
197-
# via the legacy BatchRunner class. This is a good example to look at if
198-
# you want to migrate your code to use `batch_run()` from `BatchRunner`.
199-
# Things to note:
200-
# - You have to set "reserve_percent" in br_params to `[5]`, because the
201-
# legacy BatchRunner doesn't auto-detect that it is single-valued.
202-
# - The model reporters need to be explicitly specified in the legacy
203-
# BatchRunner
204-
"""
205-
from mesa.batchrunner import BatchRunnerMP
206-
br = BatchRunnerMP(
207-
BankReservesModel,
208-
nr_processes=2,
209-
variable_parameters=br_params,
210-
iterations=2,
211-
max_steps=1000,
212-
model_reporters={"Data Collector": lambda m: m.datacollector},
213-
)
214-
br.run_all()
215-
br_df = br.get_model_vars_dataframe()
216-
br_step_data = pd.DataFrame()
217-
for i in range(len(br_df["Data Collector"])):
218-
if isinstance(br_df["Data Collector"][i], DataCollector):
219-
i_run_data = br_df["Data Collector"][i].get_model_vars_dataframe()
220-
br_step_data = br_step_data.append(i_run_data, ignore_index=True)
221-
br_step_data.to_csv("BankReservesModel_Step_Data.csv")
222-
"""
198+
199+
if __name__ == "__main__":
200+
main()

examples/sugarscape_g1mt/run.py

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -61,47 +61,45 @@ def assess_results(results, single_agent):
6161

6262

6363
# Run the model
64+
def main():
65+
args = sys.argv[1:]
66+
67+
if len(args) == 0:
68+
server.launch()
69+
70+
elif args[0] == "-s":
71+
print("Running Single Model")
72+
model = SugarscapeG1mt()
73+
model.run_model()
74+
model_results = model.datacollector.get_model_vars_dataframe()
75+
model_results["Step"] = model_results.index
76+
agent_results = model.datacollector.get_agent_vars_dataframe()
77+
agent_results = agent_results.reset_index()
78+
assess_results(model_results, agent_results)
79+
80+
elif args[0] == "-b":
81+
print("Conducting a Batch Run")
82+
params = {
83+
"width": 50,
84+
"height": 50,
85+
"vision_min": range(1, 4),
86+
"metabolism_max": [2, 3, 4, 5],
87+
}
88+
89+
results_batch = mesa.batch_run(
90+
SugarscapeG1mt,
91+
parameters=params,
92+
iterations=1,
93+
number_processes=1,
94+
data_collection_period=1,
95+
display_progress=True,
96+
)
97+
98+
assess_results(results_batch, None)
6499

65-
args = sys.argv[1:]
66-
67-
if len(args) == 0:
68-
server.launch()
69-
70-
elif args[0] == "-s":
71-
print("Running Single Model")
72-
# instantiate the model
73-
model = SugarscapeG1mt()
74-
# run the model
75-
model.run_model()
76-
# Get results
77-
model_results = model.datacollector.get_model_vars_dataframe()
78-
# Convert to make similar to batch_run_results
79-
model_results["Step"] = model_results.index
80-
agent_results = model.datacollector.get_agent_vars_dataframe()
81-
agent_results = agent_results.reset_index()
82-
# assess the results
83-
assess_results(model_results, agent_results)
84-
85-
elif args[0] == "-b":
86-
print("Conducting a Batch Run")
87-
# Batch Run
88-
params = {
89-
"width": 50,
90-
"height": 50,
91-
"vision_min": range(1, 4),
92-
"metabolism_max": [2, 3, 4, 5],
93-
}
94-
95-
results_batch = mesa.batch_run(
96-
SugarscapeG1mt,
97-
parameters=params,
98-
iterations=1,
99-
number_processes=1,
100-
data_collection_period=1,
101-
display_progress=True,
102-
)
103-
104-
assess_results(results_batch, None)
105-
106-
else:
107-
raise Exception("Option not found")
100+
else:
101+
raise Exception("Option not found")
102+
103+
104+
if __name__ == "__main__":
105+
main()

0 commit comments

Comments
 (0)