Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class SupervisedTrainer(Trainer):
#ignite.engine.engine.Engine.register_events.
decollate: whether to decollate the batch-first data to a list of data after model computation,
recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
default to `True`.
default to `False` as training slows due to tensor movement to CPU for decollation when enabled.
optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
Expand Down Expand Up @@ -154,7 +154,7 @@ def __init__(
amp: bool = False,
event_names: list[str | EventEnum | type[EventEnum]] | None = None,
event_to_attr: dict | None = None,
decollate: bool = True,
decollate: bool = False,
optim_set_to_none: bool = False,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def run_interaction(self, train, compose):
optimizer=opt,
loss_function=loss,
iteration_update=i,
decollate=True,
)
engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one)
engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_deepedit_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def run_interaction(self, train):
loss_function=loss,
postprocessing=post_transforms,
iteration_update=i,
decollate=True,
)
engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one)
engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)
Expand Down
3 changes: 2 additions & 1 deletion tests/testing_data/config_fl_train.json
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@
"loss_function": "@loss",
"optimizer": "@optimizer",
"inferer": "@train#inferer",
"train_handlers": "@train#handlers"
"train_handlers": "@train#handlers",
"decollate": true
}
},
"validate": {
Expand Down
Loading