From df5f0674f41d64996358c430fd0032de5ec6438b Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Tue, 5 Aug 2025 19:01:03 +0000 Subject: [PATCH 01/19] Add tp command to save tokenizer to gcs and modify e2e test to use cached tokenizer --- .github/workflows/cpu_test.yml | 4 +- .github/workflows/e2e_test.yml | 18 ++++---- .github/workflows/reusable_e2e_check.yml | 4 +- torchprime/launcher/cli.py | 39 ++++++++++++++++ .../launcher/save_hf_tokenizer_and_model.py | 46 +++++++++++++++++++ .../torch_xla_models/model/model_utils.py | 37 +++++++++++++++ torchprime/torch_xla_models/train.py | 3 +- 7 files changed, 137 insertions(+), 14 deletions(-) create mode 100644 torchprime/launcher/save_hf_tokenizer_and_model.py diff --git a/.github/workflows/cpu_test.yml b/.github/workflows/cpu_test.yml index b6907764..bb9ee9d6 100644 --- a/.github/workflows/cpu_test.yml +++ b/.github/workflows/cpu_test.yml @@ -43,7 +43,7 @@ jobs: - name: Run PyTest env: # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token. - HF_TOKEN: ${{ secrets.HF_TOKEN }} + # HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | export PJRT_DEVICE=CPU export JAX_PLATFORMS=cpu @@ -52,7 +52,7 @@ jobs: - name: Run model forward env: # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token. - HF_TOKEN: ${{ secrets.HF_TOKEN }} + # HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | export PJRT_DEVICE=CPU export JAX_PLATFORMS=cpu diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/e2e_test.yml index 0d3828a6..37665a7e 100644 --- a/.github/workflows/e2e_test.yml +++ b/.github/workflows/e2e_test.yml @@ -68,7 +68,6 @@ jobs: - name: Run Llama 3.0 8B id: run-llama-3-8b env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} XLA_IR_DEBUG: 1 XLA_HLO_DEBUG: 1 run: | @@ -78,6 +77,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3-8b \ + model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ dataset=wikitext \ task=train \ task.global_batch_size=8 \ @@ -112,7 +112,6 @@ jobs: - name: Run Llama 3.1 8B (Splash Attention) id: run-llama-3_1-8b-SplashAttention env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} XLA_IR_DEBUG: 1 XLA_HLO_DEBUG: 1 run: | @@ -122,6 +121,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3.1-8b \ + model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3.1-405B \ model.attention_kernel=splash_attention \ dataset=wikitext \ task=train \ @@ -134,7 +134,6 @@ jobs: - name: Run Llama 3.1 8B (Scan + Offload) id: run-llama-3_1-8b-scan-offload env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} XLA_IR_DEBUG: 1 XLA_HLO_DEBUG: 1 run: | @@ -144,6 +143,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3.1-8b \ + model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3.1-405B \ model/remat=llama-scan-offload \ dataset=wikitext \ task=train \ @@ -156,7 +156,6 @@ jobs: - name: Run Llama 3.0 8B (2D sharding) id: run-llama-3-8b-2d env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} XLA_IR_DEBUG: 1 XLA_HLO_DEBUG: 1 run: | @@ -166,6 +165,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3-8b \ + model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ model/sharding=llama-fsdp-tp \ dataset=wikitext \ task=train \ @@ -179,7 +179,6 @@ jobs: - name: Run Llama 3.0 8B (fsdp + cp) id: run-llama-3-8b-fsdp-cp env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} XLA_IR_DEBUG: 1 XLA_HLO_DEBUG: 1 run: | @@ -189,6 +188,7 @@ jobs: --name $name \ 
torchprime/torch_xla_models/train.py \ model=llama-3-8b-cp \ + model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ model/sharding=llama-fsdp-tp-cp \ dataset=wikitext \ task=train \ @@ -201,7 +201,6 @@ jobs: - name: Run Mixtral 8x7B id: run-mixtral-8x7b env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} XLA_IR_DEBUG: 1 XLA_HLO_DEBUG: 1 run: | @@ -211,6 +210,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=mixtral-8x7b \ + model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/mistralai--Mixtral-8x7B-v0.1 \ model.num_hidden_layers=16 \ dataset=wikitext \ task=train \ @@ -223,7 +223,6 @@ jobs: - name: Run Llama 3.0 8B (2 slice) id: run-llama-3-8b-2-slice env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} XLA_IR_DEBUG: 1 XLA_HLO_DEBUG: 1 run: | @@ -233,6 +232,7 @@ jobs: --name $name \ --num-slices 2 \ torchprime/torch_xla_models/train.py \ + model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ model=llama-3-8b \ model/sharding=llama-fsdp \ dataset=wikitext \ @@ -247,7 +247,6 @@ jobs: - name: Run Llama 3.0 8B SFT id: run-llama-3-8b-sft env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} XLA_IR_DEBUG: 1 XLA_HLO_DEBUG: 1 run: | @@ -257,6 +256,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ --config-name llama-3-8b-sft-w-gsm8k \ + model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ ici_mesh.fsdp=4 \ task.max_steps=50 \ task.convert_to_safetensors=False \ @@ -265,7 +265,6 @@ jobs: - name: Run Llama 3.0 8B (ddp + fsdp) id: run-llama-3-8b-ddp-fsdp env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} XLA_IR_DEBUG: 1 XLA_HLO_DEBUG: 1 run: | @@ -275,6 +274,7 @@ jobs: --name $name \ --num-slices 2 \ torchprime/torch_xla_models/train.py \ + model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ model=llama-3-8b \ model/sharding=llama-fsdp \ dataset=wikitext \ diff --git a/.github/workflows/reusable_e2e_check.yml b/.github/workflows/reusable_e2e_check.yml index f7faf677..fce9ee9a 100644 --- a/.github/workflows/reusable_e2e_check.yml +++ b/.github/workflows/reusable_e2e_check.yml @@ -31,8 +31,8 @@ on: GCP_SA_KEY: required: true # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token. - HF_TOKEN: - required: true + # HF_TOKEN: + # required: true jobs: results: diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index a5134337..765556a1 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -17,12 +17,14 @@ import click import toml from dataclasses_json import dataclass_json +from huggingface_hub.errors import RepositoryNotFoundError from pathspec import PathSpec from pathspec.patterns import GitWildMatchPattern # type: ignore from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer import torchprime.launcher.doctor +from torchprime.launcher import save_hf_tokenizer_and_model from torchprime.launcher.buildpush import buildpush from torchprime.launcher.util import run_docker @@ -77,6 +79,43 @@ def cli(ctx, interactive): ctx.obj["interactive"] = interactive +@cli.command("save-hf-tokenizers-to-gcs") +@click.option( + "--tokenizer-name", + type=str, + required=True, + multiple=True, + help="Hugging Face tokenizer name (e.g., 'meta-llama/Llama-3-8B-hf'). 
Can be specified multiple times.", +) +@click.option( + "--gcs-base-path", + type=str, + required=True, + help="Base GCS path for the tokenizers (e.g., 'gs://bucket/tokenizers').", +) +def save_hf_tokenizers_to_gcs(tokenizer_name: tuple[str], gcs_base_path: str): + """ + Downloads one or more tokenizers from Hugging Face Hub and saves them to a + Google Cloud Storage (GCS) bucket. + """ + for name in tokenizer_name: + # Create a safe directory name from the repo ID by replacing slashes + safe_dir_name = name.replace("/", "--") + gcs_path = f"{gcs_base_path.rstrip('/')}/{safe_dir_name}" + + click.echo(f"\nPreparing to save tokenizer from '{name}' to '{gcs_path}'...") + try: + save_hf_tokenizer_and_model.save_tokenizer_to_gcs(name, gcs_path) + click.secho(f" -> Successfully saved tokenizer to {gcs_path}", fg="green") + except RepositoryNotFoundError: + click.secho(f"\n❌ Error: Tokenizer '{name}' not found.", fg="red") + click.echo("Please check the following:") + click.echo(f"1. The tokenizer name '{name}' is spelled correctly.") + click.echo("2. If it's a gated repository, ensure you are authenticated by running 'huggingface-cli login' or exporting your HF_TOKEN.") + except Exception as e: + click.secho(f"\n❌ An unexpected error occurred for tokenizer '{name}': {e}", fg="red") + + @cli.command() @click.option("--cluster", required=True, help="Name of the XPK cluster") @click.option("--project", required=True, help="GCP project the cluster belongs to") diff --git a/torchprime/launcher/save_hf_tokenizer_and_model.py b/torchprime/launcher/save_hf_tokenizer_and_model.py new file mode 100644 index 00000000..c3befeb7 --- /dev/null +++ b/torchprime/launcher/save_hf_tokenizer_and_model.py @@ -0,0 +1,46 @@ +"""Utilities for preparing Hugging Face assets (models and tokenizers) for GCS.""" + +import logging +import os +import subprocess +import tempfile +from pathlib import Path + +from transformers import AutoTokenizer + +logger = logging.getLogger(__name__) + + +def _upload_directory_to_gcs(local_path: Path, gcs_path: str): + """Uploads the contents of a local directory to GCS using gsutil.""" + if not gcs_path.startswith("gs://"): + raise ValueError("GCS path must start with gs://") + + logger.info(f"Uploading contents of '{local_path}' to '{gcs_path}'...") + # Using gsutil for efficient, parallel uploads. + # The '/*' at the end of local_path ensures the contents are copied, not the directory itself. + command = ["gsutil", "-m", "cp", "-r", f"{str(local_path).rstrip('/')}/*", gcs_path] + try: + subprocess.run(command, check=True, capture_output=True, text=True) + logger.info(f"Successfully uploaded assets to {gcs_path}.") + except subprocess.CalledProcessError as e: + logger.error(f"Failed to upload to GCS. Error: {e.stderr}") + raise + + +def save_tokenizer_to_gcs(tokenizer_name: str, gcs_path: str): + """Downloads a tokenizer from a Hugging Face repo and uploads to GCS.""" + with tempfile.TemporaryDirectory() as tmpdir: + local_path = Path(tmpdir) + logger.info(f"Created temporary directory: {local_path}") + + # Re-saving the tokenizer ensures it's in a standardized format, which is a good practice. + # This will only download tokenizer-related files, not the large model weights. 
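+        # This helper is normally reached through the new `tp save-hf-tokenizers-to-gcs`
+        # launcher command defined in cli.py earlier in this patch; an illustrative
+        # invocation (bucket and path are examples only, not part of this change):
+        #   tp save-hf-tokenizers-to-gcs \
+        #     --tokenizer-name meta-llama/Meta-Llama-3-8B \
+        #     --gcs-base-path gs://my-bucket/tokenizers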
+ logger.info(f"Standardizing tokenizer for '{tokenizer_name}'...") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ.get("HF_TOKEN")) + tokenizer.save_pretrained(str(local_path)) + + logger.info(f"Tokenizer for '{tokenizer_name}' downloaded and prepared locally.") + + # Upload the directory contents to the specified GCS path. + _upload_directory_to_gcs(local_path, gcs_path) \ No newline at end of file diff --git a/torchprime/torch_xla_models/model/model_utils.py b/torchprime/torch_xla_models/model/model_utils.py index 0f9fa1b9..536506a4 100644 --- a/torchprime/torch_xla_models/model/model_utils.py +++ b/torchprime/torch_xla_models/model/model_utils.py @@ -2,6 +2,7 @@ from __future__ import annotations +import atexit import importlib import json import logging @@ -31,6 +32,20 @@ "generation_config.json", ] +_TEMP_DIRS_TO_CLEAN = [] + + +def _cleanup_temp_dirs(): + """Removes all temporary directories created by this module.""" + for d in _TEMP_DIRS_TO_CLEAN: + try: + logger.info(f"Cleaning up temporary directory: {d}") + shutil.rmtree(d) + except OSError as e: + logger.warning(f"Failed to remove temporary directory {d}: {e}") + +atexit.register(_cleanup_temp_dirs) + def load_safetensors_to_state_dict(model_dir: str) -> dict: """Load a model state dict from safetensors, supporting both sharded and single-file formats. @@ -484,3 +499,25 @@ def save_hf_tokenizer(model_path_or_repo: str, save_dir: Path) -> None: save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) tokenizer.save_pretrained(save_dir) + + +def download_gcs_dir_if_needed(path_or_repo: str) -> str: + """If a path is a GCS path, download it to a local temp dir and return the local path.""" + if path_or_repo.startswith("gs://"): + local_dir = tempfile.mkdtemp() + _TEMP_DIRS_TO_CLEAN.append(local_dir) + logger.info(f"Downloading {path_or_repo} to temporary directory {local_dir}") + + # Using gsutil for efficient, parallel downloads. + # The '/*' at the end of the GCS path ensures the contents are copied, not the directory itself. + command = ["gsutil", "-m", "cp", "-r", path_or_repo.rstrip("/") + "/*", local_dir] + try: + subprocess.run(command, check=True, capture_output=True, text=True) + logger.info(f"Successfully downloaded assets from {path_or_repo}.") + return local_dir + except (subprocess.CalledProcessError, FileNotFoundError) as e: + stderr = getattr(e, "stderr", str(e)) + logger.error(f"Failed to download from GCS using gsutil. Error: {stderr}") + raise RuntimeError(f"Could not download {path_or_repo}") from e + + return path_or_repo diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index 534a37f0..d25a832f 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -51,8 +51,9 @@ def main(config: omegaconf.DictConfig): # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Add tokenizers to torchprime. 
tokenizer_name = config.model.tokenizer_name + local_tokenizer_path = model_utils.download_gcs_dir_if_needed(tokenizer_name) tokenizer = retry.retry( - lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_name) + lambda: transformers.AutoTokenizer.from_pretrained(local_tokenizer_path) ) assert config.torch_dtype == "bfloat16", "Currently only bfloat16 is supported" From aefad4a4f315f85e07d1dece83276c3f02fae97f Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Tue, 5 Aug 2025 19:05:00 +0000 Subject: [PATCH 02/19] Ruff formatting --- torchprime/launcher/cli.py | 8 ++++++-- torchprime/launcher/save_hf_tokenizer_and_model.py | 6 ++++-- torchprime/torch_xla_models/model/model_utils.py | 1 + 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index 765556a1..9aa1259d 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -111,9 +111,13 @@ def save_hf_tokenizers_to_gcs(tokenizer_name: tuple[str], gcs_base_path: str): click.secho(f"\n❌ Error: Tokenizer '{name}' not found.", fg="red") click.echo("Please check the following:") click.echo(f"1. The tokenizer name '{name}' is spelled correctly.") - click.echo("2. If it's a gated repository, ensure you are authenticated by running 'huggingface-cli login' or exporting your HF_TOKEN.") + click.echo( + "2. If it's a gated repository, ensure you are authenticated by running 'huggingface-cli login' or exporting your HF_TOKEN." + ) except Exception as e: - click.secho(f"\n❌ An unexpected error occurred for tokenizer '{name}': {e}", fg="red") + click.secho( + f"\n❌ An unexpected error occurred for tokenizer '{name}': {e}", fg="red" + ) @cli.command() diff --git a/torchprime/launcher/save_hf_tokenizer_and_model.py b/torchprime/launcher/save_hf_tokenizer_and_model.py index c3befeb7..aaea0778 100644 --- a/torchprime/launcher/save_hf_tokenizer_and_model.py +++ b/torchprime/launcher/save_hf_tokenizer_and_model.py @@ -37,10 +37,12 @@ def save_tokenizer_to_gcs(tokenizer_name: str, gcs_path: str): # Re-saving the tokenizer ensures it's in a standardized format, which is a good practice. # This will only download tokenizer-related files, not the large model weights. logger.info(f"Standardizing tokenizer for '{tokenizer_name}'...") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ.get("HF_TOKEN")) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, token=os.environ.get("HF_TOKEN") + ) tokenizer.save_pretrained(str(local_path)) logger.info(f"Tokenizer for '{tokenizer_name}' downloaded and prepared locally.") # Upload the directory contents to the specified GCS path. 
- _upload_directory_to_gcs(local_path, gcs_path) \ No newline at end of file + _upload_directory_to_gcs(local_path, gcs_path) diff --git a/torchprime/torch_xla_models/model/model_utils.py b/torchprime/torch_xla_models/model/model_utils.py index 536506a4..48c97399 100644 --- a/torchprime/torch_xla_models/model/model_utils.py +++ b/torchprime/torch_xla_models/model/model_utils.py @@ -44,6 +44,7 @@ def _cleanup_temp_dirs(): except OSError as e: logger.warning(f"Failed to remove temporary directory {d}: {e}") + atexit.register(_cleanup_temp_dirs) From 0560ffb099c153fde85ee707e3cb9f4fc4fcfe03 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Wed, 6 Aug 2025 21:19:10 +0000 Subject: [PATCH 03/19] Save pretrained weights for SFT e2e --- .github/workflows/e2e_test.yml | 1 + torchprime/launcher/cli.py | 39 ++++++++++ .../launcher/save_hf_tokenizer_and_model.py | 34 ++++++++ .../torch_xla_models/model/base_causal_lm.py | 11 +-- .../torch_xla_models/model/model_utils.py | 77 +++++++++++++++---- 5 files changed, 141 insertions(+), 21 deletions(-) diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/e2e_test.yml index 37665a7e..272216fb 100644 --- a/.github/workflows/e2e_test.yml +++ b/.github/workflows/e2e_test.yml @@ -256,6 +256,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ --config-name llama-3-8b-sft-w-gsm8k \ + model.pretrained_model=gs://torchprime/jackoh-exp/models/Llama-3-8B \ model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ ici_mesh.fsdp=4 \ task.max_steps=50 \ diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index 9aa1259d..b1113300 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -120,6 +120,45 @@ def save_hf_tokenizers_to_gcs(tokenizer_name: tuple[str], gcs_base_path: str): ) +@cli.command("save-hf-model-to-gcs") +@click.option( + "--model-name", + type=str, + required=True, + help="Hugging Face model name (e.g., 'meta-llama/Llama-3-8B').", +) +@click.option( + "--gcs-path", + type=str, + required=True, + help="Target GCS path for the model (e.g., 'gs://bucket/models/Llama-3-8B').", +) +@click.option( + "--temp-dir", + type=str, + default=None, + help="Path to a temporary directory with sufficient space. Defaults to system temp.", +) +def save_hf_model_to_gcs(model_name: str, gcs_path: str, temp_dir: str | None): + """ + Downloads a full model from Hugging Face Hub and saves it to a + Google Cloud Storage (GCS) bucket. + """ + click.echo(f"Preparing to save model from '{model_name}' to '{gcs_path}'...") + try: + save_hf_tokenizer_and_model.save_model_to_gcs(model_name, gcs_path, temp_dir=temp_dir) + click.secho(f" -> Successfully saved model to {gcs_path}", fg="green") + except RepositoryNotFoundError: + click.secho(f"\n❌ Error: Model '{model_name}' not found.", fg="red") + click.echo("Please check the following:") + click.echo(f"1. The model name '{model_name}' is spelled correctly.") + click.echo("2. 
If it's a gated repository, ensure you are authenticated by running 'huggingface-cli login' or exporting your HF_TOKEN.") + except Exception as e: + click.secho( + f"\n❌ An unexpected error occurred for model '{model_name}': {e}", fg="red" + ) + + @cli.command() @click.option("--cluster", required=True, help="Name of the XPK cluster") @click.option("--project", required=True, help="GCP project the cluster belongs to") diff --git a/torchprime/launcher/save_hf_tokenizer_and_model.py b/torchprime/launcher/save_hf_tokenizer_and_model.py index aaea0778..75851b58 100644 --- a/torchprime/launcher/save_hf_tokenizer_and_model.py +++ b/torchprime/launcher/save_hf_tokenizer_and_model.py @@ -6,6 +6,7 @@ import tempfile from pathlib import Path +from huggingface_hub import snapshot_download from transformers import AutoTokenizer logger = logging.getLogger(__name__) @@ -46,3 +47,36 @@ def save_tokenizer_to_gcs(tokenizer_name: str, gcs_path: str): # Upload the directory contents to the specified GCS path. _upload_directory_to_gcs(local_path, gcs_path) + +def save_model_to_gcs(model_name: str, gcs_path: str, temp_dir: str | None = None): + """Downloads a model from a Hugging Face repo and uploads to GCS.""" + with tempfile.TemporaryDirectory(dir=temp_dir) as tmpdir: + # tmpdir is the root for the temporary cache. + logger.info(f"Created temporary directory: {tmpdir}") + + logger.info(f"Downloading model snapshot for '{model_name}'...") + # We use the temporary directory as the cache_dir to ensure that the + # files are downloaded directly into it. This avoids a copy operation + # from the default Hugging Face cache (~/.cache/huggingface) to /tmp, + # which can fail if /tmp and /home are on different filesystems and /tmp + # has limited space. The function returns the path to the actual snapshot directory. + # We explicitly list the patterns for files we need to ensure we don't download + # unnecessary files like READMEs or large non-safetensors model weights. + allow_patterns = [ + "*.safetensors*", + "config.json", + "generation_config.json", + "tokenizer*.json", + "special_tokens_map.json", + ] + snapshot_path = snapshot_download( + repo_id=model_name, + cache_dir=str(tmpdir), + token=os.environ.get("HF_TOKEN"), + allow_patterns=allow_patterns, + ) + + logger.info(f"Model '{model_name}' downloaded locally to '{snapshot_path}'.") + + # Upload the directory contents to the specified GCS path. + _upload_directory_to_gcs(Path(snapshot_path), gcs_path) diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index 5dc5efbe..1b683e11 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -113,14 +113,15 @@ def from_pretrained(self, model_path_or_repo: str): Args: model_path_or_repo: Path to the local directory or Hugging Face Hub repository ID. """ - if os.path.isdir(model_path_or_repo): - model_dir = model_path_or_repo - else: + # Convert GCS path to gcsfuse path if necessary. This allows us to treat + # GCS paths like local directories for the subsequent logic. 
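+    # Note: for local directories and plain Hugging Face repo IDs the helper returns its
+    # input unchanged, so only gs:// paths are resolved to a downloaded local copy here.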
+ model_dir = model_utils.download_gcs_dir_if_needed(model_path_or_repo) + + if not os.path.isdir(model_dir): model_dir = huggingface_hub.snapshot_download( - repo_id=model_path_or_repo, + repo_id=model_dir, allow_patterns=["*.safetensors*"] + model_utils.HF_MODEL_CONFIG_FILES, ) - # Load weights state_dict = model_utils.load_safetensors_to_state_dict(model_dir) self.load_state_dict(state_dict) diff --git a/torchprime/torch_xla_models/model/model_utils.py b/torchprime/torch_xla_models/model/model_utils.py index 48c97399..6c1ada29 100644 --- a/torchprime/torch_xla_models/model/model_utils.py +++ b/torchprime/torch_xla_models/model/model_utils.py @@ -463,6 +463,9 @@ def copy_hf_config_files(model_path_or_repo: str, save_dir: Path) -> None: """ patterns = HF_MODEL_CONFIG_FILES + # Convert GCS path to gcsfuse path if necessary. This allows us to treat + # GCS paths like local directories for the subsequent logic. + model_path_or_repo = download_gcs_dir_if_needed(model_path_or_repo) if os.path.isdir(model_path_or_repo): model_dir = Path(model_path_or_repo) else: @@ -496,6 +499,8 @@ def save_hf_tokenizer(model_path_or_repo: str, save_dir: Path) -> None: model repo ID (e.g., "meta-llama/Llama-2-7b-hf"). save_dir: Directory where the tokenizer files will be saved. """ + # If it's a GCS path, convert it to the gcsfuse mount path. + model_path_or_repo = download_gcs_dir_if_needed(model_path_or_repo) tokenizer = AutoTokenizer.from_pretrained(model_path_or_repo) save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) @@ -503,22 +508,62 @@ def save_hf_tokenizer(model_path_or_repo: str, save_dir: Path) -> None: def download_gcs_dir_if_needed(path_or_repo: str) -> str: - """If a path is a GCS path, download it to a local temp dir and return the local path.""" - if path_or_repo.startswith("gs://"): + """Resolves a GCS path to a local path, trying gcsfuse first and falling back to download.""" + if not path_or_repo.startswith("gs://"): + return path_or_repo + + from urllib.parse import urlparse + + # Consistently parse the GCS path to get the path inside the bucket. + path_inside_bucket = urlparse(path_or_repo).path.lstrip("/") + + # Strategy 1: Try to use the gcsfuse mount point. This is the most efficient. + fuse_path = os.path.join("/tmp/gcs-mount", path_inside_bucket) + if os.path.exists(fuse_path): + logger.info("Found existing gcsfuse mount for %s at %s", path_or_repo, fuse_path) + return fuse_path + + # Strategy 2: Fallback to downloading if the fuse mount doesn't exist. + logger.warning( + "gcsfuse path %s not found. Falling back to downloading from %s.", + fuse_path, + path_or_repo, + ) + from google.cloud import storage + + try: local_dir = tempfile.mkdtemp() _TEMP_DIRS_TO_CLEAN.append(local_dir) - logger.info(f"Downloading {path_or_repo} to temporary directory {local_dir}") - # Using gsutil for efficient, parallel downloads. - # The '/*' at the end of the GCS path ensures the contents are copied, not the directory itself. - command = ["gsutil", "-m", "cp", "-r", path_or_repo.rstrip("/") + "/*", local_dir] - try: - subprocess.run(command, check=True, capture_output=True, text=True) - logger.info(f"Successfully downloaded assets from {path_or_repo}.") - return local_dir - except (subprocess.CalledProcessError, FileNotFoundError) as e: - stderr = getattr(e, "stderr", str(e)) - logger.error(f"Failed to download from GCS using gsutil. 
Error: {stderr}") - raise RuntimeError(f"Could not download {path_or_repo}") from e - - return path_or_repo + # Parse GCS path to get bucket and prefix for listing blobs. + parsed_url = urlparse(path_or_repo, scheme="gs") + bucket_name = parsed_url.netloc + prefix = path_inside_bucket + if prefix and not prefix.endswith("/"): + prefix += "/" + + storage_client = storage.Client() + blobs = list(storage_client.list_blobs(bucket_name, prefix=prefix)) + + if not blobs: + raise FileNotFoundError(f"No objects found at GCS path: {path_or_repo}") + + for blob in blobs: + # Recreate the directory structure locally. + relative_path = os.path.relpath(blob.name, prefix) + dest_path = os.path.join(local_dir, relative_path) + + # Handle subdirectories explicitly. + if blob.name.endswith("/"): + os.makedirs(dest_path, exist_ok=True) + continue + + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + blob.download_to_filename(dest_path) + + logger.info("Successfully downloaded assets from %s to %s.", path_or_repo, local_dir) + return local_dir + except Exception as e: + logger.error("Failed to download from GCS using google-cloud-storage. Error: %s", e) + shutil.rmtree(local_dir) # Clean up failed download + raise RuntimeError(f"Could not download {path_or_repo}") from e From 5feb9dd1b85c0a4d26bc4278c394a3ee31e2e5c2 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Wed, 6 Aug 2025 21:24:59 +0000 Subject: [PATCH 04/19] Formatting --- torchprime/launcher/cli.py | 8 ++++++-- torchprime/launcher/save_hf_tokenizer_and_model.py | 13 +++++++------ torchprime/torch_xla_models/model/model_utils.py | 4 +++- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index b1113300..ec9e9cc6 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -146,13 +146,17 @@ def save_hf_model_to_gcs(model_name: str, gcs_path: str, temp_dir: str | None): """ click.echo(f"Preparing to save model from '{model_name}' to '{gcs_path}'...") try: - save_hf_tokenizer_and_model.save_model_to_gcs(model_name, gcs_path, temp_dir=temp_dir) + save_hf_tokenizer_and_model.save_model_to_gcs( + model_name, gcs_path, temp_dir=temp_dir + ) click.secho(f" -> Successfully saved model to {gcs_path}", fg="green") except RepositoryNotFoundError: click.secho(f"\n❌ Error: Model '{model_name}' not found.", fg="red") click.echo("Please check the following:") click.echo(f"1. The model name '{model_name}' is spelled correctly.") - click.echo("2. If it's a gated repository, ensure you are authenticated by running 'huggingface-cli login' or exporting your HF_TOKEN.") + click.echo( + "2. If it's a gated repository, ensure you are authenticated by running 'huggingface-cli login' or exporting your HF_TOKEN." + ) except Exception as e: click.secho( f"\n❌ An unexpected error occurred for model '{model_name}': {e}", fg="red" diff --git a/torchprime/launcher/save_hf_tokenizer_and_model.py b/torchprime/launcher/save_hf_tokenizer_and_model.py index 75851b58..2266016c 100644 --- a/torchprime/launcher/save_hf_tokenizer_and_model.py +++ b/torchprime/launcher/save_hf_tokenizer_and_model.py @@ -47,7 +47,8 @@ def save_tokenizer_to_gcs(tokenizer_name: str, gcs_path: str): # Upload the directory contents to the specified GCS path. 
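        # The call below boils down to roughly `gsutil -m cp -r <tmpdir>/* <gcs_path>`, so only
        # the directory contents (not the temporary directory itself) land under the target prefix.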
_upload_directory_to_gcs(local_path, gcs_path) - + + def save_model_to_gcs(model_name: str, gcs_path: str, temp_dir: str | None = None): """Downloads a model from a Hugging Face repo and uploads to GCS.""" with tempfile.TemporaryDirectory(dir=temp_dir) as tmpdir: @@ -63,11 +64,11 @@ def save_model_to_gcs(model_name: str, gcs_path: str, temp_dir: str | None = Non # We explicitly list the patterns for files we need to ensure we don't download # unnecessary files like READMEs or large non-safetensors model weights. allow_patterns = [ - "*.safetensors*", - "config.json", - "generation_config.json", - "tokenizer*.json", - "special_tokens_map.json", + "*.safetensors*", + "config.json", + "generation_config.json", + "tokenizer*.json", + "special_tokens_map.json", ] snapshot_path = snapshot_download( repo_id=model_name, diff --git a/torchprime/torch_xla_models/model/model_utils.py b/torchprime/torch_xla_models/model/model_utils.py index 6c1ada29..aa710d3b 100644 --- a/torchprime/torch_xla_models/model/model_utils.py +++ b/torchprime/torch_xla_models/model/model_utils.py @@ -561,7 +561,9 @@ def download_gcs_dir_if_needed(path_or_repo: str) -> str: os.makedirs(os.path.dirname(dest_path), exist_ok=True) blob.download_to_filename(dest_path) - logger.info("Successfully downloaded assets from %s to %s.", path_or_repo, local_dir) + logger.info( + "Successfully downloaded assets from %s to %s.", path_or_repo, local_dir + ) return local_dir except Exception as e: logger.error("Failed to download from GCS using google-cloud-storage. Error: %s", e) From a0f68c1af65f167ba481ace496d7a7f7a7dab901 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Thu, 7 Aug 2025 00:57:45 +0000 Subject: [PATCH 05/19] Fix issue in loading weights from gcs --- .../torch_xla_models/model/base_causal_lm.py | 64 +++++++++++++++---- 1 file changed, 51 insertions(+), 13 deletions(-) diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index 1b683e11..614e8615 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -10,6 +10,7 @@ import json import logging import os +import tempfile from pathlib import Path import huggingface_hub @@ -107,23 +108,60 @@ def from_pretrained(self, model_path_or_repo: str): It supports both local directories and remote repositories hosted on the Hugging Face Hub. Note: - In distributed training setups, ensure that all replicas perform the loading operation - to synchronize model weights across processes. + In distributed training setups, to avoid I/O contention on shared + filesystems, only rank 0 loads data from disk. The state dict is then + broadcast to all other ranks. Args: model_path_or_repo: Path to the local directory or Hugging Face Hub repository ID. """ - # Convert GCS path to gcsfuse path if necessary. This allows us to treat - # GCS paths like local directories for the subsequent logic. - model_dir = model_utils.download_gcs_dir_if_needed(model_path_or_repo) - - if not os.path.isdir(model_dir): - model_dir = huggingface_hub.snapshot_download( - repo_id=model_dir, - allow_patterns=["*.safetensors*"] + model_utils.HF_MODEL_CONFIG_FILES, - ) - # Load weights - state_dict = model_utils.load_safetensors_to_state_dict(model_dir) + tmp_path = "" + state_dict = None + if xr.process_index() == 0: + # On rank 0, resolve the model path (which may involve downloading) + # and load the state dictionary from disk. 
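+      # The weights loaded here on rank 0 are then staged in a local temp file whose
+      # path is shared with the other ranks further below before load_state_dict runs.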
+ model_dir = model_utils.download_gcs_dir_if_needed(model_path_or_repo) + logger.info("Loading model from %s", model_dir) + + if not os.path.isdir(model_dir): + model_dir = huggingface_hub.snapshot_download( + repo_id=model_dir, + allow_patterns=["*.safetensors*"] + model_utils.HF_MODEL_CONFIG_FILES, + ) + + if os.path.isdir(model_dir): + logger.info("Listing contents of model directory: %s", model_dir) + for root, _, files in os.walk(model_dir): + relative_root = os.path.relpath(root, model_dir) + for name in files: + logger.info( + " - %s", os.path.join(relative_root, name).lstrip("./") + ) + + logger.info("Loading safetensors on rank 0...") + state_dict = model_utils.load_safetensors_to_state_dict(model_dir) + logger.info("Finished loading safetensors on rank 0.") + + # Save state_dict to a temporary file on a fast local filesystem + # to avoid broadcast issues with large objects. + with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp: + tmp_path = tmp.name + logger.info("Saving state_dict to temporary file: %s", tmp_path) + torch.save(state_dict, tmp_path) + del state_dict # Free memory + + if xr.process_count() > 1: + # Rendezvous can be used to broadcast small amounts of data like a path string. + tmp_path = xm.rendezvous("broadcast_tmp_path", tmp_path) + + state_dict = torch.load(tmp_path, map_location="cpu") + + if xr.process_count() > 1: + xm.rendezvous("wait_for_state_dict_load") + + if xr.process_index() == 0: + os.remove(tmp_path) + self.load_state_dict(state_dict) def _maybe_save_checkpoint(self, config: DictConfig) -> None: From dca490cd59db1100f699573e9739288cff0c7469 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Thu, 7 Aug 2025 00:59:20 +0000 Subject: [PATCH 06/19] Formatting --- torchprime/torch_xla_models/model/base_causal_lm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index 614e8615..870e02be 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -134,9 +134,7 @@ def from_pretrained(self, model_path_or_repo: str): for root, _, files in os.walk(model_dir): relative_root = os.path.relpath(root, model_dir) for name in files: - logger.info( - " - %s", os.path.join(relative_root, name).lstrip("./") - ) + logger.info(" - %s", os.path.join(relative_root, name).lstrip("./")) logger.info("Loading safetensors on rank 0...") state_dict = model_utils.load_safetensors_to_state_dict(model_dir) From 19da2b1d774acc98648c82da8b9c29fe79650061 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Thu, 7 Aug 2025 22:29:30 +0000 Subject: [PATCH 07/19] Refactor --- .github/workflows/e2e_test.yml | 20 ++--- torchprime/launcher/cli.py | 78 +++++------------ .../launcher/save_hf_tokenizer_and_model.py | 72 +++++++-------- torchprime/tools/prepare_gcs_model_files.py | 87 +++++++++++++++++++ .../torch_xla_models/model/base_causal_lm.py | 62 ++++--------- .../torch_xla_models/model/model_utils.py | 73 +++++----------- .../torch_xla_models/trainer/sft_trainer.py | 1 + 7 files changed, 189 insertions(+), 204 deletions(-) create mode 100644 torchprime/tools/prepare_gcs_model_files.py diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/e2e_test.yml index 272216fb..06e52160 100644 --- a/.github/workflows/e2e_test.yml +++ b/.github/workflows/e2e_test.yml @@ -77,7 +77,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3-8b \ - 
model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ + model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ dataset=wikitext \ task=train \ task.global_batch_size=8 \ @@ -121,7 +121,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3.1-8b \ - model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3.1-405B \ + model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3.1-405b \ model.attention_kernel=splash_attention \ dataset=wikitext \ task=train \ @@ -143,7 +143,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3.1-8b \ - model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3.1-405B \ + model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3.1-405b \ model/remat=llama-scan-offload \ dataset=wikitext \ task=train \ @@ -165,7 +165,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3-8b \ - model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ + model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ model/sharding=llama-fsdp-tp \ dataset=wikitext \ task=train \ @@ -188,7 +188,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3-8b-cp \ - model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ + model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ model/sharding=llama-fsdp-tp-cp \ dataset=wikitext \ task=train \ @@ -210,7 +210,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=mixtral-8x7b \ - model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/mistralai--Mixtral-8x7B-v0.1 \ + model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/mixtral-8x7b-v0.1/ \ model.num_hidden_layers=16 \ dataset=wikitext \ task=train \ @@ -232,7 +232,7 @@ jobs: --name $name \ --num-slices 2 \ torchprime/torch_xla_models/train.py \ - model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ + model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ model=llama-3-8b \ model/sharding=llama-fsdp \ dataset=wikitext \ @@ -256,8 +256,8 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ --config-name llama-3-8b-sft-w-gsm8k \ - model.pretrained_model=gs://torchprime/jackoh-exp/models/Llama-3-8B \ - model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ + model.pretrained_model=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ + model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ ici_mesh.fsdp=4 \ task.max_steps=50 \ task.convert_to_safetensors=False \ @@ -275,7 +275,7 @@ jobs: --name $name \ --num-slices 2 \ torchprime/torch_xla_models/train.py \ - model.tokenizer_name=gs://torchprime/jackoh-exp/tokenizers/meta-llama--Meta-Llama-3-8B \ + model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ model=llama-3-8b \ model/sharding=llama-fsdp \ dataset=wikitext \ diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index ec9e9cc6..48376df1 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -79,59 +79,24 @@ def cli(ctx, interactive): ctx.obj["interactive"] = interactive -@cli.command("save-hf-tokenizers-to-gcs") +@cli.command("save-hf-model-files-to-gcs") @click.option( - "--tokenizer-name", + "--repo-id", type=str, required=True, - multiple=True, - help="Hugging Face 
tokenizer name (e.g., 'meta-llama/Llama-3-8B-hf'). Can be specified multiple times.", + help="Hugging Face model or tokenizer repo ID (e.g., 'meta-llama/Llama-3-8B-hf').", ) @click.option( - "--gcs-base-path", - type=str, - required=True, - help="Base GCS path for the tokenizers (e.g., 'gs://bucket/tokenizers').", -) -def save_hf_tokenizers_to_gcs(tokenizer_name: tuple[str], gcs_base_path: str): - """ - Downloads one or more tokenizers from Hugging Face Hub and saves them to a - Google Cloud Storage (GCS) bucket. - """ - for name in tokenizer_name: - # Create a safe directory name from the repo ID by replacing slashes - safe_dir_name = name.replace("/", "--") - gcs_path = f"{gcs_base_path.rstrip('/')}/{safe_dir_name}" - - click.echo(f"\nPreparing to save tokenizer from '{name}' to '{gcs_path}'...") - try: - save_hf_tokenizer_and_model.save_tokenizer_to_gcs(name, gcs_path) - click.secho(f" -> Successfully saved tokenizer to {gcs_path}", fg="green") - except RepositoryNotFoundError: - click.secho(f"\n❌ Error: Tokenizer '{name}' not found.", fg="red") - click.echo("Please check the following:") - click.echo(f"1. The tokenizer name '{name}' is spelled correctly.") - click.echo( - "2. If it's a gated repository, ensure you are authenticated by running 'huggingface-cli login' or exporting your HF_TOKEN." - ) - except Exception as e: - click.secho( - f"\n❌ An unexpected error occurred for tokenizer '{name}': {e}", fg="red" - ) - - -@cli.command("save-hf-model-to-gcs") -@click.option( - "--model-name", + "--gcs-path", type=str, required=True, - help="Hugging Face model name (e.g., 'meta-llama/Llama-3-8B').", + help="Target GCS path for the model files (e.g., 'gs://bucket/models/Llama-3-8B-hf').", ) @click.option( - "--gcs-path", - type=str, - required=True, - help="Target GCS path for the model (e.g., 'gs://bucket/models/Llama-3-8B').", + "--file-type", + type=click.Choice(["tokenizer", "model", "all"], case_sensitive=False), + default="all", + help="Type of files to save. 'tokenizer' for tokenizer files, 'model' for model weights and configs, 'all' for both.", ) @click.option( "--temp-dir", @@ -139,27 +104,28 @@ def save_hf_tokenizers_to_gcs(tokenizer_name: tuple[str], gcs_base_path: str): default=None, help="Path to a temporary directory with sufficient space. Defaults to system temp.", ) -def save_hf_model_to_gcs(model_name: str, gcs_path: str, temp_dir: str | None): - """ - Downloads a full model from Hugging Face Hub and saves it to a - Google Cloud Storage (GCS) bucket. - """ - click.echo(f"Preparing to save model from '{model_name}' to '{gcs_path}'...") +def save_hf_model_files_to_gcs( + repo_id: str, gcs_path: str, file_type: str, temp_dir: str | None +): + """Downloads model and tokenizer files from Hugging Face Hub and saves them to Google Cloud Storage.""" + click.echo( + f"Preparing to save '{file_type}' files from '{repo_id}' to '{gcs_path}'..." + ) try: - save_hf_tokenizer_and_model.save_model_to_gcs( - model_name, gcs_path, temp_dir=temp_dir + save_hf_tokenizer_and_model.save_hf_model_files_to_gcs( + repo_id, gcs_path, file_type=file_type, temp_dir=temp_dir ) - click.secho(f" -> Successfully saved model to {gcs_path}", fg="green") + click.secho(f" -> Successfully saved files to {gcs_path}", fg="green") except RepositoryNotFoundError: - click.secho(f"\n❌ Error: Model '{model_name}' not found.", fg="red") + click.secho(f"\n❌ Error: Repository '{repo_id}' not found.", fg="red") click.echo("Please check the following:") - click.echo(f"1. 
The model name '{model_name}' is spelled correctly.") + click.echo(f"1. The repository ID '{repo_id}' is spelled correctly.") click.echo( "2. If it's a gated repository, ensure you are authenticated by running 'huggingface-cli login' or exporting your HF_TOKEN." ) except Exception as e: click.secho( - f"\n❌ An unexpected error occurred for model '{model_name}': {e}", fg="red" + f"\n❌ An unexpected error occurred for repository '{repo_id}': {e}", fg="red" ) diff --git a/torchprime/launcher/save_hf_tokenizer_and_model.py b/torchprime/launcher/save_hf_tokenizer_and_model.py index 2266016c..092fa65e 100644 --- a/torchprime/launcher/save_hf_tokenizer_and_model.py +++ b/torchprime/launcher/save_hf_tokenizer_and_model.py @@ -7,10 +7,24 @@ from pathlib import Path from huggingface_hub import snapshot_download -from transformers import AutoTokenizer logger = logging.getLogger(__name__) +TOKENIZER_PATTERNS = [ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "*.model", # For sentencepiece tokenizers + "vocab.txt", # For WordPiece/BERT tokenizers + "merges.txt", # For BPE tokenizers +] + +MODEL_PATTERNS = [ + "*.safetensors*", + "config.json", + "generation_config.json", +] + def _upload_directory_to_gcs(local_path: Path, gcs_path: str): """Uploads the contents of a local directory to GCS using gsutil.""" @@ -29,55 +43,35 @@ def _upload_directory_to_gcs(local_path: Path, gcs_path: str): raise -def save_tokenizer_to_gcs(tokenizer_name: str, gcs_path: str): - """Downloads a tokenizer from a Hugging Face repo and uploads to GCS.""" - with tempfile.TemporaryDirectory() as tmpdir: - local_path = Path(tmpdir) - logger.info(f"Created temporary directory: {local_path}") - - # Re-saving the tokenizer ensures it's in a standardized format, which is a good practice. - # This will only download tokenizer-related files, not the large model weights. - logger.info(f"Standardizing tokenizer for '{tokenizer_name}'...") - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, token=os.environ.get("HF_TOKEN") - ) - tokenizer.save_pretrained(str(local_path)) - - logger.info(f"Tokenizer for '{tokenizer_name}' downloaded and prepared locally.") - - # Upload the directory contents to the specified GCS path. - _upload_directory_to_gcs(local_path, gcs_path) +def save_hf_model_files_to_gcs( + repo_id: str, + gcs_path: str, + file_type: str, + temp_dir: str | None = None, +): + """Downloads model and tokenizer files from a Hugging Face repo and uploads to GCS.""" + allow_patterns = [] + if file_type in ("tokenizer", "all"): + allow_patterns.extend(TOKENIZER_PATTERNS) + if file_type in ("model", "all"): + allow_patterns.extend(MODEL_PATTERNS) + if not allow_patterns: + raise ValueError("file_type must be one of 'tokenizer', 'model', or 'all'") -def save_model_to_gcs(model_name: str, gcs_path: str, temp_dir: str | None = None): - """Downloads a model from a Hugging Face repo and uploads to GCS.""" with tempfile.TemporaryDirectory(dir=temp_dir) as tmpdir: - # tmpdir is the root for the temporary cache. logger.info(f"Created temporary directory: {tmpdir}") - logger.info(f"Downloading model snapshot for '{model_name}'...") - # We use the temporary directory as the cache_dir to ensure that the - # files are downloaded directly into it. This avoids a copy operation - # from the default Hugging Face cache (~/.cache/huggingface) to /tmp, - # which can fail if /tmp and /home are on different filesystems and /tmp - # has limited space. The function returns the path to the actual snapshot directory. 
- # We explicitly list the patterns for files we need to ensure we don't download - # unnecessary files like READMEs or large non-safetensors model weights. - allow_patterns = [ - "*.safetensors*", - "config.json", - "generation_config.json", - "tokenizer*.json", - "special_tokens_map.json", - ] + logger.info(f"Downloading files for '{repo_id}' with patterns: {allow_patterns}") snapshot_path = snapshot_download( - repo_id=model_name, + repo_id=repo_id, cache_dir=str(tmpdir), token=os.environ.get("HF_TOKEN"), allow_patterns=allow_patterns, + ignore_patterns=["*.bin*"], # Avoid large pytorch_model.bin files ) - logger.info(f"Model '{model_name}' downloaded locally to '{snapshot_path}'.") + logger.info(f"Files for '{repo_id}' downloaded locally to '{snapshot_path}'.") # Upload the directory contents to the specified GCS path. _upload_directory_to_gcs(Path(snapshot_path), gcs_path) diff --git a/torchprime/tools/prepare_gcs_model_files.py b/torchprime/tools/prepare_gcs_model_files.py new file mode 100644 index 00000000..6f2bc524 --- /dev/null +++ b/torchprime/tools/prepare_gcs_model_files.py @@ -0,0 +1,87 @@ +""" +This script downloads specified model and tokenizer files from the Hugging Face Hub +and uploads them to a Google Cloud Storage (GCS) bucket. + +It is designed to prepare files required for training runs. + +Usage: + python torchprime/tools/prepare_gcs_model_files.py gs://your-bucket/your-path [--temp-dir /path/to/temp] +""" + +import os +import sys + +from torchprime.launcher import save_hf_tokenizer_and_model + +# --- Configuration --- +# List of models and the specific files to download for each. +# file_type can be 'tokenizer', 'model', or 'all'. +FILES_TO_PREPARE = [ + { + "repo_id": "meta-llama/Meta-Llama-3-8B", + "gcs_dir_name": "meta-llama-3-8b", + "file_type": "all", # Download model weights, config, and tokenizer + }, + { + "repo_id": "meta-llama/Meta-Llama-3.1-405B", + "gcs_dir_name": "meta-llama-3.1-405b", + "file_type": "tokenizer", + }, + { + "repo_id": "meta-llama/Llama-4-Scout-17B-16E", + "gcs_dir_name": "llama-4-scout-17b-16e", + "file_type": "tokenizer", + }, + { + "repo_id": "mistralai/Mixtral-8x7B-v0.1", + "gcs_dir_name": "mixtral-8x7b-v0.1", + "file_type": "tokenizer", + }, +] + + +def main(): + """Downloads and uploads specified Hugging Face files to GCS.""" + if len(sys.argv) < 2 or not sys.argv[1].startswith("gs://"): + print( + f"Usage: python {sys.argv[0]} gs://your-bucket/your-path [--temp-dir /path/to/temp]", + file=sys.stderr, + ) + sys.exit(1) + + gcs_base_path = sys.argv[1] + temp_dir = None + if "--temp-dir" in sys.argv: + try: + idx = sys.argv.index("--temp-dir") + temp_dir = sys.argv[idx + 1] + except IndexError: + print("Error: --temp-dir requires a path.", file=sys.stderr) + sys.exit(1) + + if not os.environ.get("HF_TOKEN"): + raise RuntimeError( + "The HF_TOKEN environment variable is not set. " + "Please run 'huggingface-cli login' or export your token." 
+ ) + + print(f"--- Starting file preparation for GCS path: {gcs_base_path} ---") + + for i, file_info in enumerate(FILES_TO_PREPARE): + repo_id, gcs_dir, file_type = ( + file_info["repo_id"], + file_info["gcs_dir_name"], + file_info["file_type"], + ) + gcs_path = f"{gcs_base_path.rstrip('/')}/{gcs_dir}" + print(f"\n[{i + 1}/{len(FILES_TO_PREPARE)}] Processing '{repo_id}'...") + save_hf_tokenizer_and_model.save_hf_model_files_to_gcs( + repo_id=repo_id, gcs_path=gcs_path, file_type=file_type, temp_dir=temp_dir + ) + print(f" -> Successfully saved '{file_type}' files for '{repo_id}' to {gcs_path}") + + print("\n--- File preparation complete. ---") + + +if __name__ == "__main__": + main() diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index 870e02be..b6d5cf22 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -10,7 +10,6 @@ import json import logging import os -import tempfile from pathlib import Path import huggingface_hub @@ -115,51 +114,19 @@ def from_pretrained(self, model_path_or_repo: str): Args: model_path_or_repo: Path to the local directory or Hugging Face Hub repository ID. """ - tmp_path = "" - state_dict = None - if xr.process_index() == 0: - # On rank 0, resolve the model path (which may involve downloading) - # and load the state dictionary from disk. - model_dir = model_utils.download_gcs_dir_if_needed(model_path_or_repo) - logger.info("Loading model from %s", model_dir) - - if not os.path.isdir(model_dir): - model_dir = huggingface_hub.snapshot_download( - repo_id=model_dir, - allow_patterns=["*.safetensors*"] + model_utils.HF_MODEL_CONFIG_FILES, - ) - - if os.path.isdir(model_dir): - logger.info("Listing contents of model directory: %s", model_dir) - for root, _, files in os.walk(model_dir): - relative_root = os.path.relpath(root, model_dir) - for name in files: - logger.info(" - %s", os.path.join(relative_root, name).lstrip("./")) - - logger.info("Loading safetensors on rank 0...") - state_dict = model_utils.load_safetensors_to_state_dict(model_dir) - logger.info("Finished loading safetensors on rank 0.") - - # Save state_dict to a temporary file on a fast local filesystem - # to avoid broadcast issues with large objects. - with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp: - tmp_path = tmp.name - logger.info("Saving state_dict to temporary file: %s", tmp_path) - torch.save(state_dict, tmp_path) - del state_dict # Free memory - - if xr.process_count() > 1: - # Rendezvous can be used to broadcast small amounts of data like a path string. 
- tmp_path = xm.rendezvous("broadcast_tmp_path", tmp_path) - - state_dict = torch.load(tmp_path, map_location="cpu") - - if xr.process_count() > 1: - xm.rendezvous("wait_for_state_dict_load") - - if xr.process_index() == 0: - os.remove(tmp_path) - + local_path = model_utils.download_gcs_dir_if_needed(model_path_or_repo) + logger.info("Loading model from %s", local_path) + + if os.path.isdir(local_path): + model_dir = local_path + else: + model_dir = huggingface_hub.snapshot_download( + repo_id=local_path, + allow_patterns=["*.safetensors*"] + model_utils.HF_MODEL_CONFIG_FILES, + ) + + # Load weights + state_dict = model_utils.load_safetensors_to_state_dict(model_dir) self.load_state_dict(state_dict) def _maybe_save_checkpoint(self, config: DictConfig) -> None: @@ -191,7 +158,8 @@ def _maybe_save_checkpoint(self, config: DictConfig) -> None: if xr.process_index() == 0: logger.info("Saving Hugging Face configs and tokenizer to %s", save_dir) model_utils.copy_hf_config_files(config.model.pretrained_model, save_dir) - model_utils.save_hf_tokenizer(config.model.pretrained_model, save_dir) + # model_utils.save_hf_tokenizer(config.model.pretrained_model, save_dir) + model_utils.save_hf_tokenizer(config.model.tokenizer_name, save_dir) # Step 4: Initialize torch.distributed process group if not dist.is_initialized(): diff --git a/torchprime/torch_xla_models/model/model_utils.py b/torchprime/torch_xla_models/model/model_utils.py index aa710d3b..2ea83eb0 100644 --- a/torchprime/torch_xla_models/model/model_utils.py +++ b/torchprime/torch_xla_models/model/model_utils.py @@ -508,64 +508,33 @@ def save_hf_tokenizer(model_path_or_repo: str, save_dir: Path) -> None: def download_gcs_dir_if_needed(path_or_repo: str) -> str: - """Resolves a GCS path to a local path, trying gcsfuse first and falling back to download.""" + """Resolves a GCS path to a local path by downloading its contents using gsutil.""" if not path_or_repo.startswith("gs://"): return path_or_repo - from urllib.parse import urlparse - - # Consistently parse the GCS path to get the path inside the bucket. - path_inside_bucket = urlparse(path_or_repo).path.lstrip("/") - - # Strategy 1: Try to use the gcsfuse mount point. This is the most efficient. - fuse_path = os.path.join("/tmp/gcs-mount", path_inside_bucket) - if os.path.exists(fuse_path): - logger.info("Found existing gcsfuse mount for %s at %s", path_or_repo, fuse_path) - return fuse_path - - # Strategy 2: Fallback to downloading if the fuse mount doesn't exist. - logger.warning( - "gcsfuse path %s not found. Falling back to downloading from %s.", - fuse_path, - path_or_repo, - ) - from google.cloud import storage + if not shutil.which("gsutil"): + raise RuntimeError( + "gsutil command not found, but is required for downloading from GCS. " + "Please install the Google Cloud SDK." + ) + local_dir = tempfile.mkdtemp() + _TEMP_DIRS_TO_CLEAN.append(local_dir) try: - local_dir = tempfile.mkdtemp() - _TEMP_DIRS_TO_CLEAN.append(local_dir) - - # Parse GCS path to get bucket and prefix for listing blobs. - parsed_url = urlparse(path_or_repo, scheme="gs") - bucket_name = parsed_url.netloc - prefix = path_inside_bucket - if prefix and not prefix.endswith("/"): - prefix += "/" - - storage_client = storage.Client() - blobs = list(storage_client.list_blobs(bucket_name, prefix=prefix)) - - if not blobs: - raise FileNotFoundError(f"No objects found at GCS path: {path_or_repo}") - - for blob in blobs: - # Recreate the directory structure locally. 
- relative_path = os.path.relpath(blob.name, prefix) - dest_path = os.path.join(local_dir, relative_path) - - # Handle subdirectories explicitly. - if blob.name.endswith("/"): - os.makedirs(dest_path, exist_ok=True) - continue - - os.makedirs(os.path.dirname(dest_path), exist_ok=True) - blob.download_to_filename(dest_path) + gcs_path = path_or_repo.rstrip("/") + "/*" + command = ["gsutil", "-m", "cp", "-r", gcs_path, local_dir] + subprocess.run(command, check=True) logger.info( - "Successfully downloaded assets from %s to %s.", path_or_repo, local_dir + "Successfully downloaded files from %s to %s using gsutil.", + path_or_repo, + local_dir, ) return local_dir - except Exception as e: - logger.error("Failed to download from GCS using google-cloud-storage. Error: %s", e) - shutil.rmtree(local_dir) # Clean up failed download - raise RuntimeError(f"Could not download {path_or_repo}") from e + except (subprocess.CalledProcessError, Exception): + # Clean up the partially created directory on failure. + shutil.rmtree(local_dir) + logger.error( + "gsutil download failed for %s. See gsutil output above for details.", + path_or_repo, + ) diff --git a/torchprime/torch_xla_models/trainer/sft_trainer.py b/torchprime/torch_xla_models/trainer/sft_trainer.py index 1bedbffb..da553084 100644 --- a/torchprime/torch_xla_models/trainer/sft_trainer.py +++ b/torchprime/torch_xla_models/trainer/sft_trainer.py @@ -38,6 +38,7 @@ def __init__( if xr.process_index() == 0: logger.info("Loading model weights from %s", self.pretrained_model) model.from_pretrained(self.pretrained_model) + logger.info("Model loaded from %s", self.pretrained_model) xm.mark_step() else: logger.warning( From 1cf724b806b1d0607010d585cf2a7aae9a4b5b3d Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Thu, 7 Aug 2025 22:33:18 +0000 Subject: [PATCH 08/19] Remove unused lines --- .github/workflows/cpu_test.yml | 6 ------ torchprime/launcher/save_hf_tokenizer_and_model.py | 4 +--- torchprime/torch_xla_models/model/base_causal_lm.py | 1 - 3 files changed, 1 insertion(+), 10 deletions(-) diff --git a/.github/workflows/cpu_test.yml b/.github/workflows/cpu_test.yml index bb9ee9d6..8516a38a 100644 --- a/.github/workflows/cpu_test.yml +++ b/.github/workflows/cpu_test.yml @@ -41,18 +41,12 @@ jobs: python -m pip install --upgrade pip pip install -e '.[dev]' - name: Run PyTest - env: - # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token. - # HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | export PJRT_DEVICE=CPU export JAX_PLATFORMS=cpu export CI=true pytest -v - name: Run model forward - env: - # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token. - # HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | export PJRT_DEVICE=CPU export JAX_PLATFORMS=cpu diff --git a/torchprime/launcher/save_hf_tokenizer_and_model.py b/torchprime/launcher/save_hf_tokenizer_and_model.py index 092fa65e..e843fb25 100644 --- a/torchprime/launcher/save_hf_tokenizer_and_model.py +++ b/torchprime/launcher/save_hf_tokenizer_and_model.py @@ -32,14 +32,12 @@ def _upload_directory_to_gcs(local_path: Path, gcs_path: str): raise ValueError("GCS path must start with gs://") logger.info(f"Uploading contents of '{local_path}' to '{gcs_path}'...") - # Using gsutil for efficient, parallel uploads. - # The '/*' at the end of local_path ensures the contents are copied, not the directory itself. 
command = ["gsutil", "-m", "cp", "-r", f"{str(local_path).rstrip('/')}/*", gcs_path] try: subprocess.run(command, check=True, capture_output=True, text=True) logger.info(f"Successfully uploaded assets to {gcs_path}.") except subprocess.CalledProcessError as e: - logger.error(f"Failed to upload to GCS. Error: {e.stderr}") + logger.error(f"Failed to upload {local_path} to {gcs_path}. Error: {e.stderr}") raise diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index b6d5cf22..8d6c6833 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -158,7 +158,6 @@ def _maybe_save_checkpoint(self, config: DictConfig) -> None: if xr.process_index() == 0: logger.info("Saving Hugging Face configs and tokenizer to %s", save_dir) model_utils.copy_hf_config_files(config.model.pretrained_model, save_dir) - # model_utils.save_hf_tokenizer(config.model.pretrained_model, save_dir) model_utils.save_hf_tokenizer(config.model.tokenizer_name, save_dir) # Step 4: Initialize torch.distributed process group From f23f46d156b7a6233d5f476e0318b0af196e249d Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Thu, 7 Aug 2025 23:24:23 +0000 Subject: [PATCH 09/19] Refactor --- .github/workflows/reusable_e2e_check.yml | 3 --- torchprime/torch_xla_models/model/base_causal_lm.py | 5 ++--- torchprime/torch_xla_models/model/model_utils.py | 3 --- torchprime/torch_xla_models/train.py | 4 ++-- torchprime/torch_xla_models/trainer/sft_trainer.py | 1 - 5 files changed, 4 insertions(+), 12 deletions(-) diff --git a/.github/workflows/reusable_e2e_check.yml b/.github/workflows/reusable_e2e_check.yml index fce9ee9a..c282b83f 100644 --- a/.github/workflows/reusable_e2e_check.yml +++ b/.github/workflows/reusable_e2e_check.yml @@ -30,9 +30,6 @@ on: secrets: GCP_SA_KEY: required: true - # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token. - # HF_TOKEN: - # required: true jobs: results: diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index 8d6c6833..7bdec3d9 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -107,9 +107,8 @@ def from_pretrained(self, model_path_or_repo: str): It supports both local directories and remote repositories hosted on the Hugging Face Hub. Note: - In distributed training setups, to avoid I/O contention on shared - filesystems, only rank 0 loads data from disk. The state dict is then - broadcast to all other ranks. + In distributed training setups, ensure that all replicas perform the loading operation + to synchronize model weights across processes. Args: model_path_or_repo: Path to the local directory or Hugging Face Hub repository ID. diff --git a/torchprime/torch_xla_models/model/model_utils.py b/torchprime/torch_xla_models/model/model_utils.py index 2ea83eb0..57918e3d 100644 --- a/torchprime/torch_xla_models/model/model_utils.py +++ b/torchprime/torch_xla_models/model/model_utils.py @@ -463,8 +463,6 @@ def copy_hf_config_files(model_path_or_repo: str, save_dir: Path) -> None: """ patterns = HF_MODEL_CONFIG_FILES - # Convert GCS path to gcsfuse path if necessary. This allows us to treat - # GCS paths like local directories for the subsequent logic. 
model_path_or_repo = download_gcs_dir_if_needed(model_path_or_repo) if os.path.isdir(model_path_or_repo): model_dir = Path(model_path_or_repo) @@ -499,7 +497,6 @@ def save_hf_tokenizer(model_path_or_repo: str, save_dir: Path) -> None: model repo ID (e.g., "meta-llama/Llama-2-7b-hf"). save_dir: Directory where the tokenizer files will be saved. """ - # If it's a GCS path, convert it to the gcsfuse mount path. model_path_or_repo = download_gcs_dir_if_needed(model_path_or_repo) tokenizer = AutoTokenizer.from_pretrained(model_path_or_repo) save_dir = Path(save_dir) diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index d25a832f..bca302c0 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -51,9 +51,9 @@ def main(config: omegaconf.DictConfig): # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Add tokenizers to torchprime. tokenizer_name = config.model.tokenizer_name - local_tokenizer_path = model_utils.download_gcs_dir_if_needed(tokenizer_name) + tokenizer_path_or_repo = model_utils.download_gcs_dir_if_needed(tokenizer_name) tokenizer = retry.retry( - lambda: transformers.AutoTokenizer.from_pretrained(local_tokenizer_path) + lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_path_or_repo) ) assert config.torch_dtype == "bfloat16", "Currently only bfloat16 is supported" diff --git a/torchprime/torch_xla_models/trainer/sft_trainer.py b/torchprime/torch_xla_models/trainer/sft_trainer.py index da553084..1bedbffb 100644 --- a/torchprime/torch_xla_models/trainer/sft_trainer.py +++ b/torchprime/torch_xla_models/trainer/sft_trainer.py @@ -38,7 +38,6 @@ def __init__( if xr.process_index() == 0: logger.info("Loading model weights from %s", self.pretrained_model) model.from_pretrained(self.pretrained_model) - logger.info("Model loaded from %s", self.pretrained_model) xm.mark_step() else: logger.warning( From c3edee25844eef410fa24cf6f50b78b546407c3a Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Fri, 8 Aug 2025 01:16:07 +0000 Subject: [PATCH 10/19] Use e2e gcs directory and refactored --- .github/workflows/e2e_test.yml | 20 ++--- torchprime/launcher/cli.py | 8 +- .../launcher/save_hf_tokenizer_and_model.py | 33 +++++-- torchprime/tools/prepare_gcs_model_files.py | 87 ------------------- .../torch_xla_models/model/base_causal_lm.py | 7 +- .../torch_xla_models/model/model_utils.py | 23 ++++- torchprime/torch_xla_models/train.py | 2 +- 7 files changed, 63 insertions(+), 117 deletions(-) delete mode 100644 torchprime/tools/prepare_gcs_model_files.py diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/e2e_test.yml index 06e52160..a71c83c8 100644 --- a/.github/workflows/e2e_test.yml +++ b/.github/workflows/e2e_test.yml @@ -77,7 +77,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3-8b \ - model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ + model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \ dataset=wikitext \ task=train \ task.global_batch_size=8 \ @@ -121,7 +121,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3.1-8b \ - model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3.1-405b \ + model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3.1-405b \ model.attention_kernel=splash_attention \ dataset=wikitext \ task=train \ @@ -143,7 +143,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3.1-8b \ - 
model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3.1-405b \ + model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3.1-405b \ model/remat=llama-scan-offload \ dataset=wikitext \ task=train \ @@ -165,7 +165,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3-8b \ - model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ + model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \ model/sharding=llama-fsdp-tp \ dataset=wikitext \ task=train \ @@ -188,7 +188,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=llama-3-8b-cp \ - model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ + model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \ model/sharding=llama-fsdp-tp-cp \ dataset=wikitext \ task=train \ @@ -210,7 +210,7 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ model=mixtral-8x7b \ - model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/mixtral-8x7b-v0.1/ \ + model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/mixtral-8x7b-v0.1/ \ model.num_hidden_layers=16 \ dataset=wikitext \ task=train \ @@ -232,7 +232,7 @@ jobs: --name $name \ --num-slices 2 \ torchprime/torch_xla_models/train.py \ - model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ + model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \ model=llama-3-8b \ model/sharding=llama-fsdp \ dataset=wikitext \ @@ -256,8 +256,8 @@ jobs: --name $name \ torchprime/torch_xla_models/train.py \ --config-name llama-3-8b-sft-w-gsm8k \ - model.pretrained_model=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ - model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ + model.pretrained_model=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \ + model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \ ici_mesh.fsdp=4 \ task.max_steps=50 \ task.convert_to_safetensors=False \ @@ -275,7 +275,7 @@ jobs: --name $name \ --num-slices 2 \ torchprime/torch_xla_models/train.py \ - model.tokenizer_name=gs://torchprime/jackoh-exp/hf-models/meta-llama-3-8b \ + model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \ model=llama-3-8b \ model/sharding=llama-fsdp \ dataset=wikitext \ diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index 48376df1..a7534527 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -115,18 +115,16 @@ def save_hf_model_files_to_gcs( save_hf_tokenizer_and_model.save_hf_model_files_to_gcs( repo_id, gcs_path, file_type=file_type, temp_dir=temp_dir ) - click.secho(f" -> Successfully saved files to {gcs_path}", fg="green") + click.echo(f" -> Successfully saved files to {gcs_path}") except RepositoryNotFoundError: - click.secho(f"\n❌ Error: Repository '{repo_id}' not found.", fg="red") + click.echo(f"\n❌ Error: Repository '{repo_id}' not found.") click.echo("Please check the following:") click.echo(f"1. The repository ID '{repo_id}' is spelled correctly.") click.echo( "2. If it's a gated repository, ensure you are authenticated by running 'huggingface-cli login' or exporting your HF_TOKEN." 
) except Exception as e: - click.secho( - f"\n❌ An unexpected error occurred for repository '{repo_id}': {e}", fg="red" - ) + click.echo(f"\n❌ An unexpected error occurred for repository '{repo_id}': {e}") @cli.command() diff --git a/torchprime/launcher/save_hf_tokenizer_and_model.py b/torchprime/launcher/save_hf_tokenizer_and_model.py index e843fb25..b4e50483 100644 --- a/torchprime/launcher/save_hf_tokenizer_and_model.py +++ b/torchprime/launcher/save_hf_tokenizer_and_model.py @@ -14,9 +14,9 @@ "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", - "*.model", # For sentencepiece tokenizers - "vocab.txt", # For WordPiece/BERT tokenizers - "merges.txt", # For BPE tokenizers + "*.model", + "vocab.txt", + "merges.txt", ] MODEL_PATTERNS = [ @@ -27,7 +27,12 @@ def _upload_directory_to_gcs(local_path: Path, gcs_path: str): - """Uploads the contents of a local directory to GCS using gsutil.""" + """Uploads the contents of a local directory to GCS using gsutil. + + Args: + local_path: The local directory whose contents will be uploaded. + gcs_path: The destination GCS path (e.g., 'gs://my-bucket/models/'). + """ if not gcs_path.startswith("gs://"): raise ValueError("GCS path must start with gs://") @@ -47,7 +52,23 @@ def save_hf_model_files_to_gcs( file_type: str, temp_dir: str | None = None, ): - """Downloads model and tokenizer files from a Hugging Face repo and uploads to GCS.""" + """Downloads model or tokenizer files from Hugging Face and uploads them to GCS. + + This function uses `huggingface_hub.snapshot_download` to fetch specific + files based on predefined patterns for models and tokenizers. The downloaded + files are then uploaded to the specified GCS path. + + Args: + repo_id: The ID of the Hugging Face repository (e.g., 'meta-llama/Llama-3-8B-hf'). + gcs_path: The target GCS path for the files (e.g., 'gs://bucket/models/Llama-3-8B-hf'). + file_type: The type of files to download. Must be one of 'tokenizer', + 'model', or 'all'. + temp_dir: An optional path to a temporary directory for downloading. If + None, the system's default temporary directory is used. + + Raises: + ValueError: If an invalid `file_type` is provided. + """ allow_patterns = [] if file_type in ("tokenizer", "all"): allow_patterns.extend(TOKENIZER_PATTERNS) @@ -66,10 +87,8 @@ def save_hf_model_files_to_gcs( cache_dir=str(tmpdir), token=os.environ.get("HF_TOKEN"), allow_patterns=allow_patterns, - ignore_patterns=["*.bin*"], # Avoid large pytorch_model.bin files ) logger.info(f"Files for '{repo_id}' downloaded locally to '{snapshot_path}'.") - # Upload the directory contents to the specified GCS path. _upload_directory_to_gcs(Path(snapshot_path), gcs_path) diff --git a/torchprime/tools/prepare_gcs_model_files.py b/torchprime/tools/prepare_gcs_model_files.py deleted file mode 100644 index 6f2bc524..00000000 --- a/torchprime/tools/prepare_gcs_model_files.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -This script downloads specified model and tokenizer files from the Hugging Face Hub -and uploads them to a Google Cloud Storage (GCS) bucket. - -It is designed to prepare files required for training runs. - -Usage: - python torchprime/tools/prepare_gcs_model_files.py gs://your-bucket/your-path [--temp-dir /path/to/temp] -""" - -import os -import sys - -from torchprime.launcher import save_hf_tokenizer_and_model - -# --- Configuration --- -# List of models and the specific files to download for each. -# file_type can be 'tokenizer', 'model', or 'all'. 
-FILES_TO_PREPARE = [ - { - "repo_id": "meta-llama/Meta-Llama-3-8B", - "gcs_dir_name": "meta-llama-3-8b", - "file_type": "all", # Download model weights, config, and tokenizer - }, - { - "repo_id": "meta-llama/Meta-Llama-3.1-405B", - "gcs_dir_name": "meta-llama-3.1-405b", - "file_type": "tokenizer", - }, - { - "repo_id": "meta-llama/Llama-4-Scout-17B-16E", - "gcs_dir_name": "llama-4-scout-17b-16e", - "file_type": "tokenizer", - }, - { - "repo_id": "mistralai/Mixtral-8x7B-v0.1", - "gcs_dir_name": "mixtral-8x7b-v0.1", - "file_type": "tokenizer", - }, -] - - -def main(): - """Downloads and uploads specified Hugging Face files to GCS.""" - if len(sys.argv) < 2 or not sys.argv[1].startswith("gs://"): - print( - f"Usage: python {sys.argv[0]} gs://your-bucket/your-path [--temp-dir /path/to/temp]", - file=sys.stderr, - ) - sys.exit(1) - - gcs_base_path = sys.argv[1] - temp_dir = None - if "--temp-dir" in sys.argv: - try: - idx = sys.argv.index("--temp-dir") - temp_dir = sys.argv[idx + 1] - except IndexError: - print("Error: --temp-dir requires a path.", file=sys.stderr) - sys.exit(1) - - if not os.environ.get("HF_TOKEN"): - raise RuntimeError( - "The HF_TOKEN environment variable is not set. " - "Please run 'huggingface-cli login' or export your token." - ) - - print(f"--- Starting file preparation for GCS path: {gcs_base_path} ---") - - for i, file_info in enumerate(FILES_TO_PREPARE): - repo_id, gcs_dir, file_type = ( - file_info["repo_id"], - file_info["gcs_dir_name"], - file_info["file_type"], - ) - gcs_path = f"{gcs_base_path.rstrip('/')}/{gcs_dir}" - print(f"\n[{i + 1}/{len(FILES_TO_PREPARE)}] Processing '{repo_id}'...") - save_hf_tokenizer_and_model.save_hf_model_files_to_gcs( - repo_id=repo_id, gcs_path=gcs_path, file_type=file_type, temp_dir=temp_dir - ) - print(f" -> Successfully saved '{file_type}' files for '{repo_id}' to {gcs_path}") - - print("\n--- File preparation complete. ---") - - -if __name__ == "__main__": - main() diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index 7bdec3d9..2b1d23f7 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -113,7 +113,7 @@ def from_pretrained(self, model_path_or_repo: str): Args: model_path_or_repo: Path to the local directory or Hugging Face Hub repository ID. 
""" - local_path = model_utils.download_gcs_dir_if_needed(model_path_or_repo) + local_path = model_utils.copy_gcs_to_local(model_path_or_repo) logger.info("Loading model from %s", local_path) if os.path.isdir(local_path): @@ -156,8 +156,9 @@ def _maybe_save_checkpoint(self, config: DictConfig) -> None: # Step 3: Save the HF config files and tokenizer if xr.process_index() == 0: logger.info("Saving Hugging Face configs and tokenizer to %s", save_dir) - model_utils.copy_hf_config_files(config.model.pretrained_model, save_dir) - model_utils.save_hf_tokenizer(config.model.tokenizer_name, save_dir) + model_path_or_repo = model_utils.copy_gcs_to_local(config.model.pretrained_model) + model_utils.copy_hf_config_files(model_path_or_repo, save_dir) + model_utils.save_hf_tokenizer(model_path_or_repo, save_dir) # Step 4: Initialize torch.distributed process group if not dist.is_initialized(): diff --git a/torchprime/torch_xla_models/model/model_utils.py b/torchprime/torch_xla_models/model/model_utils.py index 57918e3d..c5c22985 100644 --- a/torchprime/torch_xla_models/model/model_utils.py +++ b/torchprime/torch_xla_models/model/model_utils.py @@ -463,7 +463,6 @@ def copy_hf_config_files(model_path_or_repo: str, save_dir: Path) -> None: """ patterns = HF_MODEL_CONFIG_FILES - model_path_or_repo = download_gcs_dir_if_needed(model_path_or_repo) if os.path.isdir(model_path_or_repo): model_dir = Path(model_path_or_repo) else: @@ -497,15 +496,30 @@ def save_hf_tokenizer(model_path_or_repo: str, save_dir: Path) -> None: model repo ID (e.g., "meta-llama/Llama-2-7b-hf"). save_dir: Directory where the tokenizer files will be saved. """ - model_path_or_repo = download_gcs_dir_if_needed(model_path_or_repo) tokenizer = AutoTokenizer.from_pretrained(model_path_or_repo) save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) tokenizer.save_pretrained(save_dir) -def download_gcs_dir_if_needed(path_or_repo: str) -> str: - """Resolves a GCS path to a local path by downloading its contents using gsutil.""" +def copy_gcs_to_local(path_or_repo: str) -> str: + """Download gcs content to local temporaily directory. + + If the input `path_or_repo` starts with 'gs://', this function will download + the contents of the GCS directory to a temporary local directory using the + `gsutil` command-line tool. The local directory will be automatically cleaned + up when the program exits. + + If the input is not a GCS path, it is assumed to be a local path or huggingface repo ID, and is + returned unmodified. + + Args: + path_or_repo: The path to resolve. Can be a GCS URI (e.g., + 'gs://bucket/data') or a local file path. + + Returns: + A string containing the path to the local temporary directory. + """ if not path_or_repo.startswith("gs://"): return path_or_repo @@ -535,3 +549,4 @@ def download_gcs_dir_if_needed(path_or_repo: str) -> str: "gsutil download failed for %s. See gsutil output above for details.", path_or_repo, ) + raise diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index bca302c0..ffe7e17c 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -51,7 +51,7 @@ def main(config: omegaconf.DictConfig): # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Add tokenizers to torchprime. 
tokenizer_name = config.model.tokenizer_name - tokenizer_path_or_repo = model_utils.download_gcs_dir_if_needed(tokenizer_name) + tokenizer_path_or_repo = model_utils.copy_gcs_to_local(tokenizer_name) tokenizer = retry.retry( lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_path_or_repo) ) From 15ef3eb3d34154f37dfb8816831ab8102b62ebe1 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Fri, 8 Aug 2025 01:34:32 +0000 Subject: [PATCH 11/19] Formatting --- .../torch_xla_models/model/base_causal_lm.py | 15 +++++++++------ torchprime/torch_xla_models/train.py | 5 ++--- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index 2b1d23f7..298f3ad1 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -113,14 +113,13 @@ def from_pretrained(self, model_path_or_repo: str): Args: model_path_or_repo: Path to the local directory or Hugging Face Hub repository ID. """ - local_path = model_utils.copy_gcs_to_local(model_path_or_repo) - logger.info("Loading model from %s", local_path) + model_path_or_repo = model_utils.copy_gcs_to_local(model_path_or_repo) - if os.path.isdir(local_path): - model_dir = local_path + if os.path.isdir(model_path_or_repo): + model_dir = model_path_or_repo else: model_dir = huggingface_hub.snapshot_download( - repo_id=local_path, + repo_id=model_path_or_repo, allow_patterns=["*.safetensors*"] + model_utils.HF_MODEL_CONFIG_FILES, ) @@ -156,8 +155,12 @@ def _maybe_save_checkpoint(self, config: DictConfig) -> None: # Step 3: Save the HF config files and tokenizer if xr.process_index() == 0: logger.info("Saving Hugging Face configs and tokenizer to %s", save_dir) + # Copy to local if in GCS + tokenizer_path_or_repo = model_utils.copy_gcs_to_local( + config.model.tokenizer_name + ) model_path_or_repo = model_utils.copy_gcs_to_local(config.model.pretrained_model) - model_utils.copy_hf_config_files(model_path_or_repo, save_dir) + model_utils.copy_hf_config_files(tokenizer_path_or_repo, save_dir) model_utils.save_hf_tokenizer(model_path_or_repo, save_dir) # Step 4: Initialize torch.distributed process group diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index ffe7e17c..506ad9e4 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -50,10 +50,9 @@ def main(config: omegaconf.DictConfig): logger.info(f"Profiling server started: {str(server)}") # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Add tokenizers to torchprime. 
- tokenizer_name = config.model.tokenizer_name - tokenizer_path_or_repo = model_utils.copy_gcs_to_local(tokenizer_name) + tokenizer_name = model_utils.copy_gcs_to_local(config.model.tokenizer_name) tokenizer = retry.retry( - lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_path_or_repo) + lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_name) ) assert config.torch_dtype == "bfloat16", "Currently only bfloat16 is supported" From 0e5952844abe5d7100ffcdf1bf1f5217031beb51 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Wed, 13 Aug 2025 22:41:15 +0000 Subject: [PATCH 12/19] Change copy_gcs_to_local to gcs_to_local which uses context manager to auto clean --- .../torch_xla_models/model/base_causal_lm.py | 37 ++++++----- .../torch_xla_models/model/model_utils.py | 62 +++++++------------ torchprime/torch_xla_models/train.py | 8 +-- 3 files changed, 46 insertions(+), 61 deletions(-) diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index 298f3ad1..07baa76b 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -113,19 +113,18 @@ def from_pretrained(self, model_path_or_repo: str): Args: model_path_or_repo: Path to the local directory or Hugging Face Hub repository ID. """ - model_path_or_repo = model_utils.copy_gcs_to_local(model_path_or_repo) - - if os.path.isdir(model_path_or_repo): - model_dir = model_path_or_repo - else: - model_dir = huggingface_hub.snapshot_download( - repo_id=model_path_or_repo, - allow_patterns=["*.safetensors*"] + model_utils.HF_MODEL_CONFIG_FILES, - ) - - # Load weights - state_dict = model_utils.load_safetensors_to_state_dict(model_dir) - self.load_state_dict(state_dict) + with model_utils.gcs_to_local(model_path_or_repo) as local_model_path_or_repo: + if os.path.isdir(local_model_path_or_repo): + model_dir = local_model_path_or_repo + else: + model_dir = huggingface_hub.snapshot_download( + repo_id=local_model_path_or_repo, + allow_patterns=["*.safetensors*"] + model_utils.HF_MODEL_CONFIG_FILES, + ) + + # Load weights + state_dict = model_utils.load_safetensors_to_state_dict(model_dir) + self.load_state_dict(state_dict) def _maybe_save_checkpoint(self, config: DictConfig) -> None: """Save a sharded checkpoint and optionally convert it to safetensors format. 
@@ -156,12 +155,12 @@ def _maybe_save_checkpoint(self, config: DictConfig) -> None: if xr.process_index() == 0: logger.info("Saving Hugging Face configs and tokenizer to %s", save_dir) # Copy to local if in GCS - tokenizer_path_or_repo = model_utils.copy_gcs_to_local( - config.model.tokenizer_name - ) - model_path_or_repo = model_utils.copy_gcs_to_local(config.model.pretrained_model) - model_utils.copy_hf_config_files(tokenizer_path_or_repo, save_dir) - model_utils.save_hf_tokenizer(model_path_or_repo, save_dir) + with ( + model_utils.gcs_to_local(config.model.tokenizer_name) as tokenizer_path, + model_utils.gcs_to_local(config.model.pretrained_model) as model_path, + ): + model_utils.copy_hf_config_files(tokenizer_path, save_dir) + model_utils.save_hf_tokenizer(model_path, save_dir) # Step 4: Initialize torch.distributed process group if not dist.is_initialized(): diff --git a/torchprime/torch_xla_models/model/model_utils.py b/torchprime/torch_xla_models/model/model_utils.py index c5c22985..57344dc1 100644 --- a/torchprime/torch_xla_models/model/model_utils.py +++ b/torchprime/torch_xla_models/model/model_utils.py @@ -2,7 +2,6 @@ from __future__ import annotations -import atexit import importlib import json import logging @@ -32,21 +31,6 @@ "generation_config.json", ] -_TEMP_DIRS_TO_CLEAN = [] - - -def _cleanup_temp_dirs(): - """Removes all temporary directories created by this module.""" - for d in _TEMP_DIRS_TO_CLEAN: - try: - logger.info(f"Cleaning up temporary directory: {d}") - shutil.rmtree(d) - except OSError as e: - logger.warning(f"Failed to remove temporary directory {d}: {e}") - - -atexit.register(_cleanup_temp_dirs) - def load_safetensors_to_state_dict(model_dir: str) -> dict: """Load a model state dict from safetensors, supporting both sharded and single-file formats. @@ -502,26 +486,31 @@ def save_hf_tokenizer(model_path_or_repo: str, save_dir: Path) -> None: tokenizer.save_pretrained(save_dir) -def copy_gcs_to_local(path_or_repo: str) -> str: - """Download gcs content to local temporaily directory. +@contextmanager +def gcs_to_local(path_or_repo: str, temp_dir: str | None = None): + """A context manager to download GCS content to a local temporary directory. If the input `path_or_repo` starts with 'gs://', this function will download the contents of the GCS directory to a temporary local directory using the `gsutil` command-line tool. The local directory will be automatically cleaned - up when the program exits. + up when the context is exited. - If the input is not a GCS path, it is assumed to be a local path or huggingface repo ID, and is - returned unmodified. + If the input is not a GCS path, it is assumed to be a local path or a + Hugging Face repository ID, and is yielded unmodified with no cleanup. Args: path_or_repo: The path to resolve. Can be a GCS URI (e.g., 'gs://bucket/data') or a local file path. + temp_dir: An optional path to a directory for creating the temporary + download location. If None, the system's default temporary directory + is used. - Returns: - A string containing the path to the local temporary directory. + Yields: + A string containing the path to the local directory. """ if not path_or_repo.startswith("gs://"): - return path_or_repo + yield path_or_repo + return if not shutil.which("gsutil"): raise RuntimeError( @@ -529,24 +518,21 @@ def copy_gcs_to_local(path_or_repo: str) -> str: "Please install the Google Cloud SDK." 
    )

-  local_dir = tempfile.mkdtemp()
-  _TEMP_DIRS_TO_CLEAN.append(local_dir)
+  local_dir = tempfile.mkdtemp(dir=temp_dir)
   try:
     gcs_path = path_or_repo.rstrip("/") + "/*"
-    command = ["gsutil", "-m", "cp", "-r", gcs_path, local_dir]
-    subprocess.run(command, check=True)
-
+    command = ["gsutil", "-m", "-q", "cp", "-r", gcs_path, local_dir]
+    subprocess.run(command, check=True, capture_output=True, text=True)
     logger.info(
-      "Successfully downloaded files from %s to %s using gsutil.",
+      "Successfully downloaded files from %s to temporary directory %s.",
       path_or_repo,
       local_dir,
     )
-    return local_dir
-  except (subprocess.CalledProcessError, Exception):
-    # Clean up the partially created directory on failure.
-    shutil.rmtree(local_dir)
-    logger.error(
-      "gsutil download failed for %s. See gsutil output above for details.",
-      path_or_repo,
-    )
+
+    yield local_dir
+  except subprocess.CalledProcessError as e:
+    logger.error("gsutil download failed for %s. Stderr:\n%s", path_or_repo, e.stderr)
     raise
+  finally:
+    logger.info(f"Cleaning up temporary directory: {local_dir}")
+    shutil.rmtree(local_dir, ignore_errors=True)
diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py
index 506ad9e4..be0ca113 100644
--- a/torchprime/torch_xla_models/train.py
+++ b/torchprime/torch_xla_models/train.py
@@ -50,10 +50,10 @@ def main(config: omegaconf.DictConfig):
   logger.info(f"Profiling server started: {str(server)}")
 
   # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Add tokenizers to torchprime.
-  tokenizer_name = model_utils.copy_gcs_to_local(config.model.tokenizer_name)
-  tokenizer = retry.retry(
-    lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_name)
-  )
+  with model_utils.local_path_from_gcs(config.model.tokenizer_name) as tokenizer_path:
+    tokenizer = retry.retry(
+      lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_path)
+    )
 
   assert config.torch_dtype == "bfloat16", "Currently only bfloat16 is supported"
   model_dtype = getattr(torch, config.torch_dtype)

From 766afb379894be8cd2e362cf962f8217ea88bdcd Mon Sep 17 00:00:00 2001
From: Jack Oh
Date: Wed, 13 Aug 2025 22:48:52 +0000
Subject: [PATCH 13/19] Remove hf_token from deepseek e2e

---
 .github/workflows/e2e_test.yml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/e2e_test.yml
index 0a9fdf43..899652b6 100644
--- a/.github/workflows/e2e_test.yml
+++ b/.github/workflows/e2e_test.yml
@@ -291,7 +291,6 @@ jobs:
     - name: Run Deepseek v3 Shallow
       id: run-ds-v3-shallow
       env:
-        HF_TOKEN: ${{ secrets.HF_TOKEN }}
        XLA_IR_DEBUG: 1
        XLA_HLO_DEBUG: 1
      run: |

From 91d71b867b74d6539dc1a40db3654b73e4ff0c48 Mon Sep 17 00:00:00 2001
From: Jack Oh
Date: Wed, 13 Aug 2025 23:02:07 +0000
Subject: [PATCH 14/19] Fix error caused by calling the wrong function name

---
 torchprime/torch_xla_models/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py
index be0ca113..3ae21b68 100644
--- a/torchprime/torch_xla_models/train.py
+++ b/torchprime/torch_xla_models/train.py
@@ -50,7 +50,7 @@ def main(config: omegaconf.DictConfig):
   logger.info(f"Profiling server started: {str(server)}")
 
   # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Add tokenizers to torchprime.
- with model_utils.local_path_from_gcs(config.model.tokenizer_name) as tokenizer_path: + with model_utils.gcs_to_local(config.model.tokenizer_name) as tokenizer_path: tokenizer = retry.retry( lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_path) ) From 14dc4191008061e2d2a38a25e1256a76f671a35d Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Wed, 13 Aug 2025 23:04:25 +0000 Subject: [PATCH 15/19] Change function name to be more descriptive --- torchprime/torch_xla_models/model/base_causal_lm.py | 6 +++--- torchprime/torch_xla_models/model/model_utils.py | 2 +- torchprime/torch_xla_models/train.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index 07baa76b..d5f8eb4e 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -113,7 +113,7 @@ def from_pretrained(self, model_path_or_repo: str): Args: model_path_or_repo: Path to the local directory or Hugging Face Hub repository ID. """ - with model_utils.gcs_to_local(model_path_or_repo) as local_model_path_or_repo: + with model_utils.local_path_from_gcs(model_path_or_repo) as local_model_path_or_repo: if os.path.isdir(local_model_path_or_repo): model_dir = local_model_path_or_repo else: @@ -156,8 +156,8 @@ def _maybe_save_checkpoint(self, config: DictConfig) -> None: logger.info("Saving Hugging Face configs and tokenizer to %s", save_dir) # Copy to local if in GCS with ( - model_utils.gcs_to_local(config.model.tokenizer_name) as tokenizer_path, - model_utils.gcs_to_local(config.model.pretrained_model) as model_path, + model_utils.local_path_from_gcs(config.model.tokenizer_name) as tokenizer_path, + model_utils.local_path_from_gcs(config.model.pretrained_model) as model_path, ): model_utils.copy_hf_config_files(tokenizer_path, save_dir) model_utils.save_hf_tokenizer(model_path, save_dir) diff --git a/torchprime/torch_xla_models/model/model_utils.py b/torchprime/torch_xla_models/model/model_utils.py index 57344dc1..ecc00540 100644 --- a/torchprime/torch_xla_models/model/model_utils.py +++ b/torchprime/torch_xla_models/model/model_utils.py @@ -487,7 +487,7 @@ def save_hf_tokenizer(model_path_or_repo: str, save_dir: Path) -> None: @contextmanager -def gcs_to_local(path_or_repo: str, temp_dir: str | None = None): +def local_path_from_gcs(path_or_repo: str, temp_dir: str | None = None): """A context manager to download GCS content to a local temporary directory. If the input `path_or_repo` starts with 'gs://', this function will download diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index 3ae21b68..1c80ba72 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -50,9 +50,9 @@ def main(config: omegaconf.DictConfig): logger.info(f"Profiling server started: {str(server)}") # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Add tokenizers to torchprime. 
- with model_utils.gcs_to_local(config.model.tokenizer_name) as tokenizer_path: + with model_utils.local_path_from_gcs(config.model.tokenizer_name) as tokenizer_path_or_repo: tokenizer = retry.retry( - lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_path) + lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_path_or_repo) ) assert config.torch_dtype == "bfloat16", "Currently only bfloat16 is supported" From c728389ce0148ff355d96d061cd5a803e4ff4cd2 Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Wed, 13 Aug 2025 23:07:21 +0000 Subject: [PATCH 16/19] Ruff format --- torchprime/torch_xla_models/model/base_causal_lm.py | 4 +++- torchprime/torch_xla_models/train.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index d5f8eb4e..1d3ac52b 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -113,7 +113,9 @@ def from_pretrained(self, model_path_or_repo: str): Args: model_path_or_repo: Path to the local directory or Hugging Face Hub repository ID. """ - with model_utils.local_path_from_gcs(model_path_or_repo) as local_model_path_or_repo: + with model_utils.local_path_from_gcs( + model_path_or_repo + ) as local_model_path_or_repo: if os.path.isdir(local_model_path_or_repo): model_dir = local_model_path_or_repo else: diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index 1c80ba72..1a68079e 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -50,7 +50,9 @@ def main(config: omegaconf.DictConfig): logger.info(f"Profiling server started: {str(server)}") # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Add tokenizers to torchprime. 
- with model_utils.local_path_from_gcs(config.model.tokenizer_name) as tokenizer_path_or_repo: + with model_utils.local_path_from_gcs( + config.model.tokenizer_name + ) as tokenizer_path_or_repo: tokenizer = retry.retry( lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_path_or_repo) ) From b33a9daeb93afda49d7278167f1a5c3f6320e20d Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Thu, 14 Aug 2025 18:32:51 +0000 Subject: [PATCH 17/19] Fix error from previous merging --- torchprime/launcher/cli.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index 402a0eaa..31f718b1 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -647,12 +647,6 @@ def main(): # `doctor` command parser_doctor = subparsers.add_parser("doctor", help=doctor.__doc__) parser_doctor.set_defaults(func=doctor) - - # Parse arguments - known_args, remaining_args = parser.parse_known_args() - - func_to_run = known_args.func - is_interactive = known_args.interactive # `save-hf-model-files-to-gcs` command parser_save_hf = subparsers.add_parser( @@ -687,6 +681,12 @@ def main(): ) parser_save_hf.set_defaults(func=save_hf_model_files_to_gcs) + # Parse arguments + known_args, remaining_args = parser.parse_known_args() + + func_to_run = known_args.func + is_interactive = known_args.interactive + # Prepare arguments for the function call func_kwargs = vars(known_args) From a57b6b1b8059309ad870f60811f66a3b5492f7db Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Thu, 14 Aug 2025 18:38:45 +0000 Subject: [PATCH 18/19] Ruff format --- torchprime/launcher/cli.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index 31f718b1..65aff52b 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -57,9 +57,7 @@ def save_hf_model_files_to_gcs( repo_id: str, gcs_path: str, file_type: str, temp_dir: str | None ): """Downloads model and tokenizer files from Hugging Face Hub and saves them to Google Cloud Storage.""" - print( - f"Preparing to save '{file_type}' files from '{repo_id}' to '{gcs_path}'..." 
- ) + print(f"Preparing to save '{file_type}' files from '{repo_id}' to '{gcs_path}'...") try: save_hf_tokenizer_and_model.save_hf_model_files_to_gcs( repo_id, gcs_path, file_type=file_type, temp_dir=temp_dir @@ -74,7 +72,7 @@ def save_hf_model_files_to_gcs( ) except Exception as e: print(f"\n❌ An unexpected error occurred for repository '{repo_id}': {e}") - + def use( cluster: str, @@ -647,7 +645,7 @@ def main(): # `doctor` command parser_doctor = subparsers.add_parser("doctor", help=doctor.__doc__) parser_doctor.set_defaults(func=doctor) - + # `save-hf-model-files-to-gcs` command parser_save_hf = subparsers.add_parser( "save-hf-model-files-to-gcs", @@ -687,7 +685,6 @@ def main(): func_to_run = known_args.func is_interactive = known_args.interactive - # Prepare arguments for the function call func_kwargs = vars(known_args) del func_kwargs["command"] From aeef24a6ba7ec9efceef9b2a705174bd6dfeecbe Mon Sep 17 00:00:00 2001 From: Jack Oh Date: Thu, 14 Aug 2025 21:19:45 +0000 Subject: [PATCH 19/19] Address comment --- torchprime/launcher/save_hf_tokenizer_and_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchprime/launcher/save_hf_tokenizer_and_model.py b/torchprime/launcher/save_hf_tokenizer_and_model.py index b4e50483..17cf7684 100644 --- a/torchprime/launcher/save_hf_tokenizer_and_model.py +++ b/torchprime/launcher/save_hf_tokenizer_and_model.py @@ -10,14 +10,14 @@ logger = logging.getLogger(__name__) -TOKENIZER_PATTERNS = [ +TOKENIZER_PATTERNS = ( "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "*.model", "vocab.txt", "merges.txt", -] +) MODEL_PATTERNS = [ "*.safetensors*",