|
| 1 | +""" |
| 2 | +Copyright 2023 Google LLC |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +"""Compare expected sharding of models with actual sharding of models.""" |
| 18 | + |
| 19 | + |
| 20 | +import hashlib |
| 21 | +from MaxText.train_compile import get_shaped_inputs, get_topology_mesh, validate_config |
| 22 | +from MaxText.tests.sharding_dump import named_shardings_to_json, load_named_sharding_json, TEST_CASES |
| 23 | +from MaxText import pyconfig |
| 24 | +import pytest |
| 25 | +import os |
| 26 | +import json |
| 27 | + |
| 28 | + |
| 29 | +def compute_checksum(d: dict) -> str: |
| 30 | + """Compute a checksum (SHA256) of a dictionary.""" |
| 31 | + # Serialize the dictionary into a JSON string (ensuring consistent ordering of keys) |
| 32 | + json_str = json.dumps(d, sort_keys=True) |
| 33 | + |
| 34 | + # Compute the SHA256 checksum of the serialized string |
| 35 | + checksum = hashlib.sha256(json_str.encode("utf-8")).hexdigest() |
| 36 | + |
| 37 | + return checksum |
| 38 | + |
| 39 | + |
| 40 | +def compare_named_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_name: str) -> bool: |
| 41 | + """Compare two json files and print the differences if any.""" |
| 42 | + keys1 = set(json1.keys()) |
| 43 | + keys2 = set(json2.keys()) |
| 44 | + |
| 45 | + only_in_1 = keys1 - keys2 |
| 46 | + only_in_2 = keys2 - keys1 |
| 47 | + shared_keys = keys1 & keys2 |
| 48 | + |
| 49 | + if only_in_1: |
| 50 | + print(f"Keys only in {model1_name}:") |
| 51 | + for k in sorted(only_in_1): |
| 52 | + print(f" {k}") |
| 53 | + |
| 54 | + if only_in_2: |
| 55 | + print(f"Keys only in {model2_name}:") |
| 56 | + for k in sorted(only_in_2): |
| 57 | + print(f" {k}") |
| 58 | + |
| 59 | + for key in sorted(shared_keys): |
| 60 | + entry1 = json1[key] |
| 61 | + entry2 = json2[key] |
| 62 | + |
| 63 | + mesh1 = entry1.get("mesh", {}) |
| 64 | + mesh2 = entry2.get("mesh", {}) |
| 65 | + spec1 = entry1.get("partition_spec", []) |
| 66 | + spec2 = entry2.get("partition_spec", []) |
| 67 | + |
| 68 | + if mesh1 != mesh2: |
| 69 | + print(f"\nMesh mismatch at '{key}':") |
| 70 | + print(f" mesh1: {mesh1}") |
| 71 | + print(f" mesh2: {mesh2}") |
| 72 | + |
| 73 | + if spec1 != spec2: |
| 74 | + print(f"\nPartitionSpec mismatch at '{key}':") |
| 75 | + print(f" spec1: {spec1}") |
| 76 | + print(f" spec2: {spec2}") |
| 77 | + |
| 78 | + return not only_in_1 and not only_in_2 and all(json1[k] == json2[k] for k in shared_keys) |
| 79 | + |
| 80 | + |
| 81 | +@pytest.mark.parametrize("model_name, topology, num_slice", TEST_CASES) |
| 82 | +def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) -> None: |
| 83 | + """Test if the sharding of new model implementation is as expected.""" |
| 84 | + params = [ |
| 85 | + "/deps/MaxText/tests/sharding_compare_test", |
| 86 | + "MaxText/configs/base.yml", |
| 87 | + f"compile_topology={topology}", |
| 88 | + f"compile_topology_num_slices={num_slice}", |
| 89 | + f"model_name={model_name}", |
| 90 | + ] |
| 91 | + |
| 92 | + json_path = f"sharding_info/" f"{model_name}/" f"{topology}/" f"slice_{num_slice}/named_shardings.json" |
| 93 | + if not os.path.exists(json_path): |
| 94 | + return |
| 95 | + |
| 96 | + config = pyconfig.initialize(params) |
| 97 | + validate_config(config) |
| 98 | + |
| 99 | + topology_mesh = get_topology_mesh(config) |
| 100 | + _, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config) |
| 101 | + actual_json = named_shardings_to_json(state_mesh_shardings) |
| 102 | + expected_json = load_named_sharding_json(json_path) |
| 103 | + |
| 104 | + actual_checksum = compute_checksum(actual_json) |
| 105 | + expected_checksum2 = compute_checksum(expected_json) |
| 106 | + result = actual_checksum == expected_checksum2 |
| 107 | + |
| 108 | + if not result: |
| 109 | + compare_named_sharding_jsons(expected_json, f"expected_{model_name}", actual_json, f"actual_{model_name}") |
| 110 | + |
| 111 | + assert result is True |
0 commit comments