
Commit 754ae5b

feat(layer): add Conv2DTranspose
1 parent 0c11017 commit 754ae5b

1 file changed (+178 −17 lines)

1 file changed

+178
-17
lines changed

neuralnetlib/layers.py

Lines changed: 178 additions & 17 deletions
@@ -2540,6 +2540,165 @@ def from_config(config: dict):
         )
 
 
+class Conv2DTranspose(Layer):
+    def __init__(self, filters: int, kernel_size: int | tuple, strides: int | tuple = 1,
+                 padding: str = 'valid', weights_init: str = "default", bias_init: str = "default",
+                 random_state: int = None, **kwargs):
+        self.filters = filters
+        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+        self.strides = (strides, strides) if isinstance(strides, int) else strides
+        self.padding = padding
+
+        self.weights = None
+        self.bias = None
+        self.d_weights = None
+        self.d_bias = None
+
+        self.weights_init = weights_init
+        self.bias_init = bias_init
+        self.random_state = random_state
+
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+    def initialize_weights(self, input_shape: tuple):
+        _, _, _, in_channels = input_shape
+
+        self.rng = np.random.default_rng(
+            self.random_state if self.random_state is not None else int(time.time_ns()))
+
+        if self.weights_init == "xavier":
+            self.weights = self.rng.normal(0, np.sqrt(2 / (np.prod(self.kernel_size) * self.filters)),
+                                           (*self.kernel_size, self.filters, in_channels))
+        elif self.weights_init == "he":
+            self.weights = self.rng.normal(0, np.sqrt(2 / (in_channels * np.prod(self.kernel_size))),
+                                           (*self.kernel_size, self.filters, in_channels))
+        elif self.weights_init == "default":
+            self.weights = self.rng.normal(0, 0.01, (*self.kernel_size, self.filters, in_channels))
+        else:
+            raise ValueError("Invalid weights_init value. Possible values are 'xavier', 'he', and 'default'.")
+
+        if self.bias_init == "default":
+            self.bias = np.zeros((1, 1, 1, self.filters))
+        elif self.bias_init == "normal":
+            self.bias = self.rng.normal(0, 0.01, (1, 1, 1, self.filters))
+        elif self.bias_init == "uniform":
+            self.bias = self.rng.uniform(-0.1, 0.1, (1, 1, 1, self.filters))
+        elif self.bias_init == "small":
+            self.bias = np.full((1, 1, 1, self.filters), 0.01)
+        else:
+            raise ValueError("Invalid bias_init value. Possible values are 'default', 'normal', 'uniform', and 'small'.")
+
+        self.d_weights = np.zeros_like(self.weights)
+        self.d_bias = np.zeros_like(self.bias)
+
+    def forward_pass(self, input_data: np.ndarray) -> np.ndarray:
+        if self.weights is None:
+            assert len(input_data.shape) == 4, "Conv2DTranspose input must be 4D (batch_size, height, width, channels)"
+            self.initialize_weights(input_data.shape)
+
+        self.input = input_data
+        batch_size, in_height, in_width, in_channels = input_data.shape
+        kernel_height, kernel_width, out_channels, _ = self.weights.shape
+
+        if self.padding == 'same':
+            out_height = in_height * self.strides[0]
+            out_width = in_width * self.strides[1]
+            pad_height = max((in_height - 1) * self.strides[0] + kernel_height - out_height, 0)
+            pad_width = max((in_width - 1) * self.strides[1] + kernel_width - out_width, 0)
+            pad_top = pad_height // 2
+            pad_bottom = pad_height - pad_top
+            pad_left = pad_width // 2
+            pad_right = pad_width - pad_left
+        else:
+            out_height = (in_height - 1) * self.strides[0] + kernel_height
+            out_width = (in_width - 1) * self.strides[1] + kernel_width
+            pad_top = pad_bottom = pad_left = pad_right = 0
+
+        padded_output = np.zeros((batch_size, out_height + pad_top + pad_bottom,
+                                  out_width + pad_left + pad_right, out_channels))
+
+        for h in range(in_height):
+            for w in range(in_width):
+                h_start = h * self.strides[0]
+                w_start = w * self.strides[1]
+
+                out_slice = padded_output[:, h_start:h_start + kernel_height,
+                                          w_start:w_start + kernel_width, :]
+
+                for c in range(in_channels):
+                    weight_slice = self.weights[:, :, :, c]
+                    input_val = input_data[:, h, w, c:c+1]
+                    out_slice += np.expand_dims(weight_slice, 0) * np.expand_dims(input_val, (1, 2))
+
+        if self.padding == 'valid':
+            output = padded_output
+        else:
+            output = padded_output[:, pad_top:pad_top + out_height,
+                                   pad_left:pad_left + out_width, :]
+
+        return output + self.bias
+
+    def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
+        batch_size = output_error.shape[0]
+        kernel_height, kernel_width, out_channels, in_channels = self.weights.shape
+
+        if self.padding == 'same':
+            # forward_pass cropped the padded output for 'same' padding, so
+            # re-pad the incoming gradient to realign the scatter indices below
+            pad_height = max(kernel_height - self.strides[0], 0)
+            pad_width = max(kernel_width - self.strides[1], 0)
+            pad_top = pad_height // 2
+            pad_left = pad_width // 2
+            output_error = np.pad(output_error,
+                                  ((0, 0), (pad_top, pad_height - pad_top),
+                                   (pad_left, pad_width - pad_left), (0, 0)))
+
+        d_input = np.zeros_like(self.input)
+        self.d_weights = np.zeros_like(self.weights)
+        self.d_bias = np.sum(output_error, axis=(0, 1, 2), keepdims=True)
+
+        for h in range(d_input.shape[1]):
+            for w in range(d_input.shape[2]):
+                h_start = h * self.strides[0]
+                w_start = w * self.strides[1]
+
+                error_field = output_error[:, h_start:h_start + kernel_height,
+                                           w_start:w_start + kernel_width, :]
+
+                if error_field.shape[1:3] == (kernel_height, kernel_width):
+                    for c in range(in_channels):
+                        weight_slice = self.weights[:, :, :, c]
+                        d_input[:, h, w, c] = np.sum(error_field * weight_slice, axis=(1, 2, 3))
+
+                        for b in range(batch_size):
+                            self.d_weights[:, :, :, c] += error_field[b] * self.input[b, h, w, c]
+
+        return d_input
+
+    def __str__(self) -> str:
+        return f'Conv2DTranspose(filters={self.filters}, kernel_size={self.kernel_size}, strides={self.strides}, padding={self.padding})'
+
+    def get_config(self) -> dict:
+        return {
+            'name': self.__class__.__name__,
+            'weights': self.weights.tolist() if self.weights is not None else None,
+            'bias': self.bias.tolist() if self.bias is not None else None,
+            'filters': self.filters,
+            'kernel_size': self.kernel_size,
+            'strides': self.strides,
+            'padding': self.padding,
+            'weights_init': self.weights_init,
+            'bias_init': self.bias_init,
+            'random_state': self.random_state
+        }
+
+    @staticmethod
+    def from_config(config: dict):
+        layer = Conv2DTranspose(
+            config['filters'],
+            config['kernel_size'],
+            config['strides'],
+            config['padding'],
+            config['weights_init'],
+            config['bias_init'],
+            config['random_state']
+        )
+        if config['weights'] is not None:
+            layer.weights = np.array(config['weights'])
+            layer.bias = np.array(config['bias'])
+        return layer
+
+
 class UpSampling2D(Layer):
     def __init__(self, size=(2, 2), interpolation="nearest", **kwargs):
         super().__init__()
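
The forward pass above scatters each input pixel into a kernel-sized window of the output, which gives the usual transposed-convolution output sizes: out = (in − 1) · stride + kernel_size for 'valid' padding, and out = in · stride for 'same' padding. For example, a 16×16 input with kernel_size=3 and strides=2 yields 33×33 under 'valid' and 32×32 under 'same'. A minimal shape check (a sketch, not part of the commit; it assumes numpy is installed and the library is importable as neuralnetlib.layers):

    import numpy as np
    from neuralnetlib.layers import Conv2DTranspose

    x = np.random.rand(2, 16, 16, 3)  # (batch, height, width, channels)

    # 'valid' padding: (16 - 1) * 2 + 3 = 33
    layer = Conv2DTranspose(filters=8, kernel_size=3, strides=2, padding='valid')
    print(layer.forward_pass(x).shape)  # (2, 33, 33, 8)

    # 'same' padding: 16 * 2 = 32
    layer = Conv2DTranspose(filters=8, kernel_size=3, strides=2, padding='same')
    print(layer.forward_pass(x).shape)  # (2, 32, 32, 8)
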
@@ -3627,56 +3786,58 @@ def from_config(config: dict) -> "TransformerDecoderLayer":
     Conv2D: [Conv1D, LSTM, GRU, Bidirectional, Unidirectional],
 
     UpSampling2D: [Conv1D, LSTM, GRU, Bidirectional, Unidirectional],
+
+    Conv2DTranspose: [Conv1D, LSTM, GRU, Bidirectional, Unidirectional],
 
     MaxPooling2D: [Conv1D, MaxPooling1D, AveragePooling1D, LSTM, GRU, Bidirectional, Unidirectional],
 
     AveragePooling2D: [Conv1D, MaxPooling1D, AveragePooling1D, LSTM, GRU, Bidirectional, Unidirectional],
 
     GlobalAveragePooling2D: [Conv1D, MaxPooling1D, AveragePooling1D, LSTM, GRU, Bidirectional, Unidirectional],
 
-    Conv1D: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D],
+    Conv1D: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D],
 
-    MaxPooling1D: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D],
+    MaxPooling1D: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D],
 
-    AveragePooling1D: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D],
+    AveragePooling1D: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D],
 
-    GlobalAveragePooling1D: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    GlobalAveragePooling1D: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
     Flatten: [],
 
     Dropout: [],
 
-    Embedding: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    Embedding: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
     BatchNormalization: [],
 
     LayerNormalization: [],
 
     Permute: [],
 
-    TextVectorization: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    TextVectorization: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
     Reshape: [],
 
-    LSTM: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    LSTM: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
-    GRU: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    GRU: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
-    Bidirectional: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    Bidirectional: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
-    Unidirectional: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    Unidirectional: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
-    Attention: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    Attention: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
-    MultiHeadAttention: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    MultiHeadAttention: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
-    PositionalEncoding: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    PositionalEncoding: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
-    FeedForward: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    FeedForward: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
-    AddNorm: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    AddNorm: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
-    TransformerEncoderLayer: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    TransformerEncoderLayer: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 
-    TransformerDecoderLayer: [Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D],
+    TransformerDecoderLayer: [Conv2D, UpSampling2D, Conv2DTranspose, MaxPooling2D, AveragePooling2D],
 }
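
The layer's gradients can be sanity-checked the same way: run a forward pass, feed a dummy upstream gradient of the output's shape through backward_pass, and compare shapes; get_config/from_config should round-trip the learned weights. A minimal sketch under the same assumptions as above (numpy installed, neuralnetlib.layers importable):

    import numpy as np
    from neuralnetlib.layers import Conv2DTranspose

    layer = Conv2DTranspose(filters=4, kernel_size=3, strides=2, padding='valid', random_state=42)
    x = np.random.rand(1, 8, 8, 2)

    out = layer.forward_pass(x)  # (1, 17, 17, 4): (8 - 1) * 2 + 3 = 17
    d_input = layer.backward_pass(np.ones_like(out))

    print(d_input.shape)          # (1, 8, 8, 2), matches the input
    print(layer.d_weights.shape)  # (3, 3, 4, 2): (kh, kw, filters, in_channels)
    print(layer.d_bias.shape)     # (1, 1, 1, 4)

    # Serialization round trip: from_config restores the weights saved by get_config
    clone = Conv2DTranspose.from_config(layer.get_config())
    print(np.allclose(clone.weights, layer.weights))  # True
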
