Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
7a51966
Add `Problem` as type parameter to `SweepIterator`
Sep 22, 2025
245182b
Format test files and improve comparisons for readabilty on failure
Sep 24, 2025
7af4b25
Redesign iterator interface by introducing AbstractNetworkIterator ab…
Sep 24, 2025
c0ae5d0
Add `EachRegion` adapter that wraps `RegionIterator`, behaving the sa…
Sep 25, 2025
3b9d0af
Add unit tests for the `AbstractNetworkIterator` interface
Sep 30, 2025
4ef4e75
Rename `done` to `laststep` to better reflect the when it evalutes to…
Sep 30, 2025
e112eb4
Rename `previous_region` to `prev_region` to better align with julia …
Sep 30, 2025
da360e0
Rename `PauseAfterIncrement` -> `NoComputeStep` and improve some vari…
Oct 1, 2025
8bfc483
Make `extract` and `subspace_expand` mutating
Oct 3, 2025
1ef8498
Make `update` mutable
Oct 3, 2025
0a6e891
Make `insert` mutable
Oct 3, 2025
0653c47
First implementation of an `options` system.
Oct 3, 2025
d77321e
Simplify options interface to a single function `default_kwargs`.
Oct 6, 2025
aff14c7
Put calls to `extract!` etc in `compute!` function directly
Oct 6, 2025
4b21cc9
Refactor the region plan generating code.
Oct 7, 2025
e71512f
Have `dmrg` take a strict number of arguments
Oct 7, 2025
a4ce308
Purge non-mutating field setter functions.
Oct 7, 2025
a8b2c51
Use `current_kwargs` for getting kwargs from `RegionIterator`
Oct 7, 2025
18a8503
Introduce defaults using `default_kwargs` and be stricter about which…
Oct 7, 2025
0c9022c
Swap order of local_state and region_iter args
Oct 7, 2025
a9be11e
Add some unit tests for the defaults
Oct 7, 2025
4d52088
Rename file options.jl -> test_default_kwargs.jl
Oct 7, 2025
613d533
Fix `euler_sweep` returning kwargs not as `NamedTuple`
Oct 7, 2025
20bf783
The `sweep_solve` callbacks now get called without any keyword argume…
Oct 7, 2025
568c631
Some minor refactoring of the iterators.
Oct 7, 2025
fed9137
The `EachRegion` adapter now flattens the nested Sweep/Region iterato…
Oct 9, 2025
4ce453e
Add tests for `EachRegion` and `eachregion` wrapper functions
Oct 9, 2025
c59a9c5
Rename `laststep` -> `islaststep` in fitting with Julia conventions.
Oct 9, 2025
62195b6
Overhaul `default_kwargs` such that it mirrors the function signature…
Oct 9, 2025
917f2f1
Rename `NoComputeStep` to `IncrementOnly`
Oct 9, 2025
112d55e
Remove @info statement and fix bug with `astypes` not promoting corre…
Oct 10, 2025
0a9f127
Update `default_kwargs` tests.
Oct 10, 2025
e35f325
Remove stray `end` from `adapters.jl`.
Oct 14, 2025
6a8cdb1
Fix typo in docstring of `EachRegion` adapter.
jack-dunham Oct 14, 2025
9760de1
Function `reverse_regions` is now more concise.
Oct 14, 2025
26ece7b
Use explicit imports in `default_kwargs.jl`
Oct 14, 2025
340d805
Fix test imports and broken tests in `test_iterators.jl`.
Oct 14, 2025
f89c379
Merge branch 'network_solvers' of https://github.com/jack-dunham/ITen…
Oct 14, 2025
6a33f29
Rename @default_kwargs -> @define_default_kwargs
Oct 14, 2025
b4bcb93
Remove `astypes` option from `@define_default_kwargs`.
Oct 14, 2025
624f964
Update `default_kwargs` tests.
Oct 14, 2025
bd35f09
Add `sweep_solve` method for `EachRegion` adapter.
Oct 14, 2025
0b5314d
Add `@with_kwargs` macro which automatically splats `default_kwargs` …
Oct 14, 2025
a58ec92
Make use of `@with_kwargs` macro make code more concise.
Oct 14, 2025
b72a08f
The fallback default callback functions now no longer accept `kwargs.…
Oct 15, 2025
c5de5c4
Test fix: tests founds in sub-directories are now actually ran when i…
Oct 15, 2025
2788057
Skip broken tests for now
Oct 15, 2025
33b9e28
Rename `sweep_solve` -> `sweep_solve!` to obey convention
Oct 15, 2025
dedd82e
The `EachRegion` adapter now returns itself from `iterate` instead of…
Oct 15, 2025
d39f09e
The `sweep_solve!` function now always returns the type of the input …
Oct 15, 2025
3f5c97c
Mutating functions now return the first argument before any additiona…
Oct 15, 2025
7ad3138
Remove depreciated `solvers` code and tests from old interface
Oct 16, 2025
60235bc
Method `subspace_expand!(::Backend"densitymatrix")` now defines kwarg…
Oct 16, 2025
da3ad27
Solvers code now no longer relies on `default_kwargs` system
Oct 16, 2025
8725370
Remove `default_kwargs` related to source files
Oct 16, 2025
8afce8a
Delete stale include
mtfishman Oct 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ include("solvers/abstract_problem.jl")
include("solvers/eigsolve.jl")
include("solvers/applyexp.jl")
include("solvers/fitting.jl")
include("solvers/default_kwargs.jl")

include("apply.jl")
include("inner.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/abstract_problem.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

abstract type AbstractProblem end

set_truncation_info(P::AbstractProblem, args...; kws...) = P
set_truncation_info!(P::AbstractProblem, args...; kws...) = P
55 changes: 31 additions & 24 deletions src/solvers/adapters.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,39 @@
"""
struct PauseAfterIncrement{S<:AbstractNetworkIterator}

#
# TupleRegionIterator
#
# Adapts outputs to be (region, region_kwargs) tuples
#
# More generic design? maybe just assuming RegionIterator
# or its outputs implement some interface function that
# generates each tuple?
#

mutable struct TupleRegionIterator{RegionIter}
region_iterator::RegionIter
Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the
process. This allows one to manually call a custom `compute!` or insert their own code it in
the loop body in place of `compute!`.
"""
struct NoComputeStep{S<:AbstractNetworkIterator} <: AbstractNetworkIterator
parent::S
end

region_iterator(T::TupleRegionIterator) = T.region_iterator
laststep(adapter::NoComputeStep) = laststep(adapter.parent)
state(adapter::NoComputeStep) = state(adapter.parent)
increment!(adapter::NoComputeStep) = increment!(adapter.parent)
compute!(adapter::NoComputeStep) = adapter

function Base.iterate(T::TupleRegionIterator, which=1)
state = iterate(region_iterator(T), which)
isnothing(state) && return nothing
(current_region, region_kwargs) = current_region_plan(region_iterator(T))
return (current_region, region_kwargs), last(state)
end
NoComputeStep(adapter::NoComputeStep) = adapter

"""
region_tuples(R::RegionIterator)
struct EachRegion{RegionIterator} <: AbstractNetworkIterator

The `region_tuples` adapter converts a RegionIterator into an
iterator which outputs a tuple of the form (current_region, current_region_kwargs)
at each step.
Wapper adapter that returns a tuple (region, kwargs) at each step rather than the iterator
itself.
"""
region_tuples(R::RegionIterator) = TupleRegionIterator(R)
struct EachRegion{R<:RegionIterator} <: AbstractNetworkIterator
parent::R
end

# Essential definitions
Base.length(adapter::EachRegion) = length(adapter.parent)
state(adapter::EachRegion) = state(adapter.parent)
increment!(adapter::EachRegion) = state(adapter.parent)

function compute!(adapter::EachRegion)
# Do the usual compute! for RegionIterator
compute!(adapter.parent)
# But now lets return something useful
return current_region_plan(adapter)
end
89 changes: 43 additions & 46 deletions src/solvers/applyexp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Printf: @printf
using Accessors: @set

@kwdef mutable struct ApplyExpProblem{State} <: AbstractProblem
operator
Expand All @@ -11,66 +10,64 @@ operator(A::ApplyExpProblem) = A.operator
state(A::ApplyExpProblem) = A.state
current_exponent(A::ApplyExpProblem) = A.current_exponent
function current_time(A::ApplyExpProblem)
t = im*A.current_exponent
t = im * A.current_exponent
return iszero(imag(t)) ? real(t) : t
end

set_operator(A::ApplyExpProblem, operator) = (@set A.operator = operator)
set_state(A::ApplyExpProblem, state) = (@set A.state = state)
set_current_exponent(A::ApplyExpProblem, exponent) = (@set A.current_exponent = exponent)

function region_plan(A::ApplyExpProblem; nsites, time_step, sweep_kwargs...)
return applyexp_regions(state(A), time_step; nsites, sweep_kwargs...)
# Rename region_plan
function region_plan(A::ApplyExpProblem; nsites, exponent_step, sweep_kwargs...)
# The `exponent_step` kwarg for the `update!` function needs some pre-processing.
return applyexp_regions(state(A), exponent_step; nsites, sweep_kwargs...)
end

function update(
prob::ApplyExpProblem,
local_state,
region_iterator;
function update!(
region_iterator::RegionIterator{<:ApplyExpProblem},
local_state;
nsites,
exponent_step,
solver=runge_kutta_solver,
outputlevel,
kws...,
)
iszero(abs(exponent_step)) && return prob, local_state
prob = problem(region_iterator)

iszero(abs(exponent_step)) && return local_state

local_state, info = solver(
x->optimal_map(operator(prob), x), exponent_step, local_state; kws...
local_state, _ = solver(
x -> optimal_map(operator(prob), x), exponent_step, local_state; kws...
)
if nsites==1
if nsites == 1
curr_reg = current_region(region_iterator)
next_reg = next_region(region_iterator)
if !isnothing(next_reg) && next_reg != curr_reg
next_edge = first(edge_sequence_between_regions(state(prob), curr_reg, next_reg))
v1, v2 = src(next_edge), dst(next_edge)
psi = copy(state(prob))
psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2]))
shifted_operator = position(operator(prob), psi, NamedEdge(v1=>v2))
R_t, _ = solver(x->optimal_map(shifted_operator, x), -exponent_step, R; kws...)
local_state = psi[v1]*R_t
shifted_operator = position(operator(prob), psi, NamedEdge(v1 => v2))
R_t, _ = solver(x -> optimal_map(shifted_operator, x), -exponent_step, R; kws...)
local_state = psi[v1] * R_t
end
end

prob = set_current_exponent(prob, current_exponent(prob)+exponent_step)
prob.current_exponent += exponent_step

return prob, local_state
return local_state
end

function sweep_callback(
problem::ApplyExpProblem;
function default_sweep_callback(
sweep_iterator::SweepIterator{<:ApplyExpProblem};
exponent_description="exponent",
outputlevel,
sweep,
nsweeps,
outputlevel=0,
process_time=identity,
kws...,
)
if outputlevel >= 1
the_problem = problem(sweep_iterator)
@printf(
" Current %s = %s, ", exponent_description, process_time(current_exponent(problem))
" Current %s = %s, ",
exponent_description,
process_time(current_exponent(the_problem))
)
@printf("maxlinkdim=%d", maxlinkdim(state(problem)))
@printf("maxlinkdim=%d", maxlinkdim(state(the_problem)))
println()
flush(stdout)
end
Expand All @@ -79,19 +76,20 @@ end
function applyexp(
init_prob::AbstractProblem,
exponents;
extract_kwargs=(;),
update_kwargs=(;),
insert_kwargs=(;),
outputlevel=0,
nsites=1,
sweep_callback=default_sweep_callback,
order=4,
kws...,
nsites=2,
sweep_kwargs...,
)
exponent_steps = diff([zero(eltype(exponents)); exponents])
sweep_kws = (; outputlevel, extract_kwargs, insert_kwargs, nsites, order, update_kwargs)
kws_array = [(; sweep_kws..., time_step=t) for t in exponent_steps]
sweep_iter = sweep_iterator(init_prob, kws_array)
converged_prob = sweep_solve(sweep_iter; outputlevel, kws...)

kws_array = [
(; order, nsites, sweep_kwargs..., exponent_step) for exponent_step in exponent_steps
]
sweep_iter = SweepIterator(init_prob, kws_array)

converged_prob = sweep_solve(sweep_callback, sweep_iter)

return state(converged_prob)
end

Expand All @@ -111,11 +109,10 @@ function time_evolve(
time_points,
init_state;
process_time=process_real_times,
sweep_callback=(
a...; k...
)->sweep_callback(a...; exponent_description="time", process_time, k...),
kws...,
sweep_callback=iter ->
default_sweep_callback(iter; exponent_description="time", process_time),
sweep_kwargs...,
)
exponents = [-im*t for t in time_points]
return applyexp(operator, exponents, init_state; sweep_callback, kws...)
exponents = [-im * t for t in time_points]
return applyexp(operator, exponents, init_state; sweep_callback, sweep_kwargs...)
end
53 changes: 53 additions & 0 deletions src/solvers/default_kwargs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
default_kwargs(f, [obj = Any])
Return the default keyword arguments for the function `f`. These defaults may be
derived from the contents or type of the second arugment `obj`.
## Interface
Given a function `f`, one can optionally set the default keyword arguments for this
function by specializing either of the following two-argument methods:
```
ITensorNetworks.default_kwargs(::typeof(f), prob::AbstractProblem)
ITensorNetworks.default_kwargs(::typeof(f), ::Type{<:AbstractProblem})
```
If one does not require the contents of `prob::Prob` to generate the defaults then it is
recommended to dispatch on `Type{<:Prob}` directly (second method) so the defaults
can be accessed without constructing an instance of a `Prob`.
The return value of `default_kwargs` should be a `NamedTuple`, and will overwrite any
default values set in the function signature.
"""
default_kwargs(f) = default_kwargs(f, Any)
default_kwargs(f, obj) = _default_kwargs_fallback(f, obj)

# To avoid annoying potential method ambiguities.
function _default_kwargs_fallback(f, iter::RegionIterator)
return default_kwargs(f, problem(iter))
end
function _default_kwargs_fallback(f, problem::AbstractProblem)
return default_kwargs(f, typeof(problem))
end

# Eventually we reach this if nothing is specialized.
_default_kwargs_fallback(::Any, ::DataType) = (;)

"""
current_kwargs(f, iter::RegionIterator)
Return the keyword arguments to be passed to the function `f` for the current region
defined by the stateful iterator `iter`.
"""
function current_kwargs(f::Function, iter::RegionIterator)
region_kwargs = get(current_region_kwargs(iter), Symbol(f, :_kwargs), (;))
rv = merge(default_kwargs(f, iter), region_kwargs)
return rv
end

# Generic

# I think these should be set independent of a function, but for now:
function default_kwargs(::typeof(factorize), ::Any)
return (; maxdim=typemax(Int), cutoff=0.0, mindim=1)
end
69 changes: 31 additions & 38 deletions src/solvers/eigsolve.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Accessors: @set
using Printf: @printf
using ITensors: truncerror

Expand All @@ -14,42 +13,48 @@ state(E::EigsolveProblem) = E.state
operator(E::EigsolveProblem) = E.operator
max_truncerror(E::EigsolveProblem) = E.max_truncerror

set_operator(E::EigsolveProblem, operator) = (@set E.operator = operator)
set_eigenvalue(E::EigsolveProblem, eigenvalue) = (@set E.eigenvalue = eigenvalue)
set_state(E::EigsolveProblem, state) = (@set E.state = state)
set_max_truncerror(E::EigsolveProblem, truncerror) = (@set E.max_truncerror = truncerror)

function set_truncation_info(E::EigsolveProblem; spectrum=nothing)
function set_truncation_info!(E::EigsolveProblem; spectrum=nothing)
if !isnothing(spectrum)
E = set_max_truncerror(E, max(max_truncerror(E), truncerror(spectrum)))
E.max_truncerror = max(max_truncerror(E), truncerror(spectrum))
end
return E
end

function update(
prob::EigsolveProblem,
local_state,
region_iterator;
outputlevel,
solver=eigsolve_solver,
kws...,
function update!(
region_iterator::RegionIterator{<:EigsolveProblem}, local_state; outputlevel, solver
)
eigval, local_state = solver(ψ->optimal_map(operator(prob), ψ), local_state; kws...)
prob = set_eigenvalue(prob, eigval)
prob = problem(region_iterator)

eigval, local_state = solver(
ψ -> optimal_map(operator(prob), ψ),
local_state;
current_kwargs(solver, region_iterator)...,
)

prob.eigenvalue = eigval

if outputlevel >= 2
@printf(
" Region %s: energy = %.12f\n", current_region(region_iterator), eigenvalue(prob)
)
end
return prob, local_state
return local_state
end

function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, kws...)
function default_kwargs(::typeof(update!), ::Type{<:EigsolveProblem})
return (; outputlevel=0, solver=eigsolve_solver)
end

function default_sweep_callback(
sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel=0
)
if outputlevel >= 1
if nsweeps >= 10
@printf("After sweep %02d/%d ", sweep, nsweeps)
nsweeps = length(sweep_iterator)
current_sweep = sweep_iterator.which_sweep
if length(sweep_iterator) >= 10
@printf("After sweep %02d/%d ", current_sweep, nsweeps)
else
@printf("After sweep %d/%d ", sweep, nsweeps)
@printf("After sweep %d/%d ", current_sweep, nsweeps)
end
@printf("eigenvalue=%.12f", eigenvalue(problem))
@printf(" maxlinkdim=%d", maxlinkdim(state(problem)))
Expand All @@ -59,25 +64,13 @@ function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, k
end
end

function eigsolve(
operator,
init_state;
nsweeps,
nsites=1,
outputlevel=0,
extract_kwargs=(;),
update_kwargs=(;),
insert_kwargs=(;),
kws...,
)
function eigsolve(operator, init_state; nsweeps, nsites=1, outputlevel=0, sweep_kwargs...)
init_prob = EigsolveProblem(;
state=align_indices(init_state), operator=ProjTTN(align_indices(operator))
)
sweep_iter = sweep_iterator(
init_prob, nsweeps; nsites, outputlevel, extract_kwargs, update_kwargs, insert_kwargs
)
prob = sweep_solve(sweep_iter; outputlevel, kws...)
sweep_iter = SweepIterator(init_prob, nsweeps; nsites, outputlevel, sweep_kwargs...)
prob = sweep_solve(sweep_iter)
return eigenvalue(prob), state(prob)
end

dmrg(args...; kws...) = eigsolve(args...; kws...)
dmrg(operator, init_state; kwargs...) = eigsolve(operator, init_state; kwargs...)
Loading
Loading