Skip to content

Commit 8c796dd

Browse files
Merge pull request #30 from LCSB-BioCore/mk-dpmap2
add distributed pool map
2 parents 682251f + 9c3d91d commit 8c796dd

File tree

6 files changed

+79
-72
lines changed

6 files changed

+79
-72
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name = "DistributedData"
22
uuid = "f6a0035f-c5ac-4ad0-b410-ad102ced35df"
33
authors = ["Mirek Kratochvil <[email protected]>",
44
"LCSB R3 team <[email protected]>"]
5-
version = "0.1.2"
5+
version = "0.1.3"
66

77
[deps]
88
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

src/DistributedData.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@ export save_at,
1717
dtransform,
1818
dmapreduce,
1919
dmap,
20+
dpmap,
2021
gather_array,
2122
tmp_symbol
2223

2324
include("io.jl")
24-
export dstore,
25-
dload,
26-
dunlink
25+
export dstore, dload, dunlink
2726

2827
include("tools.jl")
2928
export dcopy,

src/base.jl

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ The symbols are saved in Main module on the corresponding worker. For example,
3030
`save_at(1, :x, nothing)` _will_ erase your local `x` variable. Beware of name
3131
collisions.
3232
"""
33-
function save_at(worker, sym::Symbol, val; mod=Main)
33+
function save_at(worker, sym::Symbol, val; mod = Main)
3434
remotecall(() -> Base.eval(mod, :(
3535
begin
3636
$sym = $val
@@ -45,7 +45,7 @@ end
4545
Get a value `val` from a remote `worker`; quoting of `val` works just as with
4646
`save_at`. Returns a future with the requested value.
4747
"""
48-
function get_from(worker, val; mod=Main)
48+
function get_from(worker, val; mod = Main)
4949
remotecall(() -> Base.eval(mod, :($val)), worker)
5050
end
5151

@@ -162,11 +162,7 @@ end
162162
163163
Same as `dtransform`, but specialized for `Dinfo`.
164164
"""
165-
function dtransform(
166-
dInfo::Dinfo,
167-
fn,
168-
tgt::Symbol = dInfo.val,
169-
)::Dinfo
165+
function dtransform(dInfo::Dinfo, fn, tgt::Symbol = dInfo.val)::Dinfo
170166
dtransform(dInfo.val, fn, dInfo.workers, tgt)
171167
end
172168

@@ -314,11 +310,34 @@ Call a function `fn` on `workers`, with a single parameter arriving from the
314310
corresponding position in `arr`.
315311
"""
316312
function dmap(arr::Vector, fn, workers)
317-
futures = [
318-
remotecall(() -> Base.eval(Main, :($fn($(arr[i])))), pid) #TODO convert to get_from
319-
for (i, pid) in enumerate(workers)
320-
]
321-
return [fetch(f) for f in futures]
313+
fetch.([get_from(w, :($fn($(arr[i])))) for (i, w) in enumerate(workers)])
314+
end
315+
316+
"""
317+
dpmap(fn, args...; mod = Main, kwargs...)
318+
319+
"Distributed pool map."
320+
321+
A wrapper for `pmap` from `Distributed` package that executes the code in the
322+
correct module, so that it can access the distributed variables at remote
323+
workers. All arguments other than the first function `fn` are passed to `pmap`.
324+
325+
The function `fn` should return an expression that is going to get evaluated.
326+
327+
# Example
328+
329+
```julia
330+
using Distributed
331+
dpmap(x -> :(computeSomething(someData, \$x)), WorkerPool(workers), Vector(1:10))
332+
```
333+
334+
```julia
335+
di = distributeSomeData()
336+
dpmap(x -> :(computeSomething(\$(di.val), \$x)), CachingPool(di.workers), Vector(1:10))
337+
```
338+
"""
339+
function dpmap(fn, args...; mod = Main, kwargs...)
340+
return pmap(x -> Base.eval(mod, fn(x)), args...; kwargs...)
322341
end
323342

324343
"""

src/io.jl

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,7 @@ end
3333
3434
Overloaded functionality for `Dinfo`.
3535
"""
36-
function dstore(
37-
dInfo::Dinfo,
38-
files = defaultFiles(dInfo.val, dInfo.workers),
39-
)
36+
function dstore(dInfo::Dinfo, files = defaultFiles(dInfo.val, dInfo.workers))
4037
dstore(dInfo.val, dInfo.workers, files)
4138
end
4239

@@ -47,16 +44,12 @@ Import the content of symbol `sym` by each worker specified by `pids` from the
4744
corresponding filename in `files`.
4845
"""
4946
function dload(sym::Symbol, pids, files = defaultFiles(sym, pids))
50-
dmap(
51-
files,
52-
(fn) -> Base.eval(Main, :(
53-
begin
54-
$sym = open($deserialize, $fn)
55-
nothing
56-
end
57-
)),
58-
pids,
59-
)
47+
dmap(files, (fn) -> Base.eval(Main, :(
48+
begin
49+
$sym = open($deserialize, $fn)
50+
nothing
51+
end
52+
)), pids)
6053
return Dinfo(sym, pids)
6154
end
6255

@@ -65,10 +58,7 @@ end
6558
6659
Overloaded functionality for `Dinfo`.
6760
"""
68-
function dload(
69-
dInfo::Dinfo,
70-
files = defaultFiles(dInfo.val, dInfo.workers),
71-
)
61+
function dload(dInfo::Dinfo, files = defaultFiles(dInfo.val, dInfo.workers))
7262
dload(dInfo.val, dInfo.workers, files)
7363
end
7464

@@ -87,9 +77,6 @@ end
8777
8878
Overloaded functionality for `Dinfo`.
8979
"""
90-
function dunlink(
91-
dInfo::Dinfo,
92-
files = defaultFiles(dInfo.val, dInfo.workers),
93-
)
80+
function dunlink(dInfo::Dinfo, files = defaultFiles(dInfo.val, dInfo.workers))
9481
dunlink(dInfo.val, dInfo.workers, files)
9582
end

src/tools.jl

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,7 @@ end
1313
1414
Reduce dataset to selected columns, optionally save it under a different name.
1515
"""
16-
function dselect(
17-
dInfo::Dinfo,
18-
columns::Vector{Int},
19-
tgt::Symbol = dInfo.val,
20-
)::Dinfo
16+
function dselect(dInfo::Dinfo, columns::Vector{Int}, tgt::Symbol = dInfo.val)::Dinfo
2117
dtransform(dInfo, mtx -> mtx[:, columns], tgt)
2218
end
2319

@@ -91,10 +87,7 @@ end
9187
Compute mean and standard deviation of the columns in dataset. Returns a tuple
9288
with a vector of means in `columns`, and a vector of corresponding sdevs.
9389
"""
94-
function dstat(
95-
dInfo::Dinfo,
96-
columns::Vector{Int},
97-
)::Tuple{Vector{Float64},Vector{Float64}}
90+
function dstat(dInfo::Dinfo, columns::Vector{Int})::Tuple{Vector{Float64},Vector{Float64}}
9891

9992
sum_squares = x -> sum(x .^ 2)
10093

@@ -136,8 +129,7 @@ function dstat_buckets(
136129
)
137130

138131
# extract the bucketed stats
139-
(sums, sqsums, ns) =
140-
dmapreduce([dInfo, buckets], get_bucketed_stats, combine_stats)
132+
(sums, sqsums, ns) = dmapreduce([dInfo, buckets], get_bucketed_stats, combine_stats)
141133

142134
return (
143135
sums ./ ns, #means
@@ -285,7 +277,8 @@ less or higher than `targets`.
285277
"""
286278
function update_extrema(counts, targets, lims, mids)
287279
broadcast(
288-
(cnt, target, lim, mid) -> cnt >= target ? # if the count is too high,
280+
(cnt, target, lim, mid) ->
281+
cnt >= target ? # if the count is too high,
289282
(lim[1], mid) : # median is going to be in the lower half
290283
(mid, lim[2]), # otherwise in the higher half
291284
counts,
@@ -313,11 +306,8 @@ function dmedian(dInfo::Dinfo, columns::Vector{Int}; iters = 20)
313306
target = dmapreduce(dInfo, d -> size(d, 1), +) ./ 2
314307

315308
# current estimation range for the median (tuples of min, max)
316-
lims = dmapreduce(
317-
dInfo,
318-
d -> mapslices(extrema, d[:, columns], dims = 1),
319-
reduce_extrema,
320-
)
309+
lims =
310+
dmapreduce(dInfo, d -> mapslices(extrema, d[:, columns], dims = 1), reduce_extrema)
321311

322312
# convert the limits to a simple vector
323313
lims = cat(lims..., dims = 1)
@@ -368,8 +358,8 @@ function dmedian_buckets(
368358
get_bucket_extrema =
369359
(d, b) -> catmapbuckets(
370360
(_, x) -> length(x) > 0 ? # if there are some elements
371-
extrema(x) : # just take the extrema
372-
(Inf, -Inf), # if not, use backup values
361+
extrema(x) : # just take the extrema
362+
(Inf, -Inf), # if not, use backup values
373363
d[:, columns],
374364
nbuckets,
375365
b,
@@ -384,21 +374,22 @@ function dmedian_buckets(
384374
# this counts the elements smaller than mids in buckets
385375
# (both mids and elements are bucketed and column-sliced into matrices)
386376
bucketed_count_smaller_than_mids =
387-
(d, b) -> vcat(mapbuckets(
388-
(bucketID, d) ->
389-
[
390-
count(x -> x < mids[bucketID, colID], d[:, colID])
391-
for (colID, c) in enumerate(columns)
392-
]',
393-
d,
394-
nbuckets,
395-
b,
396-
slicedims = (1, 2),
397-
)...)
377+
(d, b) -> vcat(
378+
mapbuckets(
379+
(bucketID, d) ->
380+
[
381+
count(x -> x < mids[bucketID, colID], d[:, colID]) for
382+
(colID, c) in enumerate(columns)
383+
]',
384+
d,
385+
nbuckets,
386+
b,
387+
slicedims = (1, 2),
388+
)...,
389+
)
398390

399391
# gather the counts
400-
counts =
401-
dmapreduce([dInfo, buckets], bucketed_count_smaller_than_mids, +)
392+
counts = dmapreduce([dInfo, buckets], bucketed_count_smaller_than_mids, +)
402393

403394
lims = update_extrema(counts, targets, lims, mids)
404395
end

test/base.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,21 @@
8585
@test dmapreduce(:noname, x -> x, (a, b) -> a + b, []) == nothing
8686
end
8787

88+
@testset "`pmap` on distributed data" begin
89+
fetch.(save_at.(W, :test, 1234321))
90+
di = Dinfo(:test, W) # also test the example in docs
91+
@test dpmap(
92+
x -> :($(di.val) + $x),
93+
WorkerPool(di.workers),
94+
[4321234, 1234, 4321],
95+
) == [5555555, 1235555, 1238642]
96+
fetch.(remove_from.(W, :test))
97+
end
98+
8899
@testset "Internal utilities" begin
89100
@test DistributedData.tmp_symbol(:test) != :test
90-
@test DistributedData.tmp_symbol(:test, prefix = "abc",
91-
suffix = "def") == :abctestdef
101+
@test DistributedData.tmp_symbol(:test, prefix = "abc", suffix = "def") ==
102+
:abctestdef
92103
@test DistributedData.tmp_symbol(Dinfo(:test, W)) != :test
93104
end
94105

0 commit comments

Comments
 (0)