diff --git a/.github/workflows/cpu_test.yml b/.github/workflows/cpu_test.yml
index b6907764..8516a38a 100644
--- a/.github/workflows/cpu_test.yml
+++ b/.github/workflows/cpu_test.yml
@@ -41,18 +41,12 @@ jobs:
           python -m pip install --upgrade pip
           pip install -e '.[dev]'
       - name: Run PyTest
-        env:
-          # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token.
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
         run: |
           export PJRT_DEVICE=CPU
           export JAX_PLATFORMS=cpu
           export CI=true
           pytest -v
       - name: Run model forward
-        env:
-          # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token.
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
         run: |
           export PJRT_DEVICE=CPU
           export JAX_PLATFORMS=cpu
diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/e2e_test.yml
index 0d3828a6..a71c83c8 100644
--- a/.github/workflows/e2e_test.yml
+++ b/.github/workflows/e2e_test.yml
@@ -68,7 +68,6 @@ jobs:
       - name: Run Llama 3.0 8B
        id: run-llama-3-8b
         env:
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
           XLA_IR_DEBUG: 1
           XLA_HLO_DEBUG: 1
         run: |
@@ -78,6 +77,7 @@ jobs:
             --name $name \
             torchprime/torch_xla_models/train.py \
             model=llama-3-8b \
+            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
             dataset=wikitext \
             task=train \
             task.global_batch_size=8 \
@@ -112,7 +112,6 @@ jobs:
       - name: Run Llama 3.1 8B (Splash Attention)
         id: run-llama-3_1-8b-SplashAttention
         env:
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
           XLA_IR_DEBUG: 1
           XLA_HLO_DEBUG: 1
         run: |
@@ -122,6 +121,7 @@ jobs:
             --name $name \
             torchprime/torch_xla_models/train.py \
             model=llama-3.1-8b \
+            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3.1-405b \
             model.attention_kernel=splash_attention \
             dataset=wikitext \
             task=train \
@@ -134,7 +134,6 @@ jobs:
       - name: Run Llama 3.1 8B (Scan + Offload)
         id: run-llama-3_1-8b-scan-offload
         env:
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
           XLA_IR_DEBUG: 1
           XLA_HLO_DEBUG: 1
         run: |
@@ -144,6 +143,7 @@ jobs:
             --name $name \
             torchprime/torch_xla_models/train.py \
             model=llama-3.1-8b \
+            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3.1-405b \
             model/remat=llama-scan-offload \
             dataset=wikitext \
             task=train \
@@ -156,7 +156,6 @@ jobs:
       - name: Run Llama 3.0 8B (2D sharding)
         id: run-llama-3-8b-2d
         env:
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
           XLA_IR_DEBUG: 1
           XLA_HLO_DEBUG: 1
         run: |
@@ -166,6 +165,7 @@ jobs:
             --name $name \
             torchprime/torch_xla_models/train.py \
             model=llama-3-8b \
+            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
             model/sharding=llama-fsdp-tp \
             dataset=wikitext \
             task=train \
@@ -179,7 +179,6 @@ jobs:
       - name: Run Llama 3.0 8B (fsdp + cp)
         id: run-llama-3-8b-fsdp-cp
         env:
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
           XLA_IR_DEBUG: 1
           XLA_HLO_DEBUG: 1
         run: |
@@ -189,6 +188,7 @@ jobs:
             --name $name \
             torchprime/torch_xla_models/train.py \
             model=llama-3-8b-cp \
+            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
             model/sharding=llama-fsdp-tp-cp \
             dataset=wikitext \
             task=train \
@@ -201,7 +201,6 @@ jobs:
       - name: Run Mixtral 8x7B
         id: run-mixtral-8x7b
         env:
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
           XLA_IR_DEBUG: 1
           XLA_HLO_DEBUG: 1
         run: |
@@ -211,6 +210,7 @@ jobs:
             --name $name \
             torchprime/torch_xla_models/train.py \
             model=mixtral-8x7b \
+            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/mixtral-8x7b-v0.1/ \
             model.num_hidden_layers=16 \
             dataset=wikitext \
             task=train \
@@ -223,7 +223,6 @@ jobs:
       - name: Run Llama 3.0 8B (2 slice)
         id: run-llama-3-8b-2-slice
         env:
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
           XLA_IR_DEBUG: 1
           XLA_HLO_DEBUG: 1
         run: |
@@ -233,6 +232,7 @@ jobs:
             --name $name \
             --num-slices 2 \
             torchprime/torch_xla_models/train.py \
+            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
             model=llama-3-8b \
             model/sharding=llama-fsdp \
             dataset=wikitext \
@@ -247,7 +247,6 @@ jobs:
       - name: Run Llama 3.0 8B SFT
         id: run-llama-3-8b-sft
         env:
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
           XLA_IR_DEBUG: 1
           XLA_HLO_DEBUG: 1
         run: |
@@ -257,6 +256,8 @@ jobs:
             --name $name \
             torchprime/torch_xla_models/train.py \
             --config-name llama-3-8b-sft-w-gsm8k \
+            model.pretrained_model=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
+            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
             ici_mesh.fsdp=4 \
             task.max_steps=50 \
             task.convert_to_safetensors=False \
@@ -265,7 +266,6 @@ jobs:
       - name: Run Llama 3.0 8B (ddp + fsdp)
         id: run-llama-3-8b-ddp-fsdp
         env:
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
           XLA_IR_DEBUG: 1
           XLA_HLO_DEBUG: 1
         run: |
@@ -275,6 +275,7 @@ jobs:
             --name $name \
             --num-slices 2 \
             torchprime/torch_xla_models/train.py \
+            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
             model=llama-3-8b \
             model/sharding=llama-fsdp \
             dataset=wikitext \
diff --git a/.github/workflows/reusable_e2e_check.yml b/.github/workflows/reusable_e2e_check.yml
index f7faf677..c282b83f 100644
--- a/.github/workflows/reusable_e2e_check.yml
+++ b/.github/workflows/reusable_e2e_check.yml
@@ -30,9 +30,6 @@ on:
     secrets:
       GCP_SA_KEY:
         required: true
-      # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token.
-      HF_TOKEN:
-        required: true
 
 jobs:
   results:
diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py
index a5134337..a7534527 100644
--- a/torchprime/launcher/cli.py
+++ b/torchprime/launcher/cli.py
@@ -17,12 +17,14 @@
 import click
 import toml
 from dataclasses_json import dataclass_json
+from huggingface_hub.errors import RepositoryNotFoundError
 from pathspec import PathSpec
 from pathspec.patterns import GitWildMatchPattern  # type: ignore
 from watchdog.events import FileSystemEventHandler
 from watchdog.observers import Observer
 
 import torchprime.launcher.doctor
+from torchprime.launcher import save_hf_tokenizer_and_model
 from torchprime.launcher.buildpush import buildpush
 from torchprime.launcher.util import run_docker
 
@@ -77,6 +79,54 @@ def cli(ctx, interactive):
   ctx.obj["interactive"] = interactive
 
 
+@cli.command("save-hf-model-files-to-gcs")
+@click.option(
+  "--repo-id",
+  type=str,
+  required=True,
+  help="Hugging Face model or tokenizer repo ID (e.g., 'meta-llama/Llama-3-8B-hf').",
+)
+@click.option(
+  "--gcs-path",
+  type=str,
+  required=True,
+  help="Target GCS path for the model files (e.g., 'gs://bucket/models/Llama-3-8B-hf').",
+)
+@click.option(
+  "--file-type",
+  type=click.Choice(["tokenizer", "model", "all"], case_sensitive=False),
+  default="all",
+  help="Type of files to save. 'tokenizer' for tokenizer files, 'model' for model weights and configs, 'all' for both.",
+)
+@click.option(
+  "--temp-dir",
+  type=str,
+  default=None,
+  help="Path to a temporary directory with sufficient space. Defaults to system temp.",
+)
+def save_hf_model_files_to_gcs(
+  repo_id: str, gcs_path: str, file_type: str, temp_dir: str | None
+):
+  """Downloads model and tokenizer files from Hugging Face Hub and saves them to Google Cloud Storage."""
+  click.echo(
+    f"Preparing to save '{file_type}' files from '{repo_id}' to '{gcs_path}'..."
+  )
+  try:
+    save_hf_tokenizer_and_model.save_hf_model_files_to_gcs(
+      repo_id, gcs_path, file_type=file_type, temp_dir=temp_dir
+    )
+    click.echo(f" -> Successfully saved files to {gcs_path}")
+  except RepositoryNotFoundError:
+    click.echo(f"\n❌ Error: Repository '{repo_id}' not found.")
+    click.echo("Please check the following:")
+    click.echo(f"1. The repository ID '{repo_id}' is spelled correctly.")
+    click.echo(
+      "2. If it's a gated repository, ensure you are authenticated by running 'huggingface-cli login' or exporting your HF_TOKEN."
+    )
+  except Exception as e:
+    click.echo(f"\n❌ An unexpected error occurred for repository '{repo_id}': {e}")
+
+
 @cli.command()
 @click.option("--cluster", required=True, help="Name of the XPK cluster")
 @click.option("--project", required=True, help="GCP project the cluster belongs to")
diff --git a/torchprime/launcher/save_hf_tokenizer_and_model.py b/torchprime/launcher/save_hf_tokenizer_and_model.py
new file mode 100644
index 00000000..b4e50483
--- /dev/null
+++ b/torchprime/launcher/save_hf_tokenizer_and_model.py
@@ -0,0 +1,94 @@
+"""Utilities for preparing Hugging Face assets (models and tokenizers) for GCS."""
+
+import logging
+import os
+import subprocess
+import tempfile
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+
+logger = logging.getLogger(__name__)
+
+TOKENIZER_PATTERNS = [
+  "tokenizer.json",
+  "tokenizer_config.json",
+  "special_tokens_map.json",
+  "*.model",
+  "vocab.txt",
+  "merges.txt",
+]
+
+MODEL_PATTERNS = [
+  "*.safetensors*",
+  "config.json",
+  "generation_config.json",
+]
+
+
+def _upload_directory_to_gcs(local_path: Path, gcs_path: str):
+  """Uploads the contents of a local directory to GCS using gsutil.
+
+  Args:
+    local_path: The local directory whose contents will be uploaded.
+    gcs_path: The destination GCS path (e.g., 'gs://my-bucket/models/').
+  """
+  if not gcs_path.startswith("gs://"):
+    raise ValueError("GCS path must start with gs://")
+
+  logger.info(f"Uploading contents of '{local_path}' to '{gcs_path}'...")
+  command = ["gsutil", "-m", "cp", "-r", f"{str(local_path).rstrip('/')}/*", gcs_path]
+  try:
+    subprocess.run(command, check=True, capture_output=True, text=True)
+    logger.info(f"Successfully uploaded assets to {gcs_path}.")
+  except subprocess.CalledProcessError as e:
+    logger.error(f"Failed to upload {local_path} to {gcs_path}. Error: {e.stderr}")
+    raise
+
+
+def save_hf_model_files_to_gcs(
+  repo_id: str,
+  gcs_path: str,
+  file_type: str,
+  temp_dir: str | None = None,
+):
+  """Downloads model or tokenizer files from Hugging Face and uploads them to GCS.
+
+  This function uses `huggingface_hub.snapshot_download` to fetch specific
+  files based on predefined patterns for models and tokenizers. The downloaded
+  files are then uploaded to the specified GCS path.
+
+  Args:
+    repo_id: The ID of the Hugging Face repository (e.g., 'meta-llama/Llama-3-8B-hf').
+    gcs_path: The target GCS path for the files (e.g., 'gs://bucket/models/Llama-3-8B-hf').
+    file_type: The type of files to download. Must be one of 'tokenizer',
+      'model', or 'all'.
+    temp_dir: An optional path to a temporary directory for downloading. If
+      None, the system's default temporary directory is used.
+
+  Raises:
+    ValueError: If an invalid `file_type` is provided.
+ """ + allow_patterns = [] + if file_type in ("tokenizer", "all"): + allow_patterns.extend(TOKENIZER_PATTERNS) + if file_type in ("model", "all"): + allow_patterns.extend(MODEL_PATTERNS) + + if not allow_patterns: + raise ValueError("file_type must be one of 'tokenizer', 'model', or 'all'") + + with tempfile.TemporaryDirectory(dir=temp_dir) as tmpdir: + logger.info(f"Created temporary directory: {tmpdir}") + + logger.info(f"Downloading files for '{repo_id}' with patterns: {allow_patterns}") + snapshot_path = snapshot_download( + repo_id=repo_id, + cache_dir=str(tmpdir), + token=os.environ.get("HF_TOKEN"), + allow_patterns=allow_patterns, + ) + + logger.info(f"Files for '{repo_id}' downloaded locally to '{snapshot_path}'.") + + _upload_directory_to_gcs(Path(snapshot_path), gcs_path) diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index 5dc5efbe..298f3ad1 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -113,6 +113,8 @@ def from_pretrained(self, model_path_or_repo: str): Args: model_path_or_repo: Path to the local directory or Hugging Face Hub repository ID. """ + model_path_or_repo = model_utils.copy_gcs_to_local(model_path_or_repo) + if os.path.isdir(model_path_or_repo): model_dir = model_path_or_repo else: @@ -153,8 +155,13 @@ def _maybe_save_checkpoint(self, config: DictConfig) -> None: # Step 3: Save the HF config files and tokenizer if xr.process_index() == 0: logger.info("Saving Hugging Face configs and tokenizer to %s", save_dir) - model_utils.copy_hf_config_files(config.model.pretrained_model, save_dir) - model_utils.save_hf_tokenizer(config.model.pretrained_model, save_dir) + # Copy to local if in GCS + tokenizer_path_or_repo = model_utils.copy_gcs_to_local( + config.model.tokenizer_name + ) + model_path_or_repo = model_utils.copy_gcs_to_local(config.model.pretrained_model) + model_utils.copy_hf_config_files(tokenizer_path_or_repo, save_dir) + model_utils.save_hf_tokenizer(model_path_or_repo, save_dir) # Step 4: Initialize torch.distributed process group if not dist.is_initialized(): diff --git a/torchprime/torch_xla_models/model/model_utils.py b/torchprime/torch_xla_models/model/model_utils.py index 0f9fa1b9..c5c22985 100644 --- a/torchprime/torch_xla_models/model/model_utils.py +++ b/torchprime/torch_xla_models/model/model_utils.py @@ -2,6 +2,7 @@ from __future__ import annotations +import atexit import importlib import json import logging @@ -31,6 +32,21 @@ "generation_config.json", ] +_TEMP_DIRS_TO_CLEAN = [] + + +def _cleanup_temp_dirs(): + """Removes all temporary directories created by this module.""" + for d in _TEMP_DIRS_TO_CLEAN: + try: + logger.info(f"Cleaning up temporary directory: {d}") + shutil.rmtree(d) + except OSError as e: + logger.warning(f"Failed to remove temporary directory {d}: {e}") + + +atexit.register(_cleanup_temp_dirs) + def load_safetensors_to_state_dict(model_dir: str) -> dict: """Load a model state dict from safetensors, supporting both sharded and single-file formats. @@ -484,3 +500,53 @@ def save_hf_tokenizer(model_path_or_repo: str, save_dir: Path) -> None: save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) tokenizer.save_pretrained(save_dir) + + +def copy_gcs_to_local(path_or_repo: str) -> str: + """Download gcs content to local temporaily directory. 
+
+  If the input `path_or_repo` starts with 'gs://', this function will download
+  the contents of the GCS directory to a temporary local directory using the
+  `gsutil` command-line tool. The local directory will be automatically cleaned
+  up when the program exits.
+
+  If the input is not a GCS path, it is assumed to be a local path or a
+  Hugging Face repo ID and is returned unmodified.
+
+  Args:
+    path_or_repo: The path to resolve. Can be a GCS URI (e.g.,
+      'gs://bucket/data') or a local file path.
+
+  Returns:
+    The local download directory, or the original input if not a GCS path.
+  """
+  if not path_or_repo.startswith("gs://"):
+    return path_or_repo
+
+  if not shutil.which("gsutil"):
+    raise RuntimeError(
+      "gsutil command not found, but is required for downloading from GCS. "
+      "Please install the Google Cloud SDK."
+    )
+
+  local_dir = tempfile.mkdtemp()
+  _TEMP_DIRS_TO_CLEAN.append(local_dir)
+  try:
+    gcs_path = path_or_repo.rstrip("/") + "/*"
+    command = ["gsutil", "-m", "cp", "-r", gcs_path, local_dir]
+    subprocess.run(command, check=True)
+
+    logger.info(
+      "Successfully downloaded files from %s to %s using gsutil.",
+      path_or_repo,
+      local_dir,
+    )
+    return local_dir
+  except Exception:
+    # Clean up the partially created directory on failure.
+    shutil.rmtree(local_dir)
+    logger.error(
+      "gsutil download failed for %s. See gsutil output above for details.",
+      path_or_repo,
+    )
+    raise
diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py
index 534a37f0..506ad9e4 100644
--- a/torchprime/torch_xla_models/train.py
+++ b/torchprime/torch_xla_models/train.py
@@ -50,7 +50,7 @@ def main(config: omegaconf.DictConfig):
   logger.info(f"Profiling server started: {str(server)}")
 
   # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Add tokenizers to torchprime.
-  tokenizer_name = config.model.tokenizer_name
+  tokenizer_name = model_utils.copy_gcs_to_local(config.model.tokenizer_name)
   tokenizer = retry.retry(
     lambda: transformers.AutoTokenizer.from_pretrained(tokenizer_name)
   )