From 68f9bcced13bd1a55d0f94aac5d271082c86088f Mon Sep 17 00:00:00 2001 From: pandora <128635000+pandora-s-git@users.noreply.github.com> Date: Mon, 10 Nov 2025 16:09:50 +0100 Subject: [PATCH 1/6] add tokenizer comparison script This script compares the basic `.encode` tokenization between Hugging Face and Mistral Common tokenizers across multiple datasets. --- scripts/compare_tokenizer.py | 218 +++++++++++++++++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 scripts/compare_tokenizer.py diff --git a/scripts/compare_tokenizer.py b/scripts/compare_tokenizer.py new file mode 100644 index 00000000..b042247e --- /dev/null +++ b/scripts/compare_tokenizer.py @@ -0,0 +1,218 @@ +""" +This script compares the basic `.encode` tokenization between Hugging Face and Mistral tokenizers across multiple datasets. +It identifies and reports mismatches in tokenization results, helping ensure consistency between the two tokenizers. +The comparison is limited to the first 10,000 characters to avoid excessively long sequences. +The mismatch rate is calculated as the percentage of samples where tokenization results differ at any point. +Results include the exact token where the mismatch starts, as well as the three tokens immediately before and after the mismatch, totaling 7 tokens. +If no mismatch is detected, it only indicates that the first 10,000 characters are consistent—not necessarily the entire sample. +Some datasets may show a high mismatch rate due to frequent occurrences of the same problematic string, hence we recommend taking a look at the raw results. + +Usage: + python3 compare_tokenizer.py --hf_model --mc_model [--num_samples ] [--hf_token ] [-save-results] + +Arguments: + --hf_model: Model name or path for the Hugging Face tokenizer. + --mc_model: Model name or path for the Mistral-Common tokenizer loaded from Hugging Face, if not provided, will default to hf_model. + --num_samples: Maximum number of samples to test (default: 3000). + --hf_token: Hugging Face token for private models (optional). + --save_results: If provided, saves the raw mismatch results to a JSON file (default: tokenizer_mismatches_data.json). + +Examples: + python3 compare_tokenizer.py --hf_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 --mc_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 --save_results + python3 compare_tokenizer.py --hf_model mistralai/Mistral-Nemo-Instruct-2407 --mc_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 --num_samples 1000 +""" + +import argparse +from transformers import AutoTokenizer +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy +from datasets import load_dataset +from tqdm import tqdm +from typing import List, Dict, Any +from collections import Counter +import json + +def compare_tokenizers( + hf_tokenizer: AutoTokenizer, + mc_tokenizer: MistralTokenizer, + content: str, +) -> Dict[str, Any]: + """Compare tokenization between Hugging Face and Mistral tokenizers.""" + content = content[:10_000] # Limit to 10000 characters to avoid long sequences + hf_tokens = hf_tokenizer.encode(content) + mc_tokens = mc_tokenizer.instruct_tokenizer.tokenizer.encode(content, bos=True, eos=False) + if hf_tokens != mc_tokens: + mismatch_id = next( + (idx for idx, (a, b) in enumerate(zip(hf_tokens, mc_tokens)) if a != b), + min(len(hf_tokens), len(mc_tokens)), + ) + start = max(0, mismatch_id - 3) + end = min(len(hf_tokens), mismatch_id + 4) + hf_str = [hf_tokenizer.decode([t]) for t in hf_tokens[start:end]] + mc_str = [ + mc_tokenizer.decode([t], special_token_policy=SpecialTokenPolicy.KEEP) + for t in mc_tokens[start:end] + ] + return { + "string_content": ''.join(mc_str), + "mistral_common": mc_str, + "hugging_face": hf_str + } + return None + +def process_dataset( + dataset_name: str, + split: str, + hf_tokenizer: AutoTokenizer, + mc_tokenizer: MistralTokenizer, + num_samples: int, + **kwargs, +) -> List[Dict[str, Any]]: + """Process a dataset and collect tokenizer mismatches.""" + mismatches = [] + dataset = load_dataset(dataset_name, split=split, streaming=True) + for i, example in tqdm(enumerate(dataset), desc=f"Processing {dataset_name} ({split})", total=num_samples): + if i >= num_samples: + break + content = example["text"] + mismatch = compare_tokenizers(hf_tokenizer, mc_tokenizer, content) + if mismatch: + mismatch.update(kwargs) + mismatches.append(mismatch) + return mismatches + +def test_tokenizer( + hf_model: str, + mc_model: str, + num_samples: int = 800, + hf_token: str = None, +) -> Dict[str, List[Dict[str, Any]]]: + """Test tokenizer consistency across multiple datasets.""" + print("Loading HF Tokenizer.") + hf_tokenizer = AutoTokenizer.from_pretrained(hf_model, token=hf_token) + print("Loading Mistral-Common Tokenizer.") + mc_tokenizer = MistralTokenizer.from_hf_hub(mc_model, token=hf_token) + + n_per_set = num_samples // 3 + + # Web Data - Mostly English + print(f"Testing Tokenizers on HuggingFaceFW/fineweb with {n_per_set} samples.") + web_mismatches = process_dataset( + "HuggingFaceFW/fineweb", "train", hf_tokenizer, mc_tokenizer, n_per_set + ) + web_mismatch_rate = (len(web_mismatches) / n_per_set) * 100 + print(f"HuggingFaceFW/fineweb: {web_mismatch_rate:.2f}% Sample Mismatch ({len(web_mismatches)}/{n_per_set})") + + # Web PDF Data - Mostly English + print(f"Testing Tokenizers on HuggingFaceFW/finepdfs with {n_per_set} samples.") + pdf_mismatches = process_dataset( + "HuggingFaceFW/finepdfs", "train", hf_tokenizer, mc_tokenizer, n_per_set + ) + pdf_mismatch_rate = (len(pdf_mismatches) / n_per_set) * 100 + print(f"HuggingFaceFW/finepdfs: {pdf_mismatch_rate:.2f}% Sample Mismatch ({len(pdf_mismatches)}/{n_per_set})") + + # Multilingual Creative Writing Data + print(f"Testing Tokenizers on manu/project_gutenberg with {n_per_set} samples.") + multi_mismatches = [] + dataset_info = load_dataset("manu/project_gutenberg", streaming=True) + splits = [s for s in list(dataset_info.keys()) if s != "en"] + n_per_split = n_per_set // len(splits) + collected = 0 + + for split in splits: + if collected >= n_per_set: + break + split_mismatches = process_dataset( + "manu/project_gutenberg", split, hf_tokenizer, mc_tokenizer, n_per_split + ) + multi_mismatches.extend(split_mismatches) + collected += n_per_split + + if collected < n_per_set: + remaining = num_samples - collected + en_mismatches = process_dataset( + "manu/project_gutenberg", "en", hf_tokenizer, mc_tokenizer, remaining + ) + multi_mismatches.extend(en_mismatches) + + multi_mismatch_rate = (len(multi_mismatches) / n_per_set) * 100 + print(f"manu/project_gutenberg: {multi_mismatch_rate:.2f}% Sample Mismatch ({len(multi_mismatches)}/{n_per_set})") + + # Total Mismatch + print("\n=== SUMMARY ===") + print(f"Total Mismatch Rate: {(len(web_mismatches) + len(pdf_mismatches) + len(multi_mismatches)) / num_samples * 100:.2f}% ( {len(web_mismatches) + len(pdf_mismatches) + len(multi_mismatches)} / {num_samples} )") + print(f" - HuggingFaceFW/fineweb: {len(web_mismatches) / n_per_set * 100:.2f}% ( {len(web_mismatches)} / {n_per_set} )") + print(f" - HuggingFaceFW/finepdfs: {len(pdf_mismatches) / n_per_set * 100:.2f}% ( {len(pdf_mismatches)} / {n_per_set} )") + print(f" - manu/project_gutenberg: {len(multi_mismatches) / n_per_set * 100:.2f}% ( {len(multi_mismatches)} / {n_per_set} )") + + return { + "web_model_mismatches": web_mismatches, + "pdf_model_mismatches": pdf_mismatches, + "multi_model_mismatches": multi_mismatches, + } + +def generate_mismatch_report(mismatch_results: Dict[str, List[Dict[str, Any]]]) -> None: + """Generate a report of the most frequent tokens in mismatched samples.""" + all_tokens = [] + mismatch_percentages = [] + + # Flatten all mismatches from all datasets + for dataset_key in mismatch_results: + for mismatch in mismatch_results[dataset_key]: + hf_tokens = mismatch["hugging_face"] + mc_tokens = mismatch["mistral_common"] + all_tokens.extend(hf_tokens) + all_tokens.extend(mc_tokens) + + # Count and rank tokens + token_counter = Counter(all_tokens) + most_common = token_counter.most_common(10) # Top 10 + + print("\n=== MISMATCH REPORT: TOP 10 FREQUENT TOKENS ===") + print(f"Total mismatched unique tokens analyzed: {len(token_counter)}") + print("Token -> Frequency (across all mismatches)") + print("-" * 40) + for token, count in most_common: + print(f"{token!r} -> {count} times") + +def save_results(results: Dict[str, List[Dict[str, Any]]], filename: str = "tokenizer_mismatches_data.json") -> None: + """Save the results to a JSON file.""" + with open(filename, "w") as f: + json.dump(results, f, indent=2) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test tokenizer consistency between Hugging Face and Mistral Common.") + parser.add_argument("--hf_model", type=str, help="Model name or path for the HF tokenizer.") + parser.add_argument("--mc_model", type=str, default=None, help="Model name or path for the Mistral-Common tokenizer.") + parser.add_argument("--n", type=int, default=3000, help="Maximum number of samples to test") + parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face token for private models") + parser.add_argument("--save_results", action="store_true", help="Save results to a JSON file") + args = parser.parse_args() + results = test_tokenizer(args.hf_model, args.mc_model if args.mc_model else args.hf_model, args.n, args.hf_token) + generate_mismatch_report(results) + if args.save_results: + save_results(results) + +# Output Example +""" +=== SUMMARY === +Total Mismatch Rate: 1.87% ( 56 / 3000 ) + - HuggingFaceFW/fineweb: 0.60% ( 6 / 1000 ) + - HuggingFaceFW/finepdfs: 1.70% ( 17 / 1000 ) + - manu/project_gutenberg: 3.30% ( 33 / 1000 ) + +=== MISMATCH REPORT: TOP 10 FREQUENT TOKENS === +Total mismatched unique tokens analyzed: 199 +Token -> Frequency (across all mismatches) +---------------------------------------- +"'" -> 73 times +' O' -> 44 times +' Jimmy' -> 32 times +' (' -> 22 times +'/' -> 21 times +'Re' -> 18 times +'Produ' -> 18 times +' and' -> 16 times +'gan' -> 16 times +'Reg' -> 16 times +""" From 5d0c69ed33f2e0fa87ff1ca6ee6de797a9333cc6 Mon Sep 17 00:00:00 2001 From: pandora Date: Wed, 12 Nov 2025 16:36:00 +0100 Subject: [PATCH 2/6] max_chars arg and nits --- scripts/compare_tokenizer.py | 111 ++++++++++++++++------------------- 1 file changed, 49 insertions(+), 62 deletions(-) diff --git a/scripts/compare_tokenizer.py b/scripts/compare_tokenizer.py index b042247e..199acc83 100644 --- a/scripts/compare_tokenizer.py +++ b/scripts/compare_tokenizer.py @@ -1,34 +1,31 @@ """ -This script compares the basic `.encode` tokenization between Hugging Face and Mistral tokenizers across multiple datasets. +This script compares the basic `.encode` tokenization between Hugging Face and Mistral Common tokenizers across multiple datasets. It identifies and reports mismatches in tokenization results, helping ensure consistency between the two tokenizers. -The comparison is limited to the first 10,000 characters to avoid excessively long sequences. +The comparison is limited to a user-specified number of characters to avoid excessively long sequences. The mismatch rate is calculated as the percentage of samples where tokenization results differ at any point. Results include the exact token where the mismatch starts, as well as the three tokens immediately before and after the mismatch, totaling 7 tokens. -If no mismatch is detected, it only indicates that the first 10,000 characters are consistent—not necessarily the entire sample. +If no mismatch is detected, it only indicates that the first `max_chars` characters are consistent—not necessarily the entire sample. Some datasets may show a high mismatch rate due to frequent occurrences of the same problematic string, hence we recommend taking a look at the raw results. - Usage: - python3 compare_tokenizer.py --hf_model --mc_model [--num_samples ] [--hf_token ] [-save-results] - + python3 compare_tokenizer.py --hf_model --mc_model [--num_samples ] [--max_chars ] [--hf_token ] [--save_results] Arguments: --hf_model: Model name or path for the Hugging Face tokenizer. --mc_model: Model name or path for the Mistral-Common tokenizer loaded from Hugging Face, if not provided, will default to hf_model. --num_samples: Maximum number of samples to test (default: 3000). + --max_chars: Maximum number of characters to tokenize per sample (default: 10_000). --hf_token: Hugging Face token for private models (optional). --save_results: If provided, saves the raw mismatch results to a JSON file (default: tokenizer_mismatches_data.json). - Examples: python3 compare_tokenizer.py --hf_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 --mc_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 --save_results - python3 compare_tokenizer.py --hf_model mistralai/Mistral-Nemo-Instruct-2407 --mc_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 --num_samples 1000 + python3 compare_tokenizer.py --hf_model mistralai/Mistral-Nemo-Instruct-2407 --mc_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 --num_samples 1000 --max_chars 5000 """ - import argparse from transformers import AutoTokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from datasets import load_dataset from tqdm import tqdm -from typing import List, Dict, Any +from typing import Any from collections import Counter import json @@ -36,9 +33,10 @@ def compare_tokenizers( hf_tokenizer: AutoTokenizer, mc_tokenizer: MistralTokenizer, content: str, -) -> Dict[str, Any]: + max_chars: int = 10_000, +) -> dict[str, Any] | None: """Compare tokenization between Hugging Face and Mistral tokenizers.""" - content = content[:10_000] # Limit to 10000 characters to avoid long sequences + content = content[:max_chars] # Limit to max_chars characters to avoid long sequences hf_tokens = hf_tokenizer.encode(content) mc_tokens = mc_tokenizer.instruct_tokenizer.tokenizer.encode(content, bos=True, eos=False) if hf_tokens != mc_tokens: @@ -66,8 +64,9 @@ def process_dataset( hf_tokenizer: AutoTokenizer, mc_tokenizer: MistralTokenizer, num_samples: int, - **kwargs, -) -> List[Dict[str, Any]]: + max_chars: int, + **kwargs: Any, +) -> list[dict[str, Any]]: """Process a dataset and collect tokenizer mismatches.""" mismatches = [] dataset = load_dataset(dataset_name, split=split, streaming=True) @@ -75,7 +74,7 @@ def process_dataset( if i >= num_samples: break content = example["text"] - mismatch = compare_tokenizers(hf_tokenizer, mc_tokenizer, content) + mismatch = compare_tokenizers(hf_tokenizer, mc_tokenizer, content, max_chars) if mismatch: mismatch.update(kwargs) mismatches.append(mismatch) @@ -84,33 +83,52 @@ def process_dataset( def test_tokenizer( hf_model: str, mc_model: str, - num_samples: int = 800, - hf_token: str = None, -) -> Dict[str, List[Dict[str, Any]]]: - """Test tokenizer consistency across multiple datasets.""" + num_samples: int = 3000, + max_chars: int = 10_000, + hf_token: str | None = None, +) -> dict[str, list[dict[str, Any]]]: + """ + Test tokenizer consistency across multiple datasets. + Output Example: + === SUMMARY === + Total Mismatch Rate: 1.87% ( 56 / 3000 ) + - HuggingFaceFW/fineweb: 0.60% ( 6 / 1000 ) + - HuggingFaceFW/finepdfs: 1.70% ( 17 / 1000 ) + - manu/project_gutenberg: 3.30% ( 33 / 1000 ) + === MISMATCH REPORT: TOP 10 FREQUENT TOKENS === + Total mismatched unique tokens analyzed: 199 + Token -> Frequency (across all mismatches) + ---------------------------------------- + "'" -> 73 times + ' O' -> 44 times + ' Jimmy' -> 32 times + ' (' -> 22 times + '/' -> 21 times + 'Re' -> 18 times + 'Produ' -> 18 times + ' and' -> 16 times + 'gan' -> 16 times + 'Reg' -> 16 times + """ print("Loading HF Tokenizer.") hf_tokenizer = AutoTokenizer.from_pretrained(hf_model, token=hf_token) print("Loading Mistral-Common Tokenizer.") mc_tokenizer = MistralTokenizer.from_hf_hub(mc_model, token=hf_token) - n_per_set = num_samples // 3 - # Web Data - Mostly English print(f"Testing Tokenizers on HuggingFaceFW/fineweb with {n_per_set} samples.") web_mismatches = process_dataset( - "HuggingFaceFW/fineweb", "train", hf_tokenizer, mc_tokenizer, n_per_set + "HuggingFaceFW/fineweb", "train", hf_tokenizer, mc_tokenizer, n_per_set, max_chars ) web_mismatch_rate = (len(web_mismatches) / n_per_set) * 100 print(f"HuggingFaceFW/fineweb: {web_mismatch_rate:.2f}% Sample Mismatch ({len(web_mismatches)}/{n_per_set})") - # Web PDF Data - Mostly English print(f"Testing Tokenizers on HuggingFaceFW/finepdfs with {n_per_set} samples.") pdf_mismatches = process_dataset( - "HuggingFaceFW/finepdfs", "train", hf_tokenizer, mc_tokenizer, n_per_set + "HuggingFaceFW/finepdfs", "train", hf_tokenizer, mc_tokenizer, n_per_set, max_chars ) pdf_mismatch_rate = (len(pdf_mismatches) / n_per_set) * 100 print(f"HuggingFaceFW/finepdfs: {pdf_mismatch_rate:.2f}% Sample Mismatch ({len(pdf_mismatches)}/{n_per_set})") - # Multilingual Creative Writing Data print(f"Testing Tokenizers on manu/project_gutenberg with {n_per_set} samples.") multi_mismatches = [] @@ -118,44 +136,38 @@ def test_tokenizer( splits = [s for s in list(dataset_info.keys()) if s != "en"] n_per_split = n_per_set // len(splits) collected = 0 - for split in splits: if collected >= n_per_set: break split_mismatches = process_dataset( - "manu/project_gutenberg", split, hf_tokenizer, mc_tokenizer, n_per_split + "manu/project_gutenberg", split, hf_tokenizer, mc_tokenizer, n_per_split, max_chars ) multi_mismatches.extend(split_mismatches) collected += n_per_split - if collected < n_per_set: remaining = num_samples - collected en_mismatches = process_dataset( - "manu/project_gutenberg", "en", hf_tokenizer, mc_tokenizer, remaining + "manu/project_gutenberg", "en", hf_tokenizer, mc_tokenizer, remaining, max_chars ) multi_mismatches.extend(en_mismatches) - multi_mismatch_rate = (len(multi_mismatches) / n_per_set) * 100 print(f"manu/project_gutenberg: {multi_mismatch_rate:.2f}% Sample Mismatch ({len(multi_mismatches)}/{n_per_set})") - # Total Mismatch print("\n=== SUMMARY ===") print(f"Total Mismatch Rate: {(len(web_mismatches) + len(pdf_mismatches) + len(multi_mismatches)) / num_samples * 100:.2f}% ( {len(web_mismatches) + len(pdf_mismatches) + len(multi_mismatches)} / {num_samples} )") print(f" - HuggingFaceFW/fineweb: {len(web_mismatches) / n_per_set * 100:.2f}% ( {len(web_mismatches)} / {n_per_set} )") print(f" - HuggingFaceFW/finepdfs: {len(pdf_mismatches) / n_per_set * 100:.2f}% ( {len(pdf_mismatches)} / {n_per_set} )") print(f" - manu/project_gutenberg: {len(multi_mismatches) / n_per_set * 100:.2f}% ( {len(multi_mismatches)} / {n_per_set} )") - return { "web_model_mismatches": web_mismatches, "pdf_model_mismatches": pdf_mismatches, "multi_model_mismatches": multi_mismatches, } -def generate_mismatch_report(mismatch_results: Dict[str, List[Dict[str, Any]]]) -> None: +def generate_mismatch_report(mismatch_results: dict[str, list[dict[str, Any]]]) -> None: """Generate a report of the most frequent tokens in mismatched samples.""" all_tokens = [] mismatch_percentages = [] - # Flatten all mismatches from all datasets for dataset_key in mismatch_results: for mismatch in mismatch_results[dataset_key]: @@ -163,11 +175,9 @@ def generate_mismatch_report(mismatch_results: Dict[str, List[Dict[str, Any]]]) mc_tokens = mismatch["mistral_common"] all_tokens.extend(hf_tokens) all_tokens.extend(mc_tokens) - # Count and rank tokens token_counter = Counter(all_tokens) most_common = token_counter.most_common(10) # Top 10 - print("\n=== MISMATCH REPORT: TOP 10 FREQUENT TOKENS ===") print(f"Total mismatched unique tokens analyzed: {len(token_counter)}") print("Token -> Frequency (across all mismatches)") @@ -175,7 +185,7 @@ def generate_mismatch_report(mismatch_results: Dict[str, List[Dict[str, Any]]]) for token, count in most_common: print(f"{token!r} -> {count} times") -def save_results(results: Dict[str, List[Dict[str, Any]]], filename: str = "tokenizer_mismatches_data.json") -> None: +def save_results(results: dict[str, list[dict[str, Any]]], filename: str = "tokenizer_mismatches_data.json") -> None: """Save the results to a JSON file.""" with open(filename, "w") as f: json.dump(results, f, indent=2) @@ -185,34 +195,11 @@ def save_results(results: Dict[str, List[Dict[str, Any]]], filename: str = "toke parser.add_argument("--hf_model", type=str, help="Model name or path for the HF tokenizer.") parser.add_argument("--mc_model", type=str, default=None, help="Model name or path for the Mistral-Common tokenizer.") parser.add_argument("--n", type=int, default=3000, help="Maximum number of samples to test") + parser.add_argument("--max_chars", type=int, default=10_000, help="Maximum number of characters to tokenize per sample") parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face token for private models") parser.add_argument("--save_results", action="store_true", help="Save results to a JSON file") args = parser.parse_args() - results = test_tokenizer(args.hf_model, args.mc_model if args.mc_model else args.hf_model, args.n, args.hf_token) + results = test_tokenizer(args.hf_model, args.mc_model if args.mc_model else args.hf_model, args.n, args.max_chars, args.hf_token) generate_mismatch_report(results) if args.save_results: save_results(results) - -# Output Example -""" -=== SUMMARY === -Total Mismatch Rate: 1.87% ( 56 / 3000 ) - - HuggingFaceFW/fineweb: 0.60% ( 6 / 1000 ) - - HuggingFaceFW/finepdfs: 1.70% ( 17 / 1000 ) - - manu/project_gutenberg: 3.30% ( 33 / 1000 ) - -=== MISMATCH REPORT: TOP 10 FREQUENT TOKENS === -Total mismatched unique tokens analyzed: 199 -Token -> Frequency (across all mismatches) ----------------------------------------- -"'" -> 73 times -' O' -> 44 times -' Jimmy' -> 32 times -' (' -> 22 times -'/' -> 21 times -'Re' -> 18 times -'Produ' -> 18 times -' and' -> 16 times -'gan' -> 16 times -'Reg' -> 16 times -""" From aedfb3df73e68c127c8b869242680ac4309b52f7 Mon Sep 17 00:00:00 2001 From: pandora Date: Wed, 12 Nov 2025 16:49:26 +0100 Subject: [PATCH 3/6] formatting --- scripts/compare_tokenizer.py | 183 +++++++++++++++++++++++++++-------- 1 file changed, 140 insertions(+), 43 deletions(-) diff --git a/scripts/compare_tokenizer.py b/scripts/compare_tokenizer.py index 199acc83..b9564a07 100644 --- a/scripts/compare_tokenizer.py +++ b/scripts/compare_tokenizer.py @@ -1,33 +1,68 @@ """ -This script compares the basic `.encode` tokenization between Hugging Face and Mistral Common tokenizers across multiple datasets. -It identifies and reports mismatches in tokenization results, helping ensure consistency between the two tokenizers. -The comparison is limited to a user-specified number of characters to avoid excessively long sequences. -The mismatch rate is calculated as the percentage of samples where tokenization results differ at any point. -Results include the exact token where the mismatch starts, as well as the three tokens immediately before and after the mismatch, totaling 7 tokens. -If no mismatch is detected, it only indicates that the first `max_chars` characters are consistent—not necessarily the entire sample. -Some datasets may show a high mismatch rate due to frequent occurrences of the same problematic string, hence we recommend taking a look at the raw results. +This script compares the basic `.encode` tokenization +between Hugging Face and Mistral Common tokenizers +across multiple datasets. + +It identifies and reports mismatches in tokenization results, +helping ensure consistency between the two tokenizers. +The comparison is limited to a user-specified number of characters +to avoid excessively long sequences. + +The mismatch rate is calculated as the percentage of samples +where tokenization results differ at any point. +Results include the exact token where the mismatch starts, +as well as the three tokens immediately before and after, +totaling 7 tokens. + +If no mismatch is detected, it only indicates that +the first `max_chars` characters are consistent— +not necessarily the entire sample. + +Some datasets may show a high mismatch rate due to +frequent occurrences of the same problematic string. +We recommend reviewing the raw results for details. + Usage: - python3 compare_tokenizer.py --hf_model --mc_model [--num_samples ] [--max_chars ] [--hf_token ] [--save_results] + python3 compare_tokenizer.py --hf_model --mc_model + [--num_samples ] + [--max_chars ] + [--hf_token ] + [--save_results] + Arguments: - --hf_model: Model name or path for the Hugging Face tokenizer. - --mc_model: Model name or path for the Mistral-Common tokenizer loaded from Hugging Face, if not provided, will default to hf_model. - --num_samples: Maximum number of samples to test (default: 3000). - --max_chars: Maximum number of characters to tokenize per sample (default: 10_000). - --hf_token: Hugging Face token for private models (optional). - --save_results: If provided, saves the raw mismatch results to a JSON file (default: tokenizer_mismatches_data.json). +- `--hf_model`: Model name or path for the Hugging Face tokenizer. +- `--mc_model`: Model name or path for the Mistral-Common tokenizer. + If not provided, defaults to `hf_model`. +- `--num_samples`: Maximum number of samples to test (default: 3000). +- `--max_chars`: Maximum number of characters to tokenize per sample (default: 10,000). +- `--hf_token`: Hugging Face token for private models (optional). +- `--save_results`: If provided, saves raw mismatch results to a JSON file + (default: `tokenizer_mismatches_data.json`). + Examples: - python3 compare_tokenizer.py --hf_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 --mc_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 --save_results - python3 compare_tokenizer.py --hf_model mistralai/Mistral-Nemo-Instruct-2407 --mc_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 --num_samples 1000 --max_chars 5000 + python3 compare_tokenizer.py + --hf_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 + --mc_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 + --save_results + python3 compare_tokenizer.py + --hf_model mistralai/Mistral-Nemo-Instruct-2407 + --mc_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 + --num_samples 1000 + --max_chars 5000 """ + import argparse -from transformers import AutoTokenizer -from mistral_common.tokens.tokenizers.mistral import MistralTokenizer -from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy +import json +from collections import Counter +from typing import Any + from datasets import load_dataset from tqdm import tqdm -from typing import Any -from collections import Counter -import json +from transformers import AutoTokenizer + +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + def compare_tokenizers( hf_tokenizer: AutoTokenizer, @@ -47,17 +82,15 @@ def compare_tokenizers( start = max(0, mismatch_id - 3) end = min(len(hf_tokens), mismatch_id + 4) hf_str = [hf_tokenizer.decode([t]) for t in hf_tokens[start:end]] - mc_str = [ - mc_tokenizer.decode([t], special_token_policy=SpecialTokenPolicy.KEEP) - for t in mc_tokens[start:end] - ] + mc_str = [mc_tokenizer.decode([t], special_token_policy=SpecialTokenPolicy.KEEP) for t in mc_tokens[start:end]] return { - "string_content": ''.join(mc_str), + "string_content": "".join(mc_str), "mistral_common": mc_str, - "hugging_face": hf_str + "hugging_face": hf_str, } return None + def process_dataset( dataset_name: str, split: str, @@ -70,7 +103,11 @@ def process_dataset( """Process a dataset and collect tokenizer mismatches.""" mismatches = [] dataset = load_dataset(dataset_name, split=split, streaming=True) - for i, example in tqdm(enumerate(dataset), desc=f"Processing {dataset_name} ({split})", total=num_samples): + for i, example in tqdm( + enumerate(dataset), + desc=f"Processing {dataset_name} ({split})", + total=num_samples, + ): if i >= num_samples: break content = example["text"] @@ -80,6 +117,7 @@ def process_dataset( mismatches.append(mismatch) return mismatches + def test_tokenizer( hf_model: str, mc_model: str, @@ -118,14 +156,24 @@ def test_tokenizer( # Web Data - Mostly English print(f"Testing Tokenizers on HuggingFaceFW/fineweb with {n_per_set} samples.") web_mismatches = process_dataset( - "HuggingFaceFW/fineweb", "train", hf_tokenizer, mc_tokenizer, n_per_set, max_chars + "HuggingFaceFW/fineweb", + "train", + hf_tokenizer, + mc_tokenizer, + n_per_set, + max_chars, ) web_mismatch_rate = (len(web_mismatches) / n_per_set) * 100 print(f"HuggingFaceFW/fineweb: {web_mismatch_rate:.2f}% Sample Mismatch ({len(web_mismatches)}/{n_per_set})") # Web PDF Data - Mostly English print(f"Testing Tokenizers on HuggingFaceFW/finepdfs with {n_per_set} samples.") pdf_mismatches = process_dataset( - "HuggingFaceFW/finepdfs", "train", hf_tokenizer, mc_tokenizer, n_per_set, max_chars + "HuggingFaceFW/finepdfs", + "train", + hf_tokenizer, + mc_tokenizer, + n_per_set, + max_chars, ) pdf_mismatch_rate = (len(pdf_mismatches) / n_per_set) * 100 print(f"HuggingFaceFW/finepdfs: {pdf_mismatch_rate:.2f}% Sample Mismatch ({len(pdf_mismatches)}/{n_per_set})") @@ -140,34 +188,57 @@ def test_tokenizer( if collected >= n_per_set: break split_mismatches = process_dataset( - "manu/project_gutenberg", split, hf_tokenizer, mc_tokenizer, n_per_split, max_chars + "manu/project_gutenberg", + split, + hf_tokenizer, + mc_tokenizer, + n_per_split, + max_chars, ) multi_mismatches.extend(split_mismatches) collected += n_per_split if collected < n_per_set: remaining = num_samples - collected en_mismatches = process_dataset( - "manu/project_gutenberg", "en", hf_tokenizer, mc_tokenizer, remaining, max_chars + "manu/project_gutenberg", + "en", + hf_tokenizer, + mc_tokenizer, + remaining, + max_chars, ) multi_mismatches.extend(en_mismatches) multi_mismatch_rate = (len(multi_mismatches) / n_per_set) * 100 print(f"manu/project_gutenberg: {multi_mismatch_rate:.2f}% Sample Mismatch ({len(multi_mismatches)}/{n_per_set})") # Total Mismatch print("\n=== SUMMARY ===") - print(f"Total Mismatch Rate: {(len(web_mismatches) + len(pdf_mismatches) + len(multi_mismatches)) / num_samples * 100:.2f}% ( {len(web_mismatches) + len(pdf_mismatches) + len(multi_mismatches)} / {num_samples} )") - print(f" - HuggingFaceFW/fineweb: {len(web_mismatches) / n_per_set * 100:.2f}% ( {len(web_mismatches)} / {n_per_set} )") - print(f" - HuggingFaceFW/finepdfs: {len(pdf_mismatches) / n_per_set * 100:.2f}% ( {len(pdf_mismatches)} / {n_per_set} )") - print(f" - manu/project_gutenberg: {len(multi_mismatches) / n_per_set * 100:.2f}% ( {len(multi_mismatches)} / {n_per_set} )") + rate = (len(web_mismatches) + len(pdf_mismatches) + len(multi_mismatches)) / num_samples * 100 + print( + f"Total Mismatch Rate: {rate:.2f}% " + f"( {len(web_mismatches) + len(pdf_mismatches) + len(multi_mismatches)} / {num_samples} )" + ) + print( + f" - HuggingFaceFW/fineweb: {len(web_mismatches) / n_per_set * 100:.2f}% " + f"( {len(web_mismatches)} / {n_per_set} )" + ) + print( + f" - HuggingFaceFW/finepdfs: {len(pdf_mismatches) / n_per_set * 100:.2f}% " + f"( {len(pdf_mismatches)} / {n_per_set} )" + ) + print( + f" - manu/project_gutenberg: {len(multi_mismatches) / n_per_set * 100:.2f}% " + f"( {len(multi_mismatches)} / {n_per_set} )" + ) return { "web_model_mismatches": web_mismatches, "pdf_model_mismatches": pdf_mismatches, "multi_model_mismatches": multi_mismatches, } + def generate_mismatch_report(mismatch_results: dict[str, list[dict[str, Any]]]) -> None: """Generate a report of the most frequent tokens in mismatched samples.""" all_tokens = [] - mismatch_percentages = [] # Flatten all mismatches from all datasets for dataset_key in mismatch_results: for mismatch in mismatch_results[dataset_key]: @@ -185,21 +256,47 @@ def generate_mismatch_report(mismatch_results: dict[str, list[dict[str, Any]]]) for token, count in most_common: print(f"{token!r} -> {count} times") -def save_results(results: dict[str, list[dict[str, Any]]], filename: str = "tokenizer_mismatches_data.json") -> None: + +def save_results( + results: dict[str, list[dict[str, Any]]], + filename: str = "tokenizer_mismatches_data.json", +) -> None: """Save the results to a JSON file.""" with open(filename, "w") as f: json.dump(results, f, indent=2) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Test tokenizer consistency between Hugging Face and Mistral Common.") parser.add_argument("--hf_model", type=str, help="Model name or path for the HF tokenizer.") - parser.add_argument("--mc_model", type=str, default=None, help="Model name or path for the Mistral-Common tokenizer.") + parser.add_argument( + "--mc_model", + type=str, + default=None, + help="Model name or path for the Mistral-Common tokenizer.", + ) parser.add_argument("--n", type=int, default=3000, help="Maximum number of samples to test") - parser.add_argument("--max_chars", type=int, default=10_000, help="Maximum number of characters to tokenize per sample") - parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face token for private models") + parser.add_argument( + "--max_chars", + type=int, + default=10_000, + help="Maximum number of characters to tokenize per sample", + ) + parser.add_argument( + "--hf_token", + type=str, + default=None, + help="Hugging Face token for private models", + ) parser.add_argument("--save_results", action="store_true", help="Save results to a JSON file") args = parser.parse_args() - results = test_tokenizer(args.hf_model, args.mc_model if args.mc_model else args.hf_model, args.n, args.max_chars, args.hf_token) + results = test_tokenizer( + args.hf_model, + args.mc_model if args.mc_model else args.hf_model, + args.n, + args.max_chars, + args.hf_token, + ) generate_mismatch_report(results) if args.save_results: save_results(results) From 60e6e75d2d85f3572bc841ce2101a10fde69fb79 Mon Sep 17 00:00:00 2001 From: pandora Date: Fri, 14 Nov 2025 16:26:21 +0100 Subject: [PATCH 4/6] minor bos fix --- scripts/compare_tokenizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/compare_tokenizer.py b/scripts/compare_tokenizer.py index b9564a07..90dd7e53 100644 --- a/scripts/compare_tokenizer.py +++ b/scripts/compare_tokenizer.py @@ -72,8 +72,8 @@ def compare_tokenizers( ) -> dict[str, Any] | None: """Compare tokenization between Hugging Face and Mistral tokenizers.""" content = content[:max_chars] # Limit to max_chars characters to avoid long sequences - hf_tokens = hf_tokenizer.encode(content) - mc_tokens = mc_tokenizer.instruct_tokenizer.tokenizer.encode(content, bos=True, eos=False) + hf_tokens = hf_tokenizer.encode(content, add_special_tokens=False) + mc_tokens = mc_tokenizer.instruct_tokenizer.tokenizer.encode(content, bos=False, eos=False) if hf_tokens != mc_tokens: mismatch_id = next( (idx for idx, (a, b) in enumerate(zip(hf_tokens, mc_tokens)) if a != b), From 83854cb1c432df5cbc1802cb3508b0ea04bde981 Mon Sep 17 00:00:00 2001 From: pandora Date: Fri, 14 Nov 2025 17:26:34 +0100 Subject: [PATCH 5/6] restructure and new full_compare script --- external/transformers/scripts/config.json | 31 ++ .../scripts/encode_compare_tokenizer.py | 0 .../scripts/full_compare_tokenizer.py | 374 ++++++++++++++++++ 3 files changed, 405 insertions(+) create mode 100644 external/transformers/scripts/config.json rename scripts/compare_tokenizer.py => external/transformers/scripts/encode_compare_tokenizer.py (100%) create mode 100644 external/transformers/scripts/full_compare_tokenizer.py diff --git a/external/transformers/scripts/config.json b/external/transformers/scripts/config.json new file mode 100644 index 00000000..2177e3a5 --- /dev/null +++ b/external/transformers/scripts/config.json @@ -0,0 +1,31 @@ +{ + "text_encode": { + "datasets": [ + { + "name": "HuggingFaceFW/fineweb", + "split": "train", + "column": "text", + "num_samples": 500, + "max_tokens": 2048 + }, + { + "name": "HuggingFaceFW/finepdfs", + "split": "train", + "column": "text", + "num_samples": 500, + "max_tokens": 2048 + } + ] + }, + "text_instruct": { + "datasets": [ + { + "name": "HuggingFaceH4/ultrachat_200k", + "split": "train_sft", + "column": "messages", + "num_samples": 1000, + "max_tokens": 2048 + } + ] + } + } \ No newline at end of file diff --git a/scripts/compare_tokenizer.py b/external/transformers/scripts/encode_compare_tokenizer.py similarity index 100% rename from scripts/compare_tokenizer.py rename to external/transformers/scripts/encode_compare_tokenizer.py diff --git a/external/transformers/scripts/full_compare_tokenizer.py b/external/transformers/scripts/full_compare_tokenizer.py new file mode 100644 index 00000000..db30a009 --- /dev/null +++ b/external/transformers/scripts/full_compare_tokenizer.py @@ -0,0 +1,374 @@ +""" +This script compares tokenization between Hugging Face and Mistral Common tokenizers +across multiple datasets and modes (e.g., text_encode, text_instruct). +It identifies and reports mismatches in tokenization results, +helping ensure consistency between the two tokenizers. +The comparison is limited to a user-specified number of tokens +to avoid excessively long sequences. +The mismatch rate is calculated as the percentage of samples +where tokenization results differ at any point. +Results include the exact token where the mismatch starts, +as well as the three tokens immediately before and after, +totaling 7 tokens. +If no mismatch is detected, it only indicates that +the first `max_tokens` tokens are consistent— +not necessarily the entire sample. +Some datasets may show a high mismatch rate due to +frequent occurrences of the same problematic string. +We recommend reviewing the raw results for details. + +Usage: + python3 compare_tokenizer.py --hf_model [--mc_model ] + [--config ] + [--type ...] + [--hf_token ] + [--save_results] + +Arguments: +- `--hf_model`: Model name or path for the Hugging Face tokenizer (required). +- `--mc_model`: Model name or path for the Mistral-Common tokenizer. + If not provided, defaults to `hf_model`. +- `--config`: Path to JSON config file with mode-specific settings + (default: "external/transformers/scripts/full_compare_tokenizer_config.json"). +- `--type`: Type(s) of tokenization to perform (space-separated). + Supported modes: text_encode, text_instruct, vision_encode, vision_instruct, + tool_call_instruct, text_reasoning, vision_reasoning, tool_call_reasoning. + Default: ["text_encode"]. +- `--hf_token`: Hugging Face token for private models (optional). +- `--save_results`: If provided, saves raw mismatch results and report to JSON files + (default: "tokenizer_mismatches_data.json" and "tokenizer_mismatches_report.json"). + +Examples: + python3 compare_tokenizer.py + --hf_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 + --type text_encode text_instruct + --save_results + + python3 compare_tokenizer.py + --hf_model mistralai/Mistral-Nemo-Instruct-2407 + --mc_model mistralai/Mistral-Small-3.1-24b-Instruct-2503 + --config custom_config.json + --type text_encode text_instruct + --save_results +""" + +import argparse +import json +from collections import Counter +from typing import Any + +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + + +def compare_tokenizers( + hf_tokenizer: AutoTokenizer, + mc_tokenizer: MistralTokenizer, + content: str | dict | list, + max_tokens: int = 2_048, + mode: str = "text_encode", +) -> dict[str, Any] | None: + """Compare tokenization between Hugging Face and Mistral tokenizers.""" + if mode == "text_encode": + assert isinstance(content, str), "Content must be a string for text_encode mode." + content_str = content + hf_tokens = hf_tokenizer.encode(content_str, add_special_tokens=False)[:max_tokens] + mc_tokens = mc_tokenizer.instruct_tokenizer.tokenizer.encode(content_str, bos=False, eos=False)[:max_tokens] + elif mode == "text_instruct": + assert isinstance(content, list), "Content must be a list of messages for text_instruct mode." + if content and content[-1]["role"] == "assistant": + content = content[:-1] + if content and content[0]["role"] != "system": + content = [{"role": "system", "content": ""}] + content + hf_tokens = hf_tokenizer.apply_chat_template(content, tokenize=True)[:max_tokens] + request = ChatCompletionRequest(messages=content) + mc_tokens = mc_tokenizer.encode_chat_completion(request).tokens[:max_tokens] + else: + raise NotImplementedError(f"Mode '{mode}' is not yet implemented.") + if hf_tokens != mc_tokens: + mismatch_id = next( + (idx for idx, (a, b) in enumerate(zip(hf_tokens, mc_tokens)) if a != b), + min(len(hf_tokens), len(mc_tokens)), + ) + start = max(0, mismatch_id - 3) + end = min(len(hf_tokens), mismatch_id + 4) + hf_str = [hf_tokenizer.decode([t]) for t in hf_tokens[start:end]] + mc_str = [mc_tokenizer.decode([t], special_token_policy=SpecialTokenPolicy.KEEP) for t in mc_tokens[start:end]] + return { + "string_content": "".join(mc_str), + "mistral_common": mc_str, + "hugging_face": hf_str, + "mode": mode, + } + return None + + +def process_dataset( + dataset_name: str, + split: str, + column: str, + hf_tokenizer: AutoTokenizer, + mc_tokenizer: MistralTokenizer, + num_samples: int, + max_tokens: int, + mode: str, +) -> list[dict[str, Any]]: + """Process a dataset and collect tokenizer mismatches.""" + mismatches = [] + dataset = load_dataset(dataset_name, split=split, streaming=True) + try: + for i, example in tqdm( + enumerate(dataset), + desc=f"Processing {dataset_name} ({split}, mode: {mode})", + total=num_samples, + ): + if i >= num_samples: + break + content = example[column] + mismatch = compare_tokenizers(hf_tokenizer, mc_tokenizer, content, max_tokens, mode) + if mismatch: + mismatch.update( + { + "dataset_name": dataset_name, + "dataset_column": column, + "sample_index": i, + } + ) + mismatches.append(mismatch) + finally: + if hasattr(dataset, "close"): + print(f"Closing dataset {dataset_name}.") + dataset.close() + return mismatches + + +def test_tokenizer( + hf_model: str, + mc_model: str, + modes: list[str], + hf_token: str | None = None, + config: dict[str, Any] = {}, +) -> tuple[dict[str, list[dict[str, Any]]], dict[str, Any]]: + """Test tokenizer consistency for each mode and dataset in the config. + Returns: (mismatch_results, stats) + """ + print("Loading HF Tokenizer.") + hf_tokenizer = AutoTokenizer.from_pretrained(hf_model, token=hf_token) + print("Loading Mistral-Common Tokenizer.") + mc_tokenizer = MistralTokenizer.from_hf_hub(mc_model, token=hf_token) + all_mismatches = {} + stats = {} + + for mode in modes: + mode_config = config.get(mode, {}) + if not mode_config or "datasets" not in mode_config: + print(f"Warning: No config found for mode '{mode}'. Skipping.") + continue + + all_mismatches[mode] = [] + mode_total_samples = 0 + mode_total_mismatches = 0 + mode_dataset_stats = {} + + for dataset_config in mode_config["datasets"]: + dataset_name = dataset_config["name"] + split = dataset_config.get("split", "train") + column = dataset_config.get("column", "text") + num_samples = dataset_config.get("num_samples", 1000) + max_tokens = dataset_config.get("max_tokens", 2048) + + print( + f"\nTesting Tokenizers on {dataset_name} (column: {column}, mode: {mode}) with {num_samples} samples." + ) + mismatches = process_dataset( + dataset_name, + split, + column, + hf_tokenizer, + mc_tokenizer, + num_samples, + max_tokens, + mode, + ) + + dataset_mismatches = len(mismatches) + mismatch_rate = (dataset_mismatches / num_samples) * 100 + print(f"{dataset_name} ({mode}): {mismatch_rate:.2f}% Sample Mismatch ({dataset_mismatches}/{num_samples})") + + all_mismatches[mode].extend(mismatches) + mode_total_samples += num_samples + mode_total_mismatches += dataset_mismatches + mode_dataset_stats[dataset_name] = { + "total_samples": num_samples, + "mismatches": dataset_mismatches, + "mismatch_rate": mismatch_rate, + } + + # Save mode-level stats + mode_mismatch_rate = (mode_total_mismatches / mode_total_samples) * 100 if mode_total_samples > 0 else 0 + stats[mode] = { + "total_samples": mode_total_samples, + "total_mismatches": mode_total_mismatches, + "mismatch_rate": mode_mismatch_rate, + "datasets": mode_dataset_stats, + } + + return all_mismatches, stats + + +def generate_mismatch_report( + mismatch_results: dict[str, list[dict[str, Any]]], stats: dict[str, Any] +) -> dict[str, Any]: + """Generate a detailed report using precomputed stats and return it as a dict.""" + report = {"summary_by_mode": {}, "top_tokens_by_mode": {}, "top_tokens_overall": {}} + + # Summary by mode + for mode, mode_stat in stats.items(): + report["summary_by_mode"][mode] = { + "total_samples": mode_stat["total_samples"], + "total_mismatches": mode_stat["total_mismatches"], + "mismatch_rate": mode_stat["mismatch_rate"], + "datasets": mode_stat["datasets"], + } + + # Top tokens by mode + for mode in mismatch_results: + mode_tokens = [] + for mismatch in mismatch_results[mode]: + mode_tokens.extend(mismatch["hugging_face"]) + mode_tokens.extend(mismatch["mistral_common"]) + token_counter = Counter(mode_tokens) + report["top_tokens_by_mode"][mode] = { + "total_unique_tokens": len(token_counter), + "top_5_tokens": token_counter.most_common(5), + } + + # Top tokens overall + all_tokens = [] + for mode in mismatch_results: + for mismatch in mismatch_results[mode]: + all_tokens.extend(mismatch["hugging_face"]) + all_tokens.extend(mismatch["mistral_common"]) + token_counter = Counter(all_tokens) + report["top_tokens_overall"] = { + "total_unique_tokens": len(token_counter), + "top_10_tokens": token_counter.most_common(10), + } + + # Print the report + print("\n=== MISMATCH SUMMARY BY MODE ===") + for mode, mode_stat in stats.items(): + print(f"\n--- Mode: {mode} ---") + print(f"Total Samples: {mode_stat['total_samples']}") + print(f"Total Mismatches: {mode_stat['total_mismatches']}") + print(f"Mismatch Rate: {mode_stat['mismatch_rate']:.2f}%") + print("Per-Dataset Breakdown:") + for dataset, ds_stat in mode_stat["datasets"].items(): + print( + f" - {dataset}: {ds_stat['mismatch_rate']:.2f}% ({ds_stat['mismatches']}/{ds_stat['total_samples']})" + ) + + print("\n=== TOP 5 MISMATCHED TOKENS BY MODE ===") + for mode in mismatch_results: + mode_tokens = [] + for mismatch in mismatch_results[mode]: + mode_tokens.extend(mismatch["hugging_face"]) + mode_tokens.extend(mismatch["mistral_common"]) + token_counter = Counter(mode_tokens) + most_common = token_counter.most_common(5) + print(f"\n--- Mode: {mode} ---") + print(f"Total unique tokens: {len(token_counter)}") + print("Token -> Frequency") + print("-" * 30) + for token, count in most_common: + print(f"{token!r} -> {count} times") + + print("\n=== OVERALL TOP 10 MISMATCHED TOKENS (ALL MODES) ===") + print(f"Total unique tokens analyzed: {len(token_counter)}") + print("Token -> Frequency (across all mismatches)") + print("-" * 40) + for token, count in token_counter.most_common(10): + print(f"{token!r} -> {count} times") + + return report + + +def save_results( + results: dict[str, list[dict[str, Any]]], + report: dict[str, Any], + mismatches_filename: str = "tokenizer_mismatches_data.json", + report_filename: str = "tokenizer_mismatches_report.json", +) -> None: + """Save the results and report to JSON files.""" + with open(mismatches_filename, "w") as f: + json.dump(results, f, indent=2) + with open(report_filename, "w") as f: + json.dump(report, f, indent=2) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test tokenizer consistency between Hugging Face and Mistral Common.") + parser.add_argument( + "--hf_model", + type=str, + required=True, + help="Model name or path for the HF tokenizer.", + ) + parser.add_argument( + "--mc_model", + type=str, + default=None, + help="Model name or path for the Mistral-Common tokenizer.", + ) + parser.add_argument( + "--config", + type=str, + default="external/transformers/scripts/full_compare_tokenizer_config.json", + help="Path to JSON config file with mode-specific settings.", + ) + parser.add_argument( + "--hf_token", + type=str, + default=None, + help="Hugging Face token for private models.", + ) + parser.add_argument( + "--type", + type=str, + nargs="+", + default=["text_encode"], + choices=[ + "text_encode", + "vision_encode", + "text_instruct", + "vision_instruct", + "tool_call_instruct", + "text_reasoning", + "vision_reasoning", + "tool_call_reasoning", + ], + help="Type(s) of tokenization to perform (space-separated).", + ) + parser.add_argument("--save_results", action="store_true", help="Save results to JSON files.") + args = parser.parse_args() + + with open(args.config, "r") as f: + config = json.load(f) + + results, stats = test_tokenizer( + args.hf_model, + args.mc_model if args.mc_model else args.hf_model, + args.type, + args.hf_token, + config, + ) + + report = generate_mismatch_report(results, stats) + + if args.save_results: + save_results(results, report) From cec8f3770b0942c316e33b4553aca198a5424fc4 Mon Sep 17 00:00:00 2001 From: pandora Date: Fri, 14 Nov 2025 19:18:05 +0100 Subject: [PATCH 6/6] cleaning and add random option --- .../scripts/full_compare_tokenizer.py | 138 +++++++++++++----- ...son => full_compare_tokenizer_config.json} | 12 +- 2 files changed, 115 insertions(+), 35 deletions(-) rename external/transformers/scripts/{config.json => full_compare_tokenizer_config.json} (73%) diff --git a/external/transformers/scripts/full_compare_tokenizer.py b/external/transformers/scripts/full_compare_tokenizer.py index db30a009..eeab2aed 100644 --- a/external/transformers/scripts/full_compare_tokenizer.py +++ b/external/transformers/scripts/full_compare_tokenizer.py @@ -28,6 +28,7 @@ - `--hf_model`: Model name or path for the Hugging Face tokenizer (required). - `--mc_model`: Model name or path for the Mistral-Common tokenizer. If not provided, defaults to `hf_model`. +- `--random_only`: If provided, only use random UTF-8 strings and no datasets for quick testing. - `--config`: Path to JSON config file with mode-specific settings (default: "external/transformers/scripts/full_compare_tokenizer_config.json"). - `--type`: Type(s) of tokenization to perform (space-separated). @@ -50,12 +51,17 @@ --config custom_config.json --type text_encode text_instruct --save_results + + python3 compare_tokenizer.py + --hf_model mistralai/Mistral-Small-24b-Instruct-2501 + --type text_encode text_instruct + --random_only """ import argparse import json +import random from collections import Counter -from typing import Any from datasets import load_dataset from tqdm import tqdm @@ -72,7 +78,7 @@ def compare_tokenizers( content: str | dict | list, max_tokens: int = 2_048, mode: str = "text_encode", -) -> dict[str, Any] | None: +) -> dict[str] | None: """Compare tokenization between Hugging Face and Mistral tokenizers.""" if mode == "text_encode": assert isinstance(content, str), "Content must be a string for text_encode mode." @@ -117,43 +123,99 @@ def process_dataset( num_samples: int, max_tokens: int, mode: str, -) -> list[dict[str, Any]]: +) -> list[dict[str]]: """Process a dataset and collect tokenizer mismatches.""" mismatches = [] - dataset = load_dataset(dataset_name, split=split, streaming=True) - try: - for i, example in tqdm( - enumerate(dataset), - desc=f"Processing {dataset_name} ({split}, mode: {mode})", - total=num_samples, - ): - if i >= num_samples: - break - content = example[column] - mismatch = compare_tokenizers(hf_tokenizer, mc_tokenizer, content, max_tokens, mode) - if mismatch: - mismatch.update( - { - "dataset_name": dataset_name, - "dataset_column": column, - "sample_index": i, - } + if dataset_name not in ["random_utf8", "random_instruct_utf8"]: + dataset = load_dataset(dataset_name, split=split, streaming=True) + try: + for i, example in tqdm( + enumerate(dataset), + desc=f"Processing {dataset_name} ({split}, mode: {mode})", + total=num_samples, + ): + if i >= num_samples: + break + content = example[column] + mismatch = compare_tokenizers(hf_tokenizer, mc_tokenizer, content, max_tokens, mode) + if mismatch: + mismatch.update( + { + "dataset_name": dataset_name, + "dataset_column": column, + "sample_index": i, + } + ) + mismatches.append(mismatch) + finally: + if hasattr(dataset, "close"): + print(f"Closing dataset {dataset_name}.") + dataset.close() + else: + if dataset_name == "random_utf8": + for i in tqdm(range(num_samples), desc=f"Processing random_utf8 (mode: {mode})"): + content = "".join([chr(random.randint(0, 127)) for _ in range(max_tokens)]) + mismatch = compare_tokenizers(hf_tokenizer, mc_tokenizer, content, max_tokens, mode) + if mismatch: + mismatch.update( + { + "dataset_name": dataset_name, + "dataset_column": column, + "sample_index": i, + } + ) + mismatches.append(mismatch) + elif dataset_name == "random_instruct_utf8": + for i in tqdm( + range(num_samples), + desc=f"Processing random_instruct_utf8 (mode: {mode})", + ): + content = ( + [ + { + "role": "system", + "content": "".join([chr(random.randint(0, 127)) for _ in range(512)]), + } + ] + if random.random() < 0.5 + else [] ) - mismatches.append(mismatch) - finally: - if hasattr(dataset, "close"): - print(f"Closing dataset {dataset_name}.") - dataset.close() + for i in range(random.randint(1, 10)): + if len(content) == 0: + content.append( + { + "role": "user", + "content": "".join([chr(random.randint(0, 127)) for _ in range(512)]), + } + ) + else: + content.append( + { + "role": ("assistant" if content[-1]["role"] == "user" else "user"), + "content": "".join([chr(random.randint(0, 127)) for _ in range(512)]), + } + ) + mismatch = compare_tokenizers(hf_tokenizer, mc_tokenizer, content, max_tokens, mode) + if mismatch: + mismatch.update( + { + "dataset_name": dataset_name, + "dataset_column": column, + "sample_index": i, + } + ) + mismatches.append(mismatch) return mismatches def test_tokenizer( hf_model: str, mc_model: str, + random_only: bool, modes: list[str], hf_token: str | None = None, - config: dict[str, Any] = {}, -) -> tuple[dict[str, list[dict[str, Any]]], dict[str, Any]]: + config: dict[str] = {}, +) -> tuple[dict[str, list[dict[str]]], dict[str]]: """Test tokenizer consistency for each mode and dataset in the config. Returns: (mismatch_results, stats) """ @@ -177,6 +239,11 @@ def test_tokenizer( for dataset_config in mode_config["datasets"]: dataset_name = dataset_config["name"] + if random_only and dataset_name not in [ + "random_utf8", + "random_instruct_utf8", + ]: + continue split = dataset_config.get("split", "train") column = dataset_config.get("column", "text") num_samples = dataset_config.get("num_samples", 1000) @@ -221,9 +288,7 @@ def test_tokenizer( return all_mismatches, stats -def generate_mismatch_report( - mismatch_results: dict[str, list[dict[str, Any]]], stats: dict[str, Any] -) -> dict[str, Any]: +def generate_mismatch_report(mismatch_results: dict[str, list[dict[str]]], stats: dict[str]) -> dict[str]: """Generate a detailed report using precomputed stats and return it as a dict.""" report = {"summary_by_mode": {}, "top_tokens_by_mode": {}, "top_tokens_overall": {}} @@ -299,8 +364,8 @@ def generate_mismatch_report( def save_results( - results: dict[str, list[dict[str, Any]]], - report: dict[str, Any], + results: dict[str, list[dict[str]]], + report: dict[str], mismatches_filename: str = "tokenizer_mismatches_data.json", report_filename: str = "tokenizer_mismatches_report.json", ) -> None: @@ -331,10 +396,14 @@ def save_results( default="external/transformers/scripts/full_compare_tokenizer_config.json", help="Path to JSON config file with mode-specific settings.", ) + parser.add_argument( + "--random_only", + action="store_true", + help="Only test random UTF-8 data (for quick testing).", + ) parser.add_argument( "--hf_token", type=str, - default=None, help="Hugging Face token for private models.", ) parser.add_argument( @@ -363,6 +432,7 @@ def save_results( results, stats = test_tokenizer( args.hf_model, args.mc_model if args.mc_model else args.hf_model, + args.random_only, args.type, args.hf_token, config, diff --git a/external/transformers/scripts/config.json b/external/transformers/scripts/full_compare_tokenizer_config.json similarity index 73% rename from external/transformers/scripts/config.json rename to external/transformers/scripts/full_compare_tokenizer_config.json index 2177e3a5..3ba66372 100644 --- a/external/transformers/scripts/config.json +++ b/external/transformers/scripts/full_compare_tokenizer_config.json @@ -1,6 +1,11 @@ { "text_encode": { "datasets": [ + { + "name": "random_utf8", + "num_samples": 1000, + "max_tokens": 2048 + }, { "name": "HuggingFaceFW/fineweb", "split": "train", @@ -19,11 +24,16 @@ }, "text_instruct": { "datasets": [ + { + "name": "random_instruct_utf8", + "num_samples": 1000, + "max_tokens": 2048 + }, { "name": "HuggingFaceH4/ultrachat_200k", "split": "train_sft", "column": "messages", - "num_samples": 1000, + "num_samples": 500, "max_tokens": 2048 } ]