diff --git a/neat/population.py b/neat/population.py index 9b8de4fe..6068b527 100644 --- a/neat/population.py +++ b/neat/population.py @@ -1,5 +1,7 @@ """Implements the core evolution algorithm.""" +from itertools import count + from neat.math_util import mean from neat.reporting import ReporterSet @@ -45,6 +47,10 @@ def __init__(self, config, initial_state=None): self.species.speciate(config, self.population, self.generation) else: self.population, self.species, self.generation = initial_state + # If the reproduction object has a genome indexer, + # set it to continue from the last genome ID. + if hasattr(self.reproduction, "genome_indexer"): + self.reproduction.genome_indexer = count(max(self.population.keys()) + 1) self.best_genome = None diff --git a/tests/test_population.py b/tests/test_population.py index 33ba62b7..9275dcbd 100644 --- a/tests/test_population.py +++ b/tests/test_population.py @@ -35,6 +35,39 @@ def test_invalid_fitness_criterion(self): with self.assertRaises(Exception): p = neat.Population(config) + def test_count_after_checkpoint_restore(self): + """ + Test that the genome indexer in DefaultGenome continues from the last genome ID + after restoring from a checkpoint. + """ + # Load configuration. + local_dir = os.path.dirname(__file__) + config_path = os.path.join(local_dir, 'test_configuration') + config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, + neat.DefaultSpeciesSet, neat.DefaultStagnation, + config_path) + + p = neat.Population(config) + filename_prefix = 'neat-checkpoint-test_population' + checkpointer = neat.Checkpointer(1, 5, filename_prefix=filename_prefix) + p.add_reporter(checkpointer) + + def eval_genomes(genomes, config): + for genome_id, genome in genomes: + genome.fitness = 0.5 + + p.run(eval_genomes, 5) + + filename = '{0}{1}'.format( + filename_prefix, checkpointer.last_generation_checkpoint + ) + restored_population = neat.Checkpointer.restore_checkpoint(filename) + last_genome_key = max([x.key for x in p.population.values()]) + + self.assertEqual( + next(restored_population.reproduction.genome_indexer), + last_genome_key + 1 + ) # def test_minimal(): # # sample fitness function