Skip to content

Commit c8a60f3

Browse files
Add conj with cost_and_gradient
Signed-off-by: 周唤海 <[email protected]>
1 parent d914583 commit c8a60f3

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

src/belief.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve
8080
# TODO: speed up if needed!
8181
code = star_code(length(vectors_in))
8282
cost, gradient = cost_and_gradient(code, (t, vectors_in...))
83-
for (o, g) in zip(vectors_out, gradient[2:end])
83+
for (o, g) in zip(vectors_out, conj.(gradient[2:end]))
8484
o .= g
8585
end
8686
return cost[]

src/mar.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ probabilities of the queried variables, represented by tensors.
7878
function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}}
7979
# sometimes, the cost can overflow, then we need to rescale the tensors during contraction.
8080
cost, grads = cost_and_gradient(tn.code, (adapt_tensors(tn; usecuda, rescale)...,))
81+
grads = conj.(grads)
8182
@debug "cost = $cost"
8283
ixs = OMEinsum.getixsv(tn.code)
8384
queryvars = ixs[tn.unity_tensors_idx]

test/belief.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ end
4646
@testset "belief propagation" begin
4747
n = 5
4848
chi = 3
49-
mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi)
49+
mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi)
5050
bp = BeliefPropgation(mps_uai)
5151
@test TensorInference.initial_state(bp) isa TensorInference.BPState
5252
state, info = belief_propagate(bp)
@@ -63,7 +63,7 @@ end
6363
@testset "belief propagation on circle" begin
6464
n = 10
6565
chi = 3
66-
mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi; periodic=true)
66+
mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi; periodic=true) # FIXME: fail to converge
6767
bp = BeliefPropgation(mps_uai)
6868
@test TensorInference.initial_state(bp) isa TensorInference.BPState
6969
state, info = belief_propagate(bp; max_iter=100, tol=1e-6)

0 commit comments

Comments
 (0)