Skip to content

Commit 842964d

Browse files
authored
[Fix] Redundant operations to transfer data between CPU and GPU (vqdang#206)
* fix typo * remove redundant operations
1 parent 9b21c86 commit 842964d

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self):
4242
raise Exception("If using `original` mode, input shape must be [270,270] and output shape must be [80,80]")
4343
if model_mode == "fast":
4444
if act_shape != [256,256] or out_shape != [164,164]:
45-
raise Exception("If using `original` mode, input shape must be [256,256] and output shape must be [164,164]")
45+
raise Exception("If using `fast` mode, input shape must be [256,256] and output shape must be [164,164]")
4646

4747
self.dataset_name = "consep" # extracts dataset info from dataset.py
4848
self.log_dir = "logs/" # where checkpoints will be saved

models/hovernet/run_desc.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ def valid_step(batch_data, run_info):
125125
imgs_gpu = imgs_gpu.permute(0, 3, 1, 2).contiguous()
126126

127127
# HWC
128-
true_np = torch.squeeze(true_np).to("cuda").type(torch.int64)
129-
true_hv = torch.squeeze(true_hv).to("cuda").type(torch.float32)
128+
true_np = torch.squeeze(true_np).type(torch.int64)
129+
true_hv = torch.squeeze(true_hv).type(torch.float32)
130130

131131
true_dict = {
132132
"np": true_np,
@@ -135,7 +135,7 @@ def valid_step(batch_data, run_info):
135135

136136
if model.module.nr_types is not None:
137137
true_tp = batch_data["tp_map"]
138-
true_tp = torch.squeeze(true_tp).to("cuda").type(torch.int64)
138+
true_tp = torch.squeeze(true_tp).type(torch.int64)
139139
true_dict["tp"] = true_tp
140140

141141
# --------------------------------------------------------------
@@ -155,14 +155,14 @@ def valid_step(batch_data, run_info):
155155
result_dict = { # protocol for contents exchange within `raw`
156156
"raw": {
157157
"imgs": imgs.numpy(),
158-
"true_np": true_dict["np"].cpu().numpy(),
159-
"true_hv": true_dict["hv"].cpu().numpy(),
158+
"true_np": true_dict["np"].numpy(),
159+
"true_hv": true_dict["hv"].numpy(),
160160
"prob_np": pred_dict["np"].cpu().numpy(),
161161
"pred_hv": pred_dict["hv"].cpu().numpy(),
162162
}
163163
}
164164
if model.module.nr_types is not None:
165-
result_dict["raw"]["true_tp"] = true_dict["tp"].cpu().numpy()
165+
result_dict["raw"]["true_tp"] = true_dict["tp"].numpy()
166166
result_dict["raw"]["pred_tp"] = pred_dict["tp"].cpu().numpy()
167167
return result_dict
168168

0 commit comments

Comments
 (0)