Skip to content

Commit 86aa5fd

Browse files
authored
Upgrade omeinsum to version 0.9.1 (#101)
* upgrade - omeinsum * update * fix tests
1 parent b3a48f4 commit 86aa5fd

File tree

11 files changed

+249
-14
lines changed

11 files changed

+249
-14
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorInference"
22
uuid = "c2297e78-99bd-40ad-871d-f50e56b81012"
33
authors = ["Jin-Guo Liu", "Martin Roa Villescas"]
4-
version = "0.6.1"
4+
version = "0.6.2"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -23,7 +23,7 @@ TensorInferenceCUDAExt = "CUDA"
2323
CUDA = "4, 5"
2424
DocStringExtensions = "0.8.6, 0.9"
2525
LinearAlgebra = "1"
26-
OMEinsum = "0.8.7"
26+
OMEinsum = "0.9.1"
2727
Pkg = "1"
2828
PrettyTables = "2"
2929
ProblemReductions = "0.3"

docs/src/api/public.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ TensorInference
3434
```@docs
3535
GreedyMethod
3636
KaHyParBipartite
37+
HyperND
38+
TreeSASlicer
39+
ScoreFunction
3740
MergeGreedy
3841
MergeVectors
3942
SABipartite
@@ -73,4 +76,6 @@ update_temperature
7376
random_matrix_product_state
7477
random_matrix_product_uai
7578
random_tensor_train_uai
79+
save_tensor_network
80+
load_tensor_network
7681
```

docs/src/performance-tips.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ probability(tn)
3030

3131
# For large scale applications, it is also possible to slice over certain degrees of freedom to reduce the space complexity, i.e.
3232
# loop and accumulate over certain degrees of freedom so that one can have a smaller tensor network inside the loop due to the removal of these degrees of freedom.
33-
# In the [`TreeSA`](@ref) optimizer, one can set `nslices` to a value larger than zero to turn on this feature.
34-
# As a comparison we slice over 5 degrees of freedom, which can reduce the space complexity by at most 5.
33+
# One can use the `slicer` keyword argument to reduce the space complexity by slicing over certain degrees of freedom.
34+
# In the following example, we use the `TreeSASlicer` to reduce the space complexity to `sc_target=10`.
3535
# In this application, the slicing achieves the largest possible space complexity reduction 5, while the time and read-write complexity are only increased by less than 1,
3636
# i.e. the peak memory usage is reduced by a factor ``32``, while the (theoretical) computing time is increased by at a factor ``< 2``.
37-
optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.3:100, nslices=5)
38-
tn = TensorNetworkModel(model; optimizer, evidence);
37+
optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.3:100)
38+
tn = TensorNetworkModel(model; optimizer, evidence, slicer=TreeSASlicer(score=ScoreFunction(sc_target=10)));
3939
contraction_complexity(tn)
4040

4141
# ## Faster Tropical tensor contraction to speed up MAP and MMAP

src/Core.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ Probabilistic modeling with a tensor network.
4949
* `code` is the tensor network contraction pattern.
5050
* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `unity_tensors_labels`.
5151
* `evidence` is a dictionary used to specify degrees of freedom that are fixed to certain values.
52-
* `unity_tensors_idx` is a vector of indices of the unity tensors in the `tensors` array. Unity tensors are dummy tensors used to obtain the marginal probabilities.
52+
* `unity_tensors_idx` is a vector of indices pointing to the unity tensors in the `tensors` array. Unity tensors are dummy tensors with all entries equal to one, which are used to obtain the marginal probabilities.
5353
"""
5454
struct TensorNetworkModel{ET, MT <: AbstractArray}
5555
nvars::Int
@@ -118,6 +118,7 @@ function TensorNetworkModel(
118118
evidence = Dict{Int,Int}(),
119119
optimizer = GreedyMethod(),
120120
simplifier = nothing,
121+
slicer = nothing,
121122
unity_tensors_labels = [[i] for i=1:model.nvars]
122123
) where {ET, FT}
123124
# `optimize_code` optimizes the contraction order of a raw tensor network without a contraction order specified.
@@ -127,7 +128,7 @@ function TensorNetworkModel(
127128
rawcode = EinCode([unity_tensors_labels..., [[factor.vars...] for factor in model.factors]...], collect(Int, openvars)) # labels for vertex tensors (unity tensors) and edge tensors
128129
tensors = Array{ET}[[ones(ET, [model.cards[i] for i in lb]...) for lb in unity_tensors_labels]..., [t.vals for t in model.factors]...]
129130
size_dict = OMEinsum.get_size_dict(getixsv(rawcode), tensors)
130-
code = optimize_code(rawcode, size_dict, optimizer, simplifier)
131+
code = optimize_code(rawcode, size_dict, optimizer; simplifier, slicer)
131132
return TensorNetworkModel(model.nvars, code, tensors, evidence, collect(Int, 1:length(unity_tensors_labels)))
132133
end
133134

src/TensorInference.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ module TensorInference
99

1010
using OMEinsum, LinearAlgebra
1111
using OMEinsum: CacheTree, cached_einsum
12+
using OMEinsum.OMEinsumContractionOrders.JSON
1213
using DocStringExtensions, TropicalNumbers
1314
# The Tropical GEMM support
1415
using StatsBase
@@ -19,7 +20,7 @@ import Pkg
1920

2021
# reexport OMEinsum functions
2122
export RescaledArray
22-
export contraction_complexity, TreeSA, GreedyMethod, KaHyParBipartite, SABipartite, MergeGreedy, MergeVectors
23+
export contraction_complexity, TreeSA, GreedyMethod, KaHyParBipartite, HyperND, SABipartite, MergeGreedy, MergeVectors, TreeSASlicer, ScoreFunction
2324

2425
# read and load uai files
2526
export read_model_file, read_td_file, read_evidence_file
@@ -44,6 +45,9 @@ export update_temperature
4445
# belief propagation
4546
export BeliefPropgation, belief_propagate
4647

48+
# fileio
49+
export save_tensor_network, load_tensor_network
50+
4751
# utils
4852
export random_matrix_product_state, random_tensor_train_uai, random_matrix_product_uai
4953

@@ -56,5 +60,6 @@ include("mmap.jl")
5660
include("sampling.jl")
5761
include("cspmodels.jl")
5862
include("belief.jl")
63+
include("fileio.jl")
5964

6065
end # module

src/fileio.jl

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""
2+
save_tensor_network(tn::TensorNetworkModel; folder::String)
3+
4+
Save a tensor network model to a folder with separate files for code, tensors, and model metadata.
5+
The code is saved using `OMEinsum.writejson`, tensors as JSON, and model specifics in model.json.
6+
7+
# Arguments
8+
- `tn::TensorNetworkModel`: The tensor network model to save
9+
- `folder::String`: The folder path to save the files
10+
11+
# Files Created
12+
- `code.json`: Contains the einsum code using OMEinsum format
13+
- `tensors.json`: Contains the tensor data as JSON
14+
- `model.json`: Contains nvars, evidence, and unity_tensors_idx
15+
16+
# Example
17+
```julia
18+
tn = TensorNetworkModel(...) # create your model
19+
save_tensor_network(tn; folder="my_model")
20+
```
21+
"""
22+
function save_tensor_network(tn::TensorNetworkModel; folder::String)
23+
!isdir(folder) && mkpath(folder)
24+
25+
# save code
26+
OMEinsum.writejson(joinpath(folder, "code.json"), tn.code)
27+
28+
# save tensors
29+
open(joinpath(folder, "tensors.json"), "w") do io
30+
JSON.print(io, [tensor_to_dict(tensor) for tensor in tn.tensors], 2)
31+
end
32+
33+
# save model metadata
34+
open(joinpath(folder, "model.json"), "w") do io
35+
JSON.print(io, Dict(
36+
"nvars" => tn.nvars,
37+
"evidence" => tn.evidence,
38+
"unity_tensors_idx" => tn.unity_tensors_idx
39+
), 2)
40+
end
41+
return nothing
42+
end
43+
44+
"""
45+
load_tensor_network(folder::String)
46+
47+
Load a tensor network model from a folder containing code, tensors, and model files.
48+
49+
# Arguments
50+
- `folder::String`: The folder path containing the files
51+
52+
# Returns
53+
- `TensorNetworkModel`: The loaded tensor network model
54+
55+
# Required Files
56+
- `code.json`: Contains the einsum code using OMEinsum format
57+
- `tensors.json`: Contains the tensor data as JSON
58+
- `model.json`: Contains nvars, evidence, and unity_tensors_idx
59+
60+
# Example
61+
```julia
62+
tn = load_tensor_network("my_model")
63+
```
64+
"""
65+
function load_tensor_network(folder::String)::TensorNetworkModel
66+
!isdir(folder) && throw(SystemError("Folder not found: $folder"))
67+
68+
code_path = joinpath(folder, "code.json")
69+
tensors_path = joinpath(folder, "tensors.json")
70+
model_path = joinpath(folder, "model.json")
71+
!isfile(code_path) && throw(SystemError("Code file not found: $code_path"))
72+
!isfile(tensors_path) && throw(SystemError("Tensors file not found: $tensors_path"))
73+
!isfile(model_path) && throw(SystemError("Model file not found: $model_path"))
74+
75+
code = OMEinsum.readjson(code_path)
76+
77+
tensors = [tensor_from_dict(t) for t in JSON.parsefile(tensors_path)]
78+
79+
model_dict = JSON.parsefile(model_path)
80+
81+
# Convert evidence keys to Int (JSON parses them as strings)
82+
evidence = Dict{Int, Int}()
83+
for (k, v) in model_dict["evidence"]
84+
evidence[parse(Int, k)] = v
85+
end
86+
87+
return TensorNetworkModel(
88+
model_dict["nvars"],
89+
code,
90+
tensors,
91+
evidence,
92+
collect(Int, model_dict["unity_tensors_idx"])
93+
)
94+
end
95+
96+
"""
97+
tensor_to_dict(tensor::AbstractArray{T}) where T
98+
99+
Convert a tensor to a dictionary representation for JSON serialization.
100+
101+
# Arguments
102+
- `tensor::AbstractArray{T}`: The tensor to convert
103+
104+
# Returns
105+
- `Dict`: A dictionary containing tensor metadata and data
106+
107+
# Dictionary Structure
108+
- `"size"`: The dimensions of the tensor
109+
- `"complex"`: Boolean indicating if the tensor contains complex numbers
110+
- `"data"`: The tensor data as a flat array of real numbers
111+
"""
112+
function tensor_to_dict(tensor::AbstractArray{T}) where T
113+
d = Dict()
114+
d["size"] = collect(size(tensor))
115+
d["complex"] = T <: Complex
116+
d["data"] = vec(reinterpret(real(T), tensor))
117+
return d
118+
end
119+
120+
"""
121+
tensor_from_dict(dict::Dict)
122+
123+
Convert a dictionary back to a tensor.
124+
125+
# Arguments
126+
- `dict::Dict`: The dictionary representation of a tensor
127+
128+
# Returns
129+
- `AbstractArray`: The reconstructed tensor
130+
131+
# Dictionary Structure Expected
132+
- `"size"`: The dimensions of the tensor
133+
- `"complex"`: Boolean indicating if the tensor contains complex numbers
134+
- `"data"`: The tensor data as a flat array of real numbers
135+
"""
136+
function tensor_from_dict(dict::Dict)
137+
size_vec = Tuple(dict["size"])
138+
is_complex = dict["complex"]
139+
data = collect(Float64, dict["data"])
140+
141+
if is_complex
142+
complex_data = reinterpret(ComplexF64, data)
143+
return reshape(complex_data, size_vec...)
144+
else
145+
return reshape(data, size_vec...)
146+
end
147+
end

src/mmap.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,13 @@ function MMAPModel(vars::AbstractVector{LT}, cards::AbstractVector{Int}, factors
9191
ixsi = all_ixs[cluster]
9292
vari = unique!(vcat(ixsi...))
9393
iyi = setdiff(vari, contracted)
94-
codei = optimize_code(EinCode(ixsi, iyi), size_dict, marginalize_optimizer, marginalize_simplifier)
94+
codei = optimize_code(EinCode(ixsi, iyi), size_dict, marginalize_optimizer; simplifier=marginalize_simplifier)
9595
push!(ixs, iyi)
9696
push!(clusters, Cluster(contracted, codei, ts))
9797
end
9898
rem_indices = setdiff(1:length(all_ixs), vcat([c.second for c in subsets]...))
9999
remaining_tensors = all_tensors[rem_indices]
100-
code = optimize_code(EinCode([all_ixs[rem_indices]..., ixs...], iy), size_dict, optimizer, simplifier)
100+
code = optimize_code(EinCode([all_ixs[rem_indices]..., ixs...], iy), size_dict, optimizer; simplifier)
101101
return MMAPModel(setdiff(vars, marginalized), code, remaining_tensors, clusters, evidence)
102102
end
103103

src/sampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ function generate_samples!(code::DynamicNestedEinsum, cache::CacheTree{T}, iy_en
139139
siblings = filter(x->x !== child, cache.siblings)
140140
siblings_ixs = filter(x->x !== ix, ixs)
141141
iy_subenv = batch_label ix ? ix : [ix..., batch_label]
142-
envcode = optimize_code(EinCode([siblings_ixs..., iy_env], iy_subenv), size_dict, GreedyMethod(; nrepeat=1))
142+
envcode = optimize_code(EinCode([siblings_ixs..., iy_env], iy_subenv), size_dict, GreedyMethod())
143143
subenv = einsum(envcode, (getfield.(siblings, :content)..., env), size_dict)
144144

145145
# generate samples

test/belief.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ end
5656
tnet = TensorNetworkModel(mps_uai)
5757
mars_tnet = marginals(tnet)
5858
for v in 1:TensorInference.num_variables(bp)
59-
@test mars[[v]] mars_tnet[[v]] atol=1e-4
59+
@test mars[[v]] mars_tnet[[v]] atol=1e-3
6060
end
6161
end
6262

@@ -119,4 +119,4 @@ end
119119
@test mars[[v]] mars_tnet[[v]] atol=1e-2
120120
end
121121
end
122-
end
122+
end

test/fileio.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using Test
2+
using TensorInference
3+
using Random
4+
5+
@testset "TensorNetworkModel file I/O" begin
6+
# Create a test model
7+
n = 3
8+
chi = 2
9+
d = 2
10+
tn = random_matrix_product_state(n, chi, d)
11+
12+
# Add evidence for testing
13+
tn.evidence[1] = 0
14+
tn.evidence[2] = 1
15+
16+
# Create temporary directory
17+
test_dir = mktempdir()
18+
19+
# Test saving
20+
@testset "Saving" begin
21+
TensorInference.save_tensor_network(tn; folder=test_dir)
22+
@test isfile(joinpath(test_dir, "code.json"))
23+
@test isfile(joinpath(test_dir, "tensors.json"))
24+
@test isfile(joinpath(test_dir, "model.json"))
25+
end
26+
27+
# Test loading
28+
@testset "Loading" begin
29+
tn_loaded = TensorInference.load_tensor_network(test_dir)
30+
31+
# Verify basic properties
32+
@test tn_loaded.nvars == tn.nvars
33+
@test tn_loaded.evidence == tn.evidence
34+
@test tn_loaded.unity_tensors_idx == tn.unity_tensors_idx
35+
36+
# Verify code structure
37+
@test tn_loaded.code isa typeof(tn.code)
38+
39+
# Verify tensors
40+
@test length(tn_loaded.tensors) == length(tn.tensors)
41+
for (t_orig, t_loaded) in zip(tn.tensors, tn_loaded.tensors)
42+
@test size(t_orig) == size(t_loaded)
43+
@test eltype(t_orig) == eltype(t_loaded)
44+
@test Array(t_orig) Array(t_loaded)
45+
end
46+
47+
# Verify model functionality
48+
@test probability(tn)[] probability(tn_loaded)[]
49+
end
50+
end
51+
52+
@testset "Tensor serialization" begin
53+
Random.seed!(42)
54+
55+
# Test real tensor
56+
real_tensor = rand(2, 2)
57+
dict_real = TensorInference.tensor_to_dict(real_tensor)
58+
@test TensorInference.tensor_from_dict(dict_real) real_tensor
59+
60+
# Test complex tensor
61+
complex_tensor = rand(ComplexF64, 2, 2)
62+
dict_complex = TensorInference.tensor_to_dict(complex_tensor)
63+
@test TensorInference.tensor_from_dict(dict_complex) complex_tensor
64+
65+
# Test higher-dimensional tensor
66+
high_dim_tensor = rand(2, 3, 4)
67+
dict_high_dim = TensorInference.tensor_to_dict(high_dim_tensor)
68+
@test TensorInference.tensor_from_dict(dict_high_dim) high_dim_tensor
69+
70+
# Test invalid input
71+
@test_throws KeyError TensorInference.tensor_from_dict(Dict())
72+
@test_throws KeyError TensorInference.tensor_from_dict(Dict("size" => [2,2]))
73+
end

0 commit comments

Comments
 (0)