Skip to content

Exponential of a matrix #968

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 37 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
b0a74b1
Working implementation of expm.
loiseaujc Mar 27, 2025
fa36f33
Improved implementation + error handling.
loiseaujc Mar 27, 2025
d8b1857
Added docstring for the interface.
loiseaujc Mar 27, 2025
4089d18
Specs + example.
loiseaujc Mar 27, 2025
c6857bc
Update doc/specs/stdlib_linalg.md
loiseaujc Mar 30, 2025
4310db5
Replace matmul with gemm.
loiseaujc Mar 31, 2025
59ffb20
Error handling tests.
loiseaujc Mar 31, 2025
56474e1
Merge branch 'matrix_exponential' into master
loiseaujc Jul 3, 2025
75e0892
Merge pull request #2 from loiseaujc/master
loiseaujc Jul 3, 2025
8d6a3f9
Remove tests for failure to pinpoint seg fault.
loiseaujc Jul 3, 2025
65ad5f2
Pinpointing why the expm test fails.
loiseaujc Jul 3, 2025
cc3f1f2
Revert "Pinpointing why the expm test fails."
loiseaujc Jul 3, 2025
28bb69b
Remove print statement.
loiseaujc Jul 3, 2025
f479582
Change operator norm to standard norm for error checking.
loiseaujc Jul 3, 2025
b092515
Fix import
loiseaujc Jul 3, 2025
1bb1e01
Make use of stdlib_constants to avoid redefining some variables.
loiseaujc Jul 4, 2025
3ebdf9e
Replaced matmul with gemm.
loiseaujc Jul 4, 2025
f99e804
merge trick replacing if i == j.
loiseaujc Jul 8, 2025
8acc8de
Merge branch 'master'
loiseaujc Jul 8, 2025
34745cb
Make use of the new handle_gesv_info function.
loiseaujc Jul 8, 2025
534a88d
Define log(2.0) as a constant.
loiseaujc Jul 8, 2025
d97043d
Specify integer kind in size function.
loiseaujc Jul 8, 2025
a381d0b
subroutine driver and interface (in-place and out-of-place)
loiseaujc Jul 9, 2025
1bc6427
Fixed error handling.
loiseaujc Jul 9, 2025
531261a
Looking for the msys2-build error
loiseaujc Jul 9, 2025
3941673
Revert "Looking for the msys2-build error"
loiseaujc Jul 9, 2025
3b9c77d
Print computed matrix for reference.
loiseaujc Jul 9, 2025
2258de3
Revert "Print computed matrix for reference."
loiseaujc Jul 9, 2025
ae14bb5
Looking for the mysys2-build error.
loiseaujc Jul 9, 2025
1fb76b2
Revert "Looking for the mysys2-build error."
loiseaujc Jul 9, 2025
5eb6b47
Looking for mysys2-build error.
loiseaujc Jul 9, 2025
ed0a4e0
Jose's fix for the expm tests.
loiseaujc Jul 15, 2025
8d561ca
Update doc/specs/stdlib_linalg.md
loiseaujc Jul 15, 2025
22b8686
Update doc/specs/stdlib_linalg.md
loiseaujc Jul 15, 2025
c3ed61a
Update src/stdlib_linalg_matrix_functions.fypp
loiseaujc Jul 15, 2025
40e35d5
Fix typo.
loiseaujc Jul 15, 2025
7e8cca3
Merge branch 'master' into matrix_exponential
jalvesz Jul 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions doc/specs/stdlib_linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -1880,3 +1880,37 @@ If `err` is not present, exceptions trigger an `error stop`.
{!example/linalg/example_mnorm.f90!}
```

## `expm` - Computes the matrix exponential {#expm}

### Status

Experimental

### Description

Given a matrix \(A\), this function compute its matrix exponential \(E = \exp(A)\) using a Pade approximation.

### Syntax

`E = ` [[stdlib_linalg(module):expm(interface)]] `(a [, order, err])`

### Arguments

`a`: Shall be a rank-2 `real` or `complex` array containing the data. It is an `intent(in)` argument.

`order` (optional): Shall be a non-negative `integer` value specifying the order of the Pade approximation. By default `order=10`. It is an `intent(in)` argument.

`err` (optional): Shall be a `type(linalg_state_type)` value. This is an `intent(out)` argument.
### Return value

The returned array `E` contains the Pade approximation of \(\exp(A)\).

If `A` is non-square or `order` is negative, it raise a `LINALG_VALUE_ERROR`.
If `err` is not present, exceptions trigger an `error stop`.

### Example

```fortran
{!example/linalg/example_expm.f90!}
```

1 change: 1 addition & 0 deletions example/linalg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ ADD_EXAMPLE(qr)
ADD_EXAMPLE(qr_space)
ADD_EXAMPLE(cholesky)
ADD_EXAMPLE(chol)
ADD_EXAMPLE(expm)
7 changes: 7 additions & 0 deletions example/linalg/example_expm.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
program example_expm
use stdlib_linalg, only: expm
implicit none
real :: A(3, 3), E(3, 3)
A = reshape([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3])
E = expm(A)
end program example_expm
5 changes: 3 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@ set(fppFiles
stdlib_linalg_kronecker.fypp
stdlib_linalg_cross_product.fypp
stdlib_linalg_eigenvalues.fypp
stdlib_linalg_solve.fypp
stdlib_linalg_solve.fypp
stdlib_linalg_determinant.fypp
stdlib_linalg_qr.fypp
stdlib_linalg_inverse.fypp
stdlib_linalg_pinv.fypp
stdlib_linalg_norms.fypp
stdlib_linalg_state.fypp
stdlib_linalg_svd.fypp
stdlib_linalg_svd.fypp
stdlib_linalg_cholesky.fypp
stdlib_linalg_schur.fypp
stdlib_linalg_matrix_functions.fypp
stdlib_optval.fypp
stdlib_selection.fypp
stdlib_sorting.fypp
Expand Down
48 changes: 48 additions & 0 deletions src/stdlib_linalg.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ module stdlib_linalg
public :: eigh
public :: eigvals
public :: eigvalsh
public :: expm
public :: eye
public :: inv
public :: invert
Expand Down Expand Up @@ -1678,6 +1679,53 @@ module stdlib_linalg
#:endfor
end interface mnorm

!> Matrix exponential: function interface
interface expm
!! version : experimental
!!
!! Computes the exponential of a matrix using a rational Pade approximation.
!! ([Specification](../page/specs/stdlib_linalg.html#expm))
!!
!! ### Description
!!
!! This interface provides methods for computing the exponential of a matrix
!! represented as a standard Fortran rank-2 array. Supported data types include
!! `real` and `complex`.
!!
!! By default, the order of the Pade approximation is set to 10. It can be changed
!! via the `order` argument which must be non-negative.
!!
!! If the input matrix is non-square or the order of the Pade approximation is
!! negative, the function returns an error state.
!!
!! ### Example
!!
!! ```fortran
!! real(dp) :: A(3, 3), E(3, 3)
!!
!! A = reshape([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3])
!!
!! ! Default Pade approximation of the matrix exponential.
!! E = expm(A)
!!
!! ! Pade approximation with specified order.
!! E = expm(A, order=12)
!! ```
!!
#:for rk,rt,ri in RC_KINDS_TYPES
module function stdlib_expm_${ri}$(A, order, err) result(E)
!> Input matrix a(n, n).
${rt}$, intent(in) :: A(:, :)
!> [optional] Order of the Pade approximation (default `order=10`)
integer(ilp), optional, intent(in) :: order
!> [optional] State return flag. On error, if not requested, the code will stop.
type(linalg_state_type), optional, intent(out) :: err
!> Exponential of the input matrix E = exp(A).
${rt}$, allocatable :: E(:, :)
end function stdlib_expm_${ri}$
#:endfor
end interface expm

contains


Expand Down
130 changes: 130 additions & 0 deletions src/stdlib_linalg_matrix_functions.fypp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#:include "common.fypp"
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
submodule (stdlib_linalg) stdlib_linalg_matrix_functions
use stdlib_linalg_constants
use stdlib_linalg_lapack, only: gesv
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
implicit none

#:for rk, rt, ri in (REAL_KINDS_TYPES)
${rt}$, parameter :: zero_${ri}$ = 0._${rk}$
${rt}$, parameter :: one_${ri}$ = 1._${rk}$
#:endfor
#:for rk, rt, ri in (CMPLX_KINDS_TYPES)
${rt}$, parameter :: zero_${ri}$ = (0._${rk}$, 0._${rk}$)
${rt}$, parameter :: one_${ri}$ = (1._${rk}$, 0._${rk}$)
#:endfor

contains

#:for rk,rt,ri in RC_KINDS_TYPES
module function stdlib_expm_${ri}$(A, order, err) result(E)
!> Input matrix A(n, n).
${rt}$, intent(in) :: A(:, :)
!> [optional] Order of the Pade approximation.
integer(ilp), optional, intent(in) :: order
!> [optional] State return flag.
type(linalg_state_type), optional, intent(out) :: err
!> Exponential of the input matrix E = exp(A).
${rt}$, allocatable :: E(:, :)

! Internal variables.
${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :)
real(${rk}$) :: a_norm, c
integer(ilp) :: m, n, ee, k, s, order_, i, j
logical(lk) :: p
character(len=*), parameter :: this = "expm"
type(linalg_state_type) :: err0

! Deal with optional args.
order_ = 10 ; if (present(order)) order_ = order

! Problem's dimension.
m = size(A, 1) ; n = size(A, 2)

if (m /= n) then
err = linalg_state_type(this,LINALG_VALUE_ERROR,'Invalid matrix size A=',[m, n])
call linalg_error_handling(err0, err)
else if (order_ < 0) then
err = linalg_state_type(this, LINALG_VALUE_ERROR, 'Order of Pade approximation &
needs to be positive, order=', order_)
call linalg_error_handling(err0, err)
endif

! Compute the L-infinity norm.
a_norm = mnorm(A, "inf")

! Determine scaling factor for the matrix.
ee = int(log(a_norm) / log(2.0_${rk}$)) + 1
s = max(0, ee+1)

! Scale the input matrix & initialize polynomial.
A2 = A/2.0_${rk}$**s ; X = A2

! First step of the Pade approximation.
c = 0.5_${rk}$
allocate (E, source=A2) ; allocate (Q, source=A2)
do concurrent(i=1:n, j=1:n)
E(i, j) = c*E(i, j) ; if (i == j) E(i, j) = 1.0_${rk}$ + E(i, j) ! E = I + c*A2
Q(i, j) = -c*Q(i, j) ; if (i == j) Q(i, j) = 1.0_${rk}$ + Q(i, j) ! Q = I - c*A2
enddo

! Iteratively compute the Pade approximation.
p = .true.
do k = 2, order_
c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
X = matmul(A2, X)
do concurrent(i=1:n, j=1:n)
E(i, j) = E(i, j) + c*X(i, j) ! E = E + c*X
enddo
if (p) then
do concurrent(i=1:n, j=1:n)
Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
enddo
else
do concurrent(i=1:n, j=1:n)
Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
enddo
endif
p = .not. p
enddo

block
integer(ilp) :: ipiv(n), info
call gesv(n, n, Q, n, ipiv, E, n, info) ! E = inv(Q) @ E
call handle_gesv_info(info, n, n, n, err0)
call linalg_error_handling(err0, err)
end block

! This loop should eventually be replaced by a fast matrix_power function.
do k = 1, s
E = matmul(E, E)
enddo
return
contains
elemental subroutine handle_gesv_info(info,lda,n,nrhs,err)
integer(ilp), intent(in) :: info,lda,n,nrhs
type(linalg_state_type), intent(out) :: err
! Process output
select case (info)
case (0)
! Success
case (-1)
err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid problem size n=',n)
case (-2)
err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid rhs size n=',nrhs)
case (-4)
err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid matrix size a=',[lda,n])
case (-7)
err = linalg_state_type(this,LINALG_ERROR,'invalid matrix size a=',[lda,n])
case (1:)
err = linalg_state_type(this,LINALG_ERROR,'singular matrix')
case default
err = linalg_state_type(this,LINALG_INTERNAL_ERROR,'catastrophic error')
end select
end subroutine handle_gesv_info
end function stdlib_expm_${ri}$
#:endfor

end submodule stdlib_linalg_matrix_functions
2 changes: 2 additions & 0 deletions test/linalg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ set(
"test_linalg_svd.fypp"
"test_linalg_matrix_property_checks.fypp"
"test_linalg_sparse.fypp"
"test_linalg_expm.fypp"
)
fypp_f90("${fyppFlags}" "${fppFiles}" outFiles)

Expand All @@ -35,3 +36,4 @@ ADDTEST(linalg_schur)
ADDTEST(linalg_svd)
ADDTEST(blas_lapack)
ADDTEST(linalg_sparse)
ADDTEST(linalg_expm)
90 changes: 90 additions & 0 deletions test/linalg/test_linalg_expm.fypp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#:include "common.fypp"
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
! Test Schur decomposition
module test_linalg_expm
use testdrive, only: error_type, check, new_unittest, unittest_type
use stdlib_linalg_constants
use stdlib_linalg, only: expm, eye, mnorm

implicit none (type,external)

public :: test_expm_computation

contains

!> schur decomposition tests
subroutine test_expm_computation(tests)
!> Collection of tests
type(unittest_type), allocatable, intent(out) :: tests(:)

allocate(tests(0))

#:for rk,rt,ri in RC_KINDS_TYPES
tests = [tests, new_unittest("expm_${ri}$",test_expm_${ri}$)]
#:endfor

end subroutine test_expm_computation

!> Matrix exponential with analytic expression.
#:for rk,rt,ri in RC_KINDS_TYPES
subroutine test_expm_${ri}$(error)
type(error_type), allocatable, intent(out) :: error
! Problem dimension.
integer(ilp), parameter :: n = 5, m = 6
! Test matrix.
${rt}$ :: A(n, n), E(n, n), Eref(n, n)
real(${rk}$) :: err
integer(ilp) :: i, j

! Initialize matrix.
A = 0.0_${rk}$
do i = 1, n-1
A(i, i+1) = m*1.0_${rk}$
enddo

! Reference with analytical exponential.
Eref = eye(n, mold=1.0_${rk}$)
do i = 1, n-1
do j = 1, n-i
Eref(i, i+j) = Eref(i, i+j-1)*m/j
enddo
enddo

! Compute matrix exponential.
E = expm(A)

! Check result.
err = mnorm(Eref - E, "inf")
call check(error, err < (n**2)*epsilon(1.0_${rk}$), "Analytical matrix exponential.")
if (allocated(error)) return
return
end subroutine test_expm_${ri}$
#:endfor

end module test_linalg_expm

program test_expm
use, intrinsic :: iso_fortran_env, only : error_unit
use testdrive, only : run_testsuite, new_testsuite, testsuite_type
use test_linalg_expm, only : test_expm_computation
implicit none
integer :: stat, is
type(testsuite_type), allocatable :: testsuites(:)
character(len=*), parameter :: fmt = '("#", *(1x, a))'

stat = 0

testsuites = [ &
new_testsuite("linalg_expm", test_expm_computation) &
]

do is = 1, size(testsuites)
write(error_unit, fmt) "Testing:", testsuites(is)%name
call run_testsuite(testsuites(is)%collect, error_unit, stat)
end do

if (stat > 0) then
write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
error stop
end if
end program test_expm
Loading