Skip to content

Commit 262293b

Browse files
mat mul
1 parent ba49d16 commit 262293b

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

src/fillalgebra.jl

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
## vec
1+
const FillVector{F,A} = Fill{F,1,A}
2+
const FillMatrix{F,A} = Fill{F,2,A}
23

4+
## vec
35
vec(a::Ones{T}) where T = Ones{T}(length(a))
46
vec(a::Zeros{T}) where T = Zeros{T}(length(a))
57
vec(a::Fill{T}) where T = Fill{T}(a.value,length(a))
@@ -100,23 +102,34 @@ end
100102
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
101103
*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = reshape(sum(a; dims=2) .* b.value, size(a, 1))
102104

103-
function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
104-
fB = similar(parent(a), size(b, 1), size(b, 2))
105-
fill!(fB, b.value)
106-
return a*fB
105+
function *(x::AbstractMatrix, f::FillMatrix)
106+
axes(x, 2) axes(f, 1) &&
107+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
108+
m = size(f, 2)
109+
repeat(sum(x, dims=2) * f.value, 1, m)
107110
end
108111

109-
function *(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
110-
fB = similar(parent(a), size(b, 1), size(b, 2))
111-
fill!(fB, b.value)
112-
return a*fB
112+
function *(f::FillMatrix, x::AbstractMatrix)
113+
axes(f, 2) axes(x, 1) &&
114+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
115+
m = size(f, 1)
116+
repeat(sum(x, dims=1) * f.value, m, 1)
113117
end
114118

115-
function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T
116-
fB = similar(a, size(b, 1), size(b, 2))
117-
fill!(fB, b.value)
118-
return a*fB
119+
function *(x::AbstractMatrix, f::Ones)
120+
axes(x, 2) axes(f, 1) &&
121+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
122+
m = size(f, 2)
123+
repeat(sum(x, dims=2) * one(eltype(f)), 1, m)
119124
end
125+
126+
function *(f::Ones, x::AbstractMatrix)
127+
axes(f, 2) axes(x, 1) &&
128+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
129+
m = size(f, 1)
130+
repeat(sum(x, dims=1) * one(eltype(f)), m, 1)
131+
end
132+
120133
function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
121134
la, lb = length(a), length(b)
122135
if la lb

0 commit comments

Comments
 (0)