Skip to content

Commit a1df873

Browse files
committed
Fix: Migration to the new EA engine.
1 parent 98fa26e commit a1df873

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

examples/re_book/2_body_brain_evolution.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
- 'ctrl': Float Vector (Array) -> Decodes to CPG Parameters
88
"""
99

10-
from __future__ import annotations
11-
1210
# Standard library
1311
import argparse
1412
import random
@@ -70,6 +68,12 @@
7068
)
7169
parser.add_argument("--pop", type=int, default=80, help="Population size")
7270
parser.add_argument("--dur", type=int, default=30, help="Sim Duration")
71+
parser.add_argument(
72+
"--visualize",
73+
action=argparse.BooleanOptionalAction,
74+
default=True,
75+
help="Launch MuJoCo viewer for best individual",
76+
)
7377
args = parser.parse_args()
7478

7579
# Constants
@@ -115,7 +119,8 @@ def __init__(self) -> None:
115119
is_maximisation=False, # Minimize Distance to Target
116120
num_steps=BUDGET,
117121
target_population_size=POP_SIZE,
118-
db_file_path=DATA / "database.db",
122+
output_folder=DATA,
123+
db_file_name="database.db",
119124
)
120125

121126
# ------------------------------------------------------------------------ #
@@ -272,7 +277,6 @@ def create_individual(self) -> Individual:
272277
ind.tags["debug_joints"] = 0
273278
return ind
274279

275-
@EAOperation
276280
def reproduction(self, population: Population) -> Population:
277281
"""Joint Reproduction: Crossover (Body + Brain) + Mutation."""
278282
parents = [ind for ind in population if ind.tags.get("ps", False)]
@@ -343,7 +347,6 @@ def reproduction(self, population: Population) -> Population:
343347
population.extend(new_offspring)
344348
return population
345349

346-
@EAOperation
347350
def evaluate(self, population: Population) -> Population:
348351
"""Evaluation Loop: Calls run_simulation in 'simple' mode."""
349352
to_eval = [
@@ -363,11 +366,8 @@ def evaluate(self, population: Population) -> Population:
363366

364367
return population
365368

366-
@EAOperation
367369
def parent_selection(self, population: Population) -> Population:
368-
population.sort(
369-
key=lambda x: x.fitness_ if x.fitness_ is not None else float("inf"),
370-
)
370+
population = population.sort(sort="min", attribute="fitness_")
371371
cutoff = len(population) // 2
372372
for i, ind in enumerate(population):
373373
ind.tags["ps"] = i < cutoff
@@ -380,11 +380,8 @@ def parent_selection(self, population: Population) -> Population:
380380

381381
return population
382382

383-
@EAOperation
384383
def survivor_selection(self, population: Population) -> Population:
385-
population.sort(
386-
key=lambda x: x.fitness_ if x.fitness_ is not None else float("inf"),
387-
)
384+
population = population.sort(sort="min", attribute="fitness_")
388385
survivors = population[: self.config.target_population_size]
389386
for ind in population:
390387
if ind not in survivors:
@@ -578,13 +575,20 @@ def evolve(self) -> Individual | None:
578575
population = self.evaluate(population)
579576

580577
ops = [
581-
self.parent_selection(),
582-
self.reproduction(),
583-
self.evaluate(),
584-
self.survivor_selection(),
578+
EAOperation(self.parent_selection),
579+
EAOperation(self.reproduction),
580+
EAOperation(self.evaluate),
581+
EAOperation(self.survivor_selection),
585582
]
586583

587-
ea = EA(population, operations=ops, num_steps=BUDGET)
584+
ea = EA(
585+
population,
586+
operations=ops,
587+
num_steps=BUDGET,
588+
db_file_path=self.config.db_file_path,
589+
db_handling=self.config.db_handling,
590+
quiet=self.config.quiet,
591+
)
588592
ea.run()
589593

590594
return ea.get_solution("best", only_alive=False)
@@ -601,7 +605,8 @@ def main() -> None:
601605
if best:
602606
console.rule("[bold green]Final Best Result")
603607
console.log(f"Best Fitness (Dist to Target): {best.fitness:.4f}")
604-
evo.run_simulation("launcher", best)
608+
if args.visualize:
609+
evo.run_simulation("launcher", best)
605610

606611

607612
if __name__ == "__main__":

0 commit comments

Comments
 (0)