diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index 8da05f74e..c7fba1849 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -594,6 +594,7 @@ use_untrainable_positional_embedding: False
 trainable_position_size: -1 # enable gpt3 position embedding with a positive trainable_position_size
 # RoPE parameters
 rope_type: "default" # one of "default", "llama3.1" or "yarn"
+rope_linear_scaling_factor: 1.0 # linear scaling factor for "default" RoPE (see class `RotaryEmbedding` for more)
 rope_use_scale: True # apply rope scaling for llama3.1 (see class `LLaMARotaryEmbedding` for more)
 rope_min_timescale: 1
 rope_max_timescale: 10_000 # Timescale For global Attention
diff --git a/MaxText/configs/models/gemma3-12b.yml b/MaxText/configs/models/gemma3-12b.yml
index 5435c2160..3a81de318 100644
--- a/MaxText/configs/models/gemma3-12b.yml
+++ b/MaxText/configs/models/gemma3-12b.yml
@@ -30,3 +30,4 @@ use_post_attn_norm: true
 use_post_ffw_norm: true
 local_rope_max_timescale: 10_000
 rope_max_timescale: 1_000_000
+rope_linear_scaling_factor: 8.0
diff --git a/MaxText/configs/models/gemma3-27b.yml b/MaxText/configs/models/gemma3-27b.yml
index 26ba0bc67..59a5cf466 100644
--- a/MaxText/configs/models/gemma3-27b.yml
+++ b/MaxText/configs/models/gemma3-27b.yml
@@ -30,3 +30,4 @@ use_post_attn_norm: true
 use_post_ffw_norm: true
 local_rope_max_timescale: 10_000
 rope_max_timescale: 1_000_000
+rope_linear_scaling_factor: 8.0
diff --git a/MaxText/configs/models/gemma3-4b.yml b/MaxText/configs/models/gemma3-4b.yml
index 2bfffe808..4c1c25c1f 100644
--- a/MaxText/configs/models/gemma3-4b.yml
+++ b/MaxText/configs/models/gemma3-4b.yml
@@ -30,3 +30,4 @@ use_post_attn_norm: true
 use_post_ffw_norm: true
 local_rope_max_timescale: 10_000
 rope_max_timescale: 1_000_000
+rope_linear_scaling_factor: 8.0
diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index 1325c1a60..cf0fc104f 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -699,11 +699,18 @@ def init_rotary_embedding(self):
     # For local attention use local_rope_max_timescale if it's is positive
     if self.attention_type == AttentionType.LOCAL_SLIDING and self.config.local_rope_max_timescale > 0:
       max_timescale = self.config.local_rope_max_timescale
+
+    rope_linear_scaling_factor = self.config.rope_linear_scaling_factor
+    # In gemma3, linear scaling factor does not apply to local sliding layers.
+    if self.config.model_name.startswith("gemma3") and self.attention_type == AttentionType.LOCAL_SLIDING:
+      rope_linear_scaling_factor = 1.0
+
     rotary_embedding = RotaryEmbedding(
         min_timescale=self.config.rope_min_timescale,
         max_timescale=max_timescale,
         embedding_dims=rope_embedding_dims,
         fprop_dtype=self.dtype,
+        rope_linear_scaling_factor=rope_linear_scaling_factor,
         rngs=self.rngs,
     )
     return rotary_embedding
diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py
index e73dd38e6..fafafcd08 100644
--- a/MaxText/layers/embeddings.py
+++ b/MaxText/layers/embeddings.py
@@ -242,6 +242,7 @@ def __init__(
       fprop_dtype: DType = jnp.bfloat16,
       # Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen.
       # TODO: Remove when bridge no longer needed
+      rope_linear_scaling_factor: float = 1.0,
       rngs: nnx.Rngs = None,
   ):
     """Initializes the RotaryEmbedding module.
@@ -261,6 +262,7 @@ def __init__(
     self.embedding_dims = embedding_dims
     self.cast_as_fprop_dtype = cast_as_fprop_dtype
     self.fprop_dtype = fprop_dtype
+    self.rope_linear_scaling_factor = rope_linear_scaling_factor
 
     if self.embedding_dims % 2:
       raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.")
@@ -270,7 +272,10 @@ def timescale(self):
     """Returns the timescale for the rotary embedding."""
     half_embedding_dim = self.embedding_dims // 2
     fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
-    return self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
+    timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
+    if self.rope_linear_scaling_factor != 1.0:
+      timescale = timescale * self.rope_linear_scaling_factor
+    return timescale
 
   def __call__(
       self,  # pytype: disable=signature-mismatch # overriding-parameter-count-checks
@@ -448,9 +453,7 @@ def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.
     if len(inputs.shape) != 4:
       raise ValueError("Input is assumed to be a rank 4 tensor of shape [B, S, N, H].")
     if self.embedding_dims != inputs.shape[3]:
-      raise ValueError(
-          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
-      )
+      raise ValueError("The embedding dims of the rotary position embedding must match the hidden dimension of the inputs.")
 
     # Shift the inputs left and right as per LLaMA's specific behavior
     inputs_shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
@@ -649,9 +652,7 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
     if len(inputs.shape) != 4:
       raise ValueError("Input is assumed to be a rank 4 tensor of shape [batch, sequence, heads, dims].")
     if self.embedding_dims != inputs.shape[3]:
-      raise ValueError(
-          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
-      )
+      raise ValueError("The embedding dims of the rotary position embedding must match the hidden dimension of the inputs.")
 
     # Determine positions if not provided
     if position is None:
diff --git a/MaxText/multimodal_utils.py b/MaxText/multimodal_utils.py
index 44cbd94fe..98ddfd98f 100644
--- a/MaxText/multimodal_utils.py
+++ b/MaxText/multimodal_utils.py
@@ -35,9 +35,9 @@
 GEMMA_IMAGE_STD = (127.5,) * 3
 GEMMA_IMAGE_PLACEHOLDER_IN_PROMPT = ""
 GEMMA_BEGIN_IMAGE_TOKEN = 255999
-GEMMA_END_IMAGE_TOKEN = 262144
+GEMMA_END_IMAGE_TOKEN = 256000
 GEMMA_NEW_LINE_TOKEN = 108
-GEMMA_TOKEN_PLACEHOLDER = -2
+GEMMA_TOKEN_PLACEHOLDER = 262144
 # The number of GEMMA_TOKEN_PLACEHOLDER tokens per image in Gemma3
 GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE = 256
 # +4 means 4 extra tokens to pad around image: \n\n, , , \n\n
diff --git a/MaxText/scratch_code/generate_hf_golden_logits.py b/MaxText/scratch_code/generate_hf_golden_logits.py
index a95debb01..29b8d3387 100644
--- a/MaxText/scratch_code/generate_hf_golden_logits.py
+++ b/MaxText/scratch_code/generate_hf_golden_logits.py
@@ -15,16 +15,34 @@
 Usage:
 
 python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite \
-  --output-path=golden_DeepSeek-V2-Lite.jsonl --prompts='I love to,Today is a,What is the' \
+  --output-path=golden_DeepSeek-V2-Lite.jsonl --prompts='I love to;Today is a;What is the' \
   --gcs-bucket=my-gcs-bucket
+For large models, you can use an m1 CPU machine. Calling the script directly, instead of invoking it as a MaxText module, \
+skips importing unnecessary dependencies.
+For large Hugging Face checkpoints, you can use pre-downloaded checkpoints with the --hf-model-path argument.
+For multimodal models, use the --image-paths argument to provide image path(s),\
+ use --apply-chat-template=true if using the HF chat template to format image+prompt.\
+ When using the chat template, the prompt should not contain image placeholders.
+
+More examples:
+python3 MaxText/scratch_code/generate_hf_golden_logits.py --model-id=meta-llama/Llama-4-Scout-17B-16E \
+  --output-path=golden_Llama-4-Scout-17B-16E_vision.jsonl --prompts='Describe this image.' \
+  --apply-chat-template=true --gcs-bucket= --hf-model-path= \
+  --image-paths=MaxText/test_assets/test_image.jpg
+
+python3 MaxText/scratch_code/generate_hf_golden_logits.py --model-id=google/gemma-3-4b-it \
+  --output-path=golden_gemma-3-4b-it_vision.jsonl --prompts='' \
+  --apply-chat-template=false --gcs-bucket= --hf-model-path= \
+  --image-paths=MaxText/test_assets/test_image.jpg
 """
 
 import torch
 import argparse
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM
 import jsonlines
 from google.cloud import storage
+from PIL import Image
 
 
 # Load the tokenizer and model from Hugging Face
@@ -37,24 +55,64 @@ def upload_blob(bucket_name, source_file_name, destination_blob_name):
   blob.upload_from_filename(source_file_name)
 
 
-def save_golden_logits(model_id, output_path, prompt_texts, gcs_bucket):
+def save_golden_logits(model_id, output_path, prompt_texts, apply_chat_template, gcs_bucket, hf_model_path, image_paths):
   """save golden logits"""
-  tokenizer = AutoTokenizer.from_pretrained(model_id)
+  if hf_model_path is None:
+    hf_model_path = model_id
+  tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
   model = AutoModelForCausalLM.from_pretrained(
-      model_id,
+      hf_model_path,
       torch_dtype=torch.float32,
       trust_remote_code=True,
   )
   all_data_to_save = []
-  for prompt_text in prompt_texts:
+  for i, prompt_text in enumerate(prompt_texts):
     # Encode the prompt text
-    input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
+    if image_paths:
+      try:
+        image = Image.open(image_paths[i])
+      except Exception as e:
+        raise e
+      image = image.convert("RGB")
+      # TODO (aireenmei): remove this when Llama-4 supports dynamic image shapes.
+      if model_id.startswith("meta-llama/Llama-4"):
+        image = image.resize((336, 336))
+      processor = AutoProcessor.from_pretrained(model_id, token=True)
+      if apply_chat_template:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": prompt_text},
+                ],
+            },
+        ]
+        formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(text=formatted_prompt, images=image, return_tensors="pt")
+      else:
+        formatted_prompt = prompt_text
+        inputs = processor(text=formatted_prompt, images=image, return_tensors="pt", add_special_tokens=False)
+      with torch.no_grad():
+        outputs = model(**inputs)
+        logits = outputs.logits.cpu().numpy().astype("float32")
 
-    # Get the logits for the prompt + completion
-    with torch.no_grad():
-      outputs = model(input_ids)
-      logits = outputs.logits.cpu().numpy().astype("float32")
+      data_to_save = {
+          "prompt": prompt_text,
+          "formatted_prompt": formatted_prompt,
+          "tokens": inputs["input_ids"].tolist()[0],
+          "attention_mask": inputs["attention_mask"].tolist()[0],
+          "image_path": image_paths[i],
+          "pixel_values": inputs["pixel_values"].tolist()[0],
+          "logits": logits.tolist()[0],
+      }
+    else:
+      input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
+      # Get the logits for the prompt + completion
+      with torch.no_grad():
+        outputs = model(input_ids)
+        logits = outputs.logits.cpu().numpy().astype("float32")
 
     # Prepare data to be saved
     data_to_save = {
@@ -62,7 +120,9 @@ def save_golden_logits(model_id, output_path, prompt_texts, gcs_bucket):
         "tokens": input_ids.tolist()[0],
         "logits": logits.tolist()[0],  # Convert numpy array to list for JSON serialization
     }
-    all_data_to_save.append(data_to_save)
+    print(f"Token length is {len(data_to_save['tokens'])} for prompt: {prompt_text}")
+    print(f"raw ids: {data_to_save['tokens']}")
+    all_data_to_save.append(data_to_save)
 
   with jsonlines.open(output_path, "w") as f:
     f.write_all(all_data_to_save)
@@ -77,13 +137,33 @@ def main(raw_args=None) -> None:
   parser = argparse.ArgumentParser()
   parser.add_argument("--model-id", type=str, required=True, help="The identifier of the model to use.")
   parser.add_argument("--output-path", type=str, required=True, help="The path to save the generated golden logits.")
-  parser.add_argument("--prompts", type=str, required=True, help="A comma-separated list of prompts.")
+  parser.add_argument("--prompts", type=str, required=True, help="A semicolon-separated list of prompts.")
+  parser.add_argument(
+      "--apply-chat-template",
+      type=bool,
+      required=False,
+      default=False,
+      help="Whether to apply chat template from the HF processor. Used for image+text input.",
+  )
   parser.add_argument(
       "--gcs-bucket", type=str, required=False, default=None, help="A GCS bucket to store logits, without gs://."
   )
+  parser.add_argument("--hf-model-path", type=str, required=False, default=None, help="local path to checkpoint if exists.")
+  parser.add_argument(
+      "--image-paths", type=str, required=False, default=None, help="A semicolon-separated list of image_paths."
+  )
   args = parser.parse_args(raw_args)
-  prompts = args.prompts.split(",")
-  save_golden_logits(args.model_id, args.output_path, prompts, args.gcs_bucket)
+  prompts = args.prompts.split(";")
+  image_paths = args.image_paths.split(";") if args.image_paths else []
+  if image_paths:
+    assert len(image_paths) == len(
+        prompts
+    ), "when image paths are provided, image_paths and prompts must have the same length."
+  if args.apply_chat_template:
+    assert image_paths, "apply_chat_template is only used for image+text input, so image_paths must be provided."
+  save_golden_logits(
+      args.model_id, args.output_path, prompts, args.apply_chat_template, args.gcs_bucket, args.hf_model_path, image_paths
+  )
 
 
 if __name__ == "__main__":
diff --git a/MaxText/tests/check_gemma3_layers.py b/MaxText/tests/check_gemma3_layers.py
new file mode 100644
index 000000000..d1c3c33e4
--- /dev/null
+++ b/MaxText/tests/check_gemma3_layers.py
@@ -0,0 +1,143 @@
+import torch
+from torch import nn
+import jax
+import unittest
+import jax.numpy as jnp
+import numpy as np
+from MaxText.layers import embeddings
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+
+
+def to_jax(pt_tensor: torch.Tensor) -> jax.Array:
+  return jnp.asarray(pt_tensor.detach().numpy())
+
+
+### Original PyTorch reference implementation
+class Gemma3RotaryEmbedding(nn.Module):
+  inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+  def __init__(self, config, device=None):
+    super().__init__()
+    # BC: "rope_type" was originally "type"
+    if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+      self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+    else:
+      self.rope_type = "default"
+    self.max_seq_len_cached = config.max_position_embeddings
+    self.original_max_seq_len = config.max_position_embeddings
+
+    self.config = config
+    self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+    inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+    self.register_buffer("inv_freq", inv_freq, persistent=False)
+    self.original_inv_freq = self.inv_freq
+
+  @torch.no_grad()
+  def forward(self, x, position_ids):
+    inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+    position_ids_expanded = position_ids[:, None, :].float()
+
+    device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+      freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+      emb = torch.cat((freqs, freqs), dim=-1)
+      cos = emb.cos() * self.attention_scaling
+      sin = emb.sin() * self.attention_scaling
+
+    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+  """Rotates half the hidden dims of the input."""
+  x1 = x[..., : x.shape[-1] // 2]
+  x2 = x[..., x.shape[-1] // 2 :]
+  return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+  """Applies Rotary Position Embedding to the query and key tensors.
+
+  Args:
+    q (`torch.Tensor`): The query tensor.
+    k (`torch.Tensor`): The key tensor.
+    cos (`torch.Tensor`): The cosine part of the rotary embedding.
+    sin (`torch.Tensor`): The sine part of the rotary embedding.
+    position_ids (`torch.Tensor`, *optional*):
+      Deprecated and unused.
+    unsqueeze_dim (`int`, *optional*, defaults to 1):
+      The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+      sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+      that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+      k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+      cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+      the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+  Returns:
+    `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+  """
+  cos = cos.unsqueeze(unsqueeze_dim)
+  sin = sin.unsqueeze(unsqueeze_dim)
+  q_embed = (q * cos) + (rotate_half(q) * sin)
+  k_embed = (k * cos) + (rotate_half(k) * sin)
+  return q_embed, k_embed
+
+
+class Gemma3RotaryEmbeddingTest(unittest.TestCase):
+  """Test for Gemma3 RoPE implementation with linear scaling."""
+
+  def test_rope_compare_pytorch_and_jax(self):
+    # Config parameters
+    batch_size = 4
+    seq_len = 128
+    num_heads = 8
+    head_dim = 64
+    # embedding_dims = num_heads * head_dim
+    min_timescale = 1
+    max_timescale = 1000000  # 10000
+
+    # Create random input tensors
+    q_pt = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    k_pt = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
+
+    # PyTorch reference implementation
+    class DummyConfig:
+
+      def __init__(self, rope_theta, head_dim, max_position_embeddings):
+        self.rope_theta = rope_theta
+        self.head_dim = head_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_scaling = {"factor": 8.0, "rope_type": "linear"}
+
+    config = DummyConfig(rope_theta=max_timescale, head_dim=head_dim, max_position_embeddings=seq_len)
+
+    pt_rope = Gemma3RotaryEmbedding(config)
+    cos_pt, sin_pt = pt_rope(q_pt, position_ids)
+    q_rope_pt, k_rope_pt = apply_rotary_pos_emb(q_pt, k_pt, cos_pt, sin_pt, position_ids)
+
+    # JAX implementation
+    jax_rope = embeddings.RotaryEmbedding(
+        min_timescale=min_timescale,
+        max_timescale=max_timescale,
+        embedding_dims=head_dim,
+        cast_as_fprop_dtype=False,
+        fprop_dtype=jnp.float32,
+        rope_linear_scaling_factor=8.0,
+    )
+
+    # JAX expects [B, S, N, H]
+    q_jax = to_jax(q_pt.permute(0, 2, 1, 3))
+    k_jax = to_jax(k_pt.permute(0, 2, 1, 3))
+    position_jax = to_jax(position_ids)
+
+    # Apply JAX rotary embedding
+    q_rope_jax = jax_rope(q_jax, position=position_jax)
+    k_rope_jax = jax_rope(k_jax, position=position_jax)
+
+    # Compare outputs
+    np.testing.assert_allclose(to_jax(q_rope_pt.permute(0, 2, 1, 3)), q_rope_jax, rtol=1e-3, atol=0.05)
+    np.testing.assert_allclose(to_jax(k_rope_pt.permute(0, 2, 1, 3)), k_rope_jax, rtol=1e-3, atol=0.05)
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/MaxText/tests/forward_pass_logit_checker.py b/MaxText/tests/forward_pass_logit_checker.py
index 41497ccc4..848c6275f 100644
--- a/MaxText/tests/forward_pass_logit_checker.py
+++ b/MaxText/tests/forward_pass_logit_checker.py
@@ -43,6 +43,7 @@
 import jax.numpy as jnp
 import jsonlines
 import torch.nn.functional as F
+from google.cloud import storage
 import torch
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -58,6 +59,14 @@
 from MaxText.layers import quantizations
 
 
+def upload_blob(bucket_name, source_file_name, destination_blob_name):
+  """Uploads a file to the bucket."""
+  storage_client = storage.Client()
+  bucket = storage_client.get_bucket(bucket_name)
+  blob = bucket.blob(destination_blob_name)
+  blob.upload_from_filename(source_file_name)
+
+
 def get_top_k_tokens_scores(logits_tensor, tokenizer_instance, k=10, description=""):
   """Get the top-k tokens and their scores from a given logits tensor."""
   max_logging.log(f"\n--- {description} top {k} tokens ---")
@@ -172,13 +181,24 @@ def check_kl_divergence(model_logits, golden_logits, atol=0.02):
   assert max_kl_div < atol, f"KL divergence values {max_kl_div.item():.6f} exceed the threshold {atol}"
 
 
-def get_data(golden_data, golden_data_index, config):
+def get_data(golden_data_point, config):
   """Get the golden data for the test indexed at golden_data_index"""
-  max_logging.log(f"Comparing forward pass for golden data index = {golden_data_index}")
   max_logging.log(f"config.global_batch_size_to_train_on={config.global_batch_size_to_train_on}")
+  if config.use_multimodal:
+    assert "pixel_values" in golden_data_point, "no image found in golden data while use_multimodal=True"
+    pixel_values = np.asarray(golden_data_point["pixel_values"], dtype=np.float32)
+    max_logging.log(f"pixel_values.shape = {pixel_values.shape}")
+    model_prefix = config.model_name.split("-")[0]
+    if model_prefix in ["gemma3"]:
+      pixel_values = np.transpose(pixel_values, (1, 2, 0))
+    elif model_prefix in ["llama4"]:
+      pixel_values = pixel_values[None, :]
+    pixel_values = np.stack([pixel_values for _ in range(config.global_batch_size_to_train_on)])
+  else:
+    pixel_values = None
 
-  original_ids = np.asarray(golden_data[golden_data_index]["tokens"], dtype=np.int32)
+  original_ids = np.asarray(golden_data_point["tokens"], dtype=np.int32)
   seq_len = len(original_ids)
 
   if seq_len > config.max_target_length:
@@ -192,62 +212,57 @@ def get_data(golden_data, golden_data_index, config):
   padded_ids = np.pad(original_ids, (0, config.max_target_length - seq_len), "constant", constant_values=0)
   ids = np.stack([padded_ids for _ in range(config.global_batch_size_to_train_on)])
 
-  logits = np.asarray(golden_data[golden_data_index]["logits"], dtype=np.float32)
-  max_logging.log(
-      f" prompt=\"{golden_data[golden_data_index]['prompt']}\" raw ids={original_ids}, logits.shape = {logits.shape}"
-  )
+  logits = np.asarray(golden_data_point["logits"], dtype=np.float32)
+  if "formatted_prompt" in golden_data_point:
+    prompt = golden_data_point["formatted_prompt"]
+  else:
+    prompt = golden_data_point["prompt"]
+  max_logging.log(f' prompt="{prompt}" raw ids={original_ids}, logits.shape = {logits.shape}')
 
   decoder_segment_ids = np.zeros(s, dtype=np.int32)
   decoder_segment_ids[:, :seq_len] = DECODING_ACTIVE_SEQUENCE_INDICATOR
   decoder_positions = np.stack(
       [np.arange(config.max_target_length, dtype=np.int32) for _ in range(config.global_batch_size_to_train_on)]
   )
-
-  return ids, decoder_segment_ids, decoder_positions, logits, seq_len
+  return ids, decoder_segment_ids, decoder_positions, logits, seq_len, pixel_values
 
 
 def main(config, test_args):  # pylint: disable=W0621
   """Test the Whole Model of model_name"""
   if not test_args.run_hf_model:
     """Comparing maxtext/huggingface model with pre-loaded golden logitis"""
-    # initialize the model with weights from reference ckpt
-    if test_args.hf_model_path != "":  # Initialize model from the given HF path
-      model = AutoModelForCausalLM.from_pretrained(test_args.hf_model_path)
-    else:  # Initialize MaxText model
-      init_rng = jax.random.PRNGKey(config.init_weights_seed)
-      init_rng, rng1 = jax.random.split(init_rng)
-      devices_array = maxtext_utils.create_device_mesh(config)
-      mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
-      quant = quantizations.configure_quantization(config)
-      model = models.Transformer(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
-      state, _ = maxtext_utils.setup_decode_state(model, config, rng1, mesh, None)
+    max_logging.log("Initializing MaxText model")
+    init_rng = jax.random.PRNGKey(config.init_weights_seed)
+    init_rng, rng1 = jax.random.split(init_rng)
+    devices_array = maxtext_utils.create_device_mesh(config)
+    mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
+    quant = quantizations.configure_quantization(config)
+    model = models.Transformer(config, mesh=mesh, quant=quant)
+    state, _ = maxtext_utils.setup_decode_state(model, config, rng1, mesh, None)
 
     if test_args.golden_logits_path == "":
       input_golden_data_path = os.path.join(PKG_DIR, "test_assets", f"golden_data_{config.model_name}.jsonl")
     else:
       input_golden_data_path = test_args.golden_logits_path
+    max_logging.log("loading hf goldens from jsonl file")
     with jsonlines.open(input_golden_data_path, "r") as f:
       golden_data = list(f)
-
-    for golden_data_index in range(len(golden_data)):
-      ids, decoder_segment_ids, decoder_positions, golden_logits, seq_len = get_data(
-          golden_data, golden_data_index, config
+    max_logging.log(f"loaded {len(golden_data)} golden data points")
+    all_data_to_save = []
+    for golden_data_index, golden_data_point in enumerate(golden_data):
+      max_logging.log(f"--- Comparing forward pass for golden data index: {golden_data_index} ---")
+      ids, decoder_segment_ids, decoder_positions, golden_logits, seq_len, images = get_data(golden_data_point, config)
+      max_logging.log("maxtext forward pass")
+      full_train_logits = model.apply(
+          state.params,
+          ids,
+          decoder_positions,
+          decoder_segment_ids,
+          encoder_images=images,
+          enable_dropout=False,
+          rngs={"aqt": init_rng},
       )
-      if test_args.hf_model_path != "":
-        with torch.no_grad():
-          full_train_logits = model(torch.tensor(ids.tolist())).logits.cpu().numpy().astype("float32")
-      else:
-        # TODO(hengtaoguo): Add support for multimodal full prompt decoding check
-        full_train_logits = model.apply(
-            state.params,
-            ids,
-            decoder_positions,
-            decoder_segment_ids,
-            enable_dropout=False,
-            rngs={"aqt": init_rng},
-        )
-
       full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits)
       # if full_train_logits shape is [num_hosts, batch_size, seq_len, vocab_size]
       if full_train_logits.ndim == 4:
@@ -255,17 +270,41 @@ def main(config, test_args):  # pylint: disable=W0621
       # Slice to original sequence length
       full_train_logits = full_train_logits[:, :seq_len, :]
-      max_logging.log(f"{golden_logits[2]=}")
-      max_logging.log(f"{full_train_logits[0, 2, :]=}")
 
       token_size = int(test_args.token_size) if test_args.token_size else seq_len
-      # min_vocab_size = min(full_train_logits.shape[-1], golden_logits.shape[-1])
-      max_diff = np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))
-      max_logging.log(f"Max Numerical Difference {max_diff}")
+      if full_train_logits.shape[-1] != golden_logits.shape[-1]:
+        max_logging.log(
+            f"Vocab size mismatch: train logits vocab size {full_train_logits.shape[-1]}, "
+            f"golden logits vocab size {golden_logits.shape[-1]}. "
+            "Comparing up to the smaller vocab size."
+        )
+      min_vocab_size = min(full_train_logits.shape[-1], golden_logits.shape[-1])
+      train_logits_slice = full_train_logits[0, :token_size, :min_vocab_size]
+      golden_logits_slice = golden_logits[:token_size, :min_vocab_size]
+      max_logging.log(f"{golden_logits_slice[2]=}")
+      max_logging.log(f"{train_logits_slice[2]=}")
+
+      # Calculate absolute and relative differences for detailed reporting
+      abs_diff = jnp.abs(train_logits_slice - golden_logits_slice)
+
+      # To avoid division by zero, add a small epsilon where golden_logits_slice is zero
+      safe_golden_logits = jnp.where(golden_logits_slice == 0, 1e-8, golden_logits_slice)
+      rel_diff = abs_diff / jnp.abs(safe_golden_logits)
+
+      max_abs_diff_idx = jnp.unravel_index(jnp.argmax(abs_diff), abs_diff.shape)
+      max_rel_diff_idx = jnp.unravel_index(jnp.argmax(rel_diff), rel_diff.shape)
+
+      max_abs_diff_val = abs_diff[max_abs_diff_idx]
+      max_rel_diff_val = rel_diff[max_rel_diff_idx]
+      msg = (
+          f"Max absolute difference: {max_abs_diff_val:.6f} at index {max_abs_diff_idx}\n"
+          f"  (Train: {train_logits_slice[max_abs_diff_idx]:.6f}, Golden: {golden_logits_slice[max_abs_diff_idx]:.6f})\n"
+          f"Max relative difference: {max_rel_diff_val:.6f} at index {max_rel_diff_idx}\n"
+          f"  (Train: {train_logits_slice[max_rel_diff_idx]:.6f}, Golden: {golden_logits_slice[max_rel_diff_idx]:.6f})"
+      )
+      max_logging.log(msg)
 
-      # model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :min_vocab_size], axis=-1)
-      # golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :min_vocab_size], axis=-1)
-      model_probabilities = jax.nn.softmax(full_train_logits[..., 0, :token_size, :], axis=-1)
-      golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1)
+      model_probabilities = jax.nn.softmax(train_logits_slice, axis=-1)
+      golden_probabilities = jax.nn.softmax(golden_logits_slice, axis=-1)
       max_logging.log(f"{golden_probabilities[1]=}")
       max_logging.log(f"{model_probabilities[1]=}")
 
@@ -273,27 +312,34 @@ def main(config, test_args):  # pylint: disable=W0621
       kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1)
       max_logging.log(f"KL divergence = {kl_div}, max KL divergence = {jax.numpy.max(kl_div)}")
 
+      if jax.process_index() == 0 and test_args.output_logits_path:
+        data_to_save = {
+            "prompt": golden_data[golden_data_index]["prompt"],
+            "tokens": ids[0, :seq_len].tolist(),
+            "logits": full_train_logits[0].tolist(),
+        }
+        all_data_to_save.append(data_to_save)
+
       if test_args.max_kl_div is not None:
         max_logging.log("Checking KL Divergence between train distribution and " "golden distribution")
         assert jax.numpy.all(
-            kl_div < test_args.max_kl_div
-        ), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}"  # pylint: disable=C0301
-      else:
-        max_logging.log("Checking Numerical Differences between train logits and golden logits")  # pylint: disable=C0301
-        assert jax.numpy.allclose(
-            full_train_logits[0, :token_size, :],
-            golden_logits[:token_size, :],
-            rtol=float(test_args.rtol),
-            atol=float(test_args.atol),
-            equal_nan=False,
-        ), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}."  # pylint: disable=C0301
+            kl_div < test_args.max_kl_div,
+        ), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}"
+
+      max_logging.log("Checking Numerical Differences between train logits and golden logits against the provided atol, rtol.")  # pylint: disable=C0301
+      rtol_val = float(test_args.rtol)
+      atol_val = float(test_args.atol)
+      assert jax.numpy.allclose(
+          train_logits_slice, golden_logits_slice, rtol=rtol_val, atol=atol_val, equal_nan=False
+      ), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}."
+
   else:
     """Comparing maxtext model with HF model on-the-fly"""
     if test_args.hf_model_path == "":
       raise ValueError
     hf_model = AutoModelForCausalLM.from_pretrained(test_args.hf_model_path, torch_dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained(test_args.hf_model_path)
-    if 'Llama-3.1' in test_args.hf_model_path:
+    if "Llama-3.1" in test_args.hf_model_path:
       tokenizer.pad_token = tokenizer.eos_token
 
     init_rng = jax.random.PRNGKey(config.init_weights_seed)
@@ -305,13 +351,12 @@ def main(config, test_args):  # pylint: disable=W0621
     maxtext_state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, rng1, mesh, None)
 
     prompts = ["I love to", "Today is a", "What is the"]
+    all_data_to_save = []
     for input_text in prompts:
       max_logging.log(f"\n--- Prompt: {input_text} ---")
 
       # Tokenize for HF
-      inputs = tokenizer(
-          input_text, return_tensors="pt", padding=True, max_length=config.max_target_length, truncation=True
-      )
+      inputs = tokenizer(input_text, return_tensors="pt", padding=True, max_length=config.max_target_length, truncation=True)
       actual_seq_len = inputs["input_ids"].shape[1]
 
       # Tokenize for MaxText
@@ -358,6 +403,26 @@ def main(config, test_args):  # pylint: disable=W0621
       # --- Compare all logits in the sequence (for the first batch item) ---
       # Unsqueeze to add batch dimension for check_kl_divergence: [1, seq, vocab]
       check_kl_divergence(mt_logits_torch[0].unsqueeze(0), hf_logits_torch[0].unsqueeze(0), atol=test_args.max_kl_div)
+      if jax.process_index() == 0 and test_args.output_logits_path:
+        data_to_save = {
+            "mt_logits": mt_logits_torch[0].tolist(),
+            "hf_logits": hf_logits_torch[0].tolist(),
+        }
+        all_data_to_save.append(data_to_save)
+
+  if jax.process_index() == 0 and test_args.output_logits_path:
+    os.makedirs(os.path.dirname(test_args.output_logits_path), exist_ok=True)
+    with jsonlines.open(test_args.output_logits_path, "a") as f:
+      f.write(all_data_to_save)
+    max_logging.log(f"Saved logits to {test_args.output_logits_path}")
+
+    if test_args.gcs_output_logits_path:
+      bucket_name = test_args.gcs_output_logits_path.split("/")[2]
+      destination_blob_name = "/".join(
+          test_args.gcs_output_logits_path.split("/")[3:] + test_args.output_logits_path.split("/")[-1:]
+      )
+      upload_blob(bucket_name, test_args.output_logits_path, destination_blob_name)
+      max_logging.log(f"Uploaded logits to {test_args.gcs_output_logits_path}")
 
 
 if __name__ == "__main__":
@@ -372,6 +437,8 @@ def main(config, test_args):  # pylint: disable=W0621
   parser.add_argument("--golden_logits_path", type=str, required=False, default="")
   parser.add_argument("--hf_model_path", type=str, required=False, default="")
   parser.add_argument("--run_hf_model", type=bool, required=False, default=False)
+  parser.add_argument("--output_logits_path", type=str, required=False, default="")
+  parser.add_argument("--gcs_output_logits_path", type=str, required=False, default="")
   test_args, _ = parser.parse_known_args()
 
   # Remove args defined in this test file to avoid error from pyconfig
@@ -384,9 +451,15 @@ def main(config, test_args):  # pylint: disable=W0621
"--golden_logits_path", "--hf_model_path", "--run_hf_model", + "--output_logits_path", + "--gcs_output_logits_path", ] for arg in to_remove_args: model_args = [s for s in model_args if not s.startswith(arg)] cfg = pyconfig.initialize(model_args) + if cfg.use_multimodal: + assert ( + not test_args.run_hf_model + ), "Multimodal does not support running hf model on-the-fly, please generate hf golden logits using generate_hf_golden_logits.py" main(cfg, test_args) diff --git a/MaxText/utils/ckpt_conversion/utils/param_mapping.py b/MaxText/utils/ckpt_conversion/utils/param_mapping.py index 675690061..df56c69ca 100644 --- a/MaxText/utils/ckpt_conversion/utils/param_mapping.py +++ b/MaxText/utils/ckpt_conversion/utils/param_mapping.py @@ -34,6 +34,7 @@ parameter from the source checkpoint and build the target checkpoint. """ +import warnings import numpy as np import jax @@ -181,8 +182,10 @@ def pad_and_scale_embedding(input_tensor, target_shape): # Handle padding/truncation if source_vocab_size > target_vocab_size: + warnings.warn(f"source vocab={source_vocab_size} > target vocab={target_vocab_size}, truncate output layer for MaxText.") output_tensor = scaled_tensor[:target_vocab_size, :] elif source_vocab_size < target_vocab_size: + warnings.warn(f"source vocab={source_vocab_size} < target vocab={target_vocab_size}, pad output layer for MaxText.") padding_shape = (target_vocab_size - source_vocab_size, target_hidden_size) # Use jnp.zeros for JAX arrays, np.zeros for numpy arrays padding = (