Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion mason.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from rich.console import Console
from rich.text import Text

from open_instruct.utils import GCP_CLUSTERS, INTERCONNECT_CLUSTERS, WEKA_CLUSTERS
from open_instruct.utils import GCP_CLUSTERS, INTERCONNECT_CLUSTERS, WEKA_CLUSTERS, download_from_gs_bucket

console = Console()

Expand Down Expand Up @@ -464,6 +464,9 @@ def make_internal_command(command: list[str], args: argparse.Namespace, whoami:
continue

filtered_command = build_command_without_args(command[idx:], CACHE_EXCLUDED_ARGS)
filtered_command = maybe_download_tokenizer_from_gs_bucket(
filtered_command, args.auto_output_dir, whoami
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Incorrect attribute name causes AttributeError at runtime

The code accesses args.auto_output_dir but the argument is defined as --auto_output_dir_path, meaning the correct attribute name is args.auto_output_dir_path. Other parts of the codebase correctly use args.auto_output_dir_path. This typo will cause an AttributeError when this code path executes.

Fix in Cursor Fix in Web

)
caching_command = "python " + " ".join(filtered_command) + " --cache_dataset_only"
console.log("📦📦📦 Running the caching command with `--cache_dataset_only`")
import subprocess
Expand Down Expand Up @@ -780,6 +783,35 @@ def make_task_spec(args, full_command: str, i: int, beaker_secrets: str, whoami:
return spec


def maybe_download_tokenizer_from_gs_bucket(filtered_command: str, auto_output_dir_path: str, whoami: str):
"""if model is only on gs, download tokenizer from gs to local cache folder for dataset preprocessing"""

if "--model_name_or_path" not in filtered_command:
return filtered_command

model_arg_idx = filtered_command.index("--model_name_or_path")
model_name_idx = model_arg_idx + 1
model_name_or_path = filtered_command[model_name_idx].rstrip("/")

if model_name_or_path.startswith("gs://"):
model_name_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8]
local_cache_folder = f"{auto_output_dir_path}/{whoami}/tokenizer_{model_name_hash}/"

if not os.path.exists(local_cache_folder):
download_from_gs_bucket(
[
f"{model_name_or_path}/tokenizer.json",
f"{model_name_or_path}/tokenizer_config.json",
f"{model_name_or_path}/config.json",
],
local_cache_folder,
)

filtered_command[model_name_idx] = local_cache_folder

return filtered_command


def main():
args, commands = get_args()
# If the user is not in Ai2, we run the command as is
Expand Down
32 changes: 32 additions & 0 deletions open_instruct/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# Copied from https://github.com/huggingface/alignment-handbook/blob/main/tests/test_data.py
import json
import pathlib
import tempfile
import time
import unittest
from unittest import mock
Expand Down Expand Up @@ -262,6 +263,37 @@ def test_send_slack_alert_with_beaker_url(self, mock_environ_get, mock_get_beake
self.assertIn("Test error message", request_body["text"])


class TestDownloadFromGsBucket(unittest.TestCase):
def test_download_from_gs_bucket(self):
src_paths = ["gs://bucket/data1", "gs://bucket/data2"]

with tempfile.TemporaryDirectory() as tmp_dir:
dest_path = pathlib.Path(tmp_dir) / "downloads"
captured_cmd: dict[str, list[str]] = {}

def mock_live_subprocess_output(cmd):
captured_cmd["cmd"] = cmd

with mock.patch.object(utils, "live_subprocess_output", side_effect=mock_live_subprocess_output):
utils.download_from_gs_bucket(src_paths=src_paths, dest_path=str(dest_path))

expected_cmd = [
"gsutil",
"-o",
"GSUtil:parallel_thread_count=1",
"-o",
"GSUtil:sliced_object_download_threshold=150",
"-m",
"cp",
"-r",
*src_paths,
str(dest_path),
]

self.assertEqual(captured_cmd["cmd"], expected_cmd)
self.assertTrue(dest_path.exists())


class TestUtilityFunctions(unittest.TestCase):
"""Test utility functions in utils module."""

Expand Down
7 changes: 4 additions & 3 deletions open_instruct/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,8 @@ def download_from_hf(model_name_or_path: str, revision: str) -> None:
return output


def download_from_gs_bucket(src_path: str, dest_path: str) -> None:
def download_from_gs_bucket(src_paths: list[str], dest_path: str) -> None:
os.makedirs(dest_path, exist_ok=True)
cmd = [
"gsutil",
"-o",
Expand All @@ -1113,9 +1114,9 @@ def download_from_gs_bucket(src_path: str, dest_path: str) -> None:
"-m",
"cp",
"-r",
src_path,
dest_path,
]
cmd.extend(src_paths)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some tests here? Let's mock live_subprocess_output so we capture the cmd it's called with and verify it against some known correct values. Should be a one prompt change with Codex

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

multiple prompts but I think its good

cmd.append(dest_path)
print(f"Downloading from GS bucket with command: {cmd}")
live_subprocess_output(cmd)

Expand Down