Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pyunlocbox/acceleration.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,16 @@ def __init__(self, **kwargs):
super(fista, self).__init__(**kwargs)

def _pre(self, functions, x0):
self.sol = np.array(x0, copy=True)
enable_complex = False
for current_func in functions:
if (current_func.y().dtype == np.complex
or current_func.A(np.asarray(x0)).dtype == np.complex
or current_func.At(current_func.y()).dtype == np.complex):
enable_complex = True
if enable_complex:
self.sol = np.array(x0, copy=True, dtype=np.complex)
else:
self.sol = np.array(x0, copy=True)

def _update_sol(self, solver, objective, niter):
self.t = 1. if (niter == 1) else self.t # Restart variable t if needed
Expand Down
19 changes: 11 additions & 8 deletions pyunlocbox/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from pyunlocbox import operators as op


def _soft_threshold(z, T, handle_complex=True):
def _soft_threshold(z, T):
r"""
Return the soft thresholded signal.

Expand Down Expand Up @@ -86,17 +86,19 @@ def _soft_threshold(z, T, handle_complex=True):
array([-1, 0, 0, 0, 1])

"""
sz = np.maximum(np.abs(z) - T, 0)

if not handle_complex:
input_dtype = np.asarray(z).dtype

if not input_dtype == np.complex:
# This soft thresholding method only supports real signal.
sz[:] = np.sign(z) * sz
sz = np.sign(z) * np.maximum(np.abs(z) - T, 0)

else:
# This soft thresholding method supports complex complex signal.
# Transform to float to avoid integer division.
# In our case 0 divided by 0 should be 0, not NaN, and is not an error.
# It corresponds to 0 thresholded by 0, which is 0.
sz = np.maximum(np.abs(z) - T, 0, dtype=input_dtype)
old_err_state = np.seterr(invalid='ignore')
sz[:] = np.nan_to_num(1. * sz / (sz + T) * z)
np.seterr(**old_err_state)
Expand Down Expand Up @@ -194,7 +196,7 @@ def __init__(self, y=0, A=None, At=None, tight=True, nu=1, tol=1e-3,
elif callable(A):
self.At = A
else:
self.At = lambda x: A.T.dot(x)
self.At = lambda x: A.T.conj().dot(x)
else:
if callable(At):
self.At = At
Expand Down Expand Up @@ -491,7 +493,7 @@ def __init__(self, **kwargs):

def _eval(self, x):
sol = self.A(x) - self.y()
return self.lambda_ * np.sum((self.w * sol)**2)
return self.lambda_ * np.sum(np.abs(self.w * sol)**2)

def _prox(self, x, T):
# Gamma is T in the matlab UNLocBox implementation.
Expand All @@ -500,8 +502,9 @@ def _prox(self, x, T):
sol = x + 2. * gamma * self.At(self.y() * self.w**2)
sol /= 1. + 2. * gamma * self.nu * self.w**2
else:
res = minimize(fun=lambda z: 0.5 * np.sum((z - x)**2) + gamma *
np.sum((self.w * (self.A(z) - self.y()))**2),
res = minimize(fun=lambda z: 0.5 * np.sum(np.abs(z - x)**2)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be worth replacing the sum-abs operations with inner products with the complex conjugate of vectors here. Pro: abs makes these operations quite a bit slower. Con: I think the sum-abs formulation is a bit more readable, see line 496.

+ gamma * np.sum((self.w * np.abs(self.A(z)
- self.y()))**2),
x0=x,
method='BFGS',
jac=lambda z: z - x + 2. * gamma *
Expand Down
15 changes: 12 additions & 3 deletions pyunlocbox/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,20 @@ def pre(self, functions, x0):
context as to how the solver is using the functions.

"""
self.sol = np.asarray(x0)
self.smooth_funs = []
self.non_smooth_funs = []
self._pre(functions, self.sol)
self.accel.pre(functions, self.sol)
self._pre(functions, np.asarray(x0))
self.accel.pre(functions, np.asarray(x0))
enable_complex = False
for current_func in self.smooth_funs + self.non_smooth_funs:
if (current_func.y().dtype == np.complex
or current_func.A(np.asarray(x0)).dtype == np.complex
or current_func.At(current_func.y()).dtype == np.complex):
enable_complex = True
if enable_complex:
self.sol = np.asarray(x0, dtype=np.complex)
else:
self.sol = np.asarray(x0)

def _pre(self, functions, x0):
raise NotImplementedError("Class user should define this method.")
Expand Down
66 changes: 50 additions & 16 deletions pyunlocbox/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,42 @@ def test_func(self):
# f.grad = lambda x: 2 * x
# f.prox = lambda x, T: x + T

def assert_equivalent(param1, param2):
x = [[7, 8, 9], [10, 324, -45], [-7, -.2, 5]]
def assert_equivalent(param1, param2, test_data):
funcs = inspect.getmembers(functions, inspect.isclass)
for f in funcs:
if f[0] not in ['func', 'norm', 'proj']:
f1 = f[1](**param1)
f2 = f[1](**param2)
self.assertEqual(f1.eval(x), f2.eval(x))
nptest.assert_array_equal(f1.prox(x, 3), f2.prox(x, 3))
if 'GRAD' in f1.cap(x):
nptest.assert_array_equal(f1.grad(x), f2.grad(x))

self.assertEqual(f1.eval(test_data), f2.eval(test_data))
nptest.assert_array_equal(f1.prox(test_data, 3),
f2.prox(test_data, 3))
if 'GRAD' in f1.cap(test_data):
nptest.assert_array_equal(f1.grad(test_data),
f2.grad(test_data))

# First check with real data
x = [[7, 8, 9], [10, 324, -45], [-7, -.2, 5]]
# Default parameters. Callable or matrices.
assert_equivalent({'y': 3.2}, {'y': lambda: 3.2}, x)
assert_equivalent({'A': None}, {'A': np.identity(3)}, x)
A = np.array([[-4, 2, 5], [1, 3, -7], [2, -1, 0]])
assert_equivalent({'A': A}, {'A': A, 'At': A.T}, x)
# Repeat with complex A
A = np.array([[-4, 2j, 5], [1j, 3, -7], [2, -1j, 0]])
assert_equivalent({'A': A}, {'A': A, 'At': A.conj().T}, x)
assert_equivalent({'A': lambda x: A.dot(x)}, {'A': A, 'At': A}, x)

# Repeat the whole thing with complex data
x = [[7j, 8, 9], [10, 324, -45j], [-7, -.2j, 5]]
# Default parameters. Callable or matrices.
assert_equivalent({'y': 3.2}, {'y': lambda: 3.2})
assert_equivalent({'A': None}, {'A': np.identity(3)})
assert_equivalent({'y': 3.2}, {'y': lambda: 3.2}, x)
assert_equivalent({'A': None}, {'A': np.identity(3)}, x)
A = np.array([[-4, 2, 5], [1, 3, -7], [2, -1, 0]])
assert_equivalent({'A': A}, {'A': A, 'At': A.T})
assert_equivalent({'A': lambda x: A.dot(x)}, {'A': A, 'At': A})
assert_equivalent({'A': A}, {'A': A, 'At': A.T}, x)
# Repeat with complex A
A = np.array([[-4, 2j, 5], [1j, 3, -7], [2, -1j, 0]])
assert_equivalent({'A': A}, {'A': A, 'At': A.conj().T}, x)
assert_equivalent({'A': lambda x: A.dot(x)}, {'A': A, 'At': A}, x)

def test_dummy(self):
"""
Expand Down Expand Up @@ -131,16 +149,32 @@ def test_soft_thresholding(self):

"""
x = np.arange(-4, 5, 1)
# Test integer division for complex method.
# Test with integers
Ts = [2]
y_gold = [[-2, -1, 0, 0, 0, 0, 0, 1, 2]]
# Test division by 0 for complex method.
# Test with floats
Ts.append([.4, .3, .2, .1, 0, .1, .2, .3, .4])
y_gold.append([-3.6, -2.7, -1.8, -.9, 0, .9, 1.8, 2.7, 3.6])
for k, T in enumerate(Ts):
for cmplx in [False, True]:
y_test = functions._soft_threshold(x, T, cmplx)
nptest.assert_array_equal(y_test, y_gold[k])
y_test = functions._soft_threshold(x, T)
nptest.assert_array_equal(y_test, y_gold[k])
# Test division by 0 for complex method.
x = np.arange(-4, 5, 1) + 0j
y_gold = [-3.6 + 0j, -2.7 + 0j, -1.8 + 0j, -.9 + 0j,
0 + 0j, .9 + 0j, 1.8 + 0j, 2.7 + 0j, 3.6 + 0j]
y_test = functions._soft_threshold(x, Ts[-1])
nptest.assert_array_equal(y_test, y_gold)

x = 1j * np.arange(-4, 5, 1)
y_gold = [-3.6j, -2.7j, -1.8j, -.9j, 0j, .9j, 1.8j, 2.7j, 3.6j]
y_test = functions._soft_threshold(x, Ts[-1])
nptest.assert_array_equal(y_test, y_gold)

x = (1 + 1j) * np.array([-1, 1])
T = .1
y_gold = x - T * x / np.abs(x)
y_test = functions._soft_threshold(x, T)
nptest.assert_array_almost_equal(y_test, y_gold)

def test_norm_l1(self):
"""
Expand Down