Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 52 additions & 21 deletions rllab/sampler/stateful_pool.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@


from joblib.pool import MemmapingPool
import inspect
import multiprocessing as mp
from rllab.misc import logger
import pyprind
import time
import traceback
import sys

from joblib.pool import MemmapingPool
import pyprind

from rllab.misc import logger


class ProgBarCounter(object):
def __init__(self, total_count):
Expand Down Expand Up @@ -63,18 +64,24 @@ def initialize(self, n_parallel):

def run_each(self, runner, args_list=None):
"""
Run the method on each worker process, and collect the result of execution.
The runner method will receive 'G' as its first argument, followed by the arguments
in the args_list, if any
Run the method on each worker process, and collect the result of
execution.

The runner method will receive 'G' as its first argument, followed by
the arguments in the args_list, if any
:return:
"""
assert not inspect.ismethod(runner), (
"run_each() cannot run a class method. Please ensure that runner is"
" a function with the prototype def foo(G, ...), where G is an "
"object of type rllab.sampler.stateful_pool.SharedGlobal")

if args_list is None:
args_list = [tuple()] * self.n_parallel
assert len(args_list) == self.n_parallel
if self.n_parallel > 1:
results = self.pool.map_async(
_worker_run_each, [(runner, args) for args in args_list]
)
_worker_run_each, [(runner, args) for args in args_list])
for i in range(self.n_parallel):
self.worker_queue.get()
for i in range(self.n_parallel):
Expand All @@ -83,50 +90,74 @@ def run_each(self, runner, args_list=None):
return [runner(self.G, *args_list[0])]

def run_map(self, runner, args_list):
assert not inspect.ismethod(runner), (
"run_map() cannot run a class method. Please ensure that runner is "
"a function with the prototype 'def foo(G, ...)', where G is an "
"object of type rllab.sampler.stateful_pool.SharedGlobal")

if self.n_parallel > 1:
return self.pool.map(_worker_run_map, [(runner, args) for args in args_list])
return self.pool.map(_worker_run_map,
[(runner, args) for args in args_list])
else:
ret = []
for args in args_list:
ret.append(runner(self.G, *args))
return ret

def run_imap_unordered(self, runner, args_list):
assert not inspect.ismethod(runner), (
"run_imap_unordered() cannot run a class method. Please ensure that"
"runner is a function with the prototype 'def foo(G, ...)', where "
"G is an object of type rllab.sampler.stateful_pool.SharedGlobal")

if self.n_parallel > 1:
for x in self.pool.imap_unordered(_worker_run_map, [(runner, args) for args in args_list]):
for x in self.pool.imap_unordered(
_worker_run_map, [(runner, args) for args in args_list]):
yield x
else:
for args in args_list:
yield runner(self.G, *args)

def run_collect(self, collect_once, threshold, args=None, show_prog_bar=True):
def run_collect(self,
collect_once,
threshold,
args=None,
show_prog_bar=True):
"""
Run the collector method using the worker pool. The collect_once method will receive 'G' as
its first argument, followed by the provided args, if any. The method should return a pair of values.
The first should be the object to be collected, and the second is the increment to be added.
This will continue until the total increment reaches or exceeds the given threshold.
Run the collector method using the worker pool. The collect_once method
will receive 'G' as its first argument, followed by the provided args,
if any. The method should return a pair of values. The first should be
the object to be collected, and the second is the increment to be added.
This will continue until the total increment reaches or exceeds the
given threshold.

Sample script:

def collect_once(G):
return 'a', 1

stateful_pool.run_collect(collect_once, threshold=3) # => ['a', 'a', 'a']
stateful_pool.run_collect(collect_once, threshold=3)
# should return ['a', 'a', 'a']

:param collector:
:param threshold:
:return:
"""
assert not inspect.ismethod(collect_once), (
"run_collect() cannot run a class method. Please ensure that "
"collect_once is a function with the prototype 'def foo(G, ...)', "
"where G is an object of type "
"rllab.sampler.stateful_pool.SharedGlobal")

if args is None:
args = tuple()
if self.pool:
manager = mp.Manager()
counter = manager.Value('i', 0)
lock = manager.RLock()
results = self.pool.map_async(
_worker_run_collect,
[(collect_once, counter, lock, threshold, args)] * self.n_parallel
)
_worker_run_collect, [(collect_once, counter, lock, threshold,
args)] * self.n_parallel)
if show_prog_bar:
pbar = ProgBarCounter(threshold)
last_value = 0
Expand Down