Skip to content

Commit a050516

Browse files
author
emcastillo
authored
Merge pull request #716 from linshokaku/dict-state-object
Enable state_dict management for options in the tree structure
2 parents a4b4b50 + c0abd38 commit a050516

File tree

2 files changed

+79
-3
lines changed

2 files changed

+79
-3
lines changed

pytorch_pfn_extras/engine.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
Any,
44
Callable,
55
Dict,
6+
List,
67
Mapping,
8+
NamedTuple,
79
Optional,
810
Sequence,
911
Tuple,
1012
Type,
1113
Union,
14+
cast,
1215
)
1316

1417
import pytorch_pfn_extras.handler as handler_module
@@ -27,6 +30,41 @@
2730
from pytorch_pfn_extras.training.trigger import TriggerLike
2831

2932

33+
def filter_state_objects(
34+
args: Any, key_name: str = ""
35+
) -> List[Tuple[str, StateObjectProtocol]]:
36+
if isinstance(args, tuple) and hasattr(args, "_fields"):
37+
# namedtuple
38+
return filter_state_objects_dict(
39+
cast(NamedTuple, args)._asdict(), key_name=key_name
40+
)
41+
if isinstance(args, dict):
42+
return filter_state_objects_dict(args, key_name=key_name)
43+
if isinstance(args, (list, tuple)):
44+
return sum(
45+
[
46+
filter_state_objects(v, key_name=f"{key_name}.__{i}__")
47+
for i, v in enumerate(args)
48+
],
49+
[],
50+
)
51+
if isinstance(args, StateObjectProtocol):
52+
return [(key_name, args)]
53+
return []
54+
55+
56+
def filter_state_objects_dict(
57+
args: Dict[str, Any], key_name: str = "option"
58+
) -> List[Tuple[str, StateObjectProtocol]]:
59+
return sum(
60+
[
61+
filter_state_objects(v, key_name=f"{key_name}__::__{k}")
62+
for k, v in sorted(args.items())
63+
],
64+
[],
65+
)
66+
67+
3068
def create_trainer(
3169
models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
3270
optimizers: Union[
@@ -109,9 +147,13 @@ def create_trainer(
109147
options = options.copy() if options else {}
110148

111149
state_objects: Dict[str, StateObjectProtocol] = {}
112-
for key, value in options.items():
113-
if isinstance(value, StateObjectProtocol):
114-
state_objects[f"options_{key}"] = value
150+
state_objects_list = filter_state_objects(options)
151+
for key, value in state_objects_list:
152+
state_objects[f"options_{key}"] = value
153+
154+
assert len(state_objects_list) == len(
155+
state_objects
156+
), "There was a duplicate key name in the flattened options dictionary object."
115157

116158
# TODO(kmaehashi): deprecate specifying 'runtime' key in options
117159
runtime_options = dict(

tests/pytorch_pfn_extras_tests/training_tests/test_engine.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, Dict
2+
13
import pytest
24
import pytorch_pfn_extras as ppe
35

@@ -75,3 +77,35 @@ def test_extend_after_init(self):
7577
extension = ppe.training.extensions.LogReport()
7678
with pytest.raises(RuntimeError, match="cannot extend after"):
7779
engine.extend(extension)
80+
81+
82+
class DummyStateObjects:
83+
def __init__(self) -> None:
84+
pass
85+
86+
def state_dict(self) -> Dict[str, Any]:
87+
return {}
88+
89+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
90+
return
91+
92+
93+
@pytest.mark.parametrize(
94+
"args",
95+
[
96+
{"a.b": DummyStateObjects(), "a": {"b": DummyStateObjects()}},
97+
{"a": DummyStateObjects(), "b": DummyStateObjects()},
98+
{
99+
"a__.__b": DummyStateObjects(),
100+
"__a__": {"__b__": DummyStateObjects()},
101+
},
102+
{"a::b": DummyStateObjects(), "a": {"b": DummyStateObjects()}},
103+
{"a:b": DummyStateObjects(), "a": {"b": DummyStateObjects()}},
104+
],
105+
)
106+
def test_filter_state_objects(args) -> None:
107+
out = ppe.engine.filter_state_objects(args)
108+
key_set = set()
109+
for key, _ in out:
110+
assert key not in key_set
111+
key_set.add(key)

0 commit comments

Comments
 (0)