diff --git a/tests/assets/tokenizer/tokenizer.json b/tests/assets/tokenizer/tokenizer.json index a39c930b41..3f7f2a019d 100644 --- a/tests/assets/tokenizer/tokenizer.json +++ b/tests/assets/tokenizer/tokenizer.json @@ -2029,7 +2029,10 @@ "land": 1994, "?\n": 1995, " respect": 1996, - "ances": 1997 + "ances": 1997, + "<|image|>": 1998, + "<|begin_of_image|>": 1999, + "<|end_of_image|>": 2000 }, "merges": [ ]
diff --git a/tests/assets/tokenizer/tokenizer_config.json b/tests/assets/tokenizer/tokenizer_config.json index da6379b3f8..8ae72ccc09 100644 --- a/tests/assets/tokenizer/tokenizer_config.json +++ b/tests/assets/tokenizer/tokenizer_config.json @@ -15,11 +15,38 @@ "rstrip": false, "single_word": false, "special": true + }, + "1998": { + "content": "<|image|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1999": { + "content": "<|begin_of_image|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2000": { + "content": "<|end_of_image|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true } }, "bos_token": "<|begin_of_text|>", "clean_up_tokenization_spaces": true, "eos_token": "<|end_of_text|>", + "img_token": "<|image|>", + "boi_token": "<|begin_of_image|>", + "eoi_token": "<|end_of_image|>", "model_input_names": [ "input_ids", "attention_mask"
diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 9d81f6b885..d11ef99d88 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -7,3 +7,4 @@ import torchtitan.experiments.llama4 # noqa: F401 import torchtitan.experiments.qwen3 import torchtitan.experiments.simple_fsdp # noqa: F401 +import torchtitan.experiments.vlm # noqa: F401
diff --git a/torchtitan/experiments/vlm/README.md b/torchtitan/experiments/vlm/README.md new file mode 100644 index 0000000000..2739149a55 --- /dev/null +++ b/torchtitan/experiments/vlm/README.md @@ -0,0 +1,19 @@
+# Vision Language Model training in `torchtitan`
+
+**under active development**
+
+This folder showcases how to train modern Vision Language Models (VLMs) in torchtitan.
+
+
+## Features:
+- Native Aspect Ratio: not limited to square crops.
+- Native Resolution: images in a batch can have different sizes, no more image tiles and thumbnails.
+- Native Interleaved data: training samples can have a variable number of images, interleaved with text at different positions. You can train more than just a captioning model.
+
+
+## Design
+Distributed training usually does not play nicely with inputs of varying shapes. To handle a varying number of images and image sizes, we require two hyperparameters, image batch size `N` and image length `L` (in patches), and pad the actual image patches to this fixed size.
+Then we scatter the patch embeddings to their actual positions in the LLM input tokens.
+This results in a very simple and general interface to train modern VLMs with interleaved data and native resolution & aspect ratio.
+By setting the appropriate dataloader hyperparameters, we can easily reduce the number of padding tokens.
+We leverage Flex Attention to efficiently handle a varying number of patches per image.
diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py new file mode 100644 index 0000000000..244a5231cd --- /dev/null +++ b/torchtitan/experiments/vlm/__init__.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc.
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .datasets.mm_datasets import build_mm_dataloader +from .infra.parallelize import parallelize_vlm +# from .infra.pipeline import pipeline_llama +from .model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs +from .model.model import Llama3Siglip2Transformer + +__all__ = [ + "parallelize_vlm", + # "pipeline_llama", + "Llama3Siglip2ModelArgs", + "Llama3Siglip2Transformer", + "llama3_siglip2_configs", +] + + +siglip2_configs = { + "debugmodel": Siglip2ModelArgs( + dim=128, + ffn_dim=256, + n_layers=4, + n_heads=2, + ) +} + +llama3_siglip2_configs = { + "debugmodel": Llama3Siglip2ModelArgs( + encoder=siglip2_configs["debugmodel"], + dim=256, n_layers=6, n_heads=16, vocab_size=2048, rope_theta=500000 + ), + "debugmodel_flex_attn": Llama3Siglip2ModelArgs( + encoder=siglip2_configs["debugmodel"], + dim=256, + n_layers=6, + n_heads=16, + vocab_size=2000, + rope_theta=500000, + use_flex_attn=True, + attn_mask_type="block_causal", + ), + "8B": Llama3Siglip2ModelArgs( + encoder=siglip2_configs["debugmodel"], + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=1024, + rope_theta=500000, + ), + "70B": Llama3Siglip2ModelArgs( + encoder=siglip2_configs["debugmodel"], + dim=8192, + n_layers=80, + n_heads=64, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=4096, + rope_theta=500000, + ), + "405B": Llama3Siglip2ModelArgs( + encoder=siglip2_configs["debugmodel"], + dim=16384, + n_layers=126, + n_heads=128, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=4096, + rope_theta=500000, + ), +} + + +register_train_spec( + TrainSpec( + name="llama3-siglip2", + model_cls=Llama3Siglip2Transformer, + model_args=llama3_siglip2_configs, + parallelize_fn=parallelize_vlm, + pipelining_fn=None, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_mm_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + # state_dict_adapter=Llama3StateDictAdapter, + ) +) diff --git a/torchtitan/experiments/vlm/datasets/mm_collator_nld.py b/torchtitan/experiments/vlm/datasets/mm_collator_nld.py new file mode 100644 index 0000000000..7c44665014 --- /dev/null +++ b/torchtitan/experiments/vlm/datasets/mm_collator_nld.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
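The `register_train_spec` call above is what makes the new model selectable from a training config. Below is a minimal sketch of how that spec would be looked up at run time; it assumes torchtitan exposes a `get_train_spec` helper next to `register_train_spec` (that helper name is an assumption, while the field and flavor names come from the `TrainSpec` and config dicts defined above).

```python
# Hedged sketch: look up the spec registered above and pick the debug flavor.
# Assumes get_train_spec exists alongside register_train_spec in torchtitan.
import torchtitan.experiments.vlm  # noqa: F401  # importing the package runs register_train_spec
from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("llama3-siglip2")        # matches [model] name in the TOML below
model_args = spec.model_args["debugmodel"]     # matches [model] flavor
print(model_args.encoder.dim, model_args.dim)  # 128 (Siglip2 encoder width) -> 256 (LLM width)
```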
+ +from dataclasses import dataclass +from typing import Any, Dict, List + +import einops as E +import torch +from torch.nn.utils.rnn import pad_sequence + +from torchtitan.tools.logging import logger + + +IGNORE_INDEX = -100 + + +@dataclass +class MultiModalCollatorNLD: + """Collator that works with patches in NLD format (N=batch, L=patches, D=patch_features)""" + + padding_idx: int = 0 + ignore_idx: int = IGNORE_INDEX + max_images_per_batch: int = 5 + max_patch_per_image: int = 256 # Maximum patches per image + patch_size: int = 16 # Patch size for converting images to patches + merge_size: int = 1 # Merge size for converting spatial patches to channel dim + seq_len: int = 2048 + + def convert_to_patches( + self, pixel_values: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Direct NTHWC -> NLD conversion using einops.""" + N, T, H, W, C = pixel_values.shape + ps = self.patch_size + device = pixel_values.device + patches = E.rearrange( + pixel_values, "n t (h p1) (w p2) c -> n (t h w) (p1 p2 c)", p1=ps, p2=ps + ) + + coords = torch.meshgrid( + torch.arange(T, device=device), + torch.arange(H // ps, device=device), + torch.arange(W // ps, device=device), + indexing="ij", + ) + grid = E.rearrange(torch.stack(coords), "coords t h w -> (t h w) coords") + grid = grid.unsqueeze(0).expand(N, -1, -1) # (N, t*h*w, 3) + + # All patches are valid since we resize images to be divisible by patch_size + return patches, grid + + def _pad_to_max(self, patches, grids): + """Pad or truncate to max_patch_per_image.""" + N, L, D = patches.shape + if L == self.max_patch_per_image: + return patches, grids + elif L < self.max_patch_per_image: + # Pad + pad_len = self.max_patch_per_image - L + zero_patches = torch.zeros(N, pad_len, D, device=patches.device) + invalid_grids = torch.full( + (grids.shape[0], pad_len, 3), -1, device=grids.device + ) + return torch.cat([patches, zero_patches], 1), torch.cat( + [grids, invalid_grids], 1 + ) + else: + # Truncate + return ( + patches[:, : self.max_patch_per_image], + grids[:, : self.max_patch_per_image], + ) + + def __call__( + self, batch: List[Dict[str, Any]] + ) -> tuple[Dict[str, torch.Tensor | None], torch.Tensor]: + """Encode batch with patch-based approach.""" + if not batch: + return None + + # Count images per sample and total images + images_per_sample = [] + for sample in batch: + num_images = ( + len(sample.get("pixel_values", [])) if "pixel_values" in sample else 0 + ) + images_per_sample.append(num_images) + + # Remove samples from end until total images <= max_images_per_batch + total_images = sum(images_per_sample) + while total_images > self.max_images_per_batch and batch: + removed_images = images_per_sample.pop() + total_images -= removed_images + batch.pop() + logger.warning(f"Removed sample with {removed_images} images to keep total images <= {self.max_images_per_batch}") + + all_images = [ + img + for sample in batch + if "pixel_values" in sample + for img in sample["pixel_values"] + ] + + if all_images: + patch_list, grid_list = [], [] + for img in all_images: + p, g = self.convert_to_patches(img.unsqueeze(0)) + p, g = self._pad_to_max(p, g) + patch_list.append(p[0]) + grid_list.append(g[0]) + patches = torch.stack(patch_list) + grids = torch.stack(grid_list) + + if len(all_images) < self.max_images_per_batch: + blank_count = self.max_images_per_batch - len(all_images) + blank_patches = torch.zeros( + blank_count, + self.max_patch_per_image, + patches.shape[2], + device=patches.device, + ) + blank_grids = torch.full( + (blank_count, 
self.max_patch_per_image, 3), -1, device=grids.device + ) + patches = torch.cat([patches, blank_patches], dim=0) + grids = torch.cat([grids, blank_grids], dim=0) + else: + patches = grids = None + + # Text processing + input_ids = pad_sequence( + [s["input_ids"] for s in batch], + batch_first=True, + padding_value=self.padding_idx, + ) + labels = pad_sequence( + [s["labels"] for s in batch], + batch_first=True, + padding_value=self.padding_idx, + ) + + # Pad along batch dimension if needed + batch_size = len(batch) + if input_ids.size(0) < batch_size: + padding_needed = batch_size - input_ids.size(0) + padding_input = ( + torch.ones(padding_needed, input_ids.size(1), dtype=torch.long) + * self.padding_idx + ) + padding_labels = ( + torch.ones(padding_needed, labels.size(1), dtype=torch.long) + * self.padding_idx + ) + input_ids = torch.cat([input_ids, padding_input], dim=0) + labels = torch.cat([labels, padding_labels], dim=0) + + # Handle sequence length + current_length = input_ids.size(1) + desired_length = self.seq_len + 1 # Extra token for label shift and cut + if current_length < desired_length: + padding_length = desired_length - current_length + padding_input = ( + torch.ones(batch_size, padding_length, dtype=torch.long) + * self.padding_idx + ) + padding_labels = ( + torch.ones(batch_size, padding_length, dtype=torch.long) + * self.padding_idx + ) + input_ids = torch.cat([input_ids, padding_input], dim=1) + labels = torch.cat([labels, padding_labels], dim=1) + elif current_length > self.seq_len: + input_ids = input_ids[:, :desired_length] + labels = labels[:, :desired_length] + + labels[labels == self.padding_idx] = self.ignore_idx + # Cut and shift + input_ids = input_ids[:, :-1] + labels = labels[:, 1:] + + return { + "input": input_ids, + "pixel_values": patches, + "grid_thw": grids, + }, labels diff --git a/torchtitan/experiments/vlm/datasets/mm_datasets.py b/torchtitan/experiments/vlm/datasets/mm_datasets.py new file mode 100644 index 0000000000..f540779d7c --- /dev/null +++ b/torchtitan/experiments/vlm/datasets/mm_datasets.py @@ -0,0 +1,422 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
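To make the collator's fixed-size NLD contract concrete, here is a hedged, dummy-data sketch of its inputs and outputs. The shapes follow the collator code above; the import path assumes this PR's package layout.

```python
# Hedged sketch of MultiModalCollatorNLD's output contract on dummy data.
import torch
from torchtitan.experiments.vlm.datasets.mm_collator_nld import MultiModalCollatorNLD

collator = MultiModalCollatorNLD(
    padding_idx=0, max_images_per_batch=2, max_patch_per_image=16,
    patch_size=16, merge_size=1, seq_len=32,
)
sample = {
    "input_ids": torch.randint(1, 100, (20,)),
    "labels": torch.randint(1, 100, (20,)),
    # one 32x32 RGB image in (T=1, H, W, C) layout -> 4 patches of 16x16x3
    "pixel_values": [torch.rand(1, 32, 32, 3)],
}
inputs, labels = collator([sample])
print(inputs["input"].shape)         # torch.Size([1, 32])      B x seq_len (after the shift)
print(inputs["pixel_values"].shape)  # torch.Size([2, 16, 768]) N x L x (16*16*3), padded to fixed size
print(inputs["grid_thw"].shape)      # torch.Size([2, 16, 3])   (t, h, w) patch coords, -1 marks padding
print(labels.shape)                  # torch.Size([1, 32])      padding positions mapped to -100
```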
+ +import math +from io import BytesIO +from typing import Any + +import numpy as np +import requests +import torch +from datasets import Dataset, load_dataset +from datasets.distributed import split_dataset_by_node + +from PIL import Image +from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import IterableDataset + +from torchtitan.components.dataloader import ParallelAwareDataloader +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.config import JobConfig +from torchtitan.tools.logging import logger + +from .mm_collator_nld import MultiModalCollatorNLD + + +IGNORE_INDEX = -100 +# TODO: should add this to the tokenizer +BEGIN_OF_IMAGE_TOKEN = "<|begin_of_image|>" +END_OF_IMAGE_TOKEN = "<|end_of_image|>" +IMAGE_TOKEN = "<|image|>" + + +def smart_resize( + height: int, + width: int, + factor: int = 16, + min_pixels: int = 16 * 16 * 16, + max_pixels: int = 16 * 16 * 4 * 1280, +): + if height < factor or width < factor: + raise ValueError( + f"height:{height} or width:{width} must be larger than factor:{factor}" + ) + elif max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = np.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = np.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +def resize_image_by_patch_count( + image, + max_patch_per_image, + patch_size=16, + merge_size=1, + min_dimension=56, +): + """Resize image while maintaining aspect ratio and ensuring patch count <= max_patch_per_image. 
+ + Args: + image: PIL Image or image bytes + max_patch_per_image: Maximum number of patches (L) allowed per image + patch_size: Size of each patch (default: 16) + merge_size: Spatial Merge size factor (default: 1) + min_dimension: Minimum dimension for width/height (default: 56) + + Returns: + Resized PIL Image with dimensions divisible by factor and patches <= max_patch_per_image + """ + if not isinstance(image, Image.Image): + image = Image.open(BytesIO(image)) + + original_width, original_height = image.size + factor = patch_size * merge_size + + # Calculate current number of patches + current_patches = (original_height * original_width) // (factor * factor) + + # If already within limits and divisible, return as-is after smart_resize + if current_patches <= max_patch_per_image: + try: + resized_height, resized_width = smart_resize( + original_height, original_width, factor=factor + ) + return image.resize((resized_width, resized_height)) + except ValueError: + # If smart_resize fails, continue with scaling + pass + + # Calculate maximum area that gives us max_patch_per_image patches + max_area = max_patch_per_image * (factor * factor) + + # Calculate scaling factor to fit within max_area while maintaining aspect ratio + current_area = original_width * original_height + scale_factor = math.sqrt(max_area / current_area) + + # Scale dimensions + new_width = int(original_width * scale_factor) + new_height = int(original_height * scale_factor) + + # Ensure minimum dimensions + if new_width < min_dimension: + new_width = min_dimension + new_height = int(new_width * original_height / original_width) + if new_height < min_dimension: + new_height = min_dimension + new_width = int(new_height * original_width / original_height) + + # Use smart_resize to ensure divisibility and handle constraints + try: + resized_height, resized_width = smart_resize( + new_height, new_width, factor=factor + ) + except ValueError: + # If smart_resize fails, fall back to manual rounding + resized_height = (new_height // factor) * factor + resized_width = (new_width // factor) * factor + resized_height = max(factor, resized_height) # Ensure at least one patch + resized_width = max(factor, resized_width) + + # Final verification: ensure patch count is within limit + final_patches = (resized_height * resized_width) // (factor * factor) + if final_patches > max_patch_per_image: + # Reduce dimensions proportionally + reduction_factor = math.sqrt(max_patch_per_image / final_patches) + resized_height = int(resized_height * reduction_factor) + resized_width = int(resized_width * reduction_factor) + + # Round down to nearest factor multiple + resized_height = (resized_height // factor) * factor + resized_width = (resized_width // factor) * factor + resized_height = max(factor, resized_height) + resized_width = max(factor, resized_width) + + resized_image = image.resize((resized_width, resized_height)) + return resized_image + + +def calculate_image_tokens(image, patch_size=16, merge_size=1): + """Calculate tokens for an image based on patch size.""" + if isinstance(image, torch.Tensor): + height, width = image.shape[-2:] + else: + width, height = image.size + return ( + int((height * width) / (patch_size * patch_size * merge_size * merge_size)), + int(width / (patch_size * merge_size)), + int(height / (patch_size * merge_size)), + ) + + +class MultiModalDataset(IterableDataset, Stateful): + """PyTorch MultiModal Dataset.""" + + def __init__( + self, + dataset_path: str, + tokenizer: BaseTokenizer, + seq_len: int = 2048, + 
dp_rank: int = 0, + dp_world_size: int = 1, + infinite: bool = False, + patch_size: int = 16, + merge_size: int = 1, + max_patch_per_image: int = 256, + max_images_per_batch: int = 4, + ) -> None: + ds = load_dataset(dataset_path, split="train", streaming=True) + self._data = split_dataset_by_node(ds, dp_rank, dp_world_size) + self._tokenizer = tokenizer + self.seq_len = seq_len + self.infinite = infinite + self._sample_idx = 0 + self.patch_size = patch_size + self.merge_size = merge_size + self.max_patch_per_image = max_patch_per_image + self.max_images_per_batch = max_images_per_batch + + def _process_sample(self, sample: dict[str, Any]) -> dict[str, Any] | None: + """Process a single sample into the required format.""" + try: + # Get texts, images and metadata + texts = sample.get("texts", []) + images = sample.get("images", []) + metadata = sample.get("metadata", []) + + if not texts or len(texts) != len(images): + logger.warning( + f"Invalid sample: texts={len(texts)}, images={len(images)}" + ) + return None + + # Process images and build interleaved text + processed_images = [] + processed_text = "" + + for i, (img, txt) in enumerate(zip(images, texts)): + # Add text if it exists + if txt is not None: + processed_text += txt + + # Try to get image if it exists + if img is not None: + try: + # Handle online case (image URLs) + if isinstance(img, str) and img.startswith("http"): + response = requests.get(img) + img = Image.open(BytesIO(response.content)) + # Handle offline/cached case + elif isinstance(img, bytes): + img = Image.open(BytesIO(img)) + elif isinstance(img, str): + img = Image.open(img) + + if img.mode != "RGB": + img = img.convert("RGB") + + # Resize maintaining aspect ratio + img = resize_image_by_patch_count( + img, + max_patch_per_image=self.max_patch_per_image, + patch_size=self.patch_size, + merge_size=self.merge_size, + ) + + # Convert to numpy array and rescale to [0, 1] + img_array = np.array(img) / 255.0 + + # Normalize with OpenAI CLIP mean/std + mean = np.array([0.48145466, 0.4578275, 0.40821073]) + std = np.array([0.26862954, 0.26130258, 0.27577711]) + img_array = (img_array - mean) / std + + # Convert to tensor in NTHWC format (1, H, W, 3) + img_tensor = torch.from_numpy(img_array).float() + img_tensor = img_tensor.unsqueeze(0) # Add time dimension + + # Calculate number of image tokens needed + ( + num_tokens, + add_row_image_token_after, + _, + ) = calculate_image_tokens( + img, patch_size=self.patch_size, merge_size=self.merge_size + ) + + processed_images.append(img_tensor) + processed_text += BEGIN_OF_IMAGE_TOKEN + + # Add image tokens with row separators following dataset_utils pattern + image_tokens = [] + for token_idx in range(num_tokens): + image_tokens.append(IMAGE_TOKEN) + + processed_text += "".join(image_tokens) + processed_text += END_OF_IMAGE_TOKEN + + except Exception as e: + logger.warning(f"Error processing image {i}: {e}") + + if not processed_images: + return None + + # Add EOS token and tokenize + processed_text = processed_text + END_OF_IMAGE_TOKEN + tokens = self._tokenizer.encode(processed_text) + + input_ids = torch.tensor(tokens) + labels = torch.tensor(tokens) + + if len(input_ids) > self.seq_len: + logger.warning( + f"Skipping sample with length {len(input_ids)} greater than max_seq_len {self.seq_len}" + ) + return None + + # Get special token IDs just like in dataset_utils.py + def _get_special_token_id(token): + token_id = self._tokenizer.encode(token) + assert ( + len(token_id) == 1 + ), f"{token} is not a special token of the 
tokenizer" + return token_id[0] + + special_tokens = [ + _get_special_token_id(token) + for token in ( + IMAGE_TOKEN, + BEGIN_OF_IMAGE_TOKEN, + END_OF_IMAGE_TOKEN, + ) + ] + + labels = torch.where( + torch.isin(labels, torch.tensor(special_tokens)), IGNORE_INDEX, labels + ) + + # No truncation here - let collator handle it + + # Keep images as list + pixel_values = processed_images # List of tensors + + return { + "input_ids": input_ids, + "labels": labels, + "pixel_values": pixel_values, + } + + except Exception as e: + logger.warning(f"Error processing sample: {e}") + return None + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + try: + processed = self._process_sample(sample) + if processed is None: + continue + + # Simply yield individual samples - DataLoader will handle batching + self._sample_idx += 1 + yield processed + + except Exception as e: + logger.warning(f"Error in iteration: {e}") + continue + + if not self.infinite: + break + else: + self._sample_idx = 0 + + def _get_data_iter(self): + try: + # For streaming datasets, we don't need to check length + if not hasattr(self._data, "iterable_dataset"): + if isinstance(self._data, Dataset) and self._sample_idx == len( + self._data + ): + return iter([]) + + it = iter(self._data) + + # Skip samples if needed + if self._sample_idx > 0: + for _ in range(self._sample_idx): + next(it) + + return it + except Exception as e: + logger.error(f"Error in _get_data_iter: {e}") + return iter([]) + + def load_state_dict(self, state_dict): + self._sample_idx = state_dict["sample_idx"] + + def state_dict(self): + return {"sample_idx": self._sample_idx} + + +def build_mm_dataloader( + dp_world_size: int, + dp_rank: int, + tokenizer: BaseTokenizer, + job_config: JobConfig, + infinite: bool = True, +) -> ParallelAwareDataloader: + """Build a data loader for HuggingFace datasets.""" + dataset_name = job_config.training.dataset + dataset_path = job_config.training.dataset_path + batch_size = job_config.training.local_batch_size + seq_len = job_config.training.seq_len + # TODO: config + max_images_per_batch = batch_size * 2 + max_patch_per_image = 256 + patch_size = 16 + + hf_ds = MultiModalDataset( + dataset_path=dataset_path, + tokenizer=tokenizer, + seq_len=seq_len, + patch_size=patch_size, + merge_size=1, + max_patch_per_image=max_patch_per_image, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + infinite=infinite, + max_images_per_batch=max_images_per_batch, + ) + + collate_fn = MultiModalCollatorNLD( + padding_idx=0, + max_images_per_batch=max_images_per_batch, + max_patch_per_image=max_patch_per_image, + patch_size=patch_size, + merge_size=1, + seq_len=seq_len, + ) + + base_dataloader = ParallelAwareDataloader( + dataset=hf_ds, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + batch_size=batch_size, # Use micro_batch_size for initial batching + collate_fn=collate_fn, + ) + + return base_dataloader diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py new file mode 100644 index 0000000000..201e3dedd2 --- /dev/null +++ b/torchtitan/experiments/vlm/infra/parallelize.py @@ -0,0 +1,356 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. 
activation checkpointing and compile) to the Llama model. + +from collections import defaultdict +from typing import Optional + +import torch +import torch.nn as nn +from torch.distributed._composable.replicate import replicate +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper as ptd_checkpoint_wrapper, +) + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config.job_config import ActivationCheckpoint as ACConfig +from torchtitan.distributed import ParallelDims +from torchtitan.tools.logging import logger + + +def parallelize_vlm( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + assert isinstance(model.encoder, nn.Module) + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + if ( + job_config.parallelism.context_parallel_degree > 1 + and model.model_args.use_flex_attn + ): + raise NotImplementedError("CP support for FlexAttention is still in progress.") + + if parallel_dims.tp_enabled: + raise NotImplementedError("TP support for VLM training is still in progress.") + # if ( + # job_config.parallelism.enable_async_tensor_parallel + # and not job_config.training.compile + # ): + # raise RuntimeError("Async TP requires --training.compile") + # + # enable_float8_linear = "float8" in job_config.model.converters + # float8_is_rowwise = job_config.float8.recipe_name in ( + # "rowwise", + # "rowwise_with_gw_hp", + # ) + # + # # For now, float8 all-gather with TP is only supported for tensorwise + # # float8 scaling recipes. For rowwise recipes, we use regular TP and + # # all-gather happens in high precision. 
+ # enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + # + # apply_tp( + # model, + # world_mesh["tp"], + # loss_parallel=not job_config.parallelism.disable_loss_parallel, + # enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + # enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, + # ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + apply_ac(model.encoder, job_config.activation_checkpoint) + + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if job_config.compile.enable: + apply_compile(model) + apply_compile(model.encoder) + + if parallel_dims.fsdp_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=job_config.compile.enable, + enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, + ) + + return model + +# for selective op activation checkpointing +_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, +} + + +def _apply_ac_to_transformer_block( + module: nn.Module, ac_config: ACConfig, *, base_fqn: Optional[str] = None +): + valid_ac_modes = ("full", "selective") + if ac_config.mode not in valid_ac_modes: + raise ValueError( + f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}" + ) + + if ac_config.mode == "full": + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + + assert ac_config.mode == "selective", f"{ac_config.mode}" + use_op_sac = ac_config.selective_ac_option == "op" + use_layer_sac = ac_config.selective_ac_option.isdigit() + if not use_op_sac and not use_layer_sac: + raise ValueError( + f"Invalid selective AC option: {ac_config.selective_ac_option}. 
" + f"Valid options: 'op' or a positive int representing layer frequency" + ) + if use_op_sac: + from torch.utils.checkpoint import ( + CheckpointPolicy, + create_selective_checkpoint_contexts, + ) + + mm_recompute_shapes = set() + if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: + for module_fqn, submod in module.named_modules(): + fqn = module_fqn + if base_fqn is not None: + fqn = f"{base_fqn}.{module_fqn}" + if not any( + filter_fqn in fqn + for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns + ): + continue + if not isinstance(submod, nn.Linear): + raise ValueError( + "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match " + f"a nn.Linear, but got: {submod}" + ) + out_f, in_f = submod.weight.shape + mm_recompute_shapes.add((in_f, out_f)) + logger.debug( + f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" + ) + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + if args[1].shape in mm_recompute_shapes: + return CheckpointPolicy.PREFER_RECOMPUTE + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + return ( + CheckpointPolicy.MUST_SAVE + if to_save + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + elif use_layer_sac: + # Checkpoint every `ac_freq` of the modules passed to this function + ac_freq = int(ac_config.selective_ac_option) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module + + +def apply_ac(model: nn.Module, ac_config): + """Apply activation checkpointing to the model.""" + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = _apply_ac_to_transformer_block( + transformer_block, ac_config, base_fqn=f"layers.{layer_id}" + ) + model.layers.register_module(layer_id, transformer_block) + + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model {type(model).__name__}") + + +def apply_compile(model: nn.Module): + """ + Apply torch.compile to each TransformerBlock, which makes compilation efficient due to + repeated structure. Alternatively one can compile the whole model (after applying DP). + """ + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = torch.compile(transformer_block, fullgraph=True) + model.layers.register_module(layer_id, transformer_block) + + logger.info(f"Compiling each TransformerBlock of {type(model).__name__} with torch.compile") + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. 
+ dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + match reshard_after_forward_policy: + case "always": + reshard_after_forward = True + case "never": + reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + + if model.tok_embeddings is not None: + fully_shard( + model.tok_embeddings, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + for layer_id, transformer_block in model.encoder.layers.items(): + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + for layer_id, transformer_block in model.layers.items(): + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + # As an optimization, do not reshard_after_forward the last layers by default + # since FSDP would prefetch them immediately after the forward pass + if model.norm is not None and model.output is not None: + fully_shard( + [model.norm, model.output], + **fsdp_config, + reshard_after_forward=reshard_after_forward_policy == "always", + ) + fully_shard(model, **fsdp_config) + + +def apply_ddp( + model: nn.Module, + dp_mesh: DeviceMesh, + enable_compile: bool, + enable_compiled_autograd: bool, +): + if enable_compile: + if enable_compiled_autograd: + torch._dynamo.config.optimize_ddp = ( + "python_reducer_without_compiled_forward" + ) + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") diff --git a/torchtitan/experiments/vlm/model/args.py b/torchtitan/experiments/vlm/model/args.py new file mode 100644 index 0000000000..4bf6583ccb --- /dev/null +++ b/torchtitan/experiments/vlm/model/args.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
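The wrapping order in `apply_fsdp` above (each encoder block, each decoder block, the embedding/output modules, then the root module) is the standard per-block FSDP2 pattern. The following is an illustrative-only sketch of that pattern on a toy module, under the assumption that `torch.distributed` is already initialized with `world_size` GPUs; it is not the repository's implementation.

```python
# Illustrative-only sketch of per-block FSDP2 wrapping on a toy module.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

def shard_toy_model(n_layers: int = 4, dim: int = 256, world_size: int = 8) -> nn.Module:
    mesh = init_device_mesh("cuda", (world_size,))
    mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    blocks = nn.ModuleDict(
        {str(i): nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
         for i in range(n_layers)}
    )
    for block in blocks.values():
        # Shard each block so its params are all-gathered just-in-time and
        # re-sharded after forward, mirroring apply_fsdp above.
        fully_shard(block, mesh=mesh, mp_policy=mp, reshard_after_forward=True)
    fully_shard(blocks, mesh=mesh, mp_policy=mp)  # root wrap picks up remaining params
    return blocks
```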
+ +from dataclasses import dataclass, field + +from torchtitan.models.llama3 import TransformerModelArgs as Llama3Args + + +@dataclass +class Siglip2ModelArgs: + dim: int = 768 + ffn_dim: int = 3072 + n_layers: int = 12 + n_heads: int = 12 + + n_pos_embs: int = 16 # Number of positional embeddings per h&w + n_channels: int = 3 # RGB channels + patch_size: int = 16 + + layer_norm_eps: float = 1e-6 + use_flex_attn: bool = True + attn_mask_type: str = "causal" + + +@dataclass +class Llama3Siglip2ModelArgs(Llama3Args): + encoder: Siglip2ModelArgs = field(default_factory=Siglip2ModelArgs) + img_token_id: int = 1998 diff --git a/torchtitan/experiments/vlm/model/model.py b/torchtitan/experiments/vlm/model/model.py new file mode 100644 index 0000000000..e33d11f9b8 --- /dev/null +++ b/torchtitan/experiments/vlm/model/model.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import einops as E +import torch +from torch import nn + +from torchtitan.models.attention import init_attention_mask +from torchtitan.models.llama3 import Transformer as Llama3 + +from .args import Llama3Siglip2ModelArgs +from .siglip2 import VisionTransformer + + +class Projector(nn.Module): + """Project the Encoder embedding to the LLM embedding.""" + + def __init__(self, in_dim: int, out_dim: int) -> None: + super().__init__() + self.w1 = nn.Linear(in_dim, in_dim) + self.w2 = nn.Linear(in_dim, out_dim) + self.init_weights() + + def forward(self, x_NLD: torch.Tensor): + x_NLD = self.w1(x_NLD) + x_NLD = nn.functional.silu(x_NLD) + x_NLD = self.w2(x_NLD) + return x_NLD + + def init_weights(self): + nn.init.xavier_uniform_(self.w1.weight) + if self.w1.bias is not None: + nn.init.zeros_(self.w1.bias) + nn.init.xavier_uniform_(self.w2.weight) + if self.w2.bias is not None: + nn.init.zeros_(self.w2.bias) + + +class Llama3Siglip2Transformer(Llama3): + def __init__(self, model_args: Llama3Siglip2ModelArgs): + super().__init__(model_args) + self.model_args = model_args + self.encoder = VisionTransformer(model_args.encoder) + self.projector = Projector( + in_dim=model_args.encoder.dim, out_dim=model_args.dim + ) + self.n_pixels_per_token = model_args.encoder.patch_size**2 + self.init_encoder_weights() + + def init_encoder_weights(self, buffer_device=None): + super().init_weights(buffer_device=buffer_device) + if self.encoder is not None: + self.encoder.init_weights() + if self.projector is not None: + self.projector.init_weights() + + def _scatter_img_tokens(self, h_BSD, tokens_BS, i_NLD, i_mask_NL, img_id=None): + img_id = img_id or self.model_args.img_token_id + B, S, D = h_BSD.shape + # Where are the image tokens in LLM input, make broadcastable with h_BSD + img_mask_h_BSD = E.repeat(tokens_BS == img_id, "b s -> b s 1") + # Only get valid (non-padded) tokens, result are flatten + i_flatten = torch.masked_select(i_NLD, mask=i_mask_NL.unsqueeze(-1)) + + assert i_flatten.numel() // D == img_mask_h_BSD.sum(), ( + f"Different number of visual embeddings {i_flatten.numel() // D} " + f"with placeholder in input token embeddings {img_mask_h_BSD.sum()}" + ) + h_BSD.masked_scatter_(mask=img_mask_h_BSD, source=i_flatten) + return h_BSD + + def forward( + self, + tokens: torch.Tensor, + eos_id: int | None = None, + input_batch: torch.Tensor | None = None, + pixel_values: torch.Tensor | None = None, + grid_thw: torch.Tensor | None = None, + ): + if 
self.model_args.use_flex_attn: + init_attention_mask( + input_batch if input_batch is not None else tokens, eos_id=self.eos_id + ) + + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h_BSD = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + if self.encoder is not None: + grid_hw = grid_thw[:, :, 1:] # Siglip2 only support image hw + pixel_masks = E.reduce(grid_hw != -1, "n l hw -> n l", reduction="all") + i_NLD = self.encoder(pixel_values, pixel_masks, grid_hw) + i_NLD = self.projector(i_NLD) + h_BSD = self._scatter_img_tokens(h_BSD, tokens, i_NLD, pixel_masks) + + for layer in self.layers.values(): + h_BSD = layer(h_BSD, self.freqs_cis) + + h_BSD = self.norm(h_BSD) if self.norm else h_BSD + output = self.output(h_BSD) if self.output else h_BSD + return output diff --git a/torchtitan/experiments/vlm/model/siglip2.py b/torchtitan/experiments/vlm/model/siglip2.py new file mode 100644 index 0000000000..a1183f7cbb --- /dev/null +++ b/torchtitan/experiments/vlm/model/siglip2.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import einops as E +import torch +import torch.nn.functional as F +from torch import nn + +from torchtitan.models.attention import build_attention, init_attention_mask + +from .args import Siglip2ModelArgs + + +def resize_positional_embeddings( + pos_embs_HWD: torch.Tensor, + spatial_shapes_N2: torch.Tensor, + max_length: int, +) -> torch.Tensor: + """ + Resize the learned 2D positional embeddings to image-specific size and pad to a fixed size. + + Args: + pos_embs_HWD (`torch.Tensor`): + Position embeddings of shape (height, width, embed_dim) + spatial_shapes (`torch.LongTensor`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + max_length (`int`): + Maximum length of the positional embeddings to pad resized positional embeddings to + + Returns: + `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) + """ + _, _, D = pos_embs_HWD.shape + B, _ = spatial_shapes_N2.shape + + resized_embs_BLD = torch.empty( + (B, max_length, D), + device=pos_embs_HWD.device, + dtype=pos_embs_HWD.dtype, + ) + + # TODO: group images by size, and do interpolate, + # or cache the interpolate output so we do this once per size + for i in range(B): + height, width = spatial_shapes_N2[i].tolist() + if (height + width) == 0: # Skip empty padding images + continue + + resized_emb = F.interpolate( + E.rearrange(pos_embs_HWD, "h w d -> 1 d h w"), + size=(height, width), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + resized_emb_LD = E.rearrange(resized_emb, "1 d h w -> (h w) d") + resized_embs_BLD[i, : int(height * width)] = resized_emb_LD + + return resized_embs_BLD + + +class VisionEmbeddings(nn.Module): + def __init__(self, args: Siglip2ModelArgs): + super().__init__() + self.patch_embedding = nn.Linear( + in_features=args.n_channels * args.patch_size * args.patch_size, + out_features=args.dim, + ) + self.position_embedding = nn.Embedding(args.n_pos_embs**2, args.dim) + self.n_pos_embs = args.n_pos_embs + + def init_weights(self): + nn.init.trunc_normal_(self.patch_embedding.weight, mean=0.0, std=0.02) + nn.init.normal_(self.position_embedding.weight) + + def forward(self, pixels_NLD: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: + # Apply patch embeddings to already 
patchified pixel values + patch_embeds_NLD = self.patch_embedding(pixels_NLD) + + # Get positional resized and padded positional embeddings + pos_emb_HWD = self.position_embedding.weight.reshape( + self.n_pos_embs, self.n_pos_embs, -1 + ) + spatial_h = E.reduce(grid_hw[:, :, 0], "n l -> n", reduction="max") + 1 + spatial_w = E.reduce(grid_hw[:, :, 1], "n l -> n", reduction="max") + 1 + spatial_shapes = torch.stack([spatial_h, spatial_w], dim=-1).long() + resized_positional_embeddings = resize_positional_embeddings( + pos_emb_HWD, + spatial_shapes, + max_length=pixels_NLD.shape[1], + ) + # Add positional embeddings to patch embeddings + embeddings = patch_embeds_NLD + resized_positional_embeddings + return embeddings + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of query heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, args: Siglip2ModelArgs): + super().__init__() + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + + self.q_proj = nn.Linear(self.dim, self.dim) + self.k_proj = nn.Linear(self.dim, self.dim) + self.v_proj = nn.Linear(self.dim, self.dim) + self.out_proj = nn.Linear(self.dim, self.dim) + + self.attn = build_attention( + use_flex_attn=True, attn_mask_type=args.attn_mask_type + ) + + def forward(self, x: torch.Tensor): + xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Use self.head_dim instead of `n_heads` to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. 
+ xq = E.rearrange(xq, "b l (h d) -> b h l d", d=self.head_dim) + xk = E.rearrange(xk, "b l (h d) -> b h l d", d=self.head_dim) + xv = E.rearrange(xv, "b l (h d) -> b h l d", d=self.head_dim) + + output = self.attn(xq, xk, xv) + output = E.rearrange(output, "b h l d -> b l (h d)").contiguous() + + return self.out_proj(output) + + def init_weights(self): + for linear in (self.q_proj, self.k_proj, self.v_proj, self.out_proj): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + + +class FeedForward(nn.Module): + def __init__(self, args: Siglip2ModelArgs): + super().__init__() + self.fc1 = nn.Linear(args.dim, args.ffn_dim) + self.fc2 = nn.Linear(args.ffn_dim, args.dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = F.gelu(x, approximate="tanh") + x = self.fc2(x) + return x + + def init_weights(self): + nn.init.trunc_normal_(self.fc1.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.fc2.weight, mean=0.0, std=0.02) + + +class TransformerLayer(nn.Module): + def __init__(self, args: Siglip2ModelArgs): + super().__init__() + self.layer_norm1 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) + self.self_attn = Attention(args) + self.layer_norm2 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) + self.mlp = FeedForward(args) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.self_attn(self.layer_norm1(x)) + x = x + self.mlp(self.layer_norm2(x)) + return x + + def init_weights(self): + self.layer_norm1.reset_parameters() + self.layer_norm2.reset_parameters() + self.self_attn.init_weights() + self.mlp.init_weights() + + +class VisionTransformer(nn.Module): + def __init__(self, args: Siglip2ModelArgs): + super().__init__() + self.args = args + self.eos_id = 11 + + self.embeddings = VisionEmbeddings(args) + self.layers = nn.ModuleDict( + {str(idx): TransformerLayer(args) for idx in range(args.n_layers)} + ) + self.post_layernorm = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) + + def forward( + self, + pixel_values_NLD: torch.FloatTensor, + pixel_masks_NL: torch.BoolTensor, + grid_hw: torch.LongTensor, + ): + init_attention_mask(pixel_masks_NL, eos_id=self.eos_id) + + h = self.embeddings(pixel_values_NLD, grid_hw) + + for layer in self.layers.values(): + h = layer(h) + h = self.post_layernorm(h) + + return h + + def init_weights(self): + self.embeddings.init_weights() + for layer in self.layers.values(): + layer.init_weights() + self.post_layernorm.reset_parameters() diff --git a/torchtitan/experiments/vlm/requirements.txt b/torchtitan/experiments/vlm/requirements.txt new file mode 100644 index 0000000000..d27fa26c68 --- /dev/null +++ b/torchtitan/experiments/vlm/requirements.txt @@ -0,0 +1 @@ +einops diff --git a/torchtitan/experiments/vlm/train_configs/debug_model.toml b/torchtitan/experiments/vlm/train_configs/debug_model.toml new file mode 100644 index 0000000000..56afa6d6de --- /dev/null +++ b/torchtitan/experiments/vlm/train_configs/debug_model.toml @@ -0,0 +1,80 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "Llama 3 Siglip2 VLM debug training" +print_args = false +use_for_integration_test = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "llama3-siglip2" +flavor = "debugmodel" +# test folder with tokenizer.json, 
for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 8 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 10 +compile = false +# dataset = "c4_test" +dataset_path = "HuggingFaceM4/OBELICS" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enabled = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/train.py b/torchtitan/train.py index e38446a398..be8e274abf 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -410,6 +410,7 @@ def forward_backward_step( # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage inputs = input_dict["input"] + extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=parallel_dims.world_mesh["cp"], @@ -430,7 +431,7 @@ def forward_backward_step( ) if self.pp_has_first_stage: self.pp_schedule.step( - inputs, target=targets, losses=losses, input_batch=inputs + inputs, **extra_inputs, target=targets, losses=losses, input_batch=inputs ) else: self.pp_schedule.step( @@ -449,7 +450,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs, eos_id=self.tokenizer.eos_id) + pred = model_parts[0](inputs, eos_id=self.tokenizer.eos_id, **extra_inputs) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred
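The `torchtitan/train.py` change above threads the extra collator outputs through to the model. Below is a hedged sketch of what the single-stage (non-pipeline) forward path now does; the names follow the diff, and `model`, `tokenizer`, `loss_fn`, `input_dict`, and `labels` are assumed to be the objects the trainer already holds.

```python
# Hedged sketch of the non-PP forward path after this change.
inputs = input_dict["input"]                                   # (B, S) token ids
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
# extra_inputs == {"pixel_values": (N, L, D) or None, "grid_thw": (N, L, 3) or None}
pred = model(inputs, eos_id=tokenizer.eos_id, **extra_inputs)  # Llama3Siglip2Transformer.forward
loss = loss_fn(pred, labels)
```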