Skip to content

Commit 40380ce

Browse files
committed
fix tests
1 parent c8a60f3 commit 40380ce

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

src/belief.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ Run the belief propagation algorithm, and return the final state and the informa
115115
116116
### Keyword Arguments
117117
- `max_iter::Int=100`: the maximum number of iterations
118-
- `tol::Float64=1e-6`: the tolerance for the convergence
118+
- `tol::Float64=1e-6`: the tolerance for the convergence, the convergence is checked by infidelity of messages in consecutive iterations. For complex numbers, the converged message may be different only by a phase factor.
119119
- `damping::Float64=0.2`: the damping factor for the message update, updated-message = damping * old-message + (1 - damping) * new-message
120120
"""
121121
function belief_propagate(bp::BeliefPropgation; kwargs...)
@@ -133,14 +133,21 @@ function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::In
133133
collect_message!(bp, state; normalize = true)
134134
process_message!(state; normalize = true, damping = damping)
135135
# check convergence
136-
if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol = tol), 1:length(bp.v2t[iv])), 1:num_variables(bp))
136+
if all(iv -> all(it -> message_converged(state.message_in[iv][it], pre_message_in[iv][it], atol = tol), 1:length(bp.v2t[iv])), 1:num_variables(bp))
137137
return BPInfo(true, i)
138138
end
139139
pre_message_in = deepcopy(state.message_in)
140140
end
141141
return BPInfo(false, max_iter)
142142
end
143143

144+
# check if two messages are converged by fidelity (needed for complex numbers)
145+
function message_converged(a, b; atol)
146+
a_norm = norm(a)
147+
b_norm = norm(b)
148+
return isapprox(a_norm, b_norm, atol=atol) && isapprox(sqrt(abs(a' * b)), a_norm, atol=atol)
149+
end
150+
144151
# if BP is exact and converged (e.g. tree like), the result should be the same as the tensor network contraction
145152
function contraction_results(state::BPState{T}) where {T}
146153
return [sum(reduce((x, y) -> x .* y, mi)) for mi in state.message_in]

test/belief.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,21 @@ end
4949
mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi)
5050
bp = BeliefPropgation(mps_uai)
5151
@test TensorInference.initial_state(bp) isa TensorInference.BPState
52-
state, info = belief_propagate(bp)
52+
state, info = belief_propagate(bp; max_iter=100, tol=1e-8)
5353
@test info.converged
5454
@test info.iterations < 20
5555
mars = marginals(state)
5656
tnet = TensorNetworkModel(mps_uai)
5757
mars_tnet = marginals(tnet)
5858
for v in 1:TensorInference.num_variables(bp)
59-
@test mars[[v]] mars_tnet[[v]] atol=1e-6
59+
@test mars[[v]] mars_tnet[[v]] atol=1e-4
6060
end
6161
end
6262

63-
@testset "belief propagation on circle" begin
63+
@testset "belief propagation on circle (Real)" begin
6464
n = 10
6565
chi = 3
66-
mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi; periodic=true) # FIXME: fail to converge
66+
mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi; periodic=true)
6767
bp = BeliefPropgation(mps_uai)
6868
@test TensorInference.initial_state(bp) isa TensorInference.BPState
6969
state, info = belief_propagate(bp; max_iter=100, tol=1e-6)
@@ -78,6 +78,25 @@ end
7878
end
7979
end
8080

81+
82+
@testset "belief propagation on circle (Complex)" begin
83+
n = 10
84+
chi = 3
85+
mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi; periodic=true) # FIXME: fail to converge
86+
bp = BeliefPropgation(mps_uai)
87+
@test TensorInference.initial_state(bp) isa TensorInference.BPState
88+
state, info = belief_propagate(bp; max_iter=100, tol=1e-6)
89+
@test info.converged
90+
@test info.iterations < 100
91+
contraction_res = TensorInference.contraction_results(state)
92+
tnet = TensorNetworkModel(mps_uai)
93+
mars = marginals(state)
94+
mars_tnet = marginals(tnet)
95+
for v in 1:TensorInference.num_variables(bp)
96+
@test TensorInference.message_converged(mars[[v]], mars_tnet[[v]]; atol=1e-4)
97+
end
98+
end
99+
81100
@testset "marginal uai2014" begin
82101
for problem in [problem_from_artifact("uai2014", "MAR", "Promedus", 14), problem_from_artifact("uai2014", "MAR", "ObjectDetection", 42)]
83102
optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)

0 commit comments

Comments
 (0)