Skip to content

add rank-1 update of Cholesky decomposition #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/target
Cargo.lock
.idea
69 changes: 69 additions & 0 deletions src/cholesky_update.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use ndarray::{Array1, ArrayBase, DataMut, Ix2, NdFloat};


pub trait CholeskyUpdate<F> {
fn cholesky_update_inplace(&mut self, update_vector: &Array1<F>);
}
Comment on lines +4 to +6
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add documentation to this trait?


impl<V,F> CholeskyUpdate<F> for ArrayBase<V,Ix2>
where
F: NdFloat,
V: DataMut<Elem=F>,
{
fn cholesky_update_inplace(&mut self, update_vector: &Array1<F>) {
let n = self.shape()[0];
if self.shape()[0] != update_vector.len() {
panic!("update_vector should be same size as self");
}
let mut w=update_vector.to_owned();
let mut b=F::from(1.0).unwrap();
for j in 0..n{
let ljj=self[(j,j)];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For indexing, use the traits in index.rs, which are faster in release mode.

let ljj2=ljj*ljj;
let wj=w[j];
let wj2=wj*wj;
let nljj=(ljj2+wj2/b).sqrt();
let gamma=ljj2*b+wj2;
for k in j+1..n{
let lkj=self[(k,j)];
let wk=w[k]-wj*lkj/ljj;
self[(k,j)]=nljj*(lkj/ljj+wj*wk/gamma);
w[k]=wk;
}
b=b+wj2/ljj2;
self[(j,j)]=nljj;
}
}
}



#[cfg(test)]
mod test{
use approx::assert_abs_diff_eq;
use super::*;
use ndarray::{array, Array};
use crate::cholesky::Cholesky;

#[test]
fn test_cholesky_update(){
let mut arr=array![[1.0, 0.0, 2.0, 3.0, 4.0],
[-2.0, 3.0, 10.0,5.0, 6.0],
[-1.0,-2.0,-7.0, 8.0, 9.0],
[11.0, 12.0, 3.0, 14.0, 5.0],
[8.0, 2.0, 13.0, 4.0, 5.0]];
arr=arr.t().dot(&arr);
let mut l_tri = arr.cholesky().unwrap();

let x = Array::from(vec![1.0, 2.0, 3.0,0.0, 1.0]);
let vt=x.clone().into_shape((1,x.shape()[0])).unwrap();

Check warning on line 59 in src/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-1.65.0-ubuntu-latest

use of deprecated associated function `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

Check warning on line 59 in src/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-1.65.0-windows-latest

use of deprecated associated function `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`
let v=x.clone().into_shape((x.shape()[0],1)).unwrap();

Check warning on line 60 in src/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-stable-ubuntu-latest

use of deprecated method `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

Check warning on line 60 in src/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-1.65.0-ubuntu-latest

use of deprecated associated function `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

Check warning on line 60 in src/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-stable-windows-latest

use of deprecated method `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

Check warning on line 60 in src/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-1.65.0-windows-latest

use of deprecated associated function `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

l_tri.cholesky_update_inplace(&x);

let restore=l_tri.dot(&l_tri.t());
let expected=arr+v.dot(&vt);

assert_abs_diff_eq!(restore, expected, epsilon=1e-7);
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub mod reflection;
pub mod svd;
pub mod triangular;
pub mod tridiagonal;
pub mod cholesky_update;

use ndarray::{ArrayBase, Ix2, RawData, ShapeError};
use thiserror::Error;
Expand Down
32 changes: 32 additions & 0 deletions tests/cholesky_update.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use approx::assert_abs_diff_eq;
use ndarray::prelude::*;
use proptest::prelude::*;
use linfa_linalg::{cholesky::*, cholesky_update::*};
mod common;

prop_compose! {
fn gram_arr()
(arr in common::square_arr()) -> (Array2<f64>,Array1<f64>){
let dim = arr.nrows();
let mut mul = arr.t().dot(&arr);
for i in 0..dim {
mul[(i, i)] += 1.0;
}

(mul,arr.slice(s![0,..]).to_owned())
}
}

fn run_cholesky_update_test(orig: (Array2<f64>, Array1<f64>)) {

Check warning on line 20 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-stable-ubuntu-latest

function `run_cholesky_update_test` is never used

Check warning on line 20 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-1.65.0-ubuntu-latest

function `run_cholesky_update_test` is never used

Check warning on line 20 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-stable-windows-latest

function `run_cholesky_update_test` is never used

Check warning on line 20 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-1.65.0-windows-latest

function `run_cholesky_update_test` is never used
let (arr, x) = orig;
let mut l_tri = arr.cholesky().unwrap();
l_tri.cholesky_update_inplace(&x);

let vt=x.clone().into_shape((1,x.shape()[0])).unwrap();

Check warning on line 25 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-stable-ubuntu-latest

use of deprecated method `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

Check warning on line 25 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-1.65.0-ubuntu-latest

use of deprecated associated function `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

Check warning on line 25 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-stable-windows-latest

use of deprecated method `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

Check warning on line 25 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-1.65.0-windows-latest

use of deprecated associated function `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`
let v=x.clone().into_shape((x.shape()[0],1)).unwrap();

Check warning on line 26 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-stable-ubuntu-latest

use of deprecated method `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

Check warning on line 26 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-1.65.0-ubuntu-latest

use of deprecated associated function `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

Check warning on line 26 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-stable-windows-latest

use of deprecated method `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

Check warning on line 26 in tests/cholesky_update.rs

View workflow job for this annotation

GitHub Actions / testing-1.65.0-windows-latest

use of deprecated associated function `ndarray::impl_methods::<impl ndarray::ArrayBase<S, D>>::into_shape`: Use `.into_shape_with_order()` or `.to_shape()`

let restore = l_tri.dot(&l_tri.t());
let expected = arr + v.dot(&vt);
assert_abs_diff_eq!(restore, expected, epsilon = 1e-7);
}

Loading