@@ -12,115 +12,6 @@ function adapt_tensors(code, tensors, evidence; usecuda, rescale)
     end
 end
 
-# ######### Inference by back propagation ############
-# `CacheTree` stores intermediate `NestedEinsum` contraction results.
-# It is a tree structure that is isomorphic to the contraction tree:
-# `content` is the cached intermediate contraction result, and
-# `children` are the children of the current node, i.e. the tensors that are contracted to get `content`.
-mutable struct CacheTree{T}
-    content::AbstractArray{T}
-    const children::Vector{CacheTree{T}}
-end
-
-function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
-    # slicing is not supported yet.
-    if length(se.slicing) != 0
-        @warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
-    end
-    return cached_einsum(se.eins, xs, size_dict)
-end
-
-# recursively contract and cache a tensor network
-function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
-    if OMEinsum.isleaf(code)
-        # For a leaf node, cache the input tensor
-        y = xs[code.tensorindex]
-        return CacheTree(y, CacheTree{eltype(y)}[])
-    else
-        # For a non-leaf node, compute the einsum and cache the contraction result
-        caches = [cached_einsum(arg, xs, size_dict) for arg in code.args]
-        # `einsum` evaluates the einsum contraction:
-        # its 1st argument is the contraction pattern,
-        # its 2nd argument is a tuple of input tensors,
-        # its 3rd argument is the size dictionary (label as the key, size as the value).
-        y = einsum(code.eins, ntuple(i -> caches[i].content, length(caches)), size_dict)
-        return CacheTree(y, caches)
-    end
-end
-
-# compute the gradient tree by back propagation
-function generate_gradient_tree(se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
-    if length(se.slicing) != 0
-        @warn "Slicing is not supported for generating masked tree! Fallback to `NestedEinsum`."
-    end
-    return generate_gradient_tree(se.eins, cache, dy, size_dict)
-end
-
-# recursively compute the gradients and store them in a tree;
-# also known as the back-propagation algorithm.
-function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
-    if OMEinsum.isleaf(code)
-        return CacheTree(dy, CacheTree{T}[])
-    else
-        xs = ntuple(i -> cache.children[i].content, length(cache.children))
-        # `einsum_grad` is the back-propagation rule for the einsum function.
-        # If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`,
-        # then the back-propagation pass is
-        # ```
-        # A̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 1)
-        # B̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 2)
-        # ...
-        # ```
-        # Let `L` be the loss; then `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`, ...
-        dxs = einsum_backward_rule(code.eins, xs, cache.content, size_dict, dy)
-        return CacheTree(dy, generate_gradient_tree.(code.args, cache.children, dxs, Ref(size_dict)))
-    end
-end
-
-# a unified interface of the backward rules for real numbers and tropical numbers
-function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {M}, y, size_dict, dy)
-    return ntuple(i -> OMEinsum.einsum_grad(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), size_dict, dy, i), length(xs))
-end
-
-# the main function for generating the gradient tree.
-function gradient_tree(code, xs)
-    # infer the sizes from the contraction code and the input tensors `xs`; returns a label-size dictionary.
-    size_dict = OMEinsum.get_size_dict!(getixsv(code), xs, Dict{Int, Int}())
-    # forward compute and cache intermediate results.
-    cache = cached_einsum(code, xs, size_dict)
-    # initialize `y̅` as `1`. Note we always start from `L̅ := 1`.
-    dy = match_arraytype(typeof(cache.content), ones(eltype(cache.content), size(cache.content)))
-    # back-propagate
-    return copy(cache.content), generate_gradient_tree(code, cache, dy, size_dict)
-end
-
-# evaluate the cost and the gradients of the leaves
-function cost_and_gradient(code, xs)
-    cost, tree = gradient_tree(code, xs)
-    # extract the gradients on the leaves (i.e. the input tensors).
-    return cost, extract_leaves(code, tree)
-end
-
-# since slicing is not supported, we forward it to NestedEinsum.
-extract_leaves(code::SlicedEinsum, cache::CacheTree) = extract_leaves(code.eins, cache)
-
-# extract gradients on leaf nodes.
-function extract_leaves(code::NestedEinsum, cache::CacheTree)
-    res = Vector{Any}(undef, length(getixsv(code)))
-    return extract_leaves!(code, cache, res)
-end
-
-function extract_leaves!(code, cache, res)
-    if OMEinsum.isleaf(code)
-        # extract the cached gradient on this leaf
-        res[code.tensorindex] = cache.content
-    else
-        # recurse deeper
-        extract_leaves!.(code.args, cache.children, Ref(res))
-    end
-    return res
-end
-
 """
 $(TYPEDSIGNATURES)
 
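Note on the deleted block above: it implemented reverse-mode differentiation over the cached contraction tree, seeding `y̅` with ones at the root and pushing gradients down to the leaf tensors. Below is a minimal sketch (not part of this commit) of the single-node rule that `einsum_backward_rule` applied at every internal node; the small contraction, the tensors `A`, `B`, and the integer labels are made up for illustration, and only calls that already appear in the deleted code are used.

```
using OMEinsum

A, B = randn(2, 3), randn(3, 4)
code = EinCode([[1, 2], [2, 3]], [1, 3])    # y[i, k] = Σⱼ A[i, j] * B[j, k]
size_dict = OMEinsum.get_size_dict!(OMEinsum.getixsv(code), (A, B), Dict{Int, Int}())

y  = einsum(code, (A, B), size_dict)        # forward pass (what `cached_einsum` stores per node)
dy = ones(eltype(y), size(y))               # seed y̅ := ∂L/∂y = 1, i.e. L = sum(y)

# the per-input backward rule, as used by the removed `einsum_backward_rule`
dA = OMEinsum.einsum_grad(OMEinsum.getixs(code), (A, B), OMEinsum.getiy(code), size_dict, dy, 1)
dB = OMEinsum.einsum_grad(OMEinsum.getixs(code), (A, B), OMEinsum.getiy(code), size_dict, dy, 2)
```

Applied recursively from the root, these per-leaf gradients are what `extract_leaves` collected and what the `marginals` hunk below consumes via `cost_and_gradient`.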
@@ -186,7 +77,7 @@ probabilities of the queried variables, represented by tensors.
 """
 function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}}
     # sometimes the cost can overflow; then we need to rescale the tensors during contraction.
-    cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale))
+    cost, grads = cost_and_gradient(tn.code, (adapt_tensors(tn; usecuda, rescale)...,))
     @debug "cost = $cost"
     ixs = OMEinsum.getixsv(tn.code)
     queryvars = ixs[tn.unity_tensors_idx]
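The only functional change in this hunk is that the tensor list returned by `adapt_tensors` (a `Vector`) is now splatted into a `Tuple` before being handed to `cost_and_gradient`, presumably to match the argument type expected by the gradient routine that replaces the deleted in-package implementation. A tiny illustration of the splat idiom (the arrays are placeholders, not actual factors):

```
tensors = Any[randn(2, 2), randn(2)]   # a Vector of tensors, as `adapt_tensors` returns
xs = (tensors...,)                     # splat into a Tuple: (tensors[1], tensors[2])
```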