Skip to content

Commit 0eedb51

Browse files
committed
fix issue #3: more robust indexing
all test should now pass
1 parent 968bc8a commit 0eedb51

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

src/kernels.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,22 @@ end
3232

3333

3434
# BTRS algorithm, adapted from the tensorflow library (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/random_binomial_op.cc)
35-
function kernel_BTRS!(A, count, prob, randstates)
35+
function kernel_BTRS!(A, count, prob, randstates, R1, R2, Rp, Ra, count_dim_larger_than_prob_dim)
3636
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
37-
indices = CartesianIndices(A)
3837

3938
@inbounds if i <= length(A)
40-
I = indices[i].I
41-
n = count[CartesianIndex(I[1:ndims(count)])]
42-
p = prob[CartesianIndex(I[1:ndims(prob)])]
39+
I = Ra[i]
40+
Ip = Rp[I[1]]
41+
I1 = R1[Ip[1]]
42+
I2 = R2[Ip[2]]
43+
44+
if count_dim_larger_than_prob_dim
45+
n = count[CartesianIndex(I1, I2)]
46+
p = prob[I1]
47+
else
48+
n = count[I1]
49+
p = prob[CartesianIndex(I1, I2)]
50+
end
4351

4452
# wrong parameter values (currently disabled)
4553
# n < 0 && throw(ArgumentError("kernel_BTRS!: count must be a nonnegative integer."))

src/rand_binomial.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,25 @@ function rand_binom!(rng, A::BinomialArray, count::BinomialArray, prob::DenseCuA
6767
return A
6868
end
6969
if size(A)[1:ndims(count)] == size(count) && size(A)[1:ndims(prob)] == size(prob)
70-
kernel = @cuda name="BTRS_full" launch=false kernel_BTRS!(A, count, prob, rng.state)
70+
count_dim_larger_than_prob_dim = ndims(count) > ndims(prob)
71+
if count_dim_larger_than_prob_dim
72+
R1 = CartesianIndices(prob) # indices for count
73+
R2 = CartesianIndices(size(count)[ndims(prob)+1:end]) # indices for prob that are not included in R1
74+
Rr = CartesianIndices(size(A)[ndims(count)+1:end]) # remaining indices in A
75+
else
76+
R1 = CartesianIndices(count) # indices for count
77+
R2 = CartesianIndices(size(prob)[ndims(count)+1:end]) # indices for prob that are not included in R1
78+
Rr = CartesianIndices(size(A)[ndims(prob)+1:end]) # remaining indices in A
79+
end
80+
Rp = CartesianIndices((length(R1), length(R2))) # indices for parameters
81+
Ra = CartesianIndices((length(Rp), length(Rr))) # indices for parameters and A
82+
83+
kernel = @cuda name="BTRS_full" launch=false kernel_BTRS!(A, count, prob, rng.state, R1, R2, Rp, Ra, count_dim_larger_than_prob_dim)
7184
config = launch_configuration(kernel.fun)
7285
threads = Base.min(length(A), config.threads, 256) # strangely seems to be faster when defaulting to 256 threads
7386
blocks = cld(length(A), threads)
74-
kernel(A, count, prob, rng.state; threads=threads, blocks=blocks)
87+
88+
kernel(A, count, prob, rng.state, R1, R2, Rp, Ra, count_dim_larger_than_prob_dim; threads=threads, blocks=blocks)
7589
else
7690
throw(DimensionMismatch("`count` and `prob` need have size compatible with A"))
7791
end

0 commit comments

Comments
 (0)