-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathsubmit_ntire.py
More file actions
225 lines (187 loc) · 12.2 KB
/
submit_ntire.py
File metadata and controls
225 lines (187 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# Package Imports
from os import makedirs
from statistics import mean
from torch import load, no_grad, clamp, Tensor
from torch.cuda import Event, synchronize, get_device_name
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
from pathlib import Path
from shutil import make_archive, unpack_archive, copytree, rmtree
from datetime import datetime
from warnings import warn
# Local Imports
from dataset.loader import RealBokeh
from dataset.util import Mode
from util.parser import get_ntire_parser
# TODO Architecture imports, replace with your model!
from method.config import bokehlicious_size_builder
from method.model import Bokehlicious
"""
!!!!! NTIRE CHALLENGE README: !!!!!
This script produces a submission ready .zip archive to be uploaded at
https://www.codabench.org/competitions/12764/#/participate-tab for evaluation by our server.
You will have to replace our Baseline network with your own solution, makes sure to check the relevant #TODO comments!
The image input archives are expected to be at the -dataset_root_dir (default is './dataset').
Run `python submit_ntire.py -h` to see an overview of the script arguments.
Run `python submit_ntire.py -c small.pt -n 'BokN_S_(Baseline)'` to generate a sample submission based on the challenge baseline method.
"""
def unsqueeeze_batch(batch, device=None):
    """Add a leading batch dimension to every tensor in ``batch`` and move it to a device.

    The dict is modified in place and also returned. Values that are not tensors
    (e.g. ``image_name``, presumably a string) are left untouched.

    Args:
        batch: Dict of model inputs as yielded by the dataloader.
        device: Target device for the tensors. ``None`` (the default) keeps the
            historical behaviour of calling ``.cuda()`` (current CUDA device).

    Returns:
        The same ``batch`` dict, with each tensor ``v`` replaced by
        ``v.unsqueeze(0)`` on the requested device, i.e. (c h w) -> (1 c h w).
    """
    # Only existing keys are reassigned, so updating the dict while iterating is safe.
    for key, value in batch.items():
        if isinstance(value, Tensor):
            value = value.unsqueeze(0)
            batch[key] = value.cuda() if device is None else value.to(device)
    return batch
if __name__ == "__main__":
parser = get_ntire_parser()
args = parser.parse_args()
# put the path to your checkpoint file here, or use the parser argument -checkpoint / -c
checkpoint = Path(f"./checkpoints/{args.checkpoint}")
# Setup and sanity checking
assert args.phase in ['dev', 'test'], (f"Unknown argument for phase (-phase, -p): {args.phase}, "
f"only ['dev', 'test'] are supported for this script.")
assert checkpoint.is_file(), f"Checkpoint {checkpoint} is not a file."
if args.image_format != 'png':
warn(f"Image format {args.image_format} might lead to a lower score du to default compression behaviour, "
f"we recomment .png for final submission!")
dataset_path = args.dataset_root_dir / 'Bokeh_NTIRE2026'
assert args.dataset_root_dir.is_dir(), f"Path {args.dataset_root_dir} is not a directory."
if args.phase == 'dev':
if (dataset_path / 'validation').exists():
print(f"Found NTIRE 2026 Bokeh Challenge Development inputs in: {dataset_path.absolute()}")
else: # development set is missing
print(f"Could not locate NTIRE 2026 Bokeh Challenge Development inputs (\'validation\' folder) in "
f"{dataset_path.absolute()}, attempting to extract them from \'Bokeh_NTIRE2026_Development_Inputs.zip\'")
assert (args.dataset_root_dir / 'Bokeh_NTIRE2026_Development_Inputs.zip').is_file(), \
(f'{f"Could not find the \'validation\' split (which is used in the dev phase) at {dataset_path.absolute()}"
if dataset_path.is_dir() else
f"Could not find the \'Bokeh_NTIRE26\' folder at {args.dataset_root_dir.absolute()}"} '
f'OR the development inputs archive \'Bokeh_NTIRE2026_Development_Inputs.zip\'.')
unpack_archive(args.dataset_root_dir / 'Bokeh_NTIRE2026_Development_Inputs.zip', args.dataset_root_dir)
copytree(args.dataset_root_dir / 'Bokeh_NTIRE2026_Development_Inputs',
args.dataset_root_dir, dirs_exist_ok=True)
rmtree(args.dataset_root_dir / 'Bokeh_NTIRE2026_Development_Inputs')
print("Successfully finished development dataset setup!")
else: # args.phase = 'test'
if (dataset_path / 'test').exists():
print(f"Found NTIRE 2026 Bokeh Challenge Test inputs in: {dataset_path.absolute()}")
else: # Test set is missing
print(f"Could not locate NTIRE 2026 Bokeh Challenge Test inputs (\'test\' folder) in "
f"{dataset_path.absolute()}, attempting to extract them from \'Bokeh_NTIRE2026_Test_Inputs.zip\'")
assert (args.dataset_root_dir / 'Bokeh_NTIRE2026_Test_Inputs.zip').is_file(), \
(f'{f"Could not find the \'test\' split (which is used in the final test phase) at {dataset_path.absolute()}"
if dataset_path.is_dir() else
f"Could not find the \'Bokeh_NTIRE26\' folder at {args.dataset_root_dir.absolute()}"} '
f'OR the test inputs archive \'Bokeh_NTIRE2026_Test_Inputs.zip\'.')
unpack_archive(args.dataset_root_dir / 'Bokeh_NTIRE2026_Test_Inputs.zip', args.dataset_root_dir)
copytree(args.dataset_root_dir / 'Bokeh_NTIRE2026_Test_Inputs ', args.dataset_root_dir,
dirs_exist_ok=True)
rmtree(args.dataset_root_dir / 'Bokeh_NTIRE2026_Test_Inputs ')
print("Successfully finished test dataset setup!")
# Setup dirs for saving results
output_directory = args.out_path / args.name / 'NTIRE2026BokehChallenge' / args.phase
# Clean output directory
try:
rmtree(output_directory)
except FileNotFoundError:
pass
makedirs(output_directory, exist_ok=False)
print(f"Saving outputs to {output_directory.absolute()}")
# Network Name is used in the Codalab leaderboard, set as desired with the -name argument
print(f"Running Architecture {args.name} on {'Development' if args.phase == 'dev' else 'Test'} set...")
# We initialize our own method here, replace this with the appropriate code to initialize your network!
# TODO: Load your own model here for evaluation!
config = bokehlicious_size_builder('small')
model = Bokehlicious(**config)
print(f"Initialized model on {args.device}")
# This code should load your checkpoint and set up the model for evaluation
state_dict = load(checkpoint)
model.load_state_dict(state_dict)
model.to(args.device)
model.eval()
print(f"Loaded weights from {checkpoint.absolute()}")
# Initialize evaluation dataset, use '-phase test' argument for the final test phase
dataloader = RealBokeh(data_path=dataset_path, mode=Mode.VAL if args.phase=='dev' else Mode.TEST, device=args.device, challenge=True)
try:
if args.phase == 'dev':
assert len(dataloader) == 78, (
f"There should be 78 images in the development set, but {len(dataloader)} were found. \n"
f"Please make sure you are using the correct input data found here: "
f"https://www.codabench.org/datasets/download/35a0a692-48df-4b91-a716-b438beaf94de/ \n")
else: # args.phase == test
pass
except AssertionError as error:
warn(f"Incorrect input data found at {dataset_path / 'validation' if args.phase == 'dev' else 'test'}. Resetting directory!")
rmtree(dataset_path / 'validation' if args.phase == 'dev' else 'test')
raise error
print(f"Initialized RealBokeh (NTIRE 2026 challenge) {'Development' if args.phase == 'dev' else 'Test'} phase dataloader")
# We use cuda events for timing network inference times
start_events = [Event(enable_timing=True) for _ in range(len(dataloader))]
end_events = [Event(enable_timing=True) for _ in range(len(dataloader))]
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
with no_grad():
# TODO: Depending on whatever inputs your model uses you might need to modify the batch pre-processing
batch = unsqueeeze_batch(batch) # (c h w) to (b c h w) for all tensors in batch dict
synchronize() # wait for GPU to complete any current workload
start_events[i].record() # log prediction start time
### ONLY THE MODEL FORWARD CALL SHOULD BE BETWEEN start_events[i].record()
output = model(**batch) # unpack batch dict for network forward call
### AND nd_events[i].record()
end_events[i].record() # log prediction end time
output = clamp(output, 0, 1)
to_pil_image(output.squeeze(0).cpu()).save(output_directory / f"{batch['image_name']}.{args.image_format}")
print("Finished prediction!")
# record important metadata
avg_time = mean([s.elapsed_time(e) for s, e in zip(start_events, end_events)]) # avg is in ms
print(f"Average time taken for {args.phase} phase: {avg_time / 1000:.3f} seconds on {get_device_name()}.")
parameters = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {parameters / 1e6:.2f}M")
metadata_file = output_directory / "readme.txt"
with open(metadata_file, "w") as readme:
readme.write(f"This file contains relevant metadata for the challenge leaderboard, "
f"everything should be automatically generated by the submit.py script, "
f"but feel free to double check!\n")
readme.write(f"Architecture Name:{args.name}\n")
readme.write(f"Parameters:{parameters / 1e6:.2f}M\n")
readme.write(f"Runtime:{avg_time / 1000:.3f}s\n")
readme.write(f"Device:{get_device_name()}\n")
readme.write(f"Extra data:{'Yes' if args.extra_data else 'No'}\n")
readme.write(f"Script Version:{1.2}\n")
print(f"Wrote metadata to {metadata_file.absolute()}:")
print("")
with open(metadata_file, "r") as readme:
print(readme.read())
print(f"Creating zip archive for Codabench submission...")
archive_file = args.out_path / f'{args.name}_{args.phase}_{datetime.now().strftime("%Y-%m-%d_%H:%M")}'
# check if correct number of images are in the output directory
if args.phase == 'dev':
image_names = [
'8_f9.0.png', '6_f5.6.png', '23_f2.0.png', '21_f8.0.png', '17_f2.0.png', '2_f2.0.png', '25_f2.0.png',
'28_f9.0.png', '6_f3.5.png', '9_f16.png', '27_f18.png', '28_f2.0.png', '5_f2.0.png', '15_f2.2.png',
'24_f5.0.png', '18_f2.0.png', '5_f4.5.png', '28_f14.png', '21_f2.8.png', '21_f4.5.png', '6_f4.5.png',
'19_f5.0.png', '29_f4.5.png', '24_f13.png', '21_f2.0.png', '14_f18.png', '2_f6.3.png', '28_f6.3.png',
'18_f14.png', '13_f18.png', '10_f5.6.png', '10_f2.8.png', '14_f2.0.png', '12_f2.0.png', '22_f4.0.png',
'17_f20.png', '24_f2.5.png', '24_f2.0.png', '29_f2.0.png', '7_f5.0.png', '26_f2.0.png', '2_f2.2.png',
'25_f6.3.png', '3_f14.png', '2_f5.0.png', '27_f2.0.png', '10_f3.5.png', '16_f9.0.png', '1_f2.0.png',
'20_f2.2.png', '30_f2.0.png', '22_f2.0.png', '5_f4.0.png', '15_f2.0.png', '30_f18.png', '8_f2.0.png',
'19_f2.0.png', '27_f16.png', '3_f2.0.png', '7_f2.0.png', '4_f8.0.png', '9_f2.0.png', '1_f7.1.png',
'12_f16.png', '20_f2.0.png', '27_f8.0.png', '10_f2.0.png', '13_f2.0.png', '11_f2.0.png', '11_f2.2.png',
'26_f5.6.png', '13_f2.8.png', '13_f9.0.png', '5_f8.0.png', '6_f2.0.png', '16_f2.0.png', '4_f2.0.png',
'23_f8.0.png'
]
name_in_output = [file.name for file in output_directory.glob(f'*.{args.image_format}')]
assert len(set(name_in_output) - set(image_names)) == 0, \
("The following images should not be in the output directory! \n"
f"{set(name_in_output) - set(image_names)}")
assert len(set(image_names) - set(name_in_output)) == 0, \
("The following images are missing from the output directory! \n"
f"{set(image_names) - set(name_in_output)}")
for image_name in image_names:
assert (output_directory / image_name).exists(), (f"Could not find image file {image_name} "
f"in outout directory {output_directory}!")
assert len(list(output_directory.glob(f'*.{args.image_format}'))) == 78, (f"Expected 78 {args.image_format} images in the output directory ({output_directory}), "
f"but found {len(output_directory.glob(f'*.{args.image_format}'))} images.")
else:
pass
# check if the metadata file exists
assert metadata_file.exists(), "Could not find metadata file!"
make_archive(archive_file, 'zip', root_dir=output_directory)
print(f"Please upload your submission file found at {archive_file.absolute()}.zip to https://www.codabench.org/competitions/12764/#/participate-tab!")