@@ -3,6 +3,7 @@ using IterableTables
3
3
using TableTraits
4
4
using Random
5
5
using ProgressMeter
6
+ using Serialization
6
7
7
8
function Base. show (io:: IO , mcs:: MonteCarloSimulation )
8
9
println (" MonteCarloSimulation" )
@@ -83,11 +84,11 @@ function _store_trial_results(mcs::MonteCarloSimulation, trialnum::Int)
83
84
end
84
85
85
86
"""
86
- save_trial_results(mcs::MonteCarloSimulation, output_dir::String )
87
+ save_trial_results(mcs::MonteCarloSimulation, output_dir)
87
88
88
89
Save the stored MCS results to files in the directory `output_dir`
89
90
"""
90
- function save_trial_results (mcs:: MonteCarloSimulation , output_dir:: AbstractString )
91
+ function save_trial_results (mcs:: MonteCarloSimulation , output_dir)
91
92
multiple_results = (length (mcs. results) > 1 )
92
93
93
94
for (i, results) in enumerate (mcs. results)
@@ -106,8 +107,44 @@ function save_trial_results(mcs::MonteCarloSimulation, output_dir::AbstractStrin
106
107
end
107
108
end
108
109
109
- function save_trial_inputs (mcs:: MonteCarloSimulation , filename:: String )
110
- mkpath (dirname (filename), mode= 0o770 ) # ensure that the specified path exists
110
+ function _dummy_writer (mcs, pname, dir)
111
+ @info " _dummy_writer(pname=$pname , dir=$dir )"
112
+ end
113
+
114
+
115
+ """
116
+ _serialize(mcs::MonteCarloSimulation, pname, dir)
117
+
118
+ This is the default function used for writing non-bit-type (array-valued)
119
+ random vars, e.g., those produced by ReshapedDistribution.
120
+
121
+ This default method uses Serialization.serialize to write out the array.
122
+ No application-defined types are written, so the file should be robust.
123
+ """
124
+ function _serialize (mcs:: MonteCarloSimulation , pname, dir)
125
+ rv = mcs. rvdict[pname]
126
+ if rv isa RandomVariable{Mimi. SampleStore{T}} where T
127
+ name = first (split (string (pname), " !" )) # e.g., :alpha!1 --> "alpha"
128
+ path = joinpath (dir, " $name .dat" )
129
+ serialize (path, mcs. rvdict[pname]. dist. values)
130
+ else
131
+ error (" Tried to _serialize $pname , which isn't a RandomVariable{SampleStore{T}}" )
132
+ end
133
+ end
134
+
135
+ function save_trial_inputs (mcs:: MonteCarloSimulation , filename, non_bit_writer= _serialize)
136
+ dir = dirname (filename)
137
+ mkpath (dir, mode= 0o770 ) # ensure that the specified path exists
138
+
139
+ # If ReshapedDistribution was used, a single trial value for a field
140
+ # will be an Array of values. CSV format doesn't handle these well,
141
+ # so we serialize the arrays into a file using the field's name.
142
+ non_bits = [name for (name, T) in zip (column_names (mcs), column_types (mcs)) if ! isbitstype (T)]
143
+ for pname in non_bits
144
+ non_bit_writer (mcs, pname, dir)
145
+ end
146
+
147
+ # TBD: avoid writing array values to CSV files...
111
148
save (filename, mcs)
112
149
return nothing
113
150
end
@@ -190,7 +227,7 @@ function _copy_mcs_params(mcs::MonteCarloSimulation)
190
227
191
228
for (i, m) in enumerate (mcs. models)
192
229
md = m. mi. md
193
- param_vec[i] = Dict {Symbol, ModelParameter} (trans. paramname => copy (external_param (md, trans. paramname)) for trans in mcs. translist)
230
+ param_vec[i] = Dict {Symbol, ModelParameter} (trans. paramname => deepcopy (external_param (md, trans. paramname)) for trans in mcs. translist)
194
231
end
195
232
196
233
return param_vec
@@ -225,7 +262,13 @@ function _param_indices(param::ArrayModelParameter{T}, md::ModelDef, trans::Tran
225
262
num_pdims = length (pdims)
226
263
227
264
tdims = trans. dims
228
- num_dims = length (tdims)
265
+ num_dims = length (tdims)
266
+
267
+ # special case for handling reshaped data where a single draw returns a matrix of values
268
+ if num_dims == 0
269
+ indices = repeat ([Colon ()], num_pdims)
270
+ return indices
271
+ end
229
272
230
273
if num_pdims != num_dims
231
274
pname = trans. paramname
@@ -256,7 +299,8 @@ function _perturb_param!(param::ScalarModelParameter{T}, md::ModelDef, trans::Tr
256
299
end
257
300
end
258
301
259
- function _perturb_param! (param:: ArrayModelParameter{T} , md:: ModelDef , trans:: TransformSpec , rvalue:: Number ) where T
302
+ function _perturb_param! (param:: ArrayModelParameter{T} , md:: ModelDef ,
303
+ trans:: TransformSpec , rvalue:: Union{Number, Array{<: Number, N}} ) where {T, N}
260
304
op = trans. op
261
305
pvalue = value (param)
262
306
indices = _param_indices (param, md, trans)
@@ -539,6 +583,9 @@ IterableTables.getiterator(mcs::MonteCarloSimulation) = MCSIterator{mcs.nt_type}
539
583
column_names (mcs:: MonteCarloSimulation ) = fieldnames (mcs. nt_type)
540
584
column_types (mcs:: MonteCarloSimulation ) = [eltype (fld) for fld in values (mcs. rvdict)]
541
585
586
+ # TBD: strip the "!1" off the end of each field name?
587
+ # column_names(iter::MCSIterator) = Tuple([first(split(string(name), "!")) for name in fieldnames(iter.mcs.nt_type)])
588
+
542
589
column_names (iter:: MCSIterator ) = column_names (iter. mcs)
543
590
column_types (iter:: MCSIterator ) = IterableTables. column_types (iter. mcs)
544
591
0 commit comments