Skip to content

Commit 6a8a326

Browse files
gdalleamontoison
authored andcommitted
First version with passing minimal tests
1 parent 2feb679 commit 6a8a326

File tree

7 files changed

+66
-28
lines changed

7 files changed

+66
-28
lines changed

src/adjtrans.jl

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,25 @@ function mul!(
141141
conj!(res)
142142
end
143143

144+
function mul!(
145+
res::AbstractMatrix,
146+
op::AdjointLinearOperator{T, S},
147+
m::AbstractMatrix,
148+
α,
149+
β,
150+
) where {T, S}
151+
p = op.parent
152+
(size(m, 1) == size(p, 1) && size(res, 1) == size(p, 2) && size(m, 2) == size(res, 2)) ||
153+
throw(LinearOperatorException("shape mismatch"))
154+
if ishermitian(p)
155+
return mul!(res, p, m, α, β)
156+
elseif p.ctprod! !== nothing
157+
return p.ctprod!(res, m, α, β)
158+
else
159+
error("Not implemented")
160+
end
161+
end
162+
144163
function mul!(
145164
res::AbstractVector,
146165
op::TransposeLinearOperator{T, S},
@@ -188,6 +207,25 @@ function mul!(
188207
conj!(res)
189208
end
190209

210+
function mul!(
211+
res::AbstractMatrix,
212+
op::TransposeLinearOperator{T, S},
213+
m::AbstractMatrix,
214+
α,
215+
β,
216+
) where {T, S}
217+
p = op.parent
218+
(size(m, 1) == size(p, 1) && size(res, 1) == size(p, 2) && size(m, 2) == size(res, 2)) ||
219+
throw(LinearOperatorException("shape mismatch"))
220+
if issymmetric(p)
221+
return mul!(res, p, m, α, β)
222+
elseif p.tprod! !== nothing
223+
return p.tprod!(res, m, α, β)
224+
else
225+
error("Not implemented")
226+
end
227+
end
228+
191229
function mul!(
192230
res::AbstractVector,
193231
op::ConjugateLinearOperator{T, S},
@@ -200,7 +238,17 @@ function mul!(
200238
conj!(res)
201239
end
202240

203-
# TODO: overload the above for matrices?
241+
function mul!(
242+
res::AbstractMatrix,
243+
op::ConjugateLinearOperator{T, S},
244+
v::AbstractMatrix,
245+
α,
246+
β,
247+
) where {T, S}
248+
p = op.parent
249+
mul!(res, p, conj.(v), α, β)
250+
conj!(res)
251+
end
204252

205253
-(op::AdjointLinearOperator) = adjoint(-op.parent)
206254
-(op::TransposeLinearOperator) = transpose(-op.parent)

src/cat.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ function hcat_ctprod!(
3232
mul!(view(res, (Ancol + 1):nV), B, u, α, β)
3333
end
3434

35-
# TODO: overload the above for matrices?
36-
3735
function hcat(A::AbstractLinearOperator, B::AbstractLinearOperator)
3836
size(A, 1) == size(B, 1) || throw(LinearOperatorException("hcat: inconsistent row sizes"))
3937

@@ -93,8 +91,6 @@ function vcat_ctprod!(
9391
mul!(res, B, view(v, (Anrow + 1):nV), α, one(T))
9492
end
9593

96-
# TODO: overload the above for matrices?
97-
9894
function vcat(A::AbstractLinearOperator, B::AbstractLinearOperator)
9995
size(A, 2) == size(B, 2) || throw(LinearOperatorException("vcat: inconsistent column sizes"))
10096

src/constructors.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ op = LinearOperator(Float64, 2, 2, false, false,
102102
(res, v) -> mul!(res, A, v),
103103
(res, w) -> mul!(res, A', w))
104104
```
105+
106+
The 3-args `mul!` also works when applying the operator on a matrix.
105107
"""
106108
function LinearOperator(
107109
::Type{T},

src/kron.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ function kron(A::AbstractLinearOperator, B::AbstractLinearOperator)
4444
return LinearOperator{T}(nrow, ncol, symm, herm, prod!, tprod!, ctprod!)
4545
end
4646

47-
# TODO: overload the above for matrices?
48-
4947
kron(A::AbstractMatrix, B::AbstractLinearOperator) = kron(LinearOperator(A), B)
5048

5149
kron(A::AbstractLinearOperator, B::AbstractMatrix) = kron(A, LinearOperator(B))

src/operations.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,14 @@ function mul!(res::AbstractVector, op::AbstractLinearOperator{T}, v::AbstractVec
3232
end
3333
end
3434

35-
function mul!(res::AbstractMatrix, op::AbstractLinearOperator{T}, m::AbstractMatrix, α, β) where {T}
36-
# TODO: how to handle storage?
37-
error("5-argument `mul!` is not defined between a linear operator and a matrix.")
35+
function mul!(res::AbstractMatrix, op::AbstractLinearOperator, m::AbstractMatrix{T}, α, β) where {T}
36+
op.prod!(res, m, α, β)
3837
end
3938

40-
function mul!(res::AbstractVector, op::AbstractLinearOperator, v::AbstractVector{T}) where {T}
39+
function mul!(res::AbstractVecOrMat, op::AbstractLinearOperator, v::AbstractVecOrMat{T}) where {T}
4140
mul!(res, op, v, one(T), zero(T))
4241
end
4342

44-
function mul!(res::AbstractMatrix, op::AbstractLinearOperator, m::AbstractMatrix{T}) where {T}
45-
mul!(res, op, m, one(T), zero(T))
46-
end
47-
4843
# Apply an operator to a vector.
4944
function *(op::AbstractLinearOperator{T}, v::AbstractVector{S}) where {T, S}
5045
nrow, ncol = size(op)

src/special-operators.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ function mulOpEye!(res, v, α, β::T, n_min) where {T}
4343
end
4444
end
4545

46-
# TODO: overload the above for matrices?
47-
4846
"""
4947
opEye(T, n; S = Vector{T})
5048
opEye(n)
@@ -86,8 +84,6 @@ function mulOpOnes!(res, v, α, β::T) where {T}
8684
end
8785
end
8886

89-
# TODO: overload the above for matrices?
90-
9187
"""
9288
opOnes(T, nrow, ncol; S = Vector{T})
9389
opOnes(nrow, ncol)
@@ -111,8 +107,6 @@ function mulOpZeros!(res, v, α, β::T) where {T}
111107
end
112108
end
113109

114-
# TODO: overload the above for matrices?
115-
116110
"""
117111
opZeros(T, nrow, ncol; S = Vector{T})
118112
opZeros(nrow, ncol)
@@ -136,8 +130,6 @@ function mulSquareOpDiagonal!(res, d, v, α, β::T) where {T}
136130
end
137131
end
138132

139-
# TODO: overload the above for matrices?
140-
141133
"""
142134
opDiagonal(d)
143135
@@ -157,9 +149,6 @@ function mulOpDiagonal!(res, d, v, α, β::T, n_min) where {T}
157149
end
158150
res[(n_min + 1):end] .= 0
159151
end
160-
161-
# TODO: overload the above for matrices?
162-
163152
"""
164153
opDiagonal(nrow, ncol, d)
165154
@@ -184,8 +173,6 @@ function multRestrict!(res, I, u, α, β)
184173
res[I] = u
185174
end
186175

187-
# TODO: overload the above for matrices?
188-
189176
"""
190177
Z = opRestriction(I, ncol)
191178
Z = opRestriction(:, ncol)

test/test_linop.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ function test_linop()
6161
@test(norm(transpose(u) * A - transpose(u) * op) <= rtol * norm(u))
6262
@test(typeof(u' * op * v) <: Number)
6363
@test(norm(u' * A * v - u' * op * v) <= rtol * norm(u))
64+
65+
mv = hcat(v, -2v)
66+
mu = hcat(u, -2u)
67+
res_mat = similar(mu)
68+
res_trans = similar(mv)
69+
res_adj = similar(mv)
70+
mul!(res_mat, op, mv)
71+
mul!(res_trans, transpose(op), mu)
72+
mul!(res_adj, op', mu)
73+
@test(norm(A * mv - res_mat) <= rtol * norm(mv))
74+
@test(norm(transpose(A) * mu - res_trans) <= rtol * norm(mu))
75+
@test(norm(A' * mu - res_adj) <= rtol * norm(mu))
6476
end
6577

6678
A3 = Hermitian(A2' * A2)

0 commit comments

Comments
 (0)