From 877f0bbb4f7693196284ff6f3ea40e323b8f258e Mon Sep 17 00:00:00 2001 From: Thomas Arildsen Date: Thu, 29 Mar 2018 16:59:49 +0200 Subject: [PATCH 1/6] BUG: Soft thresholding helper function now handles complex numbers This is a proposed implementation that does away with the `handle_complex` option and determines the method automatically. I also added tests to validate the function with complex input. --- pyunlocbox/functions.py | 8 ++++---- pyunlocbox/tests/test_functions.py | 27 ++++++++++++++++++++++----- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/pyunlocbox/functions.py b/pyunlocbox/functions.py index 400807f..9ab11af 100644 --- a/pyunlocbox/functions.py +++ b/pyunlocbox/functions.py @@ -58,7 +58,7 @@ from pyunlocbox import operators as op -def _soft_threshold(z, T, handle_complex=True): +def _soft_threshold(z, T): r""" Return the soft thresholded signal. @@ -86,17 +86,17 @@ def _soft_threshold(z, T, handle_complex=True): array([-1, 0, 0, 0, 1]) """ - sz = np.maximum(np.abs(z) - T, 0) - if not handle_complex: + if not z.dtype == np.complex: # This soft thresholding method only supports real signal. - sz[:] = np.sign(z) * sz + sz = np.sign(z) * np.maximum(np.abs(z) - T, 0) else: # This soft thresholding method supports complex complex signal. # Transform to float to avoid integer division. # In our case 0 divided by 0 should be 0, not NaN, and is not an error. # It corresponds to 0 thresholded by 0, which is 0. + sz = np.maximum(np.abs(z) - T, 0, dtype = z.dtype) old_err_state = np.seterr(invalid='ignore') sz[:] = np.nan_to_num(1. * sz / (sz + T) * z) np.seterr(**old_err_state) diff --git a/pyunlocbox/tests/test_functions.py b/pyunlocbox/tests/test_functions.py index 8f26295..503ad43 100644 --- a/pyunlocbox/tests/test_functions.py +++ b/pyunlocbox/tests/test_functions.py @@ -131,16 +131,33 @@ def test_soft_thresholding(self): """ x = np.arange(-4, 5, 1) - # Test integer division for complex method. + # Test with integers Ts = [2] y_gold = [[-2, -1, 0, 0, 0, 0, 0, 1, 2]] - # Test division by 0 for complex method. + # Test with floats Ts.append([.4, .3, .2, .1, 0, .1, .2, .3, .4]) y_gold.append([-3.6, -2.7, -1.8, -.9, 0, .9, 1.8, 2.7, 3.6]) for k, T in enumerate(Ts): - for cmplx in [False, True]: - y_test = functions._soft_threshold(x, T, cmplx) - nptest.assert_array_equal(y_test, y_gold[k]) + y_test = functions._soft_threshold(x, T) + nptest.assert_array_equal(y_test, y_gold[k]) + # Test division by 0 for complex method. + x = np.arange(-4, 5, 1) + 0j + y_gold = [-3.6 + 0j, -2.7 + 0j, -1.8 + 0j, -.9 + 0j, + 0 + 0j, .9 + 0j, 1.8 + 0j, 2.7 + 0j, 3.6 + 0j] + y_test = functions._soft_threshold(x, Ts[-1]) + nptest.assert_array_equal(y_test, y_gold) + + x = 1j * np.arange(-4, 5, 1) + y_gold = [-3.6j, -2.7j, -1.8j, -.9j, 0j, .9j, 1.8j, 2.7j, 3.6j] + y_test = functions._soft_threshold(x, Ts[-1]) + nptest.assert_array_equal(y_test, y_gold) + + x = (1 + 1j) * np.array([-1, 1]) + T = .1 + y_gold = x - T * x / np.abs(x) + y_test = functions._soft_threshold(x, T) + nptest.assert_array_almost_equal(y_test, y_gold) + def test_norm_l1(self): """ From 04f5939579d20bd73f2bbe56ce1dfe306dd9c809 Mon Sep 17 00:00:00 2001 From: Thomas Arildsen Date: Thu, 29 Mar 2018 17:23:57 +0200 Subject: [PATCH 2/6] BUG: Fixed handling of complex input as list in soft thres. function --- pyunlocbox/functions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyunlocbox/functions.py b/pyunlocbox/functions.py index 9ab11af..7dcc593 100644 --- a/pyunlocbox/functions.py +++ b/pyunlocbox/functions.py @@ -87,7 +87,9 @@ def _soft_threshold(z, T): """ - if not z.dtype == np.complex: + input_dtype = np.asarray(z).dtype + + if not input_dtype == np.complex: # This soft thresholding method only supports real signal. sz = np.sign(z) * np.maximum(np.abs(z) - T, 0) @@ -96,7 +98,7 @@ def _soft_threshold(z, T): # Transform to float to avoid integer division. # In our case 0 divided by 0 should be 0, not NaN, and is not an error. # It corresponds to 0 thresholded by 0, which is 0. - sz = np.maximum(np.abs(z) - T, 0, dtype = z.dtype) + sz = np.maximum(np.abs(z) - T, 0, dtype = input_dtype) old_err_state = np.seterr(invalid='ignore') sz[:] = np.nan_to_num(1. * sz / (sz + T) * z) np.seterr(**old_err_state) From 1fb0deb5085b8a7202755d637eb1da7b234a91f1 Mon Sep 17 00:00:00 2001 From: Thomas Arildsen Date: Fri, 30 Mar 2018 22:22:58 +0200 Subject: [PATCH 3/6] BUG: Removed some whitespace bothering the lint check --- pyunlocbox/functions.py | 2 +- pyunlocbox/tests/test_functions.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pyunlocbox/functions.py b/pyunlocbox/functions.py index 7dcc593..8612f01 100644 --- a/pyunlocbox/functions.py +++ b/pyunlocbox/functions.py @@ -98,7 +98,7 @@ def _soft_threshold(z, T): # Transform to float to avoid integer division. # In our case 0 divided by 0 should be 0, not NaN, and is not an error. # It corresponds to 0 thresholded by 0, which is 0. - sz = np.maximum(np.abs(z) - T, 0, dtype = input_dtype) + sz = np.maximum(np.abs(z) - T, 0, dtype=input_dtype) old_err_state = np.seterr(invalid='ignore') sz[:] = np.nan_to_num(1. * sz / (sz + T) * z) np.seterr(**old_err_state) diff --git a/pyunlocbox/tests/test_functions.py b/pyunlocbox/tests/test_functions.py index 503ad43..5395c2a 100644 --- a/pyunlocbox/tests/test_functions.py +++ b/pyunlocbox/tests/test_functions.py @@ -158,7 +158,6 @@ def test_soft_thresholding(self): y_test = functions._soft_threshold(x, T) nptest.assert_array_almost_equal(y_test, y_gold) - def test_norm_l1(self): """ Test the norm_l1 derived class. From ded1832cab21fe77c6abb50a5483537ffe99e8fb Mon Sep 17 00:00:00 2001 From: Thomas Arildsen Date: Sun, 1 Apr 2018 23:23:36 +0200 Subject: [PATCH 4/6] BUG: Attempt at initializing forward-backward solution as complex This should initialize the `sol` variable of the forward-backward solver as a complex variable if necessary. This is determined from whether y, A*x or AT*y of any of the objective function constituents is complex. --- pyunlocbox/acceleration.py | 11 ++++++++++- pyunlocbox/solvers.py | 15 ++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/pyunlocbox/acceleration.py b/pyunlocbox/acceleration.py index 3badd91..c415fbd 100644 --- a/pyunlocbox/acceleration.py +++ b/pyunlocbox/acceleration.py @@ -306,7 +306,16 @@ def __init__(self, **kwargs): super(fista, self).__init__(**kwargs) def _pre(self, functions, x0): - self.sol = np.array(x0, copy=True) + enable_complex = False + for current_func in functions: + if (current_func.y().dtype == np.complex + or current_func.A(np.asarray(x0)).dtype == np.complex + or current_func.At(current_func.y()).dtype == np.complex): + enable_complex = True + if enable_complex: + self.sol = np.array(x0, copy=True, dtype=np.complex) + else: + self.sol = np.array(x0, copy=True) def _update_sol(self, solver, objective, niter): self.t = 1. if (niter == 1) else self.t # Restart variable t if needed diff --git a/pyunlocbox/solvers.py b/pyunlocbox/solvers.py index 97aff14..6dfc5e9 100644 --- a/pyunlocbox/solvers.py +++ b/pyunlocbox/solvers.py @@ -372,11 +372,20 @@ def pre(self, functions, x0): context as to how the solver is using the functions. """ - self.sol = np.asarray(x0) self.smooth_funs = [] self.non_smooth_funs = [] - self._pre(functions, self.sol) - self.accel.pre(functions, self.sol) + self._pre(functions, np.asarray(x0)) + self.accel.pre(functions, np.asarray(x0)) + enable_complex = False + for current_func in self.smooth_funs + self.non_smooth_funs: + if (current_func.y().dtype == np.complex + or current_func.A(np.asarray(x0)).dtype == np.complex + or current_func.At(current_func.y()).dtype == np.complex): + enable_complex = True + if enable_complex: + self.sol = np.asarray(x0, dtype=np.complex) + else: + self.sol = np.asarray(x0) def _pre(self, functions, x0): raise NotImplementedError("Class user should define this method.") From 140482d7fcfb16f90431651256518b325fe373b7 Mon Sep 17 00:00:00 2001 From: Thomas Arildsen Date: Sun, 1 Apr 2018 23:52:23 +0200 Subject: [PATCH 5/6] BUG: Fixed some complex number handling in objective functions The squared two-norm could not handle complex vectors correctly (absolute value was omitted). --- pyunlocbox/functions.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pyunlocbox/functions.py b/pyunlocbox/functions.py index 8612f01..6b4e42d 100644 --- a/pyunlocbox/functions.py +++ b/pyunlocbox/functions.py @@ -196,7 +196,7 @@ def __init__(self, y=0, A=None, At=None, tight=True, nu=1, tol=1e-3, elif callable(A): self.At = A else: - self.At = lambda x: A.T.dot(x) + self.At = lambda x: A.T.conj().dot(x) else: if callable(At): self.At = At @@ -493,7 +493,7 @@ def __init__(self, **kwargs): def _eval(self, x): sol = self.A(x) - self.y() - return self.lambda_ * np.sum((self.w * sol)**2) + return self.lambda_ * np.sum(np.abs(self.w * sol)**2) def _prox(self, x, T): # Gamma is T in the matlab UNLocBox implementation. @@ -502,8 +502,9 @@ def _prox(self, x, T): sol = x + 2. * gamma * self.At(self.y() * self.w**2) sol /= 1. + 2. * gamma * self.nu * self.w**2 else: - res = minimize(fun=lambda z: 0.5 * np.sum((z - x)**2) + gamma * - np.sum((self.w * (self.A(z) - self.y()))**2), + res = minimize(fun=lambda z: 0.5 * np.sum(np.abs(z - x)**2) + + gamma * np.sum((self.w * np.abs(self.A(z) + - self.y()))**2), x0=x, method='BFGS', jac=lambda z: z - x + 2. * gamma * From a9cfec3b51140973f135985ebd01068d97edff0a Mon Sep 17 00:00:00 2001 From: Thomas Arildsen Date: Thu, 5 Jul 2018 18:58:52 +0200 Subject: [PATCH 6/6] TST: Basic function class tests extended with complex data --- pyunlocbox/tests/test_functions.py | 40 ++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/pyunlocbox/tests/test_functions.py b/pyunlocbox/tests/test_functions.py index 5395c2a..a06c6a8 100644 --- a/pyunlocbox/tests/test_functions.py +++ b/pyunlocbox/tests/test_functions.py @@ -36,24 +36,42 @@ def test_func(self): # f.grad = lambda x: 2 * x # f.prox = lambda x, T: x + T - def assert_equivalent(param1, param2): - x = [[7, 8, 9], [10, 324, -45], [-7, -.2, 5]] + def assert_equivalent(param1, param2, test_data): funcs = inspect.getmembers(functions, inspect.isclass) for f in funcs: if f[0] not in ['func', 'norm', 'proj']: f1 = f[1](**param1) f2 = f[1](**param2) - self.assertEqual(f1.eval(x), f2.eval(x)) - nptest.assert_array_equal(f1.prox(x, 3), f2.prox(x, 3)) - if 'GRAD' in f1.cap(x): - nptest.assert_array_equal(f1.grad(x), f2.grad(x)) - + self.assertEqual(f1.eval(test_data), f2.eval(test_data)) + nptest.assert_array_equal(f1.prox(test_data, 3), + f2.prox(test_data, 3)) + if 'GRAD' in f1.cap(test_data): + nptest.assert_array_equal(f1.grad(test_data), + f2.grad(test_data)) + + # First check with real data + x = [[7, 8, 9], [10, 324, -45], [-7, -.2, 5]] + # Default parameters. Callable or matrices. + assert_equivalent({'y': 3.2}, {'y': lambda: 3.2}, x) + assert_equivalent({'A': None}, {'A': np.identity(3)}, x) + A = np.array([[-4, 2, 5], [1, 3, -7], [2, -1, 0]]) + assert_equivalent({'A': A}, {'A': A, 'At': A.T}, x) + # Repeat with complex A + A = np.array([[-4, 2j, 5], [1j, 3, -7], [2, -1j, 0]]) + assert_equivalent({'A': A}, {'A': A, 'At': A.conj().T}, x) + assert_equivalent({'A': lambda x: A.dot(x)}, {'A': A, 'At': A}, x) + + # Repeat the whole thing with complex data + x = [[7j, 8, 9], [10, 324, -45j], [-7, -.2j, 5]] # Default parameters. Callable or matrices. - assert_equivalent({'y': 3.2}, {'y': lambda: 3.2}) - assert_equivalent({'A': None}, {'A': np.identity(3)}) + assert_equivalent({'y': 3.2}, {'y': lambda: 3.2}, x) + assert_equivalent({'A': None}, {'A': np.identity(3)}, x) A = np.array([[-4, 2, 5], [1, 3, -7], [2, -1, 0]]) - assert_equivalent({'A': A}, {'A': A, 'At': A.T}) - assert_equivalent({'A': lambda x: A.dot(x)}, {'A': A, 'At': A}) + assert_equivalent({'A': A}, {'A': A, 'At': A.T}, x) + # Repeat with complex A + A = np.array([[-4, 2j, 5], [1j, 3, -7], [2, -1j, 0]]) + assert_equivalent({'A': A}, {'A': A, 'At': A.conj().T}, x) + assert_equivalent({'A': lambda x: A.dot(x)}, {'A': A, 'At': A}, x) def test_dummy(self): """