Skip to content

Commit 1a009b0

Browse files
BlaziusMaximusOrbax Authors
authored andcommitted
Add a function to group merged metadata leaves by checkpoint index.
PiperOrigin-RevId: 833561014
1 parent 1f5536e commit 1a009b0

File tree

1 file changed

+19
-0
lines changed
  • checkpoint/orbax/checkpoint/experimental/v1/_src/partial

1 file changed

+19
-0
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/partial/merging.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,20 @@
1414

1515
"""Partial merging utils."""
1616

17+
import collections
1718
from typing import Any, NamedTuple, TypeVar
1819

1920
from etils import epath
2021
import jax
2122
from orbax.checkpoint._src.metadata import value as value_metadata
2223
from orbax.checkpoint._src.tree import structure_utils
24+
from orbax.checkpoint._src.tree import utils as tree_utils
2325
from orbax.checkpoint.experimental.v1._src.metadata import loading as metadata_loading
2426

2527
PyTree = Any
2628
T = TypeVar('T')
2729
PyTreeOf = PyTree | T
30+
Keypath = tuple[Any, ...]
2831
ArrayMetadata = value_metadata.ArrayMetadata
2932

3033

@@ -49,3 +52,19 @@ def merge_ckpt_metadata(
4952
overwrite=True,
5053
is_leaf=lambda x: isinstance(x, SourceIndexedMetadata)
5154
)
55+
56+
57+
def group_leaves_by_ckpt(
58+
merged_metadata: PyTreeOf[SourceIndexedMetadata],
59+
) -> dict[int, dict[Keypath, ArrayMetadata]]:
60+
"""Groups leaves by the checkpoint index they belong to."""
61+
leaves_by_ckpt = collections.defaultdict(dict)
62+
for keypath, (ckpt_idx, metadata) in sorted(
63+
tree_utils.to_flat_dict(
64+
merged_metadata,
65+
is_leaf=lambda x: isinstance(x, SourceIndexedMetadata),
66+
).items(),
67+
key=lambda x: x[0],
68+
):
69+
leaves_by_ckpt[ckpt_idx][keypath] = metadata
70+
return leaves_by_ckpt

0 commit comments

Comments
 (0)