From 98e7976ee532707e0dd24fc7ceac25fb9c28a4f3 Mon Sep 17 00:00:00 2001
From: Michael Noukhovitch
Date: Wed, 12 Nov 2025 21:41:42 +0000
Subject: [PATCH 1/4] pass in `model_name_or_path` that is on augusta and it works

---
 mason.py               | 32 +++++++++++++++++++++++++++++++-
 open_instruct/utils.py |  9 ++++++---
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/mason.py b/mason.py
index 4ce339eb9e..9b1a51efca 100644
--- a/mason.py
+++ b/mason.py
@@ -441,7 +441,12 @@ def make_internal_command(command: list[str], args: argparse.Namespace, whoami:
     is_open_instruct_training = any(cmd in command for cmd in OPEN_INSTRUCT_COMMANDS)
     if is_open_instruct_training:
         from open_instruct.dataset_transformation import get_commit_hash
-        from open_instruct.utils import download_from_hf, gs_folder_exists, upload_to_gs_bucket
+        from open_instruct.utils import (
+            download_from_gs_bucket,
+            download_from_hf,
+            gs_folder_exists,
+            upload_to_gs_bucket,
+        )

         # HACK: Cache dataset logic:
         # Here we basically try to run the tokenization full_command locally before running it on beaker
@@ -467,6 +472,31 @@ def make_internal_command(command: list[str], args: argparse.Namespace, whoami:
                     continue

             filtered_command = build_command_without_args(command[idx:], CACHE_EXCLUDED_ARGS)
+
+            # if model is only on gs, download tokenizer from gs for dataset preprocessing
+            try:
+                model_arg_idx = filtered_command.index("--model_name_or_path")
+                model_name_idx = model_arg_idx + 1
+                model_name_or_path = filtered_command[model_name_idx].rstrip("/")
+
+                if model_name_or_path.startswith("gs://"):
+                    model_name_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8]
+                    local_cache_folder = f"{args.auto_output_dir_path}/{whoami}/tokenizer_{model_name_hash}/"
+
+                    if not os.path.exists(local_cache_folder):
+                        download_from_gs_bucket(
+                            [
+                                f"{model_name_or_path}/tokenizer.json",
+                                f"{model_name_or_path}/tokenizer_config.json",
+                                f"{model_name_or_path}/config.json",
+                            ],
+                            local_cache_folder,
+                        )
+
+                    filtered_command[model_name_idx] = local_cache_folder
+            except ValueError:
+                pass
+
             caching_command = "python " + " ".join(filtered_command) + " --cache_dataset_only"
             console.log("📦📦📦 Running the caching command with `--cache_dataset_only`")
             import subprocess
diff --git a/open_instruct/utils.py b/open_instruct/utils.py
index 4fcb82a867..3461eb2759 100644
--- a/open_instruct/utils.py
+++ b/open_instruct/utils.py
@@ -1064,7 +1064,8 @@ def download_from_hf(model_name_or_path: str, revision: str) -> None:
     return output


-def download_from_gs_bucket(src_path: str, dest_path: str) -> None:
+def download_from_gs_bucket(src_paths: str | list[str], dest_path: str) -> None:
+    os.makedirs(dest_path, exist_ok=True)
     cmd = [
         "gsutil",
         "-o",
@@ -1074,9 +1075,11 @@ def download_from_gs_bucket(src_path: str, dest_path: str) -> None:
         "-m",
         "cp",
         "-r",
-        src_path,
-        dest_path,
     ]
+    if not isinstance(src_paths, list):
+        src_paths = [src_paths]
+    cmd.extend(src_paths)
+    cmd.append(dest_path)
     print(f"Downloading from GS bucket with command: {cmd}")
     live_subprocess_output(cmd)


From 622d99a1fd14e97269ef5b0d073702eebd37e568 Mon Sep 17 00:00:00 2001
From: Michael Noukhovitch
Date: Wed, 26 Nov 2025 16:06:45 -0500
Subject: [PATCH 2/4] make src path list

---
 open_instruct/utils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/open_instruct/utils.py b/open_instruct/utils.py
index 48f32c9e6a..119b052819 100644
--- a/open_instruct/utils.py
+++ b/open_instruct/utils.py
@@ -1103,7 +1103,7 @@ def download_from_hf(model_name_or_path: str, revision: str) -> None:
     return output


-def download_from_gs_bucket(src_paths: str | list[str], dest_path: str) -> None:
+def download_from_gs_bucket(src_paths: list[str], dest_path: str) -> None:
     os.makedirs(dest_path, exist_ok=True)
     cmd = [
         "gsutil",
@@ -1115,8 +1115,6 @@ def download_from_gs_bucket(src_paths: str | list[str], dest_path: str) -> None:
         "cp",
         "-r",
     ]
-    if not isinstance(src_paths, list):
-        src_paths = [src_paths]
     cmd.extend(src_paths)
     cmd.append(dest_path)
     print(f"Downloading from GS bucket with command: {cmd}")
     live_subprocess_output(cmd)

From e461f7255c5ee6c89cb24a4ba580fc5e563a666e Mon Sep 17 00:00:00 2001
From: Michael Noukhovitch
Date: Thu, 27 Nov 2025 16:37:40 -0500
Subject: [PATCH 3/4] Refactor gs bucket download test

---
 open_instruct/test_utils.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py
index c46742e42c..dafbe8719e 100644
--- a/open_instruct/test_utils.py
+++ b/open_instruct/test_utils.py
@@ -14,6 +14,7 @@
 # Copied from https://github.com/huggingface/alignment-handbook/blob/main/tests/test_data.py
 import json
 import pathlib
+import tempfile
 import time
 import unittest
 from unittest import mock
@@ -262,6 +263,37 @@ def test_send_slack_alert_with_beaker_url(self, mock_environ_get, mock_get_beake
         self.assertIn("Test error message", request_body["text"])


+class TestDownloadFromGsBucket(unittest.TestCase):
+    def test_download_from_gs_bucket(self):
+        src_paths = ["gs://bucket/data1", "gs://bucket/data2"]
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dest_path = pathlib.Path(tmp_dir) / "downloads"
+            captured_cmd: dict[str, list[str]] = {}
+
+            def mock_live_subprocess_output(cmd):
+                captured_cmd["cmd"] = cmd
+
+            with mock.patch.object(utils, "live_subprocess_output", side_effect=mock_live_subprocess_output):
+                utils.download_from_gs_bucket(src_paths=src_paths, dest_path=str(dest_path))
+
+            expected_cmd = [
+                "gsutil",
+                "-o",
+                "GSUtil:parallel_thread_count=1",
+                "-o",
+                "GSUtil:sliced_object_download_threshold=150",
+                "-m",
+                "cp",
+                "-r",
+                *src_paths,
+                str(dest_path),
+            ]
+
+            self.assertEqual(captured_cmd["cmd"], expected_cmd)
+            self.assertTrue(dest_path.exists())
+
+
 class TestUtilityFunctions(unittest.TestCase):
     """Test utility functions in utils module."""

From 9f136a9947d9dde946358d0df647e9cb3b7aaff0 Mon Sep 17 00:00:00 2001
From: Michael Noukhovitch
Date: Thu, 27 Nov 2025 16:52:59 -0500
Subject: [PATCH 4/4] download_from_gs_bucket a separate command and removed try except

---
 mason.py | 66 +++++++++++++++++++++++++++++---------------------------
 1 file changed, 34 insertions(+), 32 deletions(-)

diff --git a/mason.py b/mason.py
index 43aeecdaf3..9a886160df 100644
--- a/mason.py
+++ b/mason.py
@@ -15,7 +15,7 @@
 from rich.console import Console
 from rich.text import Text

-from open_instruct.utils import GCP_CLUSTERS, INTERCONNECT_CLUSTERS, WEKA_CLUSTERS
+from open_instruct.utils import GCP_CLUSTERS, INTERCONNECT_CLUSTERS, WEKA_CLUSTERS, download_from_gs_bucket

 console = Console()

@@ -438,12 +438,7 @@ def make_internal_command(command: list[str], args: argparse.Namespace, whoami:
     is_open_instruct_training = any(cmd in command for cmd in OPEN_INSTRUCT_COMMANDS)
     if is_open_instruct_training:
         from open_instruct.dataset_transformation import get_commit_hash
-        from open_instruct.utils import (
-            download_from_gs_bucket,
-            download_from_hf,
-            gs_folder_exists,
-            upload_to_gs_bucket,
-        )
+        from open_instruct.utils import download_from_hf, gs_folder_exists, upload_to_gs_bucket

         # HACK: Cache dataset logic:
         # Here we basically try to run the tokenization full_command locally before running it on beaker
@@ -469,31 +464,9 @@ def make_internal_command(command: list[str], args: argparse.Namespace, whoami:
                     continue

             filtered_command = build_command_without_args(command[idx:], CACHE_EXCLUDED_ARGS)
-
-            # if model is only on gs, download tokenizer from gs for dataset preprocessing
-            try:
-                model_arg_idx = filtered_command.index("--model_name_or_path")
-                model_name_idx = model_arg_idx + 1
-                model_name_or_path = filtered_command[model_name_idx].rstrip("/")
-
-                if model_name_or_path.startswith("gs://"):
-                    model_name_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8]
-                    local_cache_folder = f"{args.auto_output_dir_path}/{whoami}/tokenizer_{model_name_hash}/"
-
-                    if not os.path.exists(local_cache_folder):
-                        download_from_gs_bucket(
-                            [
-                                f"{model_name_or_path}/tokenizer.json",
-                                f"{model_name_or_path}/tokenizer_config.json",
-                                f"{model_name_or_path}/config.json",
-                            ],
-                            local_cache_folder,
-                        )
-
-                    filtered_command[model_name_idx] = local_cache_folder
-            except ValueError:
-                pass
-
+            filtered_command = maybe_download_tokenizer_from_gs_bucket(
+                filtered_command, args.auto_output_dir, whoami
+            )
             caching_command = "python " + " ".join(filtered_command) + " --cache_dataset_only"
             console.log("📦📦📦 Running the caching command with `--cache_dataset_only`")
             import subprocess
@@ -810,6 +783,35 @@ def make_task_spec(args, full_command: str, i: int, beaker_secrets: str, whoami:
     return spec


+def maybe_download_tokenizer_from_gs_bucket(filtered_command: str, auto_output_dir_path: str, whoami: str):
+    """if model is only on gs, download tokenizer from gs to local cache folder for dataset preprocessing"""
+
+    if "--model_name_or_path" not in filtered_command:
+        return filtered_command
+
+    model_arg_idx = filtered_command.index("--model_name_or_path")
+    model_name_idx = model_arg_idx + 1
+    model_name_or_path = filtered_command[model_name_idx].rstrip("/")
+
+    if model_name_or_path.startswith("gs://"):
+        model_name_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8]
+        local_cache_folder = f"{auto_output_dir_path}/{whoami}/tokenizer_{model_name_hash}/"
+
+        if not os.path.exists(local_cache_folder):
+            download_from_gs_bucket(
+                [
+                    f"{model_name_or_path}/tokenizer.json",
+                    f"{model_name_or_path}/tokenizer_config.json",
+                    f"{model_name_or_path}/config.json",
+                ],
+                local_cache_folder,
+            )
+
+        filtered_command[model_name_idx] = local_cache_folder
+
+    return filtered_command
+
+
 def main():
     args, commands = get_args()
     # If the user is not in Ai2, we run the command as is