You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
The inversion p -> 1-p was sometimes not triggered, leading to wrong statistics in part of the sampled array.
Added distributional tests for mean and variance.
# BTRS algorithm, adapted from the tensorflow library (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/random_binomial_op.cc)
35
35
36
-
## Kernel for scalar parameters
36
+
## Kernels for scalar parameters
37
+
"""
    kernel_naive_scalar!(A, n, p, seed::UInt32, counter::UInt32)

Device kernel that overwrites each element of `A` with the number of times
`rand(Float32) < p` holds in `n` consecutive draws.

# Arguments
- `A`: output array; every index in `1:length(A)` is written.
- `n`: number of draws per element.
- `p`: threshold each uniform `Float32` draw is compared against.
- `seed`, `counter`: forwarded to `Random.seed!` to initialise the device RNG.

Returns `nothing`.
"""
function kernel_naive_scalar!(A, n, p, seed::UInt32, counter::UInt32)
    device_rng = Random.default_rng()

    # initialize the state
    @inbounds Random.seed!(device_rng, seed, counter)

    # grid-stride loop
    tid = threadIdx().x
    # One pass of the whole grid handles blockDim * gridDim elements. Striding by
    # (blockDim - 1) * gridDim instead makes consecutive passes overlap and, with a
    # single thread per block, gives a stride of zero so the loop never ends.
    window = blockDim().x * gridDim().x
    offset = (blockIdx().x - 1i32) * blockDim().x

    while offset < length(A)
        i = tid + offset

        # count how many of the n draws fall below p
        k = 0
        ctr = 1
        while ctr <= n
            rand(Float32) < p && (k += 1)
            ctr += 1
        end

        # the last pass may reach beyond the end of A
        if i <= length(A)
            @inbounds A[i] = k
        end
        offset += window
    end
    return nothing
end
65
+
"""
    kernel_inversion_scalar!(A, n, p, seed::UInt32, counter::UInt32)

Device kernel that overwrites each element of `A` with a count obtained by summing
waiting times `ceil(log(u) / log(1 - p))`, with `u = rand(Float32)`, until the running
sum exceeds `n`. The stored value is the number of waiting times that fit within `n`.

Degenerate probabilities are handled without sampling: when `p >= 1` the result is
`max(n, 0)`, and when `log(1 - p)` is not negative (`p <= 0`, or `p` so small that
`1f0 - p` rounds to one) the result is `0`.

# Arguments
- `A`: output array; every index in `1:length(A)` is written.
- `n`: upper bound on the summed waiting times.
- `p`: success probability used in `log(1 - p)`.
- `seed`, `counter`: forwarded to `Random.seed!` to initialise the device RNG.

Returns `nothing`.
"""
function kernel_inversion_scalar!(A, n, p, seed::UInt32, counter::UInt32)
    device_rng = Random.default_rng()

    # initialize the state
    @inbounds Random.seed!(device_rng, seed, counter)

    # grid-stride loop
    tid = threadIdx().x
    # One pass of the whole grid handles blockDim * gridDim elements. Striding by
    # (blockDim - 1) * gridDim instead makes consecutive passes overlap and, with a
    # single thread per block, gives a stride of zero so the loop never ends.
    window = blockDim().x * gridDim().x
    offset = (blockIdx().x - 1i32) * blockDim().x

    # log(1 - p) is the same for every element, so compute it once.
    logq = CUDA.log(1f0 - p)
    # p >= 1 makes logq -Inf (or NaN): the quotient below is then 0 or NaN, geom_sum
    # never exceeds n, and the sampling loop cannot terminate.
    certain = p >= 1
    # logq == 0 pins geom_sum at -Inf, which also never exceeds n.
    samplable = !certain && logq < 0

    while offset < length(A)
        i = tid + offset

        k = 0
        if samplable
            geom_sum = 0f0
            while true
                geom = ceil(CUDA.log(rand(Float32)) / logq)
                geom_sum += geom
                geom_sum > n && break
                k += 1
            end
        end

        # the last pass may reach beyond the end of A
        if i <= length(A)
            if certain
                @inbounds A[i] = max(n, zero(n))
            else
                @inbounds A[i] = k
            end
        end
        offset += window
    end
    return nothing
end
37
96
functionkernel_BTRS_scalar!(A, n, p, seed::UInt32, counter::UInt32)
38
97
device_rng = Random.default_rng()
39
98
@@ -49,80 +108,47 @@ function kernel_BTRS_scalar!(A, n, p, seed::UInt32, counter::UInt32)
0 commit comments