Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.8"
version = "0.3.0"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
1 change: 0 additions & 1 deletion src/BaseExtensions/BaseExtensions.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
module BaseExtensions
include("indexin.jl")
include("permutedims.jl")
end
20 changes: 0 additions & 20 deletions src/BaseExtensions/permutedims.jl

This file was deleted.

3 changes: 1 addition & 2 deletions src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ export contract, contract!, eigen, eigvals, lq, left_null, qr, right_null, svd,
include("blockedtuple.jl")
include("blockedpermutation.jl")
include("BaseExtensions/BaseExtensions.jl")
include("fusedims.jl")
include("splitdims.jl")
include("matricize.jl")
include("contract/contract.jl")
include("contract/output_labels.jl")
include("contract/blockedperms.jl")
Expand Down
54 changes: 27 additions & 27 deletions src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ using Base.PermutedDimsArrays: genperm
# i.e. `ContractAdd`?
function output_axes(
::typeof(contract),
biperm_dest::BlockedPermutation{2},
biperm_dest::AbstractBlockPermutation{2},
a1::AbstractArray,
biperm1::BlockedPermutation{2},
biperm1::AbstractBlockPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation{2},
biperm2::AbstractBlockPermutation{2},
α::Number=one(Bool),
)
axes_codomain, axes_contracted = blockpermute(axes(a1), biperm1)
Expand All @@ -22,11 +22,11 @@ end
# i.e. `ContractAdd`?
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{0},
perm_dest::AbstractBlockPermutation{0},
a1::AbstractArray,
perm1::BlockedPermutation{1},
perm1::AbstractBlockPermutation{1},
a2::AbstractArray,
perm2::BlockedPermutation{1},
perm2::AbstractBlockPermutation{1},
α::Number=one(Bool),
)
axes_contracted = blockpermute(axes(a1), perm1)
Expand All @@ -38,11 +38,11 @@ end
# Vec-mat.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
perm_dest::AbstractBlockPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{1},
perm1::AbstractBlockPermutation{1},
a2::AbstractArray,
biperm2::BlockedPermutation{2},
biperm2::AbstractBlockPermutation{2},
α::Number=one(Bool),
)
(axes_contracted,) = blockpermute(axes(a1), perm1)
Expand All @@ -54,11 +54,11 @@ end
# Mat-vec.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
perm_dest::AbstractBlockPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{2},
perm1::AbstractBlockPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation{1},
biperm2::AbstractBlockPermutation{1},
α::Number=one(Bool),
)
axes_dest, axes_contracted = blockpermute(axes(a1), perm1)
Expand All @@ -70,11 +70,11 @@ end
# Outer product.
function output_axes(
::typeof(contract),
biperm_dest::BlockedPermutation{2},
biperm_dest::AbstractBlockPermutation{2},
a1::AbstractArray,
perm1::BlockedPermutation{1},
perm1::AbstractBlockPermutation{1},
a2::AbstractArray,
perm2::BlockedPermutation{1},
perm2::AbstractBlockPermutation{1},
α::Number=one(Bool),
)
@assert istrivialperm(Tuple(perm1))
Expand All @@ -86,11 +86,11 @@ end
# Array-scalar contraction.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
perm_dest::AbstractBlockPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{1},
perm1::AbstractBlockPermutation{1},
a2::AbstractArray,
perm2::BlockedPermutation{0},
perm2::AbstractBlockPermutation{0},
α::Number=one(Bool),
)
@assert istrivialperm(Tuple(perm1))
Expand All @@ -101,11 +101,11 @@ end
# Scalar-array contraction.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
perm_dest::AbstractBlockPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{0},
perm1::AbstractBlockPermutation{0},
a2::AbstractArray,
perm2::BlockedPermutation{1},
perm2::AbstractBlockPermutation{1},
α::Number=one(Bool),
)
@assert istrivialperm(Tuple(perm2))
Expand All @@ -116,11 +116,11 @@ end
# Scalar-scalar contraction.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{0},
perm_dest::AbstractBlockPermutation{0},
a1::AbstractArray,
perm1::BlockedPermutation{0},
perm1::AbstractBlockPermutation{0},
a2::AbstractArray,
perm2::BlockedPermutation{0},
perm2::AbstractBlockPermutation{0},
α::Number=one(Bool),
)
return ()
Expand All @@ -130,11 +130,11 @@ end
# i.e. `ContractAdd`?
function allocate_output(
::typeof(contract),
biperm_dest::BlockedPermutation,
biperm_dest::AbstractBlockPermutation,
a1::AbstractArray,
biperm1::BlockedPermutation,
biperm1::AbstractBlockPermutation,
a2::AbstractArray,
biperm2::BlockedPermutation,
biperm2::AbstractBlockPermutation,
α::Number=one(Bool),
)
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
Expand Down
6 changes: 3 additions & 3 deletions src/contract/blockedperms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)

permblocks_dest = (perm_codomain_dest, perm_domain_dest)
biperm_dest = blockedpermvcat(filter(!isempty, permblocks_dest)...)
biperm_dest = blockedpermvcat(permblocks_dest...)
permblocks1 = (perm_codomain1, perm_domain1)
biperm1 = blockedpermvcat(filter(!isempty, permblocks1)...)
biperm1 = blockedpermvcat(permblocks1...)
permblocks2 = (perm_codomain2, perm_domain2)
biperm2 = blockedpermvcat(filter(!isempty, permblocks2)...)
biperm2 = blockedpermvcat(permblocks2...)
return biperm_dest, biperm1, biperm2
end
12 changes: 6 additions & 6 deletions src/contract/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ default_contract_alg() = Matricize()
function contract!(
alg::Algorithm,
a_dest::AbstractArray,
biperm_dest::BlockedPermutation,
biperm_dest::AbstractBlockPermutation,
a1::AbstractArray,
biperm1::BlockedPermutation,
biperm1::AbstractBlockPermutation,
a2::AbstractArray,
biperm2::BlockedPermutation,
biperm2::AbstractBlockPermutation,
α::Number,
β::Number,
)
Expand Down Expand Up @@ -110,11 +110,11 @@ end

function contract(
alg::Algorithm,
biperm_dest::BlockedPermutation,
biperm_dest::AbstractBlockPermutation,
a1::AbstractArray,
biperm1::BlockedPermutation,
biperm1::AbstractBlockPermutation,
a2::AbstractArray,
biperm2::BlockedPermutation,
biperm2::AbstractBlockPermutation,
α::Number;
kwargs...,
)
Expand Down
101 changes: 9 additions & 92 deletions src/contract/contract_matricize/contract.jl
Original file line number Diff line number Diff line change
@@ -1,103 +1,20 @@
using LinearAlgebra: mul!

function contract!(
alg::Matricize,
::Matricize,
a_dest::AbstractArray,
biperm_dest::BlockedPermutation,
biperm_dest::AbstractBlockPermutation{2},
a1::AbstractArray,
biperm1::BlockedPermutation,
biperm1::AbstractBlockPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation,
biperm2::AbstractBlockPermutation{2},
α::Number,
β::Number,
)
a_dest_mat = fusedims(a_dest, biperm_dest)
a1_mat = fusedims(a1, biperm1)
a2_mat = fusedims(a2, biperm2)
_mul!(a_dest_mat, a1_mat, a2_mat, α, β)
splitdims!(a_dest, a_dest_mat, biperm_dest)
return a_dest
end

# Matrix multiplication.
function _mul!(
a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, α::Number, β::Number
)
mul!(a_dest, a1, a2, α, β)
return a_dest
end

# Inner product.
function _mul!(
a_dest::AbstractArray{<:Any,0},
a1::AbstractVector,
a2::AbstractVector,
α::Number,
β::Number,
)
a_dest[] = transpose(a1) * a2 * α + a_dest[] * β
return a_dest
end

# Vec-mat.
function _mul!(
a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number
)
mul!(transpose(a_dest), transpose(a1), a2, α, β)
return a_dest
end

# Mat-vec.
function _mul!(
a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number
)
mul!(a_dest, a1, a2, α, β)
return a_dest
end

# Outer product.
function _mul!(
a_dest::AbstractMatrix, a1::AbstractVector, a2::AbstractVector, α::Number, β::Number
)
mul!(a_dest, a1, transpose(a2), α, β)
return a_dest
end

# Array-scalar contraction.
function _mul!(
a_dest::AbstractVector,
a1::AbstractVector,
a2::AbstractArray{<:Any,0},
α::Number,
β::Number,
)
α′ = a2[] * α
a_dest .= a1 .* α′ .+ a_dest .* β
return a_dest
end

# Scalar-array contraction.
function _mul!(
a_dest::AbstractVector,
a1::AbstractArray{<:Any,0},
a2::AbstractVector,
α::Number,
β::Number,
)
# Preserve the ordering in case of non-commutative algebra.
a_dest .= a1[] .* a2 .* α .+ a_dest .* β
return a_dest
end

# Scalar-scalar contraction.
function _mul!(
a_dest::AbstractArray{<:Any,0},
a1::AbstractArray{<:Any,0},
a2::AbstractArray{<:Any,0},
α::Number,
β::Number,
)
# Preserve the ordering in case of non-commutative algebra.
a_dest[] = a1[] * a2[] * α + a_dest[] * β
a_dest_mat = matricize(a_dest, biperm_dest)
a1_mat = matricize(a1, biperm1)
a2_mat = matricize(a2, biperm2)
mul!(a_dest_mat, a1_mat, a2_mat, α, β)
unmatricize!(a_dest, a_dest_mat, biperm_dest)
return a_dest
end
Loading
Loading