Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ The following are currently implemented:
* 2-D DCT-II and its inverse (which is a scaled DCT-III)
* 3-D DCT-II and its inverse (which is a scaled DCT-III)

The discrete sine transform (DST) is now supported as well. For background on the type-II and type-III DST, see the paper by [Xuancheng Shao and Steven G. Johnson](https://arxiv.org/pdf/cs/0703150).

The following are currently implemented:

* 1-D DST-I and its inverse (which is a scaled DST-I)
* 1-D DST-II and its inverse (which is a scaled DST-III)
* 2-D DST-II and its inverse (which is a scaled DST-III)
* 3-D DST-II and its inverse (which is a scaled DST-III)

## Install

```
Expand Down
1 change: 1 addition & 0 deletions torch_dct/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from ._dct import dct, idct, dct1, idct1, dct_2d, idct_2d, dct_3d, idct_3d, LinearDCT, apply_linear_2d, apply_linear_3d
from ._dct import dst, idst, dst1, idst1, dst_2d, idst_2d, dst_3d, idst_3d, LinearDST
256 changes: 236 additions & 20 deletions torch_dct/_dct.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def dct1_rfft_impl(x):
return torch.view_as_real(torch.fft.rfft(x, dim=1))

def dct_fft_impl(v):
return torch.view_as_real(torch.fft.fft(v, dim=1))

Expand All @@ -18,13 +18,14 @@ def idct_irfft_impl(V):
# PyTorch 1.6.0 and older versions
def dct1_rfft_impl(x):
return torch.rfft(x, 1)

def dct_fft_impl(v):
return torch.rfft(v, 1, onesided=False)

def idct_irfft_impl(V):
return torch.irfft(V, 1, onesided=False)

# ---------- DCT Implementation Section ----------


def dct1(x):
Expand Down Expand Up @@ -73,13 +74,13 @@ def dct(x, norm=None):

Vc = dct_fft_impl(v)

k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
W_r = torch.cos(k)
W_i = torch.sin(k)

V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

if norm == 'ortho':
if norm == "ortho":
V[:, 0] /= np.sqrt(N) * 2
V[:, 1:] /= np.sqrt(N / 2) * 2

Expand Down Expand Up @@ -107,11 +108,15 @@ def idct(X, norm=None):

X_v = X.contiguous().view(-1, x_shape[-1]) / 2

if norm == 'ortho':
if norm == "ortho":
X_v[:, 0] *= np.sqrt(N) * 2
X_v[:, 1:] *= np.sqrt(N / 2) * 2

k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
k = (
torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :]
* np.pi
/ (2 * N)
)
W_r = torch.cos(k)
W_i = torch.sin(k)

Expand All @@ -125,8 +130,8 @@ def idct(X, norm=None):

v = idct_irfft_impl(V)
x = v.new_zeros(v.shape)
x[:, ::2] += v[:, :N - (N // 2)]
x[:, 1::2] += v.flip([1])[:, :N // 2]
x[:, ::2] += v[:, : N - (N // 2)]
x[:, 1::2] += v.flip([1])[:, : N // 2]

return x.view(*x_shape)

Expand Down Expand Up @@ -203,10 +208,11 @@ def idct_3d(X, norm=None):

class LinearDCT(nn.Linear):
"""Implement any DCT as a linear layer; in practice this executes around
50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
increase memory usage.
:param in_features: size of expected input
:param type: which dct function in this file to use"""

def __init__(self, in_features, type, norm=None, bias=False):
self.type = type
self.N = in_features
Expand All @@ -216,15 +222,218 @@ def __init__(self, in_features, type, norm=None, bias=False):
def reset_parameters(self):
# initialise using dct function
I = torch.eye(self.N)
if self.type == 'dct1':
if self.type == "dct1":
self.weight.data = dct1(I).data.t()
elif self.type == 'idct1':
elif self.type == "idct1":
self.weight.data = idct1(I).data.t()
elif self.type == 'dct':
elif self.type == "dct":
self.weight.data = dct(I, norm=self.norm).data.t()
elif self.type == 'idct':
elif self.type == "idct":
self.weight.data = idct(I, norm=self.norm).data.t()
self.weight.requires_grad = False # don't learn this!
self.weight.requires_grad = False # don't learn this!


# ---------- DST Implementation Section ----------
def dst1(x):
    """
    Discrete Sine Transform, Type I

    Computed from the FFT of the odd extension ``[0, x, 0, -reversed(x)]``
    (length ``2N + 2``): the negated imaginary part of bins ``1..N`` equals
    ``2 * sum_n x[n] * sin(pi * (n + 1) * (k + 1) / (N + 1))``.

    :param x: the input signal
    :return: the DST-I of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    # reshape (rather than view) so that non-contiguous inputs, such as
    # transposed tensors, are accepted instead of raising a RuntimeError
    x = x.reshape(-1, N)

    # one zero column before and one after x make the extension odd-symmetric
    zero_col = x.new_zeros(x.shape[0], 1)
    x_odd = torch.cat([zero_col, x, zero_col, -x.flip([1])], dim=1)

    X = dct_fft_impl(x_odd)
    # the last axis of X holds (real, imag); keep the imaginary part of bins 1..N
    dst_result = -X[:, 1 : N + 1, 1]

    return dst_result.view(*x_shape)


def idst1(X):
    """
    Inverse Discrete Sine Transform, Type I

    The inverse is the DST-I itself divided by ``2 * (N + 1)``, where N is
    the size of the last dimension, so that idst1(dst1(x)) == x.

    :param X: the DST-I coefficients
    :return: the inverse DST-I of the signal over the last dimension
    """
    scale = 2 * (X.shape[-1] + 1)
    return dst1(X) / scale


def _dst_alt_signs(N, dtype, device):
    """Return the length-N sign pattern (+1, -1, +1, ...) linking DST-II to DCT-II."""
    return 1 - 2 * (torch.arange(N, dtype=dtype, device=device) % 2)


def _dst_ortho_scale(N, dtype, device):
    """Return per-coefficient factors mapping the unnormalized DST-II to the orthonormal one."""
    # Every DST-II basis vector has squared norm N / 2, except the last one
    # (k = N - 1), which equals (-1)^n and therefore has squared norm N.
    # The extra factor 2 undoes the "2 * sum" of the unnormalized definition.
    scale = torch.full(
        (N,), 1.0 / (2.0 * float(np.sqrt(N / 2))), dtype=dtype, device=device
    )
    scale[-1] = 1.0 / (2.0 * float(np.sqrt(N)))
    return scale


def dst(x, norm=None):
    """
    Discrete Sine Transform, Type II (a.k.a. the DST)

    Evaluated through the DCT-II using the identity
    ``DST-II(x)[k] == DCT-II((-1)^n * x[n])[N - 1 - k]``.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dst.html

    With ``norm='ortho'`` the transform matrix is orthonormal: all
    coefficients are divided by ``2 * sqrt(N / 2)`` except the last one
    (k = N - 1), which is divided by ``2 * sqrt(N)``.

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DST-II of the signal over the last dimension
    """
    N = x.shape[-1]

    # the sign pattern broadcasts over the last dimension, so no flattening is needed
    x_alt = x * _dst_alt_signs(N, x.dtype, x.device)
    dst_result = dct(x_alt, norm=None).flip([-1])

    if norm == "ortho":
        # out-of-place multiply: avoids mutating a tensor autograd may still need
        dst_result = dst_result * _dst_ortho_scale(N, x.dtype, x.device)

    return dst_result


def idst(X, norm=None):
    """
    The inverse to DST-II, which is a scaled Discrete Sine Transform, Type III

    Our definition of idst is that idst(dst(x)) == x

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dst.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DST-II of the signal over the last dimension
    """
    N = X.shape[-1]

    if norm == "ortho":
        # undo the orthonormal scaling applied by dst (the input is left untouched)
        X = X / _dst_ortho_scale(N, X.dtype, X.device)

    # reverse the coefficient order to obtain the matching DCT-II spectrum,
    # invert it, then remove the alternating signs introduced by dst
    x = idct(X.flip([-1]), norm=None)
    return x * _dst_alt_signs(N, X.dtype, X.device)


def dst_2d(x, norm=None):
    """
    2-dimensional Discrete Sine Transform, Type II (a.k.a. the DST)

    The 1-D DST-II is applied along the last axis and then, after swapping
    the last two axes, along the second-to-last axis.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dst.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DST-II of the signal over the last 2 dimensions
    """
    out = dst(x, norm=norm)
    # swap the last two axes, transform, and swap back to the input layout
    out = dst(out.transpose(-1, -2), norm=norm)
    return out.transpose(-1, -2)


def idst_2d(X, norm=None):
    """
    The inverse to 2D DST-II, which is a scaled Discrete Sine Transform, Type III

    Defined such that idst_2d(dst_2d(x)) == x. The 1-D inverse is applied
    along the last axis and then, after swapping the last two axes, along the
    second-to-last axis.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dst.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DST-II of the signal over the last 2 dimensions
    """
    out = idst(X, norm=norm)
    # swap the last two axes, invert, and swap back to the input layout
    out = idst(out.transpose(-1, -2), norm=norm)
    return out.transpose(-1, -2)


def dst_3d(x, norm=None):
    """
    3-dimensional Discrete Sine Transform, Type II (a.k.a. the DST)

    The 1-D DST-II is applied along the last axis, then each of the other two
    trailing axes is swapped into the last position and transformed in turn.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dst.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DST-II of the signal over the last 3 dimensions
    """
    out = dst(x, norm=norm)
    for axis in (-2, -3):
        out = dst(out.transpose(-1, axis), norm=norm)
    # undo the two axis swaps (in reverse order) to restore the input layout
    return out.transpose(-1, -3).transpose(-1, -2)


def idst_3d(X, norm=None):
    """
    The inverse to 3D DST-II, which is a scaled Discrete Sine Transform, Type III

    Defined such that idst_3d(dst_3d(x)) == x. The 1-D inverse is applied
    along the last axis, then each of the other two trailing axes is swapped
    into the last position and inverted in turn.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dst.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DST-II of the signal over the last 3 dimensions
    """
    out = idst(X, norm=norm)
    for axis in (-2, -3):
        out = idst(out.transpose(-1, axis), norm=norm)
    # undo the two axis swaps (in reverse order) to restore the input layout
    return out.transpose(-1, -3).transpose(-1, -2)


class LinearDST(nn.Linear):
    """Implement any DST as a linear layer; in practice this executes around
    50x faster on GPU. Unfortunately, the DST matrix is stored, which will
    increase memory usage.
    :param in_features: size of expected input
    :param type: which dst function in this file to use, one of
        'dst1', 'idst1', 'dst' or 'idst'
    :param norm: the normalization passed to dst/idst, None or 'ortho'
        (not used by 'dst1' / 'idst1')
    :param bias: whether to add a bias term; if present it starts at zero
    :raises ValueError: if `type` is not one of the supported names"""

    def __init__(self, in_features, type, norm=None, bias=False):
        # These attributes are read by reset_parameters(), which
        # nn.Linear.__init__ calls, so they must be set before super().__init__.
        self.type = type
        self.N = in_features
        self.norm = norm
        super(LinearDST, self).__init__(in_features, in_features, bias=bias)

    def reset_parameters(self):
        # initialise using dst function: transforming the identity matrix
        # yields the transform matrix itself
        I = torch.eye(self.N)
        if self.type == "dst1":
            matrix = dst1(I)
        elif self.type == "idst1":
            matrix = idst1(I)
        elif self.type == "dst":
            matrix = dst(I, norm=self.norm)
        elif self.type == "idst":
            matrix = idst(I, norm=self.norm)
        else:
            # previously an unknown type silently left the weight uninitialised
            raise ValueError(
                "unknown DST type %r, expected one of "
                "'dst1', 'idst1', 'dst', 'idst'" % (self.type,)
            )
        self.weight.data = matrix.data.t()
        self.weight.requires_grad = False  # don't learn this!
        if self.bias is not None:
            # start from a zero bias so the layer initially equals the transform
            nn.init.zeros_(self.bias)


def apply_linear_2d(x, linear_layer):
Expand All @@ -237,6 +446,7 @@ def apply_linear_2d(x, linear_layer):
X2 = linear_layer(X1.transpose(-1, -2))
return X2.transpose(-1, -2)


def apply_linear_3d(x, linear_layer):
"""Can be used with a LinearDCT layer to do a 3D DCT.
:param x: the input signal
Expand All @@ -248,13 +458,19 @@ def apply_linear_3d(x, linear_layer):
X3 = linear_layer(X2.transpose(-1, -3))
return X3.transpose(-1, -3).transpose(-1, -2)

if __name__ == '__main__':
x = torch.Tensor(1000,4096)
x.normal_(0,1)
linear_dct = LinearDCT(4096, 'dct')

if __name__ == "__main__":
x = torch.Tensor(1000, 4096)
x.normal_(0, 1)
linear_dct = LinearDCT(4096, "dct")
error = torch.abs(dct(x) - linear_dct(x))
assert error.max() < 1e-3, (error, error.max())
linear_idct = LinearDCT(4096, 'idct')
linear_idct = LinearDCT(4096, "idct")
error = torch.abs(idct(x) - linear_idct(x))
assert error.max() < 1e-3, (error, error.max())

linear_dst = LinearDST(4096, "dst")
error = torch.abs(dst(x) - linear_dst(x))
assert error.max() < 1e-3, (error, error.max())
linear_idst = LinearDST(4096, "idst")
error = torch.abs(idst(x) - linear_idst(x))
assert error.max() < 1e-3, (error, error.max())
Loading