diff --git a/pytorch_pfn_extras/engine.py b/pytorch_pfn_extras/engine.py
index ed7a8c58c..037cc0435 100644
--- a/pytorch_pfn_extras/engine.py
+++ b/pytorch_pfn_extras/engine.py
@@ -72,7 +72,7 @@ def create_trainer(
             Device name used for selecting a corresponding runtime class.
         logic:
             A logic object. If `None` is given, an logic object is instantiated
-            from the default logic class.
+            from the :class:`pytorch_pfn_extras.handler.CodeBlockLogic` class.
         transform_model:
             A function to transform a model structure, often used to unwrap the
             a module from DDP module.
@@ -96,7 +96,7 @@ def create_trainer(
     runtime_options = dict(
         runtime_options if runtime_options
         else options.pop('runtime', {}))
-    logic = handler_module.Logic() if logic is None else logic
+    logic = handler_module.CodeBlockLogic() if logic is None else logic
     handler_class = handler_class if handler_class else handler_module.Handler

     entry_runtime_cls = runtime_registry.get_runtime_class_for_device_spec(
@@ -128,7 +128,7 @@ def create_evaluator(
         progress_bar: bool = False,
         device: 'DeviceLike' = 'cpu',
         metrics: Optional[Sequence['MetricType']] = None,
-        logic: Optional[handler_module.Logic] = None,
+        logic: Optional[handler_module.BaseLogic] = None,
         handler_class: Optional[Type[handler_module.Handler]] = None,
         options: Optional[Dict[str, Any]] = None,
         runtime_options: Optional[Mapping[str, Any]] = None,
@@ -151,7 +151,7 @@ def create_evaluator(
             output for the reporting.
         logic:
             A logic object. If `None` is given, an logic object is instantiated
-            from the default logic class.
+            from the :class:`pytorch_pfn_extras.handler.CodeBlockLogic` class.
         handler_class:
             A handler class that instantiates a handler object. If `None` is
             given, `ppe.handler.Handler` is used as a default handler class.
@@ -173,7 +173,7 @@ def create_evaluator(
     runtime_options = dict(
         runtime_options if runtime_options
         else options.pop('runtime', {}))
-    logic = handler_module.Logic() if logic is None else logic
+    logic = handler_module.CodeBlockLogic() if logic is None else logic
     handler_class = handler_class if handler_class else handler_module.Handler

     entry_runtime_cls = runtime_registry.get_runtime_class_for_device_spec(
diff --git a/pytorch_pfn_extras/handler/_code_block.py b/pytorch_pfn_extras/handler/_code_block.py
index 94d4396ad..a354a9d04 100644
--- a/pytorch_pfn_extras/handler/_code_block.py
+++ b/pytorch_pfn_extras/handler/_code_block.py
@@ -2,6 +2,7 @@
 from typing import Any, Callable, Dict, List, Optional, Set

 import torch
+import pytorch_pfn_extras as ppe


 @dataclass
@@ -26,6 +27,7 @@ class CodeBlock:
     backprop: bool
     backprop_from: Optional[str]
     backprop_to: Optional[Set[str]]
+    backprop_fn: Optional[Callable[..., Any]]
     state: Dict[str, Any]
     runtime: Any

@@ -56,6 +58,7 @@ def update_parameters(
     optimizers: List[torch.optim.Optimizer],
     backprop_from: Optional[str] = None,
     backprop_to: Optional[Set[str]] = None,
+    backprop_fn: Optional[Callable[..., Any]] = None,
 ) -> CodeBlock:
     """
     Returns a ``CodeBlock`` that performs the forward, backward passes and
@@ -80,6 +83,7 @@
         backprop=True,
         backprop_from=backprop_from,
         backprop_to=backprop_to,
+        backprop_fn=backprop_fn,
         state=codeblock.state,
         runtime=codeblock.runtime,
     )
@@ -106,9 +110,34 @@ def forward(block: Callable) -> CodeBlock:
     else:
         module = getattr(block, '__self__', None)
         assert module is not None
-        func = block
+    runtime = ppe.runtime._runtime._module_runtime_tag(module)
+
+    def _forward(batch: Any) -> Any:
+
+        def _normalize_outputs(outputs: Any) -> Dict[str, Any]:
+            target: Dict[str, Any]
+            if isinstance(outputs, tuple) and hasattr(outputs, '_fields'):
+                # namedtuple
+                target = outputs._asdict()  # type: ignore[attr-defined]
+            elif isinstance(outputs, dict):
+                target = outputs
+            elif isinstance(outputs, (list, tuple)):
+                target = {str(i): out for i, out in enumerate(outputs)}
+            else:
+                target = {"0": outputs}
+            return target
+
+        if isinstance(batch, tuple) and hasattr(batch, '_fields'):
+            # namedtuple
+            return _normalize_outputs(block(batch))
+        if isinstance(batch, dict):
+            return _normalize_outputs(block(**batch))
+        if isinstance(batch, (list, tuple)):
+            return _normalize_outputs(block(*batch))
+        return _normalize_outputs(block(batch))
+
+    func = _forward
     state = {}
-    runtime = getattr(module, '_ppe_runtime', None)
     assert runtime is not None

     return CodeBlock(
@@ -117,6 +146,7 @@
         backprop=False,
         backprop_from=None,
         backprop_to=None,
+        backprop_fn=None,
         state=state,
         runtime=runtime,
     )
diff --git a/pytorch_pfn_extras/handler/_logic.py b/pytorch_pfn_extras/handler/_logic.py
index 37ced6120..36e97b3ac 100644
--- a/pytorch_pfn_extras/handler/_logic.py
+++ b/pytorch_pfn_extras/handler/_logic.py
@@ -383,6 +383,7 @@ def consume_options(self, options: Dict[str, Any]) -> None:
         self.backward_outputs = options.pop('backward_outputs', None)
         if self.backward_outputs is not None:
             assert isinstance(self.backward_outputs, str)
+        self._backward_fn = options.pop('backward_function', None)

     def train_epoch_begin(
         self,
@@ -433,6 +434,7 @@ def train_step(
             list(optimizers.values()),
             self.backward_outputs,
             None,
+            self._backward_fn,
         )(batch)

     def train_validation_begin(
diff --git a/pytorch_pfn_extras/runtime/_runtime.py b/pytorch_pfn_extras/runtime/_runtime.py
index 391383e04..55c8dce4a 100644
--- a/pytorch_pfn_extras/runtime/_runtime.py
+++ b/pytorch_pfn_extras/runtime/_runtime.py
@@ -432,9 +432,10 @@ def _scale(x: torch.Tensor) -> torch.Tensor:

         # with autocast
         with _autocast(enabled=self._autocast):
-            out = code_block.func(**batch)
+            out = code_block.func(batch)

         # codeblocks return Dicts-per-se so it is not necessary to normalize
+        to_backprop = []
         if code_block.backprop:
             if code_block.backprop_from is None:
                 for v in out.values():
@@ -447,10 +448,15 @@ def _scale(x: torch.Tensor) -> torch.Tensor:
                             or v.dtype.is_complex
                         )
                     ):
-                        _scale(v).backward()  # type: ignore[no-untyped-call]
+                        to_backprop.append(_scale(v))
             else:
-                _scale(out[code_block.backprop_from]).backward()  # type: ignore
+                to_backprop.append(_scale(out[code_block.backprop_from]))
+        for v in to_backprop:
+            if code_block.backprop_fn is not None:
+                code_block.backprop_fn(v)  # type: ignore
+            else:
+                v.backward()  # type: ignore[no-untyped-call]

         if len(code_block.optimizers) == 0:
             return out
diff --git a/tests/pytorch_pfn_extras_tests/test_logic.py b/tests/pytorch_pfn_extras_tests/test_logic.py
index eade4d6aa..2da8b704b 100644
--- a/tests/pytorch_pfn_extras_tests/test_logic.py
+++ b/tests/pytorch_pfn_extras_tests/test_logic.py
@@ -46,8 +46,8 @@ def test_trainer(device):
     iters_per_epoch = 10
     epochs = 20
     model = MyModel()
-    ppe.to(model, device)
     model_with_loss = MyModelWithLossFn(model)
+    ppe.to(model_with_loss, device)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     data = torch.utils.data.DataLoader(
         [(torch.rand(20,), torch.rand(10,)) for i in range(iters_per_epoch)])
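
Usage note, not part of the patch: the _forward wrapper added to _code_block.py above dispatches on the batch type (namedtuple, dict, sequence, or a bare value) and normalizes whatever the wrapped block returns into a dict, which is why the runtime above can call code_block.func(batch) and skip its own normalization. The snippet below only illustrates those normalization rules; the normalize helper and Out type are stand-ins written for this note, not names from the library.

    from collections import namedtuple

    Out = namedtuple('Out', ['y', 'loss'])

    def normalize(outputs):
        # Mirrors _normalize_outputs in the hunk above.
        if isinstance(outputs, tuple) and hasattr(outputs, '_fields'):
            return outputs._asdict()          # namedtuple -> dict keyed by field
        if isinstance(outputs, dict):
            return outputs                    # dict passes through unchanged
        if isinstance(outputs, (list, tuple)):
            return {str(i): o for i, o in enumerate(outputs)}  # positional -> "0", "1", ...
        return {"0": outputs}                 # bare value -> {"0": value}

    assert normalize(Out(1, 2)) == {'y': 1, 'loss': 2}
    assert normalize({'loss': 3}) == {'loss': 3}
    assert normalize([4, 5]) == {'0': 4, '1': 5}
    assert normalize(6) == {'0': 6}
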
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py
index bbc889ae1..32458324a 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py
@@ -74,8 +74,8 @@ def test_trainer(device, path):
     if not torch.cuda.is_available() and device == 'cuda':
         pytest.skip()
     model = MyModel()
-    ppe.to(model, device)
     model_with_loss = MyModelWithLossFn(model)
+    ppe.to(model_with_loss, device)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     data = torch.utils.data.DataLoader(
         [(torch.rand(20,), torch.rand(10,)) for i in range(10)])
@@ -109,8 +109,8 @@ def test_trainer_no_to(path):
 def test_trainer_invalid_options(path):
     device = 'cpu'
     model = MyModel()
-    ppe.to(model, device)
     model_with_loss = MyModelWithLossFn(model)
+    ppe.to(model_with_loss, device)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     extensions = _make_extensions()
     options = {'UNKNOWN_OPTIONS': True}
@@ -129,8 +129,8 @@ def test_train_with_evaluator(device, progress_bar, path):
     if not torch.cuda.is_available() and device == 'cuda':
         pytest.skip()
     model = MyModel()
-    ppe.to(model, device)
     model_with_loss = MyModelWithLossFn(model)
+    ppe.to(model_with_loss, device)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     data = torch.utils.data.DataLoader(
         [(torch.rand(20,), torch.rand(10,)) for i in range(10)])
@@ -158,8 +158,8 @@ def test_evaluator_trigger(evaluator_trigger, path):
     device = 'cpu'
     progress_bar = False
     model = MyModel()
-    ppe.to(model, device)
     model_with_loss = MyModelWithLossFn(model)
+    ppe.to(model_with_loss, device)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     data = torch.utils.data.DataLoader(
         [(torch.rand(20,), torch.rand(10,)) for i in range(10)])
@@ -183,8 +183,8 @@ def test_evaluator_dict(path):
     device = 'cpu'
     progress_bar = False
     model = MyModel()
-    ppe.to(model, device)
     model_with_loss = MyModelWithLossFn(model)
+    ppe.to(model_with_loss, device)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     data = torch.utils.data.DataLoader(
         [(torch.rand(20,), torch.rand(10,)) for i in range(10)])
@@ -220,8 +220,8 @@ def test_train_result_equal(device, path):
     def get_result_from_trainer():
         model = MyModel()
-        ppe.to(model, device)
         model_with_loss = MyModelWithLossFn(model)
+        ppe.to(model_with_loss, device)
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
         extensions = _make_extensions()
@@ -238,8 +238,8 @@ def get_result_from_training_loop():
         model = MyModel()
-        ppe.to(model, device)
         model_with_loss = MyModelWithLossFn(model)
+        ppe.to(model_with_loss, device)
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
         model_with_loss.train()
@@ -293,8 +293,8 @@ def _compare_states(s1, s2):
 class TestTrainerState:
     def _get_trainer(self, epochs, out_dir):
         model = MyModel()
-        ppe.to(model, 'cpu')
         model_with_loss = MyModelWithLossFn(model)
+        ppe.to(model_with_loss, 'cpu')
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
         extensions = _make_extensions()
         trainer = engine.create_trainer(
@@ -356,8 +356,8 @@ def test_trainer_dict_input(device, progress_bar, path):
     if not torch.cuda.is_available() and device == 'cuda':
         pytest.skip()
     model = MyModel()
-    ppe.to(model, device)
     model_with_loss = MyModelWithLossDictOutput(model)
+    ppe.to(model_with_loss, device)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     data = torch.utils.data.DataLoader(
         [{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)])
@@ -405,8 +405,8 @@ def test_trainer_namedtuple_input(device, progress_bar, path):
     if not torch.cuda.is_available() and device == 'cuda':
         pytest.skip()
     model = MyModel()
-    ppe.to(model, device)
     model_with_loss = ModelNamedTupleIO(model)
+    ppe.to(model_with_loss, device)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     data = torch.utils.data.DataLoader(
         [Input(torch.rand(20,), torch.rand(10,), str(i)) for i in range(10)])
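
Usage note, not part of the patch: the _logic.py and _runtime.py hunks above route a new 'backward_function' option into CodeBlock.backprop_fn, so the runtime calls that function on each selected output instead of calling tensor.backward() directly. The sketch below is a minimal illustration, assuming that an options dict passed to create_trainer reaches the logic's consume_options (as 'backward_outputs' already does) and that the default logic is now CodeBlockLogic; the ModelWithLoss and custom_backward names are placeholders written for this note.

    import torch
    import pytorch_pfn_extras as ppe
    from pytorch_pfn_extras import engine


    class ModelWithLoss(torch.nn.Module):
        # Placeholder model that returns its loss in a dict, like the test models above.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(20, 10)

        def forward(self, x, t):
            loss = torch.nn.functional.mse_loss(self.linear(x), t)
            return {'loss': loss}


    def custom_backward(loss):
        # Called by the runtime in place of loss.backward(); a natural hook
        # for gradient scaling, clipping, or a custom autograd entry point.
        loss.backward()


    model = ModelWithLoss()
    ppe.to(model, 'cpu')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loader = torch.utils.data.DataLoader(
        [(torch.rand(20,), torch.rand(10,)) for _ in range(10)])

    trainer = engine.create_trainer(
        model, optimizer, max_epochs=1, device='cpu',
        options={'backward_function': custom_backward},
    )
    trainer.run(loader)

Because backprop_fn defaults to None and the runtime falls back to v.backward(), callers that do not pass 'backward_function' see unchanged behavior.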