
Commit 8295474

clean up
Parent: 370fad7 · Commit: 8295474

12 files changed: +29 −134 lines

Project.toml

Lines changed: 1 addition & 3 deletions
@@ -4,7 +4,6 @@ authors = ["Jin-Guo Liu", "Martin Roa Villescas"]
 version = "0.5.0"
 
 [deps]
-Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
@@ -21,11 +20,10 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 TensorInferenceCUDAExt = "CUDA"
 
 [compat]
-Artifacts = "1"
 CUDA = "4, 5"
 DocStringExtensions = "0.8.6, 0.9"
 LinearAlgebra = "1"
-OMEinsum = "0.8"
+OMEinsum = "0.8.7"
 Pkg = "1"
 PrettyTables = "2"
 ProblemReductions = "0.3"

ext/TensorInferenceCUDAExt.jl

Lines changed: 1 addition & 4 deletions
@@ -1,17 +1,14 @@
 module TensorInferenceCUDAExt
 using CUDA: CuArray
 import CUDA
-import TensorInference: match_arraytype, keep_only!, onehot_like, togpu
+import TensorInference: keep_only!, onehot_like, togpu
 
 function onehot_like(A::CuArray, j)
     mask = zero(A)
     CUDA.@allowscalar mask[j] = one(eltype(mask))
     return mask
 end
 
-# NOTE: this interface should be in OMEinsum
-match_arraytype(::Type{<:CuArray{T, N}}, target::AbstractArray{T, N}) where {T, N} = CuArray(target)
-
 function keep_only!(x::CuArray{T}, j) where T
     CUDA.@allowscalar hotvalue = x[j]
     fill!(x, zero(T))

src/Core.jl

Lines changed: 1 addition & 4 deletions
@@ -190,7 +190,4 @@ Returns the contraction complexity of a tensor newtork model.
 """
 function OMEinsum.contraction_complexity(tn::TensorNetworkModel)
     return contraction_complexity(tn.code, Dict(zip(get_vars(tn), get_cards(tn; fixedisone = true))))
-end
-
-# adapt array type with the target array type
-match_arraytype(::Type{<:Array{T, N}}, target::AbstractArray{T, N}) where {T, N} = Array(target)
+end

src/RescaledArray.jl

Lines changed: 6 additions & 1 deletion
@@ -46,4 +46,9 @@ end
 Base.size(arr::RescaledArray) = size(arr.normalized_value)
 Base.size(arr::RescaledArray, i::Int) = size(arr.normalized_value, i)
 
-match_arraytype(::Type{<:RescaledArray{T, N, AT}}, target::AbstractArray{T, N}) where {T, N, AT} = rescale_array(match_arraytype(AT, target))
+function OMEinsum.get_output_array(xs::NTuple{N, RescaledArray{T}}, size, fillzero::Bool) where {N, T}
+    return RescaledArray(zero(T), OMEinsum.get_output_array(getfield.(xs, :normalized_value), size, fillzero))
+end
+# The following two APIs are required by OMEinsum
+Base.fill!(r::RescaledArray, x) = (fill!(r.normalized_value, x ./ exp(r.log_factor)); r)
+Base.conj(r::RescaledArray) = RescaledArray(conj(r.log_factor), conj(r.normalized_value))
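
Note (not part of the diff): a RescaledArray represents a tensor as exp(log_factor) * normalized_value, which is why the new Base.fill! overload divides the fill value by exp(r.log_factor) before writing into the normalized buffer. A minimal sketch of that convention, using a hypothetical ToyRescaled stand-in rather than the real type:

# Illustrative sketch only; field names mirror the diff, the struct itself is made up.
struct ToyRescaled{T, N, AT <: AbstractArray{T, N}}
    log_factor::T
    normalized_value::AT
end

# The represented value is exp(log_factor) .* normalized_value.
value(r::ToyRescaled) = exp(r.log_factor) .* r.normalized_value

# Filling with `x` must divide by exp(log_factor) so that value(r) is `x` everywhere afterwards.
function fill_value!(r::ToyRescaled, x)
    fill!(r.normalized_value, x / exp(r.log_factor))
    return r
end

r = ToyRescaled(2.0, zeros(2, 2))
fill_value!(r, 3.0)
@assert all(value(r) .≈ 3.0)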

src/TensorInference.jl

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ $(EXPORTS)
 module TensorInference
 
 using OMEinsum, LinearAlgebra
+using OMEinsum: CacheTree, cached_einsum
 using DocStringExtensions, TropicalNumbers
 # The Tropical GEMM support
 using StatsBase

src/belief.jl

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve
     @assert length(vectors_out) == length(vectors_in) == ndims(t) "dimensions mismatch: $(length(vectors_out)), $(length(vectors_in)), $(ndims(t))"
     # TODO: speed up if needed!
     code = star_code(length(vectors_in))
-    cost, gradient = cost_and_gradient(code, [t, vectors_in...])
+    cost, gradient = cost_and_gradient(code, (t, vectors_in...))
     for (o, g) in zip(vectors_out, gradient[2:end])
         o .= g
     end
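
Note (not part of the diff): the only change here, and the matching ones in src/map.jl, src/mar.jl and src/mmap.jl below, is passing the tensors to cost_and_gradient as a tuple rather than a Vector, presumably so the upstream OMEinsum routine sees a concretely typed argument list. A trivial sketch of the vector-to-tuple splat:

# Illustrative only: converting a Vector of tensors into a tuple before the call.
tensors = [rand(2, 2), rand(2), rand(2)]
args = (tensors...,)   # splat; the trailing comma keeps a single element a tuple
@assert args isa Tuple && length(args) == 3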

src/map.jl

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 
 ########### Backward tropical tensor contraction ##############
 # This part is copied from [`GenericTensorNetworks`](https://github.com/QuEraComputing/GenericTensorNetworks.jl).
-function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Tropical}} where {M}, y, size_dict, dy)
+function OMEinsum.einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Tropical}} where {M}, y, size_dict, dy)
     return backward_tropical!(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), y, dy, size_dict)
 end
 
@@ -55,7 +55,7 @@ Returns the largest log-probability and the most probable configuration.
 function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector}
     tensor_indices = check_queryvars(tn, [[v] for v in 1:tn.nvars])
     tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
-    logp, grads = cost_and_gradient(tn.code, tensors)
+    logp, grads = cost_and_gradient(tn.code, (tensors...,))
     # use Array to convert CuArray to CPU arrays
     return content(Array(logp)[]), map(k -> haskey(tn.evidence, k) ? tn.evidence[k] : argmax(grads[tensor_indices[k]]) - 1, 1:tn.nvars)
 end
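
Note (not part of the diff): the first hunk re-registers the tropical backward rule on OMEinsum's own einsum_backward_rule generic function, since the local back-propagation code it used to extend is deleted from src/mar.jl below. For context, most_probable_config maps probabilities to Tropical log-values before contracting because the tropical semiring replaces (+, ×) with (max, +), so the contraction yields the largest log-probability instead of the partition function. A minimal sketch of the semiring, assuming TropicalNumbers.jl:

using TropicalNumbers

a, b = Tropical(log(0.2)), Tropical(log(0.5))
@assert a * b == Tropical(log(0.2) + log(0.5))   # tropical "times" adds log-probabilities
@assert a + b == Tropical(log(0.5))              # tropical "plus" takes the maximum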

src/mar.jl

Lines changed: 1 addition & 110 deletions
@@ -12,115 +12,6 @@ function adapt_tensors(code, tensors, evidence; usecuda, rescale)
     end
 end
 
-# ######### Inference by back propagation ############
-# `CacheTree` stores intermediate `NestedEinsum` contraction results.
-# It is a tree structure that isomorphic to the contraction tree,
-# `content` is the cached intermediate contraction result.
-# `children` are the children of current node, e.g. tensors that are contracted to get `content`.
-mutable struct CacheTree{T}
-    content::AbstractArray{T}
-    const children::Vector{CacheTree{T}}
-end
-
-function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
-    # slicing is not supported yet.
-    if length(se.slicing) != 0
-        @warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
-    end
-    return cached_einsum(se.eins, xs, size_dict)
-end
-
-# recursively contract and cache a tensor network
-function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
-    if OMEinsum.isleaf(code)
-        # For a leaf node, cache the input tensor
-        y = xs[code.tensorindex]
-        return CacheTree(y, CacheTree{eltype(y)}[])
-    else
-        # For a non-leaf node, compute the einsum and cache the contraction result
-        caches = [cached_einsum(arg, xs, size_dict) for arg in code.args]
-        # `einsum` evaluates the einsum contraction,
-        # Its 1st argument is the contraction pattern,
-        # Its 2nd one is a tuple of input tensors,
-        # Its 3rd argument is the size dictionary (label as the key, size as the value).
-        y = einsum(code.eins, ntuple(i -> caches[i].content, length(caches)), size_dict)
-        return CacheTree(y, caches)
-    end
-end
-
-# computed gradient tree by back propagation
-function generate_gradient_tree(se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
-    if length(se.slicing) != 0
-        @warn "Slicing is not supported for generating masked tree! Fallback to `NestedEinsum`."
-    end
-    return generate_gradient_tree(se.eins, cache, dy, size_dict)
-end
-
-# recursively compute the gradients and store it into a tree.
-# also known as the back-propagation algorithm.
-function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
-    if OMEinsum.isleaf(code)
-        return CacheTree(dy, CacheTree{T}[])
-    else
-        xs = ntuple(i -> cache.children[i].content, length(cache.children))
-        # `einsum_grad` is the back-propagation rule for einsum function.
-        # If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`
-        # Then the back-propagation pass is
-        # ```
-        # A̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 1)
-        # B̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 2)
-        # ...
-        # ```
-        # Let `L` be the loss, we will have `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`...
-        dxs = einsum_backward_rule(code.eins, xs, cache.content, size_dict, dy)
-        return CacheTree(dy, generate_gradient_tree.(code.args, cache.children, dxs, Ref(size_dict)))
-    end
-end
-
-# a unified interface of the backward rules for real numbers and tropical numbers
-function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {M}, y, size_dict, dy)
-    return ntuple(i -> OMEinsum.einsum_grad(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), size_dict, dy, i), length(xs))
-end
-
-# the main function for generating the gradient tree.
-function gradient_tree(code, xs)
-    # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
-    size_dict = OMEinsum.get_size_dict!(getixsv(code), xs, Dict{Int, Int}())
-    # forward compute and cache intermediate results.
-    cache = cached_einsum(code, xs, size_dict)
-    # initialize `y̅` as `1`. Note we always start from `L̅ := 1`.
-    dy = match_arraytype(typeof(cache.content), ones(eltype(cache.content), size(cache.content)))
-    # back-propagate
-    return copy(cache.content), generate_gradient_tree(code, cache, dy, size_dict)
-end
-
-# evaluate the cost and the gradient of leaves
-function cost_and_gradient(code, xs)
-    cost, tree = gradient_tree(code, xs)
-    # extract the gradients on leaves (i.e. the input tensors).
-    return cost, extract_leaves(code, tree)
-end
-
-# since slicing is not supported, we forward it to NestedEinsum.
-extract_leaves(code::SlicedEinsum, cache::CacheTree) = extract_leaves(code.eins, cache)
-
-# extract gradients on leaf nodes.
-function extract_leaves(code::NestedEinsum, cache::CacheTree)
-    res = Vector{Any}(undef, length(getixsv(code)))
-    return extract_leaves!(code, cache, res)
-end
-
-function extract_leaves!(code, cache, res)
-    if OMEinsum.isleaf(code)
-        # extract
-        res[code.tensorindex] = cache.content
-    else
-        # resurse deeper
-        extract_leaves!.(code.args, cache.children, Ref(res))
-    end
-    return res
-end
-
 """
 $(TYPEDSIGNATURES)
 
@@ -186,7 +77,7 @@ probabilities of the queried variables, represented by tensors.
 """
 function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}}
     # sometimes, the cost can overflow, then we need to rescale the tensors during contraction.
-    cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale))
+    cost, grads = cost_and_gradient(tn.code, (adapt_tensors(tn; usecuda, rescale)...,))
    @debug "cost = $cost"
    ixs = OMEinsum.getixsv(tn.code)
    queryvars = ixs[tn.unity_tensors_idx]
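
Note (not part of the diff): the deleted block was a local implementation of reverse-mode differentiation over the contraction tree (forward caching via cached_einsum, then a gradient tree). With OMEinsum 0.8.7 this machinery lives upstream, which is why TensorInference now imports CacheTree and cached_einsum from OMEinsum and calls cost_and_gradient on a tuple of tensors. A minimal sketch of the upstream entry point, assuming cost_and_gradient is available unqualified after `using OMEinsum`, as the new code implies:

using OMEinsum

# A tiny network: contract A[i,j] * B[j] down to a scalar (empty output labels).
code = optimize_code(EinCode([[1, 2], [2]], Int[]), Dict(1 => 2, 2 => 2), GreedyMethod())
A, B = rand(2, 2), rand(2)

# `cost` is the contracted output; `grads[k]` has the shape of the k-th input tensor.
cost, grads = cost_and_gradient(code, (A, B))
@assert size(grads[1]) == (2, 2) && size(grads[2]) == (2,)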

src/mmap.jl

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@ end
 function most_probable_config(mmap::MMAPModel; usecuda = false)::Tuple{Real, Vector}
     vars = get_vars(mmap)
     tensors = map(t -> OMEinsum.asarray(Tropical.(log.(t)), t), adapt_tensors(mmap; usecuda, rescale = false))
-    logp, grads = cost_and_gradient(mmap.code, tensors)
+    logp, grads = cost_and_gradient(mmap.code, (tensors...,))
     # use Array to convert CuArray to CPU arrays
     return content(Array(logp)[]), map(k -> haskey(mmap.evidence, vars[k]) ? mmap.evidence[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
 end

src/sampling.jl

Lines changed: 4 additions & 4 deletions
@@ -134,9 +134,9 @@ function generate_samples!(code::DynamicNestedEinsum, cache::CacheTree{T}, iy_en
     @assert length(iy_env) == ndims(env)
     if !(OMEinsum.isleaf(code))
         ixs, iy = getixsv(code.eins), getiyv(code.eins)
-        for (subcode, child, ix) in zip(code.args, cache.children, ixs)
+        for (subcode, child, ix) in zip(code.args, cache.siblings, ixs)
             # subenv for the current child, use it to sample and update its cache
-            siblings = filter(x->x !== child, cache.children)
+            siblings = filter(x->x !== child, cache.siblings)
             siblings_ixs = filter(x->x !== ix, ixs)
             iy_subenv = batch_label ∈ ix ? ix : [ix..., batch_label]
             envcode = optimize_code(EinCode([siblings_ixs..., iy_env], iy_subenv), size_dict, GreedyMethod(; nrepeat=1))
@@ -184,12 +184,12 @@ end
 function udpate_cache_tree!(ne::NestedEinsum, cache::CacheTree{T}, el::Pair{<:AbstractVector{L}}, batch_label::L, size_dict::Dict{L}) where {T, L}
     OMEinsum.isleaf(ne) && return
     updated = false
-    for (subcode, child, ix) in zip(ne.args, cache.children, getixsv(ne.eins))
+    for (subcode, child, ix) in zip(ne.args, cache.siblings, getixsv(ne.eins))
         if any(x->x ∈ el.first, ix)
             updated = true
             child.content = _eliminate!(child.content, ix, el, batch_label)
             udpate_cache_tree!(subcode, child, el, batch_label, size_dict)
         end
     end
-    updated && (cache.content = einsum(ne.eins, (getfield.(cache.children, :content)...,), size_dict))
+    updated && (cache.content = einsum(ne.eins, (getfield.(cache.siblings, :content)...,), size_dict))
 end
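
Note (not part of the diff): the children-to-siblings renames track the field name of the CacheTree type now imported from OMEinsum (see the new `using OMEinsum: CacheTree, cached_einsum` line in src/TensorInference.jl); the deleted local CacheTree in src/mar.jl called the same slot `children`. A small sketch of walking such a cache tree, under that assumption about the field names:

using OMEinsum: CacheTree

# Count the nodes of a contraction cache tree by recursing through `siblings`
# (field names taken from the diff; illustrative only, not a stable API).
count_nodes(t::CacheTree) = 1 + sum(count_nodes, t.siblings; init = 0)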
