Commit 5d2812d

Merge pull request #40 from SteveBronder/parallel-service
Add design doc for parallel services
2 parents 91177e7 + fa63256 commit 5d2812d

designs/0020-parallel-chain-api.md

Lines changed: 144 additions & 0 deletions
- Feature Name: parallel_chain_api
- Start Date: 2021-04-06
- RFC PR: (leave this empty)
- Stan Issue: (leave this empty)

# Summary
[summary]: #summary

This design outlines a services layer API for running multiple chains in one Stan program.

# Motivation
[motivation]: #motivation

Currently, to run multiple chains for a given model, a user or developer must use higher level parallelization tools such as `gnu parallel` or R/Python parallelism schemes. The high level approach is used partly because of intricacies at the lower level around managing Stan's thread local stack allocators along with multi-threaded IO. Providing a service layer API for multiple chains in one Stan program removes the requirement that each interface independently implement all the tools needed for parallel chains. Moreover, we have access to the TBB and with it a scheduler for managing hierarchical parallelism. We can use the TBB to provide service APIs for running multiple chains in one program and safely account for possible parallelism within a model from tools such as `reduce_sum()`.

The benefits of this scheme are mostly in memory savings and standardization of multi chain processes in Stan. Because a Stan model is immutable after construction, it is possible to share that model across all chains. For a model that uses 1GB of data, running 8 chains in parallel as separate processes uses 8GB of RAM for the data alone. By sharing the model across the chains we hold only one 1GB copy of the data.

Having a standardized IO and API for multi chain processes will allow researchers to develop methods which use information across chains. This research can allow for algorithms such as automated warmup periods where, instead of hard coding the number of warmup iterations, warmup continues only until a set of conditions is met, after which sampling begins.

# Guide-level explanation
[guide-level-explanation]: #guide-level-explanation

Each of the services layers in [`src/stan/services/`](https://github.com/stan-dev/stan/blob/147fba5fb93aa007ec42744a36d97cc84c291945/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp) will have the current API for single chain processes as well as an API for running multi chain processes. Their inputs are conceptually the same, but several of the inputs have been changed to be vectors of the single chain arguments in order to account for multiple chains. For instance, the single chain signature of `hmc_nuts_dense_e_adapt` is shown first below, followed by the multi chain signature, which takes `std::vector`s for the initial values, inverse metric, init writers, sample writers, and diagnostic writers.

```cpp
template <class Model>
int hmc_nuts_dense_e_adapt(
    Model& model,
    const stan::io::var_context& init,
    const stan::io::var_context& init_inv_metric,
    unsigned int random_seed,
    unsigned int init_chain_id, double init_radius, int num_warmup, int num_samples,
    int num_thin, bool save_warmup, int refresh, double stepsize,
    double stepsize_jitter, int max_depth, double delta, double gamma,
    double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer,
    unsigned int window,
    callbacks::interrupt& interrupt,
    callbacks::logger& logger,
    callbacks::writer& init_writer,
    callbacks::writer& sample_writer,
    callbacks::writer& diagnostic_writer)
```

```cpp
template <typename Model, typename InitContextPtr, typename InitInvContextPtr,
          typename InitWriter, typename SampleWriter, typename DiagnosticWriter>
int hmc_nuts_dense_e_adapt(
    Model& model,
    size_t num_chains,
    // now vectors
    const std::vector<InitContextPtr>& init,
    const std::vector<InitInvContextPtr>& init_inv_metric,
    unsigned int random_seed, unsigned int init_chain_id, double init_radius,
    int num_warmup, int num_samples, int num_thin, bool save_warmup,
    int refresh, double stepsize, double stepsize_jitter, int max_depth,
    double delta, double gamma, double kappa, double t0,
    unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
    // interrupt and logger must be threadsafe
    callbacks::interrupt& interrupt,
    callbacks::logger& logger,
    // now vectors
    std::vector<InitWriter>& init_writer,
    std::vector<SampleWriter>& sample_writer,
    std::vector<DiagnosticWriter>& diagnostic_writer)
```

Additionally the new API has an argument `num_chains`, which tells the backend how many chains to run, and takes `init_chain_id` instead of `chain`. `init_chain_id` will be used to generate the PRNG for each chain as `seed + init_chain_id + chain_num`, where `chain_num` is the index of the chain being generated. Every vector input must have size `num_chains`. `InitContextPtr` and `InitInvContextPtr` must have a valid `operator*` which returns a reference to a class derived from `stan::io::var_context`.

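The size and pointer requirements above can be checked up front. The following is a minimal sketch, not Stan code: `var_context_stub` and `file_context_stub` stand in for `stan::io::var_context` and a derived context so the example compiles without Stan, and `derefs_to_v` and `check_multi_chain_inputs` are hypothetical helper names. Only one writer vector is checked to keep the sketch short.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>

// Stand-in for stan::io::var_context (assumption, so no Stan headers are needed).
struct var_context_stub {
  virtual ~var_context_stub() = default;
};

// Stand-in for a concrete context, e.g. one read from an init file.
struct file_context_stub : var_context_stub {};

// True when dereferencing Ptr yields a type derived from Base.
template <typename Ptr, typename Base>
constexpr bool derefs_to_v = std::is_base_of<
    Base, std::decay_t<decltype(*std::declval<const Ptr&>())>>::value;

// Throws unless every per-chain vector has exactly num_chains elements.
template <typename InitContextPtr, typename InitInvContextPtr, typename Writer>
void check_multi_chain_inputs(
    std::size_t num_chains, const std::vector<InitContextPtr>& init,
    const std::vector<InitInvContextPtr>& init_inv_metric,
    const std::vector<Writer>& sample_writer) {
  static_assert(derefs_to_v<InitContextPtr, var_context_stub>,
                "*init[i] must refer to a var_context");
  static_assert(derefs_to_v<InitInvContextPtr, var_context_stub>,
                "*init_inv_metric[i] must refer to a var_context");
  if (init.size() != num_chains || init_inv_metric.size() != num_chains
      || sample_writer.size() != num_chains) {
    throw std::invalid_argument("vector inputs must have size num_chains");
  }
}
```

Any pointer-like type works here, including raw pointers and `std::shared_ptr`, because the check only relies on `operator*`.
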
The `interrupt` and `logger` callbacks and the elements of the vectors `init`, `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer` must be threadsafe. `init` and `init_inv_metric` are only read from, so they should be threadsafe by default. Writers which write to `std::cout` are free of data races by the standard, though it is recommended to write any output to a local `std::stringstream` and then pass on the fully constructed output so that the outputs of different threads are not mixed together. See the code [here](https://github.com/stan-dev/stan/pull/3033/files#diff-ab5eb0683288927defb395f1af49548c189f6e7ab4b06e217dec046b0c1be541R80) for an example. Additionally, if the elements of `init_writer`, `sample_writer`, and `diagnostic_writer` each point to unique output, they will be threadsafe as well.

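The buffering recommendation can be sketched as follows. This is an illustration under assumptions rather than the linked prototype code: `locked_sink` and `report_draws` are hypothetical names and are not part of Stan's `callbacks` API. Each message is built completely in a local `std::stringstream` and handed to the shared stream in one call made under a mutex.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical sink shared by all chains; not part of Stan's callbacks API.
class locked_sink {
 public:
  explicit locked_sink(std::ostream& out) : out_(out) {}

  // Emits one fully formed message while holding the lock, so messages
  // from different threads cannot interleave.
  void write(const std::string& message) {
    std::lock_guard<std::mutex> guard(mutex_);
    out_ << message;
  }

 private:
  std::ostream& out_;
  std::mutex mutex_;
};

// Builds the whole report for one chain in a function-local buffer, then
// passes it to the sink in a single call.
inline void report_draws(locked_sink& sink, std::size_t chain_id,
                         const std::vector<double>& draws) {
  std::stringstream buffer;
  buffer << "chain " << chain_id << ":";
  for (double draw : draws) {
    buffer << " " << draw;
  }
  buffer << "\n";
  sink.write(buffer.str());
}
```

In a multi chain run each chain's thread would call `report_draws` with the same `locked_sink`, so complete lines reach the stream one at a time.
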
# Reference-level explanation
[reference-level-explanation]: #reference-level-explanation

The services API on the backend has a prototype implementation found [here](https://github.com/stan-dev/stan/blob/147fba5fb93aa007ec42744a36d97cc84c291945/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp#L206). The main addition in this change is the creation of the following for each chain.

1. PRNGs
2. Initializations
3. Samplers
4. Inverse metrics

Then a [`tbb::parallel_for()`](https://github.com/stan-dev/stan/blob/147fba5fb93aa007ec42744a36d97cc84c291945/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp#L261) is used to run each of the samplers.

PRNGs will be initialized as in the following pseudocode, where a constant stride is used to advance each chain's PRNG.

```cpp
#include <boost/random/additive_combine.hpp>
#include <cstdint>

inline boost::ecuyer1988 create_rng(unsigned int seed, unsigned int init_chain_id,
                                    unsigned int chain_num) {
  // Initialize L'Ecuyer generator
  boost::ecuyer1988 rng(seed);

  // Seek generator to disjoint region for each chain.
  // Assumes init_chain_id + chain_num >= 1 so the subtraction cannot wrap.
  static constexpr std::uintmax_t DISCARD_STRIDE
      = static_cast<std::uintmax_t>(1) << 50;
  rng.discard(DISCARD_STRIDE * (init_chain_id + chain_num - 1));
  return rng;
}
```

The constant stride makes results reproducible, given the same seed, whether the chains are run in one multi-chain program or spread across several programs, as noted below.

### Recommended Upstream Initialization

Upstream packages can generate `init` and `init_inv_metric` as they wish, though for cmdstan the prototype uses the following rules for reading user input.

If the user specifies their init as `{file_name}.{file_ending}` with an input `id` of `N` and `M` chains, then the program will search for `{file_name}_{N..(N + M - 1)}.{file_ending}`, where `N..(N + M - 1)` is the sequence of `M` consecutive integers from `N` to `N + M - 1`. If the program does not find these per-chain files, it will then search for `{file_name}.{file_ending}` and, if found, will use that file. Otherwise an exception will occur.

For example, if a user specifies `chains=4`, `id=2`, and their init file as `init=init.data.R`, then the program will first search for `init.data_2.R`. If it finds that file, it will then search for `init.data_3.R`, `init.data_4.R`, and `init.data_5.R` and will fail unless all of them are found. If the program fails to find `init.data_2.R`, it will attempt to find `init.data.R` and, if successful, will use these initial values for all chains. If neither is found then an error will be thrown.

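The naming rule can be illustrated with a small helper. This is a sketch under assumptions, not cmdstan's implementation: `per_chain_file_names` is a hypothetical name, and it assumes the file ending is everything from the last `.` in the path, which matches the `init.data.R` example.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: lists the per-chain file names searched for a
// user-supplied path, given the starting id and the number of chains.
inline std::vector<std::string> per_chain_file_names(const std::string& path,
                                                     unsigned int id,
                                                     unsigned int chains) {
  // Split at the last '.', so "init.data.R" becomes "init.data" + ".R".
  // A path without a '.' keeps an empty ending.
  const std::size_t dot = path.rfind('.');
  const std::string stem = path.substr(0, dot);
  const std::string ending
      = (dot == std::string::npos) ? std::string() : path.substr(dot);

  std::vector<std::string> names;
  names.reserve(chains);
  for (unsigned int i = 0; i < chains; ++i) {
    names.push_back(stem + "_" + std::to_string(id + i) + ending);
  }
  return names;
}
```

With `init.data.R`, `id=2`, and `chains=4` this yields `init.data_2.R` through `init.data_5.R`, the same files named in the example above.
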
Documentation must be added to clarify reproducibility between a multi-chain program and running multiple chains across several programs. This requires

1. Using the same random seed for the multi-chain program and each program running a single chain.
2. Giving the program that runs the `i`th chain the same `id` that chain has in the multi-chain program.

For example, the following two sets of calls should produce the same results up to floating point accuracy.

```bash
# From cmdstan example folder
# Running 4 chains at once
examples/bernoulli/bernoulli sample data file=examples/bernoulli/bernoulli.data.R chains=4 id=1 random seed=123 output file=output.csv
# Running 4 separate chains
examples/bernoulli/bernoulli sample data file=examples/bernoulli/bernoulli.data.R chains=1 id=1 random seed=123 output file=output1.csv
examples/bernoulli/bernoulli sample data file=examples/bernoulli/bernoulli.data.R chains=1 id=2 random seed=123 output file=output2.csv
examples/bernoulli/bernoulli sample data file=examples/bernoulli/bernoulli.data.R chains=1 id=3 random seed=123 output file=output3.csv
examples/bernoulli/bernoulli sample data file=examples/bernoulli/bernoulli.data.R chains=1 id=4 random seed=123 output file=output4.csv
```

In general the constant stride allows for the following, where `n1 + n2 + n3 + n4 = N` chains. The four programs below together produce the same chains as a single program run with `id=1` and `chains=N`.

```
seed=848383, id=1, chains=n1
seed=848383, id=1 + n1, chains=n2
seed=848383, id=1 + n1 + n2, chains=n3
seed=848383, id=1 + n1 + n2 + n3, chains=n4
```

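The partition rule can be checked with plain integer arithmetic. The sketch below does not use Boost or Stan. It only computes the multiple of the discard stride that each chain would use, following the `init_chain_id + chain_num - 1` expression in the `create_rng` pseudocode. It assumes `chain_num` counts from 0 within a program and that `id` is at least 1. All function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Multiple of the discard stride used for one chain. Assumes chain_num
// counts from 0 within a run and init_chain_id is at least 1.
inline std::uintmax_t stride_multiple(unsigned int init_chain_id,
                                      unsigned int chain_num) {
  return static_cast<std::uintmax_t>(init_chain_id) + chain_num - 1;
}

// Stride multiples consumed by one program run with the given id and chains.
inline std::vector<std::uintmax_t> run_multiples(unsigned int id,
                                                 unsigned int chains) {
  std::vector<std::uintmax_t> multiples;
  for (unsigned int chain_num = 0; chain_num < chains; ++chain_num) {
    multiples.push_back(stride_multiple(id, chain_num));
  }
  return multiples;
}

// Concatenates the multiples of several runs whose ids follow the
// partition rule: each run starts at 1 plus the chains of all earlier runs.
inline std::vector<std::uintmax_t> partitioned_multiples(
    const std::vector<unsigned int>& chains_per_run) {
  std::vector<std::uintmax_t> all;
  unsigned int id = 1;
  for (unsigned int n : chains_per_run) {
    const std::vector<std::uintmax_t> part = run_multiples(id, n);
    all.insert(all.end(), part.begin(), part.end());
    id += n;
  }
  return all;
}
```

Comparing `run_multiples(1, 8)` with `partitioned_multiples({3, 1, 2, 2})` gives identical sequences, which is the property the reproducibility claim relies on. If `chain_num` instead counted from 1, every multiple would shift by one and the equality would still hold.
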
# Drawbacks
[drawbacks]: #drawbacks

Existing implementations take on added overhead in managing the per chain IO.
