File tree Expand file tree Collapse file tree 1 file changed +19
-0
lines changed
checkpoint/orbax/checkpoint/experimental/v1/_src/partial Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change 1414
1515"""Partial merging utils."""
1616
17+ import collections
1718from typing import Any , NamedTuple , TypeVar
1819
1920from etils import epath
2021import jax
2122from orbax .checkpoint ._src .metadata import value as value_metadata
2223from orbax .checkpoint ._src .tree import structure_utils
24+ from orbax .checkpoint ._src .tree import utils as tree_utils
2325from orbax .checkpoint .experimental .v1 ._src .metadata import loading as metadata_loading
2426
2527PyTree = Any
2628T = TypeVar ('T' )
2729PyTreeOf = PyTree | T
30+ Keypath = tuple [Any , ...]
2831ArrayMetadata = value_metadata .ArrayMetadata
2932
3033
@@ -49,3 +52,19 @@ def merge_ckpt_metadata(
4952 overwrite = True ,
5053 is_leaf = lambda x : isinstance (x , SourceIndexedMetadata )
5154 )
55+
56+
57+ def group_leaves_by_ckpt (
58+ merged_metadata : PyTreeOf [SourceIndexedMetadata ],
59+ ) -> dict [int , dict [Keypath , ArrayMetadata ]]:
60+ """Groups leaves by the checkpoint index they belong to."""
61+ leaves_by_ckpt = collections .defaultdict (dict )
62+ for keypath , (ckpt_idx , metadata ) in sorted (
63+ tree_utils .to_flat_dict (
64+ merged_metadata ,
65+ is_leaf = lambda x : isinstance (x , SourceIndexedMetadata ),
66+ ).items (),
67+ key = lambda x : x [0 ],
68+ ):
69+ leaves_by_ckpt [ckpt_idx ][keypath ] = metadata
70+ return leaves_by_ckpt
You can’t perform that action at this time.
0 commit comments