Commit 59b8792

Merge pull request #1976 from hsuan-lun-chiang:feat/sharding-test
PiperOrigin-RevId: 793823533
2 parents 252d1fc + a97bfca commit 59b8792

File tree: 33 files changed, +53256 -0 lines changed


MaxText/tests/run_sharding_dump.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
"""
Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Run script to dump the sharding of various combinations of model and topology."""

import os
import subprocess
from typing import Sequence

from absl import app

from MaxText.tests.sharding_dump import TEST_CASES


def run_single_dump(model_name: str, topology: str, num_slice: str) -> None:
  """Generate the sharding JSON file for one specific model, topology and slice count."""
  subprocess.run(
      [
          "python",
          "-m",
          "MaxText.tests.sharding_dump",
          "MaxText/configs/base.yml",
          f"compile_topology={topology}",
          f"compile_topology_num_slices={num_slice}",
          f"model_name={model_name}",
      ],
      check=True,
  )


def main(argv: Sequence[str]) -> None:
  """Generate sharding JSON files for every combination of model, topology and slices."""
  for model_name, topology, num_slice in TEST_CASES:
    json_path = f"sharding_info/{model_name}/{topology}/slice_{num_slice}/named_shardings.json"
    if os.path.exists(json_path):
      continue
    run_single_dump(model_name, topology, num_slice)


if __name__ == "__main__":
  app.run(main)
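
For a one-off regeneration, the helper above can also be called directly from Python. A minimal usage sketch; the model name, topology and slice count below are hypothetical placeholders, not entries taken from TEST_CASES:

from MaxText.tests.run_sharding_dump import run_single_dump

# Hypothetical example values; the real combinations live in TEST_CASES
# (MaxText.tests.sharding_dump). The runner expects the resulting file at
# sharding_info/<model_name>/<topology>/slice_<num_slice>/named_shardings.json.
run_single_dump(model_name="llama2-7b", topology="v5e-256", num_slice="1")
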
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
"""
Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Compare the expected sharding of models with the actual sharding of models."""

import hashlib
import json
import os

import pytest

from MaxText import pyconfig
from MaxText.tests.sharding_dump import named_shardings_to_json, load_named_sharding_json, TEST_CASES
from MaxText.train_compile import get_shaped_inputs, get_topology_mesh, validate_config


def compute_checksum(d: dict) -> str:
  """Compute a checksum (SHA-256) of a dictionary."""
  # Serialize the dictionary into a JSON string (ensuring consistent ordering of keys).
  json_str = json.dumps(d, sort_keys=True)

  # Compute the SHA-256 checksum of the serialized string.
  checksum = hashlib.sha256(json_str.encode("utf-8")).hexdigest()

  return checksum


def compare_named_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_name: str) -> bool:
  """Compare two named-sharding JSON dictionaries and print the differences, if any."""
  keys1 = set(json1.keys())
  keys2 = set(json2.keys())

  only_in_1 = keys1 - keys2
  only_in_2 = keys2 - keys1
  shared_keys = keys1 & keys2

  if only_in_1:
    print(f"Keys only in {model1_name}:")
    for k in sorted(only_in_1):
      print(f"  {k}")

  if only_in_2:
    print(f"Keys only in {model2_name}:")
    for k in sorted(only_in_2):
      print(f"  {k}")

  for key in sorted(shared_keys):
    entry1 = json1[key]
    entry2 = json2[key]

    mesh1 = entry1.get("mesh", {})
    mesh2 = entry2.get("mesh", {})
    spec1 = entry1.get("partition_spec", [])
    spec2 = entry2.get("partition_spec", [])

    if mesh1 != mesh2:
      print(f"\nMesh mismatch at '{key}':")
      print(f"  mesh1: {mesh1}")
      print(f"  mesh2: {mesh2}")

    if spec1 != spec2:
      print(f"\nPartitionSpec mismatch at '{key}':")
      print(f"  spec1: {spec1}")
      print(f"  spec2: {spec2}")

  return not only_in_1 and not only_in_2 and all(json1[k] == json2[k] for k in shared_keys)


@pytest.mark.parametrize("model_name, topology, num_slice", TEST_CASES)
def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) -> None:
  """Test whether the sharding of a new model implementation is as expected."""
  params = [
      "/deps/MaxText/tests/sharding_compare_test",
      "MaxText/configs/base.yml",
      f"compile_topology={topology}",
      f"compile_topology_num_slices={num_slice}",
      f"model_name={model_name}",
  ]

  json_path = f"sharding_info/{model_name}/{topology}/slice_{num_slice}/named_shardings.json"
  if not os.path.exists(json_path):
    return

  config = pyconfig.initialize(params)
  validate_config(config)

  topology_mesh = get_topology_mesh(config)
  _, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config)
  actual_json = named_shardings_to_json(state_mesh_shardings)
  expected_json = load_named_sharding_json(json_path)

  actual_checksum = compute_checksum(actual_json)
  expected_checksum = compute_checksum(expected_json)
  result = actual_checksum == expected_checksum

  if not result:
    compare_named_sharding_jsons(expected_json, f"expected_{model_name}", actual_json, f"actual_{model_name}")

  assert result is True
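
The equality check hinges on json.dumps(..., sort_keys=True) producing a canonical serialization, so two dictionaries with identical contents but different key insertion order hash to the same SHA-256 digest. A minimal, self-contained sketch of that property; the mesh and partition_spec values are hypothetical illustrations, not taken from an actual dump:

import hashlib
import json


def checksum(d: dict) -> str:
  """Sorted-key JSON followed by SHA-256, mirroring compute_checksum above."""
  return hashlib.sha256(json.dumps(d, sort_keys=True).encode("utf-8")).hexdigest()


# Same nested content, different insertion order of the inner keys.
a = {"params": {"mesh": {"data": 4, "tensor": 2}, "partition_spec": ["data", None]}}
b = {"params": {"partition_spec": ["data", None], "mesh": {"tensor": 2, "data": 4}}}
assert checksum(a) == checksum(b)

Assuming the test file is checked in as MaxText/tests/sharding_compare_test.py (the first element of params suggests that path), the parametrized cases can be collected with pytest; any case whose expected named_shardings.json has not yet been dumped is skipped by the early return.
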
