Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
7a51966
Add `Problem` as type parameter to `SweepIterator`
Sep 22, 2025
245182b
Format test files and improve comparisons for readabilty on failure
Sep 24, 2025
7af4b25
Redesign iterator interface by introducing AbstractNetworkIterator ab…
Sep 24, 2025
c0ae5d0
Add `EachRegion` adapter that wraps `RegionIterator`, behaving the sa…
Sep 25, 2025
3b9d0af
Add unit tests for the `AbstractNetworkIterator` interface
Sep 30, 2025
4ef4e75
Rename `done` to `laststep` to better reflect the when it evalutes to…
Sep 30, 2025
e112eb4
Rename `previous_region` to `prev_region` to better align with julia …
Sep 30, 2025
da360e0
Rename `PauseAfterIncrement` -> `NoComputeStep` and improve some vari…
Oct 1, 2025
8bfc483
Make `extract` and `subspace_expand` mutating
Oct 3, 2025
1ef8498
Make `update` mutable
Oct 3, 2025
0a6e891
Make `insert` mutable
Oct 3, 2025
0653c47
First implementation of an `options` system.
Oct 3, 2025
d77321e
Simplify options interface to a single function `default_kwargs`.
Oct 6, 2025
aff14c7
Put calls to `extract!` etc in `compute!` function directly
Oct 6, 2025
4b21cc9
Refactor the region plan generating code.
Oct 7, 2025
e71512f
Have `dmrg` take a strict number of arguments
Oct 7, 2025
a4ce308
Purge non-mutating field setter functions.
Oct 7, 2025
a8b2c51
Use `current_kwargs` for getting kwargs from `RegionIterator`
Oct 7, 2025
18a8503
Introduce defaults using `default_kwargs` and be stricter about which…
Oct 7, 2025
0c9022c
Swap order of local_state and region_iter args
Oct 7, 2025
a9be11e
Add some unit tests for the defaults
Oct 7, 2025
4d52088
Rename file options.jl -> test_default_kwargs.jl
Oct 7, 2025
613d533
Fix `euler_sweep` returning kwargs not as `NamedTuple`
Oct 7, 2025
20bf783
The `sweep_solve` callbacks now get called without any keyword argume…
Oct 7, 2025
568c631
Some minor refactoring of the iterators.
Oct 7, 2025
fed9137
The `EachRegion` adapter now flattens the nested Sweep/Region iterato…
Oct 9, 2025
4ce453e
Add tests for `EachRegion` and `eachregion` wrapper functions
Oct 9, 2025
c59a9c5
Rename `laststep` -> `islaststep` in fitting with Julia conventions.
Oct 9, 2025
62195b6
Overhaul `default_kwargs` such that it mirrors the function signature…
Oct 9, 2025
917f2f1
Rename `NoComputeStep` to `IncrementOnly`
Oct 9, 2025
112d55e
Remove @info statement and fix bug with `astypes` not promoting corre…
Oct 10, 2025
0a9f127
Update `default_kwargs` tests.
Oct 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/solvers/adapters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,21 @@ iterator which outputs a tuple of the form (current_region, current_region_kwarg
at each step.
"""
region_tuples(R::RegionIterator) = TupleRegionIterator(R)

"""
struct PauseAfterIncrement{S<:AbstractNetworkIterator}

Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the
process. This allows one to manually call a custom `compute!` or insert their own code it in
the loop body in place of `compute!`.
"""
struct PauseAfterIncrement{S<:AbstractNetworkIterator} <: AbstractNetworkIterator
parent::S
end

done(NC::PauseAfterIncrement) = done(NC.parent)
state(NC::PauseAfterIncrement) = state(NC.parent)
increment!(NC::PauseAfterIncrement) = increment!(NC.parent)
compute!(NC::PauseAfterIncrement) = NC

PauseAfterIncrement(NC::PauseAfterIncrement) = NC
Copy link
Member

@mtfishman mtfishman Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a small recommendation, I would prefer a variable name besides NC, maybe we could just use iterator.

39 changes: 20 additions & 19 deletions src/solvers/applyexp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ operator(A::ApplyExpProblem) = A.operator
state(A::ApplyExpProblem) = A.state
current_exponent(A::ApplyExpProblem) = A.current_exponent
function current_time(A::ApplyExpProblem)
t = im*A.current_exponent
t = im * A.current_exponent
return iszero(imag(t)) ? real(t) : t
end

Expand All @@ -36,41 +36,42 @@ function update(
iszero(abs(exponent_step)) && return prob, local_state

local_state, info = solver(
x->optimal_map(operator(prob), x), exponent_step, local_state; kws...
x -> optimal_map(operator(prob), x), exponent_step, local_state; kws...
)
if nsites==1
if nsites == 1
curr_reg = current_region(region_iterator)
next_reg = next_region(region_iterator)
if !isnothing(next_reg) && next_reg != curr_reg
next_edge = first(edge_sequence_between_regions(state(prob), curr_reg, next_reg))
v1, v2 = src(next_edge), dst(next_edge)
psi = copy(state(prob))
psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2]))
shifted_operator = position(operator(prob), psi, NamedEdge(v1=>v2))
R_t, _ = solver(x->optimal_map(shifted_operator, x), -exponent_step, R; kws...)
local_state = psi[v1]*R_t
shifted_operator = position(operator(prob), psi, NamedEdge(v1 => v2))
R_t, _ = solver(x -> optimal_map(shifted_operator, x), -exponent_step, R; kws...)
local_state = psi[v1] * R_t
end
end

prob = set_current_exponent(prob, current_exponent(prob)+exponent_step)
prob = set_current_exponent(prob, current_exponent(prob) + exponent_step)

return prob, local_state
end

function sweep_callback(
problem::ApplyExpProblem;
function default_sweep_callback(
sweep_iterator::SweepIterator{<:ApplyExpProblem};
exponent_description="exponent",
outputlevel,
sweep,
nsweeps,
process_time=identity,
kws...,
kwargs...,
)
if outputlevel >= 1
the_problem = problem(sweep_iterator)
@printf(
" Current %s = %s, ", exponent_description, process_time(current_exponent(problem))
" Current %s = %s, ",
exponent_description,
process_time(current_exponent(the_problem))
)
@printf("maxlinkdim=%d", maxlinkdim(state(problem)))
@printf("maxlinkdim=%d", maxlinkdim(state(the_problem)))
println()
flush(stdout)
end
Expand All @@ -88,9 +89,10 @@ function applyexp(
kws...,
)
exponent_steps = diff([zero(eltype(exponents)); exponents])
# exponent_steps = diff(exponents)
sweep_kws = (; outputlevel, extract_kwargs, insert_kwargs, nsites, order, update_kwargs)
kws_array = [(; sweep_kws..., time_step=t) for t in exponent_steps]
sweep_iter = sweep_iterator(init_prob, kws_array)
sweep_iter = SweepIterator(init_prob, kws_array)
converged_prob = sweep_solve(sweep_iter; outputlevel, kws...)
return state(converged_prob)
end
Expand All @@ -111,11 +113,10 @@ function time_evolve(
time_points,
init_state;
process_time=process_real_times,
sweep_callback=(
a...; k...
)->sweep_callback(a...; exponent_description="time", process_time, k...),
sweep_callback=(a...; k...) ->
default_sweep_callback(a...; exponent_description="time", process_time, k...),
kws...,
)
exponents = [-im*t for t in time_points]
exponents = [-im * t for t in time_points]
return applyexp(operator, exponents, init_state; sweep_callback, kws...)
end
16 changes: 10 additions & 6 deletions src/solvers/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function update(
solver=eigsolve_solver,
kws...,
)
eigval, local_state = solver(ψ->optimal_map(operator(prob), ψ), local_state; kws...)
eigval, local_state = solver(ψ -> optimal_map(operator(prob), ψ), local_state; kws...)
prob = set_eigenvalue(prob, eigval)
if outputlevel >= 2
@printf(
Expand All @@ -44,12 +44,16 @@ function update(
return prob, local_state
end

function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, kws...)
function default_sweep_callback(
sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel
)
if outputlevel >= 1
if nsweeps >= 10
@printf("After sweep %02d/%d ", sweep, nsweeps)
nsweeps = length(sweep_iterator)
current_sweep = sweep_iterator.which_sweep
if length(sweep_iterator) >= 10
@printf("After sweep %02d/%d ", current_sweep, nsweeps)
else
@printf("After sweep %d/%d ", sweep, nsweeps)
@printf("After sweep %d/%d ", current_sweep, nsweeps)
end
@printf("eigenvalue=%.12f", eigenvalue(problem))
@printf(" maxlinkdim=%d", maxlinkdim(state(problem)))
Expand All @@ -73,7 +77,7 @@ function eigsolve(
init_prob = EigsolveProblem(;
state=align_indices(init_state), operator=ProjTTN(align_indices(operator))
)
sweep_iter = sweep_iterator(
sweep_iter = SweepIterator(
init_prob, nsweeps; nsites, outputlevel, extract_kwargs, update_kwargs, insert_kwargs
)
prob = sweep_solve(sweep_iter; outputlevel, kws...)
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/fitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function fit_tensornetwork(
insert_kwargs = (; insert_kwargs..., normalize, set_orthogonal_region=false)
common_sweep_kwargs = (; nsites, outputlevel, update_kwargs, insert_kwargs)
kwargs_array = [(; common_sweep_kwargs..., sweep=s) for s in 1:nsweeps]
sweep_iter = sweep_iterator(init_prob, kwargs_array)
sweep_iter = SweepIterator(init_prob, kwargs_array)
converged_prob = sweep_solve(sweep_iter; outputlevel, kws...)
return rename_vertices(inv_vertex_map(overlap_network), ket(converged_prob))
end
Expand Down
188 changes: 123 additions & 65 deletions src/solvers/iterators.jl
Original file line number Diff line number Diff line change
@@ -1,100 +1,158 @@
#
# SweepIterator
#

mutable struct SweepIterator
sweep_kws
region_iter
which_sweep::Int
end

problem(S::SweepIterator) = problem(S.region_iter)
"""
abstract type AbstractNetworkIterator
Base.length(S::SweepIterator) = length(S.sweep_kws)
A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins
with a call to `increment!` before executing `compute!`, however the initial call to
`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that
this call is implict. Termination of the iterator is controlled by the function `done`.
"""
abstract type AbstractNetworkIterator end

function Base.iterate(S::SweepIterator, which=nothing)
if isnothing(which)
sweep_kws_state = iterate(S.sweep_kws)
else
sweep_kws_state = iterate(S.sweep_kws, which)
end
isnothing(sweep_kws_state) && return nothing
current_sweep_kws, next = sweep_kws_state
# We use greater than or equals here as we increment the state at the start of the iteration
done(NI::AbstractNetworkIterator) = state(NI) >= length(NI)

if !isnothing(which)
S.region_iter = region_iterator(
problem(S.region_iter); sweep=S.which_sweep, current_sweep_kws...
)
end
S.which_sweep += 1
return S.region_iter, next
function Base.iterate(NI::AbstractNetworkIterator, init=true)
done(NI) && return nothing
# We seperate increment! from step! and demand that any AbstractNetworkIterator *must*
# define a method for increment! This way we avoid cases where one may wish to nest
# calls to different step! methods accidentaly incrementing multiple times.
init || increment!(NI)
rv = compute!(NI)
return rv, false
end

function sweep_iterator(problem, sweep_kws)
region_iter = region_iterator(problem; sweep=1, first(sweep_kws)...)
return SweepIterator(sweep_kws, region_iter, 1)
end
function increment! end
compute!(NI::AbstractNetworkIterator) = NI

function sweep_iterator(problem, nsweeps::Integer; sweep_kws...)
return sweep_iterator(problem, Iterators.repeated(sweep_kws, nsweeps))
step!(NI::AbstractNetworkIterator) = step!(identity, NI)
function step!(f, NI::AbstractNetworkIterator)
compute!(NI)
f(NI)
increment!(NI)
return NI
end

#
# RegionIterator
#

@kwdef mutable struct RegionIterator{Problem,RegionPlan}
"""
struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
"""
mutable struct RegionIterator{Problem,RegionPlan} <: AbstractNetworkIterator
problem::Problem
region_plan::RegionPlan
which_region::Int = 1
const sweep::Int
which_region::Int
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P,R}
return new{P,R}(problem, region_plan, sweep, 1)
end
end

state(R::RegionIterator) = R.which_region
Base.length(R::RegionIterator) = length(R.region_plan)

problem(R::RegionIterator) = R.problem

current_region_plan(R::RegionIterator) = R.region_plan[R.which_region]
current_region(R::RegionIterator) = current_region_plan(R)[1]
region_kwargs(R::RegionIterator) = current_region_plan(R)[2]
function previous_region(R::RegionIterator)
R.which_region==1 ? nothing : R.region_plan[R.which_region - 1][1]

function current_region(R::RegionIterator)
region, _ = current_region_plan(R)
return region
end
function next_region(R::RegionIterator)
R.which_region==length(R.region_plan) ? nothing : R.region_plan[R.which_region + 1][1]

function current_region_kwargs(R::RegionIterator)
_, kwargs = current_region_plan(R)
return kwargs
end

function previous_region(R::RegionIterator)
state(R) <= 1 && return nothing
prev, _ = R.region_plan[R.which_region - 1]
return prev
end
is_last_region(R::RegionIterator) = isnothing(next_region(R))

function Base.iterate(R::RegionIterator, which=1)
R.which_region = which
region_plan_state = iterate(R.region_plan, which)
isnothing(region_plan_state) && return nothing
(current_region, region_kwargs), next = region_plan_state
R.problem = region_step(problem(R), R; region_kwargs...)
return R, next
function next_region(R::RegionIterator)
is_last_region(R) && return nothing
next, _ = R.region_plan[R.which_region + 1]
return next
end
is_last_region(R::RegionIterator) = length(R) === state(R)

#
# Functions associated with RegionIterator
#

function region_iterator(problem; sweep_kwargs...)
return RegionIterator(; problem, region_plan=region_plan(problem; sweep_kwargs...))
function compute!(R::RegionIterator)
region_kwargs = current_region_kwargs(R)
R.problem = region_step(R; region_kwargs...)
return R
end
function increment!(R::RegionIterator)
R.which_region += 1
return R
end

function RegionIterator(problem; sweep, sweep_kwargs...)
plan = region_plan(problem; sweep, sweep_kwargs...)
return RegionIterator(problem, plan, sweep)
end

function region_step(
problem,
region_iterator;
extract_kwargs=(;),
update_kwargs=(;),
insert_kwargs=(;),
sweep,
kws...,
region_iterator; extract_kwargs=(;), update_kwargs=(;), insert_kwargs=(;), kws...
)
problem, local_state = extract(problem, region_iterator; extract_kwargs..., sweep, kws...)
problem, local_state = update(
problem, local_state, region_iterator; update_kwargs..., kws...
)
problem = insert(problem, local_state, region_iterator; sweep, insert_kwargs..., kws...)
return problem
prob = problem(region_iterator)

sweep = region_iterator.sweep

prob, local_state = extract(prob, region_iterator; extract_kwargs..., sweep, kws...)
prob, local_state = update(prob, local_state, region_iterator; update_kwargs..., kws...)
prob = insert(prob, local_state, region_iterator; sweep, insert_kwargs..., kws...)
return prob
end

function region_plan(problem; kws...)
return euler_sweep(state(problem); kws...)
end

#
# SweepIterator
#

mutable struct SweepIterator{Problem} <: AbstractNetworkIterator
sweep_kws
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we parametrize SweepIterator by the type of sweep_kws? Or are we expecting it could be modified such that it changes type (though if that is the case we could make it explicit by setting the type parameter to Any)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering what would be the advantage of having this parameterized? I think the power of this SweepIterator object comes from that fact that it can be basically any object implementing the iterate interface (for example, an iterator that loads the kwargs from a file, even). The eltype of the iterator should be a NamedTuple, but then we cannot define this type necessarily as different kwargs then lead to different concrete NamedTuple types.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Purely for the sake of type stability, in cases where that is possible. I was just making the point that by parametrizing and then setting the default to Any we could allow for either option but allow it to be dynamic by default, but we could add that option later if needed.

Also I would recommend ordering the fields in the struct as region_iter, sweep_kws, which_sweep, that seems more natural to me since we are inputting the problem argument as the first argument and generally dispatching on the problem.

region_iter::RegionIterator{Problem}
which_sweep::Int
function SweepIterator(problem, sweep_kws)
sweep_kws = Iterators.Stateful(sweep_kws)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sweep_kws = Iterators.Stateful(sweep_kws)
stateful_sweep_kws = Iterators.Stateful(sweep_kws)

I would find that a bit clearer (i.e. it is easier to keep track that we are now using the stateful version of sweep_kws).

first_kwargs, _ = Iterators.peel(sweep_kws)
region_iter = RegionIterator(problem; sweep=1, first_kwargs...)
return new{typeof(problem)}(sweep_kws, region_iter, 1)
end
end

done(SR::SweepIterator) = isnothing(peek(SR.sweep_kws))

region_iterator(S::SweepIterator) = S.region_iter
problem(S::SweepIterator) = problem(region_iterator(S))

state(SR::SweepIterator) = SR.which_sweep
Base.length(S::SweepIterator) = length(S.sweep_kws)
function increment!(SR::SweepIterator)
SR.which_sweep += 1
sweep_kwargs, _ = Iterators.peel(SR.sweep_kws)
SR.region_iter = RegionIterator(problem(SR); sweep=state(SR), sweep_kwargs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to our discussion around how the RegionIterator gets updated and how the problem gets passed around, maybe it could help to wrap this constructor call into a function, for example:

function update_region_iterator(iterator::RegionIterator; kwargs...)
  return RegionIterator(iterator.problem; kwargs...)
end

or something like that (mostly to hide the detail of how the problem gets passed along at this level of the code).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, what does SR stand for? Generally I prefer lower case variable names, and also more descriptive ones, say sweep_iter.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think SR was an abbreviation of SweepRegionIterator which was a previous name for something I was using previously...

On the first point, I feel like update_region_iterator is not really descriptive of what the function as implemented there does. It does not update an existing RegionIterator, it creates a new one via the constructor. I do like however the following:

function update_region_iterator!(iterator::SweepIterator; kwargs...)
    iterator.region_iter = new_region_iterator(iterator.region_iter; kwargs...)
    return iterator
end

where new_region_iterator is a renaming of the method you describe. Let me know if you agree.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not update an existing RegionIterator, it creates a new one via the constructor. I do like however the following:

I suppose that's a matter of perspective, i.e. we could think of it as a fancy kind of "setter" method that keeps some properties and sets other ones, which for immutable types has to create a new object by definition, but I'm ok with new_region_iterator.

return SR
end

function compute!(SR::SweepIterator)
for _ in SR.region_iter
# TODO: Is it sensible to execute the default region callback function?
end
end

# More basic constructor where sweep_kwargs are constant throughout sweeps
function SweepIterator(problem, nsweeps::Int; sweep_kwargs...)
# Initialize this to an empty RegionIterator
sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps)
return SweepIterator(problem, sweep_kwargs_iter)
end
Loading
Loading