Skip to content

Commit 353a891

Browse files
committed
Merge branch 'main' of https://github.com/google/ml-flashpoint into rename-thread-count-to-files-per-rank
2 parents d7a4e3a + dcb28da commit 353a891

File tree

8 files changed

+620
-46
lines changed

8 files changed

+620
-46
lines changed

cloudbuild.yaml

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@
1313
# limitations under the License.
1414

1515
steps:
16-
# 1. Build the distribution (sdist and wheel).
17-
# This uses scikit-build-core as defined in pyproject.toml to compile C++ extensions.
18-
# We set BUILD_TESTING=OFF to ignore tests during the artifact build.
16+
# 1. Build the source distribution (sdist).
1917
- name: 'python:3.10'
20-
id: 'build'
18+
id: 'build-sdist'
2119
entrypoint: 'bash'
2220
args:
2321
- '-c'
@@ -27,16 +25,37 @@ steps:
2725
exit 1
2826
fi
2927
pip install build
30-
31-
echo "Building C++ extensions with editable install..."
32-
# Note: pip install -e . will also respect the build-system requirements
33-
SKBUILD_CMAKE_ARGS="-DBUILD_TESTING=OFF" pip install -e .
34-
35-
echo "Building sdist and wheel..."
36-
SKBUILD_CMAKE_ARGS="-DBUILD_TESTING=OFF" python -m build
28+
python -m build --sdist
3729
38-
# 2. Upload to internal Artifact Registry (AR) for OSS Exit Gate.
39-
# OSS Exit Gate fetches artifacts from this repository.
30+
# 2. Build the manylinux wheels.
31+
# We use the manylinux image directly to avoid the "reserved" docker.sock issue.
32+
# This image is pre-loaded with Python versions and build tools.
33+
- name: '$_MANYLINUX_IMAGE'
34+
id: 'build-wheels'
35+
entrypoint: 'bash'
36+
args:
37+
- '-c'
38+
- |
39+
# Use Python 3.10 as the build controller (it will produce the abi3 wheel)
40+
PYBIN="/opt/python/cp310-cp310/bin"
41+
42+
# Install build dependencies
43+
$$PYBIN/pip install build
44+
45+
echo "Building wheel (scikit-build-core will handle abi3 and C++ extension compilation)..."
46+
SKBUILD_CMAKE_ARGS="-DBUILD_TESTING=OFF" $$PYBIN/python -m build --wheel
47+
48+
echo "Repairing wheel with auditwheel to ensure manylinux compliance..."
49+
# auditwheel repair will bundle any external shared libraries and fix the platform tag
50+
$$PYBIN/auditwheel repair dist/*.whl --wheel-dir dist/
51+
52+
# Remove the original non-compliant wheel (the one with the 'linux' tag)
53+
# to ensure only the manylinux version is uploaded.
54+
rm dist/*-linux_*.whl
55+
waitFor: ['build-sdist']
56+
57+
58+
# 3. Upload to internal Artifact Registry (AR) for OSS Exit Gate.
4059
- name: 'python:3.10'
4160
id: 'upload-to-ar'
4261
entrypoint: 'bash'
@@ -45,9 +64,9 @@ steps:
4564
- |
4665
pip install -U twine keyring keyrings.google-artifactregistry-auth
4766
twine upload --repository-url https://us-python.pkg.dev/oss-exit-gate-prod/${_PROJECT_NAME}--pypi dist/*
48-
waitFor: ['build']
67+
waitFor: ['build-wheels']
4968

50-
# 3. Create and upload the manifest to GCS to trigger the Exit Gate publication.
69+
# 4. Create and upload the manifest to GCS to trigger the Exit Gate publication.
5170
# The presence of this file in the specific GCS bucket triggers the verification and publishing process.
5271
- name: 'gcr.io/cloud-builders/gcloud'
5372
id: 'trigger-exit-gate'
@@ -67,3 +86,5 @@ options:
6786

6887
substitutions:
6988
_PROJECT_NAME: 'ml-flashpoint'
89+
# Default to x86_64; can be overridden in the Trigger for ARM64.
90+
_MANYLINUX_IMAGE: 'quay.io/pypa/manylinux_2_28_x86_64'

pyproject.toml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ build-backend = "scikit_build_core.build"
132132
# Tells scikit-build-core to use setuptools-scm to retrieve the version from git.
133133
metadata.version.provider = "scikit_build_core.metadata.setuptools_scm"
134134

135+
# Enable Stable ABI (abi3) for Python 3.10 and later.
136+
# This produces a single wheel per architecture that works on all future Python versions.
137+
wheel.py-api = "cp310"
138+
135139
# Specifies the minimum version of CMake that must be present on the system.
136140
cmake.version = ">=3.18"
137141

@@ -193,3 +197,27 @@ fail_under = 90
193197
[tool.gcovr]
194198
fail-under-line = "80"
195199
#fail-under-branch = "85"
200+
201+
# ===================================================================
202+
# Tool-specific Configuration for cibuildwheel
203+
# ===================================================================
204+
[tool.cibuildwheel]
205+
# Build only once per architecture (using Python 3.10).
206+
# Because abi3 is enabled, this wheel will work for 3.10, 3.11, 3.12, 3.13, etc.
207+
build = "cp310-*"
208+
# Target both Intel (x86_64) and ARM (aarch64) architectures.
209+
archs = ["x86_64", "aarch64"]
210+
# Skip 32-bit builds and musllinux (less common, so skipping for simplicity)
211+
skip = "*-manylinux_i686 *-musllinux_*"
212+
213+
[tool.cibuildwheel.linux]
214+
# Pass the flag to skip tests during the build inside the container
215+
environment = { SKBUILD_CMAKE_ARGS="-DBUILD_TESTING=OFF" }
216+
# We need to install the build requirements inside the cibuildwheel container.
217+
# We also print the architecture and environment for debugging.
218+
before-build = """
219+
uname -m && \
220+
pip install pybind11 scikit-build-core cmake ninja setuptools-scm && \
221+
cmake --version
222+
"""
223+

src/ml_flashpoint/adapter/megatron/save_strategies.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import copy
1516
import json
1617
import logging
1718
import os
@@ -28,6 +29,8 @@
2829
_replace_state_dict_keys_with_sharded_keys,
2930
mcore_to_pyt_state_dict,
3031
)
32+
from torch.distributed.checkpoint.metadata import Metadata
33+
from torch.distributed.checkpoint.planner import SavePlan
3134
from torch.distributed.checkpoint.utils import _DistWrapper
3235
from typing_extensions import override
3336

@@ -83,17 +86,27 @@ def __init__(
8386
storage_writer: MemoryStorageWriter,
8487
backend: str = default_backend_format_name(),
8588
version: int = default_backend_format_version(),
89+
use_cached_ckpt_structure: bool = False,
8690
):
8791
"""
8892
Args:
8993
storage_writer (MemoryStorageWriter): The storage writer to use for saving operations.
9094
backend (str, optional): The name of the backend format. Defaults to "ml_flashpoint", which is recommended.
9195
version (int, optional): The version of the checkpoint format. Defaults to the latest version.
96+
use_cached_ckpt_structure (bool, optional): Whether to reuse the checkpoint structure (plan)
97+
from the previous save. Defaults to False.
9298
"""
9399
super().__init__(backend=backend, version=version)
94100
self._storage_writer: MemoryStorageWriter = storage_writer
95101
self._checkpoint_saver: MLFlashpointCheckpointSaver = storage_writer.checkpoint_saver
96102

103+
# Cache for state dict saving
104+
self._cached_central_plan: SavePlan | None = None
105+
self._cached_local_plan: SavePlan | None = None
106+
self._cached_global_metadata: Metadata | None = None
107+
self._validated_cache_reuse: bool = False
108+
self._use_cached_ckpt_structure: bool = use_cached_ckpt_structure
109+
97110
@override
98111
def can_handle_sharded_objects(self) -> bool:
99112
# Not currently used, but in case it is, ensure this strategy is used for ShardedObjects as well.
@@ -157,14 +170,42 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
157170
# we also use Megatron's SavePlanner during saving for compatibility.
158171
planner: MCoreSavePlanner = MCoreSavePlanner(can_run_decentralized_global_plan=False)
159172
world_dist_wrapper = _DistWrapper(group=None, use_dist=not disable_dist, coordinator_rank=0)
160-
plan, write_buckets, global_metadata = statedictsaver.generate_plan(
173+
# Try twice to validate the generated `central_plan` is the same across iterations
174+
# If so, reuse `cached_central_plan` and `cached_global_metadata`
175+
# From the 3rd iteration, `generate_plan` will not generate `global_metadata`
176+
# (it returns None), so `self._cached_global_metadata` is reused
177+
cached_structure_args = None
178+
if self._use_cached_ckpt_structure:
179+
cached_structure_args = (
180+
self._cached_central_plan,
181+
self._cached_local_plan,
182+
self._validated_cache_reuse,
183+
)
184+
185+
(
186+
write_buckets,
187+
global_metadata,
188+
self._cached_central_plan,
189+
self._cached_local_plan,
190+
self._validated_cache_reuse,
191+
) = statedictsaver.generate_plan(
161192
checkpoint_id=checkpoint_id,
162193
state_dict=pyt_state_dict,
163194
storage_writer=self._storage_writer,
164195
planner=planner,
165196
world_dist_wrapper=world_dist_wrapper,
197+
cached_ckpt_structure=cached_structure_args,
166198
)
167199

200+
if global_metadata is None:
201+
# We want to use the cached metadata structure, but ensure any modifications (like adding storage data)
202+
# are done on a copy so the cache remains clean.
203+
global_metadata = copy.deepcopy(self._cached_global_metadata)
204+
else:
205+
# Checkpoint structure (and thus metadata) changed or was generated for the first time.
206+
# Cache a clean copy of the metadata before storage data is potentially added later.
207+
self._cached_global_metadata = copy.deepcopy(global_metadata)
208+
168209
# 5. Stage to CPU.
169210
staged_write_buckets = self._storage_writer.stage_write_data_buckets(
170211
checkpoint_id, write_buckets, non_blocking=True

src/ml_flashpoint/adapter/nemo/wrapper_util.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
4545
write_files_per_rank: int = 1,
4646
initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
4747
use_optimized_save: bool = True,
48+
use_cached_ckpt_structure: bool = False,
4849
) -> MLFlashpointAutoResume:
4950
"""Wraps the trainer and creates an MLFlashpointAutoResume instance wrapping `default_auto_resume`.
5051
@@ -62,6 +63,8 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
6263
write_files_per_rank: Optional. The number of files each rank writes to for checkpoint data. Defaults to 1.
6364
initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
6465
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.
66+
use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save.
67+
Defaults to False.
6568
Returns:
6669
An MLFlashpointAutoResume instance configured for ML Flashpoint, wrapping `default_auto_resume`.
6770
"""
@@ -90,6 +93,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
9093
write_files_per_rank=write_files_per_rank,
9194
initial_write_buffer_size_bytes=initial_write_buffer_size_bytes,
9295
use_optimized_save=use_optimized_save,
96+
use_cached_ckpt_structure=use_cached_ckpt_structure,
9397
)
9498

9599
default_auto_resume_args = vars(default_auto_resume) if default_auto_resume else {}
@@ -111,6 +115,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
111115
write_files_per_rank: int = 1,
112116
initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
113117
use_optimized_save: bool = True,
118+
use_cached_ckpt_structure: bool = False,
114119
):
115120
"""Wraps the trainer's checkpoint I/O with ML Flashpoint capabilities.
116121
@@ -138,6 +143,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
138143
write_files_per_rank: Optional. The number of files each rank writes to for checkpoint data. Defaults to 1.
139144
initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
140145
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.
146+
use_cached_ckpt_structure: Whether to reuse the checkpoint structure (plan) from the previous save.
147+
Defaults to False.
141148
142149
Returns:
143150
None. The trainer's checkpoint_io is modified in-place.
@@ -218,7 +225,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
218225
),
219226
mp_manager=ctx.Manager(),
220227
files_per_rank=write_files_per_rank,
221-
)
228+
),
229+
use_cached_ckpt_structure=use_cached_ckpt_structure,
222230
)
223231
load_strategy = MLFlashpointMegatronLoadStrategy(
224232
replication_manager=replication_manager,

src/ml_flashpoint/adapter/pytorch/custom_state_dict_saver.py

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
import torch.cuda
2424
from torch import distributed as torchdist
25-
from torch.distributed.checkpoint import Metadata
2625
from torch.distributed.checkpoint import state_dict_saver as torchdistsaver
2726
from torch.distributed.checkpoint.logger import _dcp_method_logger
2827
from torch.distributed.checkpoint.planner import SavePlan
@@ -46,7 +45,14 @@ def generate_plan(
4645
storage_writer: MemoryStorageWriter,
4746
planner: torchdistsaver.SavePlanner,
4847
world_dist_wrapper: _DistWrapper,
49-
) -> tuple[SavePlan, list[ObjectWriteBucket], Metadata]:
48+
cached_ckpt_structure: tuple[SavePlan, SavePlan, bool] | None = None,
49+
) -> tuple[
50+
list[ObjectWriteBucket],
51+
torchdistsaver.Metadata | None,
52+
SavePlan,
53+
SavePlan,
54+
bool,
55+
]:
5056
"""Performs the planning phase of checkpointing.
5157
5258
This function is similar to PyTorch's `state_dict_saver.save` but only
@@ -62,9 +68,27 @@ def generate_plan(
6268
planner: The SavePlanner to use for the save.
6369
world_dist_wrapper: The distributed wrapper for world (all ranks) communication.
6470
Typically created as `_DistWrapper(process_group, not no_dist, coordinator_rank)`.
71+
cached_ckpt_structure: Tuple of (cached_central_plan, cached_local_plan, validated_cache_reuse).
72+
6573
Returns:
66-
A tuple containing the updated local plan, write buckets, and global metadata.
74+
A tuple containing:
75+
- write_buckets: The buckets of data to be written.
76+
- global_metadata: The global metadata for the checkpoint.
77+
- central_plan (for caching): The centralized plan generated by the coordinator.
78+
- local_plan (for caching): The local plan generated by this rank.
79+
- validated_cache_reuse (bool): Whether the cached plan was successfully validated against the current
80+
plan.
81+
- If True: The structure of the checkpoint has not changed (e.g., same tensor shapes and sharding),
82+
so the cached plan can be safely reused for future steps to skip expensive planning.
83+
- If False: The structure has changed or this is the first run, so the plan was re-generated.
84+
- After 1st run: cached_central_plan is None, this value stays False -> 2nd run will validate cache.
85+
- After 2nd run: cached_central_plan == central_plan (if structure stable), so this value becomes True
86+
- After 3rd run+: reuse cached plan if structure stable, otherwise regenerate.
6787
"""
88+
cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False)
89+
if cached_ckpt_structure:
90+
cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure
91+
6892
global_metadata: torchdistsaver.Metadata | None = None
6993

7094
ckpt_kwargs = {"checkpoint_id": storage_writer.current_checkpoint_id, "process_group": world_dist_wrapper.group}
@@ -79,9 +103,12 @@ def local_step() -> SavePlan:
79103
)
80104
storage_writer.set_up_storage_writer(world_dist_wrapper.is_coordinator)
81105

82-
local_plan = planner.create_local_plan()
83-
local_plan = storage_writer.prepare_local_plan(local_plan)
84-
return local_plan
106+
if cached_local_plan and validated_cache_reuse:
107+
plan = cached_local_plan
108+
else:
109+
plan = planner.create_local_plan()
110+
111+
return storage_writer.prepare_local_plan(plan)
85112

86113
@_dcp_method_logger(**ckpt_kwargs)
87114
def global_step(all_local_plans: list[SavePlan]) -> list[SavePlan]:
@@ -91,19 +118,31 @@ def global_step(all_local_plans: list[SavePlan]) -> list[SavePlan]:
91118
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
92119
return all_local_plans
93120

94-
with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"):
95-
_LOGGER.debug("Executing plan reduce_scatter to get updated_local_plan...")
96-
updated_local_plan = world_dist_wrapper.reduce_scatter("plan", local_step, global_step)
97-
98-
with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"):
99-
_LOGGER.debug("Executing global_metadata broadcast...")
100-
# TODO(perf): - can broadcast only to local rank 0 to reduce comms
101-
global_metadata = world_dist_wrapper.broadcast_object(global_metadata)
102-
103-
final_local_plan = planner.finish_plan(updated_local_plan)
121+
central_plan = None
122+
if validated_cache_reuse and cached_central_plan:
123+
_LOGGER.debug("Passed cache reusable")
124+
local_plan = local_step()
125+
central_plan = cached_central_plan
126+
else:
127+
with log_execution_time(logger=_LOGGER, name="generate_plan__reduce_scatter_plan"):
128+
_LOGGER.debug("Executing plan reduce_scatter to get central_plan...")
129+
local_plan = local_step()
130+
central_plan = world_dist_wrapper.reduce_scatter("plan", lambda: local_plan, global_step)
131+
132+
with log_execution_time(logger=_LOGGER, name="generate_plan__broadcast_metadata"):
133+
_LOGGER.debug("Executing global_metadata broadcast...")
134+
global_metadata = world_dist_wrapper.broadcast_object(global_metadata)
135+
136+
final_local_plan = planner.finish_plan(central_plan)
104137
write_buckets = storage_writer.prepare_write_data_buckets(checkpoint_id, final_local_plan, planner)
105138

106-
return final_local_plan, write_buckets, global_metadata
139+
return (
140+
write_buckets,
141+
global_metadata,
142+
central_plan,
143+
local_plan,
144+
cached_central_plan == central_plan,
145+
)
107146

108147

109148
@log_execution_time(logger=_LOGGER, name="write_data", level=logging.INFO)

0 commit comments

Comments
 (0)