Skip to content

Commit 117664b

Browse files
author
yiz
committed
merge mpi_warmup_framework
2 parents 83a450a + 6052421 commit 117664b

File tree

5 files changed

+45
-32
lines changed

5 files changed

+45
-32
lines changed

src/stan/mcmc/hmc/mpi_cross_chain_adapter.hpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,13 @@ namespace mcmc {
113113

114114
inline void
115115
write_num_cross_chain_warmup(callbacks::writer& sample_writer,
116-
int num_thin) {
117-
size_t n = num_cross_chain_draws();
118-
sample_writer("num_warmup = " + std::to_string(n / num_thin));
116+
int num_thin, int num_warmup) {
117+
if (use_cross_chain_adapt()) {
118+
size_t n = num_cross_chain_draws();
119+
sample_writer("num_warmup = " + std::to_string(n / num_thin));
120+
} else {
121+
sample_writer("num_warmup = " + std::to_string(num_warmup));
122+
}
119123
}
120124

121125
/*
@@ -454,7 +458,9 @@ namespace mcmc {
454458
sampler.set_nominal_stepsize(new_stepsize);
455459
}
456460

457-
inline bool use_cross_chain_adapt() { return true; }
461+
inline bool use_cross_chain_adapt() {
462+
return num_chains_ > 1;
463+
}
458464
};
459465

460466
#else // sequential version

src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,30 @@ class adapt_dense_e_nuts : public dense_e_nuts<Model, BaseRNG>,
3131
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
3232
s.accept_stat());
3333

34-
bool update;
3534
if (this -> use_cross_chain_adapt()) {
3635
this -> add_cross_chain_sample(s.log_prob());
37-
update = this -> cross_chain_adaptation(logger);
36+
bool update = this -> cross_chain_adaptation(logger);
3837
if (this -> is_cross_chain_adapted()) {
3938
update = false;
4039
}
41-
} else {
42-
update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_,
43-
this->z_.q);
44-
}
4540

46-
if (update) {
47-
this->init_stepsize(logger);
41+
if (update) {
42+
this->init_stepsize(logger);
4843

49-
this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
50-
this->stepsize_adaptation_.restart();
44+
this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
45+
this->stepsize_adaptation_.restart();
5146

52-
if (this -> use_cross_chain_adapt()) {
5347
this->set_cross_chain_stepsize();
5448
}
49+
} else {
50+
bool update = this->covar_adaptation_.learn_covariance(this->z_.inv_e_metric_,
51+
this->z_.q);
52+
if (update) {
53+
this->init_stepsize(logger);
54+
55+
this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
56+
this->stepsize_adaptation_.restart();
57+
}
5558
}
5659
}
5760
return s;

src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,31 @@ class adapt_diag_e_nuts : public diag_e_nuts<Model, BaseRNG>,
3131
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
3232
s.accept_stat());
3333

34-
bool update;
34+
3535
if (this -> use_cross_chain_adapt()) {
3636
this -> add_cross_chain_sample(s.log_prob());
37-
update = this -> cross_chain_adaptation(logger);
37+
bool update = this -> cross_chain_adaptation(logger);
3838
if (this -> is_cross_chain_adapted()) {
3939
update = false;
4040
}
41-
} else {
42-
update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_,
43-
this->z_.q);
44-
}
4541

46-
if (update) {
47-
this->init_stepsize(logger);
42+
if (update) {
43+
this->init_stepsize(logger);
4844

49-
this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
50-
this->stepsize_adaptation_.restart();
45+
this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
46+
this->stepsize_adaptation_.restart();
5147

52-
if (this -> use_cross_chain_adapt()) {
5348
this->set_cross_chain_stepsize();
5449
}
50+
} else {
51+
bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_,
52+
this->z_.q);
53+
if (update) {
54+
this->init_stepsize(logger);
55+
56+
this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
57+
this->stepsize_adaptation_.restart();
58+
}
5559
}
5660
}
5761
return s;

src/stan/services/util/mpi_cross_chain.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ namespace util {
5454

5555
static void write_num_warmup(Sampler& sampler,
5656
callbacks::writer& sample_writer,
57-
int num_thin) {}
57+
int num_thin, int num_warmup) {}
5858
};
5959

6060
/*
@@ -81,8 +81,8 @@ namespace util {
8181

8282
static void write_num_warmup(Sampler& sampler,
8383
callbacks::writer& sample_writer,
84-
int num_thin) {
85-
sampler.write_num_cross_chain_warmup(sample_writer, num_thin);
84+
int num_thin, int num_warmup) {
85+
sampler.write_num_cross_chain_warmup(sample_writer, num_thin, num_warmup);
8686
}
8787
};
8888
#endif
@@ -111,9 +111,9 @@ namespace util {
111111

112112
static void write_num_warmup(Sampler& sampler,
113113
callbacks::writer& sample_writer,
114-
int num_thin) {
114+
int num_thin, int num_warmup) {
115115
mpi_cross_chain_impl<Sampler, has_cross_chain_warmup<Sampler>::value>::
116-
write_num_warmup(sampler, sample_writer, num_thin);
116+
write_num_warmup(sampler, sample_writer, num_thin, num_warmup);
117117
}
118118
};
119119

src/stan/services/util/run_adaptive_sampler.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ void run_adaptive_sampler(Sampler& sampler, Model& model,
7373
mpi_cross_chain<Sampler>::num_draws(sampler),
7474
num_warmup + num_samples, num_thin, refresh, save_warmup,
7575
true, writer, s, model, rng, interrupt, logger);
76-
mpi_cross_chain<Sampler>::write_num_warmup(sampler, sample_writer, num_thin);
76+
mpi_cross_chain<Sampler>::write_num_warmup(sampler, sample_writer, num_thin, num_warmup);
7777

7878
clock_t end = clock();
7979
double warm_delta_t = static_cast<double>(end - start) / CLOCKS_PER_SEC;

0 commit comments

Comments
 (0)