Skip to content

Commit a7c61ac

Browse files
author
yiz
committed
cross-chain warmup
1 parent b69492b commit a7c61ac

23 files changed

+1746
-426
lines changed

.gitmodules

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[submodule "lib/stan_math"]
22
path = lib/stan_math
33
url = https://github.com/stan-dev/math.git
4+
branch = mpi_warmup_v2

make/mpi_warmup.mk

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
ifdef MPI_ADAPTED_WARMUP
2+
CXXFLAGS += -DSTAN_LANG_MPI -DMPI_ADAPTED_WARMUP
3+
CC=mpicxx
4+
CXX=mpicxx
5+
endif

makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ help:
1313

1414
-include $(HOME)/.config/stan/make.local # user-defined variables
1515
-include make/local # user-defined variables
16+
-include make/mpi_warmup.mk
1617

1718
MATH ?= lib/stan_math/
1819
ifeq ($(OS),Windows_NT)
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#ifndef STAN_CALLBACKS_MPI_STREAM_WRITER_HPP
2+
#define STAN_CALLBACKS_MPI_STREAM_WRITER_HPP
3+
4+
#ifdef MPI_ADAPTED_WARMUP
5+
6+
#include <stan/callbacks/writer.hpp>
7+
#include <stan/math/mpi/envionment.hpp>
8+
#include <ostream>
9+
#include <vector>
10+
#include <string>
11+
12+
namespace stan {
13+
namespace callbacks {
14+
/**
15+
* <code>mpi_stream_writer</code> is an implementation
16+
* of <code>writer</code> that writes to a stream.
17+
*/
18+
class mpi_stream_writer : public writer {
19+
public:
20+
/**
21+
* Constructs a stream writer with an output stream
22+
* and an optional prefix for comments.
23+
*
24+
* @param[in, out] output stream to write
25+
* @param[in] comment_prefix string to stream before
26+
* each comment line. Default is "".
27+
*/
28+
mpi_stream_writer(int num_chains, std::ostream& output,
29+
const std::string& comment_prefix = "")
30+
: num_chains_(num_chains), output_(output),
31+
comment_prefix_(comment_prefix)
32+
{}
33+
34+
/**
35+
* Virtual destructor
36+
*/
37+
virtual ~mpi_stream_writer() {}
38+
39+
/**
40+
* Set new value for @c num_chains_.
41+
*
42+
* @param[in] n new value of @c num_chains_
43+
*/
44+
void set_num_chains(int n) {
45+
num_chains_ = n;
46+
}
47+
48+
/**
49+
* Writes a set of names on a single line in csv format followed
50+
* by a newline.
51+
*
52+
* Note: the names are not escaped.
53+
*
54+
* @param[in] names Names in a std::vector
55+
*/
56+
void operator()(const std::vector<std::string>& names) {
57+
write_vector(names);
58+
}
59+
60+
/**
61+
* Writes a set of values in csv format followed by a newline.
62+
*
63+
* Note: the precision of the output is determined by the settings
64+
* of the stream on construction.
65+
*
66+
* @param[in] state Values in a std::vector
67+
*/
68+
void operator()(const std::vector<double>& state) {
69+
write_vector(state);
70+
}
71+
72+
/**
73+
* Writes the comment_prefix to the stream followed by a newline.
74+
*/
75+
void operator()() {
76+
if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) {
77+
output_ << comment_prefix_ << std::endl;
78+
}
79+
}
80+
81+
/**
82+
* Writes the comment_prefix then the message followed by a newline.
83+
*
84+
* @param[in] message A string
85+
*/
86+
void operator()(const std::string& message) {
87+
if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) {
88+
output_ << comment_prefix_ << message << std::endl;
89+
}
90+
}
91+
92+
private:
93+
94+
/**
95+
* nb. of chains that have its own output stream
96+
*/
97+
int num_chains_;
98+
99+
/**
100+
* Output stream
101+
*/
102+
std::ostream& output_;
103+
104+
/**
105+
* Comment prefix to use when printing comments: strings and blank lines
106+
*/
107+
std::string comment_prefix_;
108+
109+
/**
110+
* Writes a set of values in csv format followed by a newline.
111+
*
112+
* Note: the precision of the output is determined by the settings
113+
* of the stream on construction.
114+
*
115+
* @param[in] v Values in a std::vector
116+
*/
117+
template <class T>
118+
void write_vector(const std::vector<T>& v) {
119+
if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) {
120+
if (v.empty()) return;
121+
122+
typename std::vector<T>::const_iterator last = v.end();
123+
--last;
124+
125+
for (typename std::vector<T>::const_iterator it = v.begin();
126+
it != last; ++it)
127+
output_ << *it << ",";
128+
output_ << v.back() << std::endl;
129+
}
130+
}
131+
};
132+
133+
}
134+
}
135+
136+
#endif
137+
138+
#endif

src/stan/callbacks/stream_writer.hpp

Lines changed: 88 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -7,101 +7,104 @@
77
#include <string>
88

99
namespace stan {
10-
namespace callbacks {
10+
namespace callbacks {
1111

12-
/**
13-
* <code>stream_writer</code> is an implementation
14-
* of <code>writer</code> that writes to a stream.
15-
*/
16-
class stream_writer : public writer {
17-
public:
18-
/**
19-
* Constructs a stream writer with an output stream
20-
* and an optional prefix for comments.
21-
*
22-
* @param[in, out] output stream to write
23-
* @param[in] comment_prefix string to stream before
24-
* each comment line. Default is "".
25-
*/
26-
explicit stream_writer(std::ostream& output,
27-
const std::string& comment_prefix = "")
28-
: output_(output), comment_prefix_(comment_prefix) {}
12+
/**
13+
* <code>stream_writer</code> is an implementation
14+
* of <code>writer</code> that writes to a stream.
15+
*/
16+
class stream_writer : public writer {
17+
public:
18+
/**
19+
* Constructs a stream writer with an output stream
20+
* and an optional prefix for comments.
21+
*
22+
* @param[in, out] output stream to write
23+
* @param[in] comment_prefix string to stream before
24+
* each comment line. Default is "".
25+
*/
26+
stream_writer(std::ostream& output,
27+
const std::string& comment_prefix = ""):
28+
output_(output), comment_prefix_(comment_prefix) {}
2929

30-
/**
31-
* Virtual destructor
32-
*/
33-
virtual ~stream_writer() {}
30+
/**
31+
* Virtual destructor
32+
*/
33+
virtual ~stream_writer() {}
3434

35-
/**
36-
* Writes a set of names on a single line in csv format followed
37-
* by a newline.
38-
*
39-
* Note: the names are not escaped.
40-
*
41-
* @param[in] names Names in a std::vector
42-
*/
43-
void operator()(const std::vector<std::string>& names) {
44-
write_vector(names);
45-
}
35+
/**
36+
* Writes a set of names on a single line in csv format followed
37+
* by a newline.
38+
*
39+
* Note: the names are not escaped.
40+
*
41+
* @param[in] names Names in a std::vector
42+
*/
43+
void operator()(const std::vector<std::string>& names) {
44+
write_vector(names);
45+
}
4646

47-
/**
48-
* Writes a set of values in csv format followed by a newline.
49-
*
50-
* Note: the precision of the output is determined by the settings
51-
* of the stream on construction.
52-
*
53-
* @param[in] state Values in a std::vector
54-
*/
55-
void operator()(const std::vector<double>& state) { write_vector(state); }
47+
/**
48+
* Writes a set of values in csv format followed by a newline.
49+
*
50+
* Note: the precision of the output is determined by the settings
51+
* of the stream on construction.
52+
*
53+
* @param[in] state Values in a std::vector
54+
*/
55+
void operator()(const std::vector<double>& state) {
56+
write_vector(state);
57+
}
5658

57-
/**
58-
* Writes the comment_prefix to the stream followed by a newline.
59-
*/
60-
void operator()() { output_ << comment_prefix_ << std::endl; }
59+
/**
60+
* Writes the comment_prefix to the stream followed by a newline.
61+
*/
62+
void operator()() {
63+
output_ << comment_prefix_ << std::endl;
64+
}
6165

62-
/**
63-
* Writes the comment_prefix then the message followed by a newline.
64-
*
65-
* @param[in] message A string
66-
*/
67-
void operator()(const std::string& message) {
68-
output_ << comment_prefix_ << message << std::endl;
69-
}
66+
/**
67+
* Writes the comment_prefix then the message followed by a newline.
68+
*
69+
* @param[in] message A string
70+
*/
71+
void operator()(const std::string& message) {
72+
output_ << comment_prefix_ << message << std::endl;
73+
}
7074

71-
private:
72-
/**
73-
* Output stream
74-
*/
75-
std::ostream& output_;
75+
private:
76+
/**
77+
* Output stream
78+
*/
79+
std::ostream& output_;
7680

77-
/**
78-
* Comment prefix to use when printing comments: strings and blank lines
79-
*/
80-
std::string comment_prefix_;
81+
/**
82+
* Comment prefix to use when printing comments: strings and blank lines
83+
*/
84+
std::string comment_prefix_;
8185

82-
/**
83-
* Writes a set of values in csv format followed by a newline.
84-
*
85-
* Note: the precision of the output is determined by the settings
86-
* of the stream on construction.
87-
*
88-
* @param[in] v Values in a std::vector
89-
*/
90-
template <class T>
91-
void write_vector(const std::vector<T>& v) {
92-
if (v.empty())
93-
return;
86+
/**
87+
* Writes a set of values in csv format followed by a newline.
88+
*
89+
* Note: the precision of the output is determined by the settings
90+
* of the stream on construction.
91+
*
92+
* @param[in] v Values in a std::vector
93+
*/
94+
template <class T>
95+
void write_vector(const std::vector<T>& v) {
96+
if (v.empty()) return;
9497

95-
typename std::vector<T>::const_iterator last = v.end();
96-
--last;
98+
typename std::vector<T>::const_iterator last = v.end();
99+
--last;
97100

98-
for (typename std::vector<T>::const_iterator it = v.begin(); it != last;
99-
++it)
100-
output_ << *it << ",";
101-
output_ << v.back() << std::endl;
102-
}
103-
};
101+
for (typename std::vector<T>::const_iterator it = v.begin();
102+
it != last; ++it)
103+
output_ << *it << ",";
104+
output_ << v.back() << std::endl;
105+
}
106+
};
104107

105-
} // namespace callbacks
106-
} // namespace stan
108+
}
109+
}
107110
#endif

0 commit comments

Comments
 (0)