Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/belief.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve
# TODO: speed up if needed!
code = star_code(length(vectors_in))
cost, gradient = cost_and_gradient(code, (t, vectors_in...))
for (o, g) in zip(vectors_out, gradient[2:end])
for (o, g) in zip(vectors_out, conj.(gradient[2:end]))
o .= g
end
return cost[]
Expand Down
1 change: 1 addition & 0 deletions src/mar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ probabilities of the queried variables, represented by tensors.
function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}}
# sometimes, the cost can overflow, then we need to rescale the tensors during contraction.
cost, grads = cost_and_gradient(tn.code, (adapt_tensors(tn; usecuda, rescale)...,))
grads = conj.(grads)
@debug "cost = $cost"
ixs = OMEinsum.getixsv(tn.code)
queryvars = ixs[tn.unity_tensors_idx]
Expand Down
4 changes: 2 additions & 2 deletions test/belief.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ end
@testset "belief propagation" begin
n = 5
chi = 3
mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi)
mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please test both Float64 and ComplexF64.

bp = BeliefPropgation(mps_uai)
@test TensorInference.initial_state(bp) isa TensorInference.BPState
state, info = belief_propagate(bp)
Expand All @@ -63,7 +63,7 @@ end
@testset "belief propagation on circle" begin
n = 10
chi = 3
mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi; periodic=true)
mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi; periodic=true) # FIXME: fail to converge
bp = BeliefPropgation(mps_uai)
@test TensorInference.initial_state(bp) isa TensorInference.BPState
state, info = belief_propagate(bp; max_iter=100, tol=1e-6)
Expand Down
Loading