Skip to content
This repository was archived by the owner on Nov 29, 2023. It is now read-only.

Commit 1040efa

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent dda6247 commit 1040efa

File tree

12 files changed

+3
-20
lines changed

12 files changed

+3
-20
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
4646
4747
<!-- ALL-CONTRIBUTORS-LIST:END -->
4848
49-
This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome!
49+
This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome!

satflow/data/datamodules.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ def per_worker_init(self, worker_id: int):
186186
pass
187187

188188
def __getitem__(self, idx):
189-
190189
x = {
191190
SATELLITE_DATA: torch.randn(
192191
self.batch_size, self.seq_length, self.width, self.height, self.number_sat_channels

satflow/data/utils/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def eumetsat_name_to_datetime(filename: str):
3333

3434
def retrieve_pixel_value(geo_coord, data_source):
3535
"""Return floating-point value that corresponds to given point.
36-
Taken from https://gis.stackexchange.com/questions/221292/retrieve-pixel-value-with-geographic-coordinate-as-input-with-gdal"""
36+
Taken from https://gis.stackexchange.com/questions/221292/retrieve-pixel-value-with-geographic-coordinate-as-input-with-gdal
37+
"""
3738
x, y = geo_coord[0], geo_coord[1]
3839
forward_transform = affine.Affine.from_gdal(*data_source.GetGeoTransform())
3940
reverse_transform = ~forward_transform

satflow/models/conv_lstm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,6 @@ def __init__(self, input_channels, hidden_dim, out_channels, conv_type: str = "s
169169
)
170170

171171
def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4):
172-
173172
outputs = []
174173

175174
# encoder
@@ -203,7 +202,6 @@ def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3,
203202
return outputs
204203

205204
def forward(self, x, forecast_steps=0, hidden_state=None):
206-
207205
"""
208206
Parameters
209207
----------

satflow/models/gan/generators.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ def __init__(
146146

147147
mult = 2**n_downsampling
148148
for i in range(n_blocks): # add ResNet blocks
149-
150149
model += [
151150
ResnetBlock(
152151
ngf * mult,

satflow/models/layers/Attention.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def __init__(
1616
)
1717

1818
def forward(self, x):
19-
2019
return self.model(x)
2120

2221

@@ -56,7 +55,6 @@ def init_conv(self, conv, glu=True):
5655
conv.bias.data.zero_()
5756

5857
def forward(self, x):
59-
6058
batch_size, C, T, W, H = x.size()
6159

6260
assert T % 2 == 0 and W % 2 == 0 and H % 2 == 0, "T, W, H is not even"
@@ -111,7 +109,6 @@ def forward(self, x):
111109

112110
class SelfAttention(nn.Module):
113111
def __init__(self, in_dim, activation=F.relu, pooling_factor=2): # TODO for better compability
114-
115112
super(SelfAttention, self).__init__()
116113
self.activation = activation
117114

@@ -134,7 +131,6 @@ def init_conv(self, conv, glu=True):
134131
conv.bias.data.zero_()
135132

136133
def forward(self, x):
137-
138134
if len(x.size()) == 4:
139135
batch_size, C, W, H = x.size()
140136
T = 1
@@ -224,7 +220,6 @@ def forward(self, x):
224220

225221

226222
if __name__ == "__main__":
227-
228223
self_attn = SelfAttention(16) # no less than 8
229224
print(self_attn)
230225

satflow/models/layers/Discriminator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,6 @@ def forward(self, x, class_id):
651651

652652

653653
if __name__ == "__main__":
654-
655654
batch_size = 6
656655
n_frames = 8
657656
n_class = 4

satflow/models/layers/GResBlock.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def __init__(
5252
self.CBNorm2 = ConditionalNorm(out_channel, n_class)
5353

5454
def forward(self, x, condition=None):
55-
5655
# The time dimension is combined with the batch dimension here, so each frame proceeds
5756
# through the blocks independently
5857
BT, C, W, H = x.size()
@@ -100,7 +99,6 @@ def forward(self, x, condition=None):
10099

101100

102101
if __name__ == "__main__":
103-
104102
n_class = 96
105103
batch_size = 4
106104
n_frames = 20

satflow/models/layers/Generator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def __init__(self, in_dim=120, latent_dim=4, n_class=4, ch=32, n_frames=48, hier
6868
self.colorize = SpectralNorm(nn.Conv2d(2 * ch, 3, kernel_size=(3, 3), padding=1))
6969

7070
def forward(self, x, class_id):
71-
7271
if self.hierar_flag is True:
7372
noise_emb = torch.split(x, self.in_dim, dim=1)
7473
else:
@@ -87,7 +86,6 @@ def forward(self, x, class_id):
8786

8887
for k, conv in enumerate(self.conv):
8988
if isinstance(conv, ConvGRU):
90-
9189
if k > 0:
9290
_, C, W, H = y.size()
9391
y = y.view(-1, self.n_frames, C, W, H).contiguous()
@@ -132,7 +130,6 @@ def forward(self, x, class_id):
132130

133131

134132
if __name__ == "__main__":
135-
136133
batch_size = 5
137134
in_dim = 120
138135
n_class = 4

satflow/models/layers/Normalization.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def forward(self, x, class_id):
8787

8888

8989
if __name__ == "__main__":
90-
9190
cn = ConditionalNorm(3, 2)
9291
x = torch.rand([4, 3, 64, 64])
9392
class_id = torch.rand([4, 2])

0 commit comments

Comments
 (0)