Skip to content

Commit d70bf03

Browse files
committed
Adapt new solver codes into ITensorNetworks module
1 parent 9235349 commit d70bf03

15 files changed

+203
-223
lines changed

src/ITensorNetworks.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,23 @@ include("gauging.jl")
3535
include("utils.jl")
3636
include("update_observer.jl")
3737

38+
include("solvers/local_solvers/eigsolve.jl")
39+
include("solvers/local_solvers/exponentiate.jl")
40+
include("solvers/local_solvers/runge_kutta.jl")
41+
include("solvers/truncation_parameters.jl")
42+
include("solvers/sweep_solve.jl")
43+
include("solvers/iterators.jl")
44+
include("solvers/region_plans/dfs_plans.jl")
45+
include("solvers/region_plans/euler_tour.jl")
46+
include("solvers/region_plans/euler_plans.jl")
47+
include("solvers/region_plans/tdvp_region_plans.jl")
48+
include("solvers/extracter.jl")
49+
include("solvers/inserter.jl")
50+
include("solvers/subspace/subspace.jl")
51+
include("solvers/subspace/densitymatrix.jl")
52+
include("solvers/eigsolve.jl")
53+
include("solvers/applyexp.jl")
54+
3855
include("treetensornetworks/abstracttreetensornetwork.jl")
3956
include("treetensornetworks/treetensornetwork.jl")
4057
include("treetensornetworks/opsum_to_ttn/matelem.jl")

src/solvers/applyexp.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
import ITensorNetworks as itn
2-
using Printf
1+
using Printf: @printf
2+
import ConstructionBase: setproperties
33

44
@kwdef mutable struct ApplyExpProblem{State}
55
state::State
66
operator
77
current_time::Number = 0.0
88
end
99

10-
state(tdvp::ApplyExpProblem) = tdvp.state
10+
ITensorNetworks.state(tdvp::ApplyExpProblem) = tdvp.state
1111
operator(tdvp::ApplyExpProblem) = tdvp.operator
1212
current_time(tdvp::ApplyExpProblem) = tdvp.current_time
1313

@@ -31,11 +31,11 @@ function updater(
3131
curr_reg = current_region(region_iterator)
3232
next_reg = next_region(region_iterator)
3333
if !isnothing(next_reg) && next_reg != curr_reg
34-
next_edge = first(itn.edge_sequence_between_regions(state(T), curr_reg, next_reg))
35-
v1, v2 = itn.src(next_edge), itn.dst(next_edge)
34+
next_edge = first(edge_sequence_between_regions(state(T), curr_reg, next_reg))
35+
v1, v2 = src(next_edge), dst(next_edge)
3636
psi = copy(state(T))
3737
psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2]))
38-
shifted_operator = itn.position(operator(T), psi, itn.NamedEdge(v1=>v2))
38+
shifted_operator = position(operator(T), psi, NamedEdge(v1=>v2))
3939
R_t, _ = solver(x->optimal_map(shifted_operator, x), -time_step, R; kws...)
4040
local_state = psi[v1]*R_t
4141
end
@@ -53,7 +53,7 @@ function applyexp_sweep_printer(
5353
if outputlevel >= 1
5454
T = problem(region_iterator)
5555
@printf(" Current time = %s, ", process_time(current_time(T)))
56-
@printf("maxlinkdim=%d", itn.maxlinkdim(state(T)))
56+
@printf("maxlinkdim=%d", maxlinkdim(state(T)))
5757
println()
5858
flush(stdout)
5959
end
@@ -83,7 +83,7 @@ end
8383

8484
function applyexp(H, init_state, exponents; kws...)
8585
init_prob = ApplyExpProblem(;
86-
state=permute_indices(init_state), operator=itn.ProjTTN(permute_indices(H))
86+
state=permute_indices(init_state), operator=ProjTTN(permute_indices(H))
8787
)
8888
return applyexp(init_prob, exponents; kws...)
8989
end

src/solvers/eigsolve.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import ITensorNetworks as itn
2-
using Printf
1+
using Printf: @printf
2+
import ConstructionBase: setproperties
33

44
@kwdef mutable struct EigsolveProblem{State,Operator}
55
state::State
@@ -8,7 +8,7 @@ using Printf
88
end
99

1010
eigenvalue(E::EigsolveProblem) = E.eigenvalue
11-
state(E::EigsolveProblem) = E.state
11+
ITensorNetworks.state(E::EigsolveProblem) = E.state
1212
operator(E::EigsolveProblem) = E.operator
1313

1414
function updater(
@@ -36,7 +36,7 @@ function eigsolve_sweep_printer(region_iterator; outputlevel, sweep, nsweeps, kw
3636
end
3737
E = problem(region_iterator)
3838
@printf("eigenvalue=%.12f ", eigenvalue(E))
39-
@printf("maxlinkdim=%d", itn.maxlinkdim(state(E)))
39+
@printf("maxlinkdim=%d", maxlinkdim(state(E)))
4040
println()
4141
flush(stdout)
4242
end
@@ -68,7 +68,7 @@ end
6868

6969
function eigsolve(H, init_state; kws...)
7070
init_prob = EigsolveProblem(;
71-
state=permute_indices(init_state), operator=itn.ProjTTN(permute_indices(H))
71+
state=permute_indices(init_state), operator=ProjTTN(permute_indices(H))
7272
)
7373
return eigsolve(init_prob; kws...)
7474
end

src/solvers/extracter.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@ import ConstructionBase: setproperties
33
function extracter(problem, region_iterator; sweep, trunc=(;), kws...)
44
trunc = truncation_parameters(sweep; trunc...)
55
region = current_region(region_iterator)
6-
psi = itn.orthogonalize(state(problem), region)
6+
psi = orthogonalize(state(problem), region)
77
local_state = prod(psi[v] for v in region)
88
problem = setproperties(problem; state=psi)
99

1010
problem, local_state = subspace_expand(
1111
problem, local_state, region_iterator; sweep, trunc, kws...
1212
)
1313

14-
shifted_operator = itn.position(operator(problem), state(problem), region)
14+
shifted_operator = position(operator(problem), state(problem), region)
1515

1616
return setproperties(problem; operator=shifted_operator), local_state
1717
end

src/solvers/fitting.jl

Lines changed: 0 additions & 112 deletions
This file was deleted.

src/solvers/inserter.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ function inserter(
2121
indsTe = it.inds(psi[first(region)])
2222
tags = it.tags(psi, e)
2323
U, C, _ = it.factorize(local_tensor, indsTe; tags, trunc...)
24-
itn.@preserve_graph psi[first(region)] = U
24+
@preserve_graph psi[first(region)] = U
2525
else
2626
error("Region of length $(length(region)) not currently supported")
2727
end
2828
v = last(region)
29-
itn.@preserve_graph psi[v] = C
30-
psi = set_orthogonal_region ? itn.set_ortho_region(psi, [v]) : psi
31-
normalize && itn.@preserve_graph psi[v] = psi[v] / norm(psi[v])
29+
@preserve_graph psi[v] = C
30+
psi = set_orthogonal_region ? set_ortho_region(psi, [v]) : psi
31+
normalize && @preserve_graph psi[v] = psi[v] / norm(psi[v])
3232
return setproperties(problem; state=psi)
3333
end

src/solvers/operator_map.jl

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,9 @@
1-
import ITensorNetworks as itn
21

3-
function optimal_map(P::itn.ProjTTN, ψ)
4-
envs = [itn.environment(P, e) for e in itn.incident_edges(P)]
5-
site_ops = [itn.operator(P)[s] for s in itn.sites(P)]
2+
function optimal_map(P::ProjTTN, ψ)
3+
envs = [environment(P, e) for e in incident_edges(P)]
4+
site_ops = [operator(P)[s] for s in sites(P)]
65
contract_list = [envs..., site_ops..., ψ]
7-
sequence = itn.contraction_sequence(contract_list; alg="optimal")
8-
= itn.contract(contract_list; sequence)
6+
sequence = contraction_sequence(contract_list; alg="optimal")
7+
= contract(contract_list; sequence)
98
return noprime(Pψ)
109
end
11-
12-
# This function is a workaround for the slow contraction order
13-
# heuristic in ITensorNetworks/src/treetensornetworks/projttns/projttn.jl
14-
# in the projected_operator_tensors(P::ProjTTN) function (line 97 or so)
15-
function operator_map(P::itn.ProjTTN, ψ)
16-
ψ = copy(ψ)
17-
if itn.on_edge(P)
18-
for edge in itn.incident_edges(P)
19-
ψ *= itn.environment(P, edge)
20-
end
21-
else
22-
region = itn.sites(P)
23-
ie = itn.incident_edges(P)
24-
# TODO: improvement ideas
25-
# - check which vertex (first(region) vs. last(region)
26-
# has more incident edges and contract those environments first
27-
for edge in ie
28-
if itn.dst(edge) == first(region)
29-
ψ *= itn.environment(P, edge)
30-
end
31-
end
32-
for s in itn.sites(P)
33-
ψ *= itn.operator(P)[s]
34-
end
35-
for edge in ie
36-
if itn.dst(edge) != first(region)
37-
ψ *= itn.environment(P, edge)
38-
end
39-
end
40-
end
41-
return noprime(ψ)
42-
end

src/solvers/permute_indices.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import ITensorNetworks as itn
2-
using ITensors
31

42
function permute_indices(tn)
5-
si = itn.siteinds(tn)
3+
si = siteinds(tn)
64
ptn = copy(tn)
7-
for v in itn.vertices(tn)
5+
for v in vertices(tn)
86
is = inds(tn[v])
97
ls = setdiff(is, si[v])
108
isempty(ls) && continue

src/solvers/region_plans.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ import NamedGraphs: GraphsExtensions
66
# return [fwd_sweep..., reverse(fwd_sweep)...]
77
#end
88

9-
function tdvp_regions(g::AbstractGraph, time_step; nsites=1, updater_kwargs, sweep_kwargs...)
9+
function tdvp_regions(
10+
g::AbstractGraph, time_step; nsites=1, updater_kwargs, sweep_kwargs...
11+
)
1012
@assert nsites==1
1113
fwd_up_args = (; time=(time_step / 2), updater_kwargs...)
1214
rev_up_args = (; time=(-time_step / 2), updater_kwargs...)
@@ -25,7 +27,7 @@ function tdvp_regions(g::AbstractGraph, time_step; nsites=1, updater_kwargs, swe
2527
end
2628

2729
function overlap(ea::AbstractEdge, eb::AbstractEdge)
28-
return intersect([src(ea),dst(ea)], [src(eb),dst(eb)])
30+
return intersect([src(ea), dst(ea)], [src(eb), dst(eb)])
2931
end
3032

3133
function forward_region(edges, which_edge; nsites=1, region_kwargs=(;))

src/solvers/subspace.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
1-
using ITensors: commonind, dag, dim, directsum, dot, hascommoninds, Index, norm, onehot, uniqueinds, random_itensor
1+
using ITensors:
2+
commonind,
3+
dag,
4+
dim,
5+
directsum,
6+
dot,
7+
hascommoninds,
8+
Index,
9+
norm,
10+
onehot,
11+
uniqueinds,
12+
random_itensor
213

314
# TODO: hoist num_expand default value out to a function or similar
4-
function subspace_expand!(problem::EigsolveProblem, local_tensor, region; prev_region, num_expand=4, kws...)
5-
6-
if isnothing(prev_region) || isa(region, AbstractEdge)
15+
function subspace_expand!(
16+
problem::EigsolveProblem, local_tensor, region; prev_region, num_expand=4, kws...
17+
)
18+
if isnothing(prev_region) || isa(region, AbstractEdge)
719
return local_tensor
820
end
921

@@ -49,4 +61,4 @@ function subspace_expand!(problem::EigsolveProblem, local_tensor, region; prev_r
4961
local_tensor = prod(psi[v] for v in region)
5062

5163
return local_tensor
52-
end
64+
end

0 commit comments

Comments
 (0)