diff --git a/.gitignore b/.gitignore index bbe50dd..c1ce6f6 100644 --- a/.gitignore +++ b/.gitignore @@ -188,4 +188,5 @@ scores *.tar.xz *.tar.lzma *.tar.lz -*.jsonl \ No newline at end of file +*.jsonl +pretrain \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 8bd305d..0000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,10 +0,0 @@ -repos: -- repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version. - rev: v0.12.11 - hooks: - # Run the linter. - - id: ruff-check - args: [ --fix, --exit-non-zero-on-fix, --select=F,E,W,I ] - # Run the formatter. - - id: ruff-format \ No newline at end of file diff --git a/configs/config_mdbert_ft.yaml b/configs/config_ft.yaml similarity index 100% rename from configs/config_mdbert_ft.yaml rename to configs/config_ft.yaml diff --git a/configs/config_ft_bi.yaml b/configs/config_ft_bi.yaml new file mode 100644 index 0000000..50c531f --- /dev/null +++ b/configs/config_ft_bi.yaml @@ -0,0 +1,37 @@ + +logging_steps: 50 +# log_level_replica: info + +inf_free: false +model_name_or_path: bert-base-uncased +tokenizer_name: bert-base-uncased +use_l0: false + +max_seq_length: 256 +train_file: data/msmarco_hard_negatives +data_type: kd +loss_types: [marginmse] +sample_num_one_query: 2 +use_in_batch_negatives: false +flops_start_T: 0 +flops_d_lambda: 4 +flops_d_T: 50000 +flops_q_lambda: 8 +flops_q_T: 50000 +flops_d_thresh: 150 + +output_dir: output/paper/bi/bert-base-uncased +per_device_eval_batch_size: 50 +per_device_train_batch_size: 20 + +log_level: info +max_steps: 150000 +fp16: true +learning_rate: 0.00002 +weight_decay: 0.01 +lr_scheduler_type: linear +warmup_steps: 6000 +save_strategy: steps +save_steps: 150000 +dataloader_drop_last: true +max_grad_norm: null diff --git a/configs/config_ft_noidf.yaml b/configs/config_ft_noidf.yaml new file mode 100644 index 0000000..bfe99d6 --- /dev/null +++ b/configs/config_ft_noidf.yaml @@ -0,0 +1,32 @@ +logging_steps: 50 +# log_level_replica: info + +inf_free: true +model_name_or_path: bert-base-uncased +tokenizer_name: bert-base-uncased +use_l0: false + +max_seq_length: 256 +train_file: data/msmarco_hard_negatives +data_type: kd +loss_types: [kldiv] +sample_num_one_query: 2 +use_in_batch_negatives: false +flops_d_lambda: 0.004 +flops_d_T: 10000 + +output_dir: output/paper/woidf/bert-base-uncased +per_device_eval_batch_size: 50 +per_device_train_batch_size: 20 + +log_level: info +max_steps: 100000 +fp16: true +learning_rate: 0.00002 +weight_decay: 0.01 +lr_scheduler_type: linear +warmup_steps: 6000 +save_strategy: steps +save_steps: 100000 +dataloader_drop_last: true +max_grad_norm: null \ No newline at end of file diff --git a/dataloader/jsonl_in_seq/jsonl_in_seq.py b/dataloader/jsonl_in_seq/jsonl_in_seq.py index 70480e9..bf6ef35 100644 --- a/dataloader/jsonl_in_seq/jsonl_in_seq.py +++ b/dataloader/jsonl_in_seq/jsonl_in_seq.py @@ -6,6 +6,8 @@ import glob import os +import random +import re from typing import Iterator, List, Tuple import datasets @@ -66,6 +68,14 @@ def _split_generators(self, dl_manager: datasets.DownloadManager): candidates = [line.strip() for line in f if line.strip()] files = self._expand_file_patterns(candidates) + # Optionally shuffle file order before any iteration/printing + shuffle_env = ( + os.environ.get("JSONL_LOCAL_SHUFFLE_FILES", "false").strip().lower() + ) + if shuffle_env in {"1", "true", "yes", "y", "on"}: + random.seed(0) + random.shuffle(files) + print(f"All files: {files}") # No download: files are local 
paths listed in manifest return [ datasets.SplitGenerator( @@ -86,6 +96,14 @@ def _expand_file_patterns(self, candidates: List[str]) -> List[str]: Only include existing files (exclude directories). """ expanded: List[str] = [] + # Optional: threshold for numeric suffix when using wildcard patterns + suffix_max_env = os.environ.get("JSONL_LOCAL_SUFFIX_MAX") + suffix_max: int | None = None + if suffix_max_env is not None: + try: + suffix_max = int(suffix_max_env) + except Exception: + suffix_max = None for pattern in candidates: has_wildcard = ( ("*" in pattern) @@ -96,6 +114,13 @@ def _expand_file_patterns(self, candidates: List[str]) -> List[str]: if has_wildcard: matches = glob.glob(pattern, recursive=True) file_matches = [p for p in matches if os.path.isfile(p)] + # If threshold is set, filter by parsed numeric suffix. If parsing fails, keep. + if suffix_max is not None: + file_matches = [ + p + for p in file_matches + if self._accept_file_by_suffix(p, suffix_max) + ] file_matches.sort() expanded.extend(file_matches) else: @@ -114,8 +139,26 @@ def _expand_file_patterns(self, candidates: List[str]) -> List[str]: def _generate_examples( self, files: List[str], split_name: str ) -> Iterator[Tuple[str, dict]]: + print(f"Generating examples for {split_name} from {files}") # validation: every 100th line (global index % 100 == 0) # train: the rest + # Determine validation ratio via env (default 1%) by converting to a stride: every Nth line goes to validation + val_ratio_env = os.environ.get("JSONL_LOCAL_VAL_RATIO") + default_ratio = 0.01 + try: + if val_ratio_env is not None: + ratio = float(val_ratio_env) + else: + ratio = default_ratio + except Exception: + ratio = default_ratio + # clamp ratio to (0, 1] + if ratio <= 0: + ratio = default_ratio + if ratio > 1: + ratio = 1.0 + validation_stride = max(1, int(round(1.0 / ratio))) + global_index = 0 for file_idx, fp in enumerate(files): if not os.path.exists(fp): @@ -140,7 +183,7 @@ def _generate_examples( global_index += 1 continue - is_validation = global_index % 100 == 0 + is_validation = global_index % validation_stride == 0 if split_name == "validation" and not is_validation: global_index += 1 continue @@ -151,3 +194,29 @@ def _generate_examples( key = f"{file_idx}-{line_idx}" yield key, {"text": row["text"]} global_index += 1 + + def _accept_file_by_suffix(self, path: str, max_suffix: int) -> bool: + """ + Return True if the file should be kept under the numeric suffix threshold. + Logic: + - Strip multi-part extensions (e.g., .json.gz -> base name) + - Find the last continuous digit sequence in the remaining base name + - If digits are found and parseable, keep only if number <= max_suffix + - If parsing fails or digits not found, keep (as requested) + """ + name = os.path.basename(path) + base = name + # Strip all extensions iteratively (handles .json.gz, .tar.gz, etc.) 
+ while True: + base_no_ext, ext = os.path.splitext(base) + if not ext: + break + base = base_no_ext + match = re.search(r"(\d+)(?!.*\d)", base) + if not match: + return True + try: + value = int(match.group(1)) + except Exception: + return True + return value <= max_suffix diff --git a/evaluate_beir.py b/evaluate_beir.py index 7dc7895..7654edf 100644 --- a/evaluate_beir.py +++ b/evaluate_beir.py @@ -18,7 +18,7 @@ from scripts.args import nano_beir_datasets, parse_args from scripts.dataset.data_utils import cached -from scripts.dataset.dataset import BEIRCorpusDataset +from scripts.dataset.dataset import BEIRCorpusDataset, HFDatasetWrapper from scripts.ingest import ingest from scripts.search import search from scripts.utils import emit_metrics, get_model, set_logging @@ -51,10 +51,14 @@ def get_suffix(model_args, data_args): def load_beir_from_hf( dataset_name: str = "nfcorpus", split: str = "test", + load_corpus: bool = True, ) -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]: - ds_corpus = load_dataset( - f"BEIR/{dataset_name}", "corpus", split="corpus", trust_remote_code=True - ) + if load_corpus: + ds_corpus = load_dataset( + f"BEIR/{dataset_name}", "corpus", split="corpus", trust_remote_code=True + ) + else: + ds_corpus = None ds_queries = load_dataset( f"BEIR/{dataset_name}", "queries", split="queries", trust_remote_code=True ) @@ -64,8 +68,11 @@ def load_beir_from_hf( # Build BEIR-style corpus corpus: Dict[str, Dict[str, str]] = {} - for r in ds_corpus: - corpus[str(r["_id"])] = {"title": r["title"], "text": r["text"]} + if load_corpus: + for r in ds_corpus: + corpus[str(r["_id"])] = {"title": r["title"], "text": r["text"]} + else: + corpus = None # Build BEIR-style queries queries: Dict[str, str] = {} @@ -147,14 +154,22 @@ def evaluate_beir(model_args, data_args, training_args, model, accelerator): } avg_res = dict() for dataset in datasets: - corpus, queries, qrels = load_beir_from_hf(dataset_name=dataset, split="test") + _, queries, qrels = load_beir_from_hf( + dataset_name=dataset, split="test", load_corpus=False + ) + corpus = HFDatasetWrapper( + load_dataset( + f"BEIR/{dataset}", "corpus", split="corpus", trust_remote_code=True + ), + sample_function=lambda x: (x["_id"], x["title"] + " " + x["text"]), + ) logger.info( f"Loaded {dataset} with {len(corpus)} documents and {len(queries)} queries" ) if not data_args.skip_ingest: asyncio.run( ingest( - dataset=BEIRCorpusDataset(corpus=corpus), + dataset=corpus, model=model, out_dir=beir_eval_dir, index_name=dataset, @@ -352,6 +367,7 @@ def main(): model = get_model(model_args) accelerator = Accelerator(mixed_precision="fp16") + accelerator.prepare(model) accelerator.wait_for_everyone() evaluate_beir(model_args, data_args, training_args, model, accelerator) @@ -369,6 +385,7 @@ def main(): model_args.model_name_or_path, "idf.json" ) model = get_model(model_args) + accelerator.prepare(model) evaluate_nano_beir( model_args, data_args, training_args, model, accelerator, step ) diff --git a/evaluate_marco.py b/evaluate_marco.py new file mode 100644 index 0000000..a66f844 --- /dev/null +++ b/evaluate_marco.py @@ -0,0 +1,211 @@ +import asyncio +import json +import logging +import os +import sys +from collections import defaultdict +from datetime import datetime +from typing import Dict + +import ir_datasets +from accelerate import Accelerator +from beir.retrieval.evaluation import EvaluateRetrieval +from datasets import load_dataset +from transformers import ( + set_seed, +) + +from evaluate_beir import 
get_suffix, load_beir_from_hf, prepare_model_args +from scripts.args import parse_args +from scripts.dataset.dataset import HFDatasetWrapper +from scripts.ingest import ingest +from scripts.search import search +from scripts.utils import emit_metrics, get_model, set_logging + +logger = logging.getLogger(__name__) + + +def load_trec_dl(year): + ds = ir_datasets.load(f"msmarco-passage/trec-dl-{year}") + queries = {q.query_id: q.text for q in ds.queries_iter()} + + qrels = defaultdict(dict) + for qr in ds.qrels_iter(): + qrels[qr.query_id][qr.doc_id] = qr.relevance # relevance is graded (0..3) + + queries = {qid: query for qid, query in queries.items() if qid in qrels} + return queries, qrels + + +def _sorted_docids_by_score(run_doc_scores: Dict[str, float]): + return [ + doc_id + for doc_id, _ in sorted( + run_doc_scores.items(), key=lambda x: x[1], reverse=True + ) + ] + + +def compute_mrr_at_k( + run_res: Dict[str, Dict[str, float]], qrels: Dict[str, Dict[str, int]], k: int = 10 +) -> float: + mrr_total = 0.0 + evaluated = 0 + for qid, doc_scores in run_res.items(): + if qid not in qrels: + continue + ranked_docids = _sorted_docids_by_score(doc_scores)[:k] + reciprocal_rank = 0.0 + for rank, doc_id in enumerate(ranked_docids, start=1): + if qrels[qid].get(doc_id, 0) > 0: + reciprocal_rank = 1.0 / rank + break + mrr_total += reciprocal_rank + evaluated += 1 + return mrr_total / evaluated if evaluated > 0 else 0.0 + + +def evaluate_msmarco_dev(model_args, data_args, training_args, model, accelerator): + suffix = get_suffix(model_args, data_args) + out_dir = os.path.join(training_args.output_dir, "evaluate_marco") + os.makedirs(out_dir, exist_ok=True) + + dataset = "msmarco" + index_name = dataset + _, queries, qrels = load_beir_from_hf( + dataset_name=dataset, split="validation", load_corpus=False + ) + corpus = HFDatasetWrapper( + load_dataset("BeIR/msmarco", "corpus", split="corpus"), + sample_function=lambda x: (x["_id"], x["text"]), + ) + + logger.info( + f"Loaded {dataset} dev with {len(corpus)} documents and {len(queries)} queries" + ) + + if not data_args.skip_ingest: + asyncio.run( + ingest( + dataset=corpus, + model=model, + out_dir=out_dir, + index_name=index_name, + accelerator=accelerator, + max_length=data_args.eval_max_seq_length, + batch_size=training_args.per_device_eval_batch_size, + ) + ) + + metrics = {} + if data_args.do_search and accelerator.is_local_main_process: + search_result = asyncio.run( + search( + queries=queries, + model=model, + out_dir=out_dir, + index_name=index_name, + max_length=data_args.eval_max_seq_length, + batch_size=training_args.per_device_eval_batch_size, + inf_free=model_args.inf_free, + use_two_phase=data_args.use_two_phase, + query_prune=data_args.query_prune, + result_size=1000, + ) + ) + + run_res = search_result["run_res"] + mrr10 = compute_mrr_at_k(run_res, qrels, k=10) + ndcg, map_, recall, p = EvaluateRetrieval.evaluate(qrels, run_res, [10, 1000]) + + metrics = { + "MRR@10": mrr10, + "Recall@10": recall.get("Recall@10", 0.0), + "Recall@1000": recall.get("Recall@1000", 0.0), + } + + with open(os.path.join(out_dir, f"msmarco_metrics{suffix}.json"), "w") as f: + json.dump(metrics, f) + + logger.info(f"MSMARCO dev metrics: {metrics}") + + doc_id = training_args.output_dir + suffix + timestamp = datetime.now().timestamp() + + metrics = { + "flops": search_result["flops"], + "MRR@10": metrics["MRR@10"], + "Recall@10": metrics["Recall@10"], + "Recall@1000": metrics["Recall@1000"], + "timestamp": timestamp, + } + emit_metrics(metrics, 
"msmarco_eval", doc_id) + + accelerator.wait_for_everyone() + return metrics + + +def evaluate_trec_dl(model_args, data_args, training_args, model, accelerator): + suffix = get_suffix(model_args, data_args) + out_dir = os.path.join(training_args.output_dir, "evaluate_marco") + index_name = "msmarco" + + if data_args.do_search and accelerator.is_local_main_process: + metrics = {} + for year in [2019, 2020]: + queries, qrels = load_trec_dl(year) + + search_result = asyncio.run( + search( + queries=queries, + model=model, + out_dir=out_dir, + index_name=index_name, + max_length=data_args.eval_max_seq_length, + batch_size=training_args.per_device_eval_batch_size, + inf_free=model_args.inf_free, + use_two_phase=data_args.use_two_phase, + query_prune=data_args.query_prune, + result_size=1000, + ) + ) + + run_res = search_result["run_res"] + ndcg, map_, recall, p = EvaluateRetrieval.evaluate( + qrels, run_res, [10, 1000] + ) + + metrics[f"trec-dl-{year}-ndcg@10"] = ndcg.get("NDCG@10", 0.0) + metrics[f"trec-dl-{year}-Recall@10"] = recall.get("Recall@10", 0.0) + metrics[f"trec-dl-{year}-Recall@1000"] = recall.get("Recall@1000", 0.0) + + with open(os.path.join(out_dir, f"trec_dl_metrics{suffix}.json"), "w") as f: + json.dump(metrics, f) + + +def main(): + if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + use_yaml = True + else: + use_yaml = False + + model_args, data_args, training_args = parse_args() + if use_yaml: + model_args = prepare_model_args( + model_args, training_args.output_dir, training_args.max_steps + ) + + set_logging(training_args, "eval_msmarco.log") + set_seed(training_args.seed) + + model = get_model(model_args) + accelerator = Accelerator(mixed_precision="fp16") + accelerator.prepare(model) + accelerator.wait_for_everyone() + + evaluate_msmarco_dev(model_args, data_args, training_args, model, accelerator) + evaluate_trec_dl(model_args, data_args, training_args, model, accelerator) + + +if __name__ == "__main__": + main() diff --git a/get_idf.py b/get_idf.py index 3844d2c..3229d5f 100644 --- a/get_idf.py +++ b/get_idf.py @@ -46,6 +46,11 @@ def parse_args(): default=10000, help="Batch size for batched tokenization", ) + parser.add_argument( + "--train-file", + type=str, + default=None, + ) return parser.parse_args() @@ -55,15 +60,21 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True) special_ids = set(tokenizer.all_special_ids or []) - # 1) load dataset - msmarco_corpus = datasets.load_dataset("BeIR/msmarco", "corpus")["corpus"] + if args.train_file is not None: + msmarco_corpus = datasets.load_dataset("json", data_files=args.train_file)[ + "train" + ] - # 2) fix occasional text encoding issues - msmarco_corpus = msmarco_corpus.map( - lambda x: {"text": transform_str(x["text"])}, - num_proc=30, - desc="Normalizing text", - ) + else: + # 1) load dataset + msmarco_corpus = datasets.load_dataset("BeIR/msmarco", "corpus")["corpus"] + + # 2) fix occasional text encoding issues + msmarco_corpus = msmarco_corpus.map( + lambda x: {"text": transform_str(x["text"])}, + num_proc=30, + desc="Normalizing text", + ) # 3) tokenize in parallel and aggregate DF per batch to reduce later accumulation cost def _tokenize_batch(batch): diff --git a/prepare_msmarco_hard_negatives.py b/prepare_msmarco_hard_negatives.py index 619ad6a..1023448 100644 --- a/prepare_msmarco_hard_negatives.py +++ b/prepare_msmarco_hard_negatives.py @@ -39,4 +39,4 @@ def transform_str(s): ) # 5) Save to disk (directory will contain the text-only view) 
-msmarco_hard_negatives.save_to_disk("data/msmarco_ft") +msmarco_hard_negatives.save_to_disk("data/msmarco_hard_negatives") diff --git a/requirements.txt b/requirements.txt index e6d8050..2618467 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,18 +3,21 @@ torch==2.6.0 --index-url https://pypi.org/simple transformers==4.51.3 datasets==3.5.0 +ruff==0.12.11 accelerate==1.6.0 numpy==2.0.2 -opensearch-py==2.8.0 -beir==2.1.0 +ir_datasets==0.5.11 +opensearch-py +beir ipykernel boto3 uvicorn fastapi pydantic -matplotlib>=3.5.0 -pandas>=1.3.0 -seaborn>=0.11.0 +matplotlib +pandas +seaborn dotenv -orjson==3.11.3 -ruff==0.12.11 \ No newline at end of file +orjson +evaluate +tensorboard \ No newline at end of file diff --git a/run_train_eval.sh b/run_train_eval.sh index 310dd26..5feec91 100644 --- a/run_train_eval.sh +++ b/run_train_eval.sh @@ -28,6 +28,7 @@ do # Evaluate the model torchrun --nproc_per_node=${N_DEVICES} evaluate_beir.py $CONFIG_PATH + torchrun --nproc_per_node=${N_DEVICES} evaluate_marco.py $CONFIG_PATH echo "Completed processing $CONFIG_PATH" echo "----------------------------------------" diff --git a/scripts/args.py b/scripts/args.py index 9c42132..fff3afa 100644 --- a/scripts/args.py +++ b/scripts/args.py @@ -30,8 +30,10 @@ class DataTrainingArguments: use_in_batch_negatives: bool = field(default=False) flops_d_lambda: float = field(default=1e-3) flops_d_T: float = field(default=10000) + flops_start_T: int = field(default=0) flops_q_lambda: float = field(default=None) flops_q_T: float = field(default=None) + flops_d_thresh: Optional[int] = field(default=None) ranking_loss_weight: float = field(default=1) kd_ensemble_teacher_kwargs: Optional[Union[dict, str]] = field( default_factory=dict, diff --git a/scripts/dataset/dataset.py b/scripts/dataset/dataset.py index c044592..555eed6 100644 --- a/scripts/dataset/dataset.py +++ b/scripts/dataset/dataset.py @@ -385,6 +385,19 @@ def __getitem__(self, idx): } +class HFDatasetWrapper(Dataset): + def __init__(self, hf_dataset, sample_function): + self.hf_dataset = hf_dataset + self.sample_function = sample_function + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + item = self.hf_dataset[idx] + return self.sample_function(item) + + class CombinedRandomSampler(Sampler): def __init__(self, datasets, batch_size, drop_last=True): self.datasets = datasets diff --git a/scripts/ingest.py b/scripts/ingest.py index 2c92b72..60cca0e 100644 --- a/scripts/ingest.py +++ b/scripts/ingest.py @@ -40,12 +40,11 @@ async def ingest( ddp_dataset = DDPDatasetWithRank( dataset, accelerator.local_process_index, accelerator.num_processes ) + dataloader = DataLoader(ddp_dataset, batch_size=batch_size) logger.info( f"Local rank: {accelerator.local_process_index}, index_name: {index_name}, sample number: {len(ddp_dataset)}" ) - dataloader = DataLoader(ddp_dataset, batch_size=batch_size) - accelerator.prepare(model) sparse_encoder = SparseEncoder( sparse_model=model, max_length=max_length, @@ -56,7 +55,7 @@ async def ingest( if accelerator.is_local_main_process: try: # delete the index if exist - os_client.indices.delete(index_name) + os_client.indices.delete(index_name, params={"timeout": 1000}) except Exception: pass diff --git a/scripts/train/loss.py b/scripts/train/loss.py index 06362cd..4f87f8a 100644 --- a/scripts/train/loss.py +++ b/scripts/train/loss.py @@ -15,6 +15,16 @@ def __call__(self, q_rep, d_rep, inputs): def get_loss(self, q_rep, d_rep, inputs): return self.weight * self.__call__(q_rep, d_rep, inputs) + 
@torch.no_grad() + def get_avg_hits(self, q_rep, d_rep): + """ + Return the average number of common non-zero entries per (q, d) pair. + """ + q_bin = (q_rep > 0).to(dtype=torch.float32) # [bs, dim] + d_bin = (d_rep > 0).to(dtype=torch.float32) # [nd, dim] + common = torch.matmul(q_bin, d_bin.t()) # [bs, nd] + return common.mean() + class KLDivLoss(SparseTrainingLoss): def __init__(self, use_in_batch_negatives=False, weight=1, temperature=1.0): diff --git a/scripts/train/trainer.py b/scripts/train/trainer.py index 20f8f9c..037aaa4 100644 --- a/scripts/train/trainer.py +++ b/scripts/train/trainer.py @@ -73,10 +73,16 @@ def flops_value(self, representation, group_num=1): return torch.sum(flops_per_average_token) def get_lambda(self, lambda_value, lambda_T): - if self.state.global_step >= lambda_T: - return lambda_value + start_T = getattr(self.data_args, "flops_start_T", 0) or 0 step = self.state.global_step + 1 - return lambda_value * (step / lambda_T) ** 2 + # warmup delay: lambda is 0 until start_T + if step <= start_T: + return 0 + # shifted schedule after start_T + shifted_step = step - start_T + if shifted_step >= lambda_T: + return lambda_value + return lambda_value * (shifted_step / lambda_T) ** 2 def compute_loss( self, model: SparseModel, inputs, return_outputs=False, num_items_in_batch=None @@ -102,15 +108,31 @@ def compute_loss( q_rep = gather_rep(q_rep, self.accelerator) if "scores" in inputs: inputs["scores"] = gather_rep(inputs["scores"], self.accelerator) + # compute avg lengths + d_avg_len = (d_rep > 0).sum() / d_rep.shape[0] d_flops = self.flops_value(d_rep, d_rep.shape[0] // q_rep.shape[0]) - flops_loss += d_flops * self.get_lambda( + d_lambda = self.get_lambda( self.data_args.flops_d_lambda, self.data_args.flops_d_T ) - + d_flops_loss = d_flops * d_lambda + enable_flops = True + if self.data_args.flops_d_thresh is not None: + if d_avg_len.item() < float(self.data_args.flops_d_thresh): + d_flops_loss = torch.tensor(0.0, device=d_rep.device, dtype=d_rep.dtype) + enable_flops = False + flops_loss += d_flops_loss + + q_avg_len = (q_rep > 0).sum() / q_rep.shape[0] if not self.model_args.inf_free: - flops_loss += self.flops_value(q_rep) * self.get_lambda( + q_flops = self.flops_value(q_rep) + q_lambda = self.get_lambda( self.data_args.flops_q_lambda, self.data_args.flops_q_T ) + q_flops_loss = q_flops * q_lambda + # gate q flops by d's average length only + if not enable_flops: + q_flops_loss = torch.tensor(0.0, device=q_rep.device, dtype=q_rep.dtype) + flops_loss += q_flops_loss ranking_loss = 0 for loss_function in self.loss_functions: @@ -129,12 +151,26 @@ def compute_loss( if self.state.global_step % self.args.logging_steps == 0: logger.info( - f"Step {self.state.global_step}. ranking loss moving avg:{self.ranking_loss_moving_avg}, d_flops: {d_flops}, flops_loss: {flops_loss} avg doc length: {(d_rep > 0).sum() / d_rep.shape[0]}" + f"Step {self.state.global_step}. 
ranking loss moving avg:{self.ranking_loss_moving_avg}, d_flops: {d_flops}, flops_loss: {flops_loss}" ) + logger.info(f"avg doc length: {d_avg_len}, avg query length: {q_avg_len}") with torch.no_grad(): nonzero = d_rep[d_rep > 0] + q_nonzero = q_rep[q_rep > 0] + # average common non-zero entries per (q, d) pair under current loss setting + try: + if len(self.loss_functions) > 0: + avg_common_hits = self.loss_functions[0].get_avg_hits( + q_rep, d_rep + ) + logger.info(f"avg common hits per pair: {avg_common_hits}") + except Exception as e: + logger.warning(f"avg common hits computation failed: {e}") + logger.info( + f"nonzero entries: {torch.mean(nonzero)} {torch.max(nonzero)}" + ) logger.info( - f"nonzero entries: {torch.mean(nonzero)} {torch.mean(nonzero)} {torch.max(nonzero)}" + f"q_nonzero entries: {torch.mean(q_nonzero)} {torch.max(q_nonzero)}" ) # DP reduce grad by sum, while DDP reduce grad by mean # scale the loss to fix the gap diff --git a/train_bpe_tokenizer.py b/train_bpe_tokenizer.py index 30f0a39..1a30453 100644 --- a/train_bpe_tokenizer.py +++ b/train_bpe_tokenizer.py @@ -1,5 +1,6 @@ import json import os +from enum import Enum from datasets import load_dataset from tokenizers import ( @@ -8,66 +9,114 @@ models, trainers, ) -from tokenizers.pre_tokenizers import ByteLevel +from tokenizers.normalizers import BertNormalizer +from tokenizers.pre_tokenizers import ByteLevel, Sequence +from tokenizers.processors import TemplateProcessing from transformers import AutoTokenizer, PreTrainedTokenizerFast -def batch_iterator( - batch_size=1000, num_workers=40, prefetch_factor=5, persistent_workers=True -): - # Only keep the text column to avoid decoding the rest of the columns unnecessarily - from torch.utils.data import DataLoader - - dataloader = DataLoader( - dataset, - num_workers=num_workers, - prefetch_factor=prefetch_factor, - batch_size=batch_size, - persistent_workers=persistent_workers, - ) - - for batch in dataloader: - yield batch["text"] +class PROCESSING(Enum): + ALBERT = 1 + BERT = 2 + BERT_METASPACE = 3 -# use_data_file = True -# data_file = "data/wikibook.ml128.jsonl" -# data_name = "dataloader/jsonl_in_seq" -# os.environ["JSONL_LOCAL_FILES"] = "/opt/dlami/nvme/dolma/*" -# output_dir = "modernbert-bpe-1" - -use_data_file = False -data_name = "dataloader/jsonl_in_seq" -data_file = "data/wikibook.ml128.jsonl" -os.environ["JSONL_LOCAL_FILES"] = "/opt/dlami/nvme/dolma/*" -output_dir = "modernbert-bpe-full-dl" +use_data_file = True +processing = PROCESSING.BERT_METASPACE +# os.environ["JSONL_LOCAL_FILES"] = "/opt/dlami/nvme/dolma/wiki*,/opt/dlami/nvme/dolma/book*" +# os.environ["JSONL_LOCAL_SUFFIX_MAX"] = "5" +output_dir = "modernbert-bpe-bert-10k" +vocab_size = 10000 +data_files = [ + os.path.join("/home/ubuntu/tokenizer_corpus/", f) + for f in os.listdir("/home/ubuntu/tokenizer_corpus/") + if f.startswith("wiki") or f.startswith("book") +] +print(data_files) if use_data_file: dataset = load_dataset( "json", - data_files=data_file, + data_files=data_files, split="train", + num_proc=40, ) - # dataset = dataset.select(range(30000)) + + def batch_iterator(batch_size=2000): + if len(dataset) < 1e8: + texts = dataset["text"] + for i in range(0, len(texts), batch_size): + yield texts[i : i + batch_size] + else: + print("batching") + batched_dataset = dataset.batch(batch_size) + for batch in batched_dataset: + yield batch["text"] + else: dataset = load_dataset( - data_name, + "dataloader/jsonl_in_seq", split="train", streaming=True, trust_remote_code=True, ) + def batch_iterator( 
+ batch_size=1000, num_workers=10, prefetch_factor=5, persistent_workers=True + ): + # Only keep the text column to avoid decoding the rest of the columns unnecessarily + from torch.utils.data import DataLoader + + dataloader = DataLoader( + dataset, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + batch_size=batch_size, + persistent_workers=persistent_workers, + ) + + for batch in dataloader: + yield batch["text"] + + albert_tokenizer = AutoTokenizer.from_pretrained("albert-base-v2") mdbert_tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large") +bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") tokenizer = Tokenizer(models.BPE()) -tokenizer.normalizer = albert_tokenizer.backend_tokenizer.normalizer -tokenizer.pre_tokenizer = mdbert_tokenizer.backend_tokenizer.pre_tokenizer -tokenizer.pre_tokenizer.add_prefix_space = True -tokenizer.post_processor = mdbert_tokenizer.backend_tokenizer.post_processor + +if processing == PROCESSING.ALBERT: + tokenizer.normalizer = albert_tokenizer.backend_tokenizer.normalizer + tokenizer.pre_tokenizer = mdbert_tokenizer.backend_tokenizer.pre_tokenizer + tokenizer.pre_tokenizer.add_prefix_space = True +elif processing == PROCESSING.BERT_METASPACE: + tokenizer.normalizer = BertNormalizer( + clean_text=True, + handle_chinese_chars=False, + strip_accents=True, + lowercase=True, + ) + tokenizer.pre_tokenizer = Sequence( + # bert pre_tokenizer = Sequence([WhitespaceSplit(), Punctuation(behavior="isolated")]) + [bert_tokenizer.backend_tokenizer.pre_tokenizer, ByteLevel()] + ) +elif processing == PROCESSING.BERT: + tokenizer.normalizer = BertNormalizer( + clean_text=True, + handle_chinese_chars=False, + strip_accents=True, + lowercase=True, + ) + tokenizer.pre_tokenizer = Sequence( + [ + bert_tokenizer.backend_tokenizer.pre_tokenizer, + ByteLevel(add_prefix_space=False), + ] + ) + trainer = trainers.BpeTrainer( - vocab_size=30000, - min_frequency=2, + vocab_size=vocab_size, + min_frequency=10, initial_alphabet=ByteLevel.alphabet(), special_tokens=list(mdbert_tokenizer.special_tokens_map.values()), ) @@ -124,6 +173,38 @@ def batch_iterator( hf_tokenizer.backend_tokenizer.pre_tokenizer.add_prefix_space = True hf_tokenizer.model_max_length = mdbert_tokenizer.model_max_length + +# Align post-processor with ModernBERT but bind to current special token ids +cls_tok = hf_tokenizer.cls_token +sep_tok = hf_tokenizer.sep_token +mask_tok = hf_tokenizer.mask_token +pad_tok = hf_tokenizer.pad_token +unk_tok = hf_tokenizer.unk_token + +cls_id = hf_tokenizer.cls_token_id +sep_id = hf_tokenizer.sep_token_id +mask_id = hf_tokenizer.mask_token_id +pad_id = hf_tokenizer.pad_token_id +unk_id = hf_tokenizer.unk_token_id + +special_token_pairs = [] +for tok, tid in [ + (cls_tok, cls_id), + (sep_tok, sep_id), + (mask_tok, mask_id), + (pad_tok, pad_id), + (unk_tok, unk_id), +]: + if tok is not None and tid is not None: + special_token_pairs.append((tok, tid)) + +template = TemplateProcessing( + single=f"{cls_tok}:0 $A:0 {sep_tok}:0", + pair=f"{cls_tok}:0 $A:0 {sep_tok}:0 $B:0 {sep_tok}:0", + special_tokens=special_token_pairs, +) + +hf_tokenizer.backend_tokenizer.post_processor = template hf_tokenizer.save_pretrained(output_dir) tokenizer.save(os.path.join(output_dir, "original_config.json")) diff --git a/transform_mean.py b/transform_mean.py new file mode 100644 index 0000000..63f4cbc --- /dev/null +++ b/transform_mean.py @@ -0,0 +1,98 @@ +import torch +from deepfocus.focus import get_overlapping_tokens +from transformers import 
AutoModelForMaskedLM, AutoTokenizer + +fuzzy = False +save = True +target_is_bpe = True +source_model_id = "answerdotai/ModernBERT-base" +target_model_id = "modernbert-bpe-bert-10k" +save_name = "modernbert-bpe-bert-10k-focus" + +source_tokenizer = AutoTokenizer.from_pretrained(source_model_id) +target_tokenizer = AutoTokenizer.from_pretrained(target_model_id) +if target_is_bpe and target_tokenizer.backend_tokenizer.decoder is None: + target_tokenizer.backend_tokenizer.decoder = ( + source_tokenizer.backend_tokenizer.decoder + ) + print("reset target tokenizer. ", target_tokenizer.backend_tokenizer.decoder) + +overlap, additional_tokens = get_overlapping_tokens( + target_tokenizer, + source_tokenizer, + match_symbols=True, + exact_match_all=True, + fuzzy_match_all=fuzzy, +) + +target_overlap_tokens = [] +source_overlap_tokens = [] +for key, value in overlap.items(): + target_overlap_tokens.append(key) + source_overlap_tokens.append(value.source[0].native_form) + +print(len(target_overlap_tokens), len(source_overlap_tokens)) + +# Load pretrained models +source_model = AutoModelForMaskedLM.from_pretrained( + source_model_id, trust_remote_code=True +) + +# Original embeddings +source_emb = source_model.get_input_embeddings() +orig_weight = source_emb.weight.data # (V_roberta, dim) + +# target vocab size and embedding dim (use source dim) +V_target = len(target_tokenizer) +dim = orig_weight.shape[1] + +# Overlap tokens (strings) and their IDs in each vocab +overlap_ids_source = [ + source_tokenizer.convert_tokens_to_ids(tok) for tok in source_overlap_tokens +] +overlap_ids_target = [ + target_tokenizer.convert_tokens_to_ids(tok) for tok in target_overlap_tokens +] + +# Create new embedding matrix of shape (V_target, dim) +new_weight = torch.zeros( + (V_target, orig_weight.shape[1]), dtype=orig_weight.dtype, device=orig_weight.device +) + +# 1. Copy original embeddings for overlap tokens +torch_tensor_overlap_source = orig_weight[overlap_ids_source] # (|O|, dim) +new_weight[overlap_ids_target] = torch_tensor_overlap_source + +# 2. Compute embeddings for non-overlap tokens +# Identify new token IDs in target vocab not in overlap +all_target_ids = list(range(V_target)) +new_ids = [i for i in all_target_ids if i not in set(overlap_ids_target)] + +# Prepare tensors +source_overlap = torch_tensor_overlap_source # (|O|, dim) in source space + +# Mean of overlap embeddings in source space, replicated for all new ids +mean_source_overlap = source_overlap.mean(dim=0) # (dim,) +new_rows = mean_source_overlap.repeat(len(new_ids), 1) # (N_new, dim) + +# Assign to new_weight +new_weight[new_ids] = new_rows + +# Replace source's embedding layer with the new matrix +new_emb_layer = torch.nn.Embedding.from_pretrained(new_weight, freeze=False) + +source_model.resize_token_embeddings(len(target_tokenizer)) +source_model.set_input_embeddings(new_emb_layer) +source_model.set_output_embeddings(new_emb_layer) + +source_model.config.bos_token_id = target_tokenizer.vocab["[CLS]"] +source_model.config.eos_token_id = target_tokenizer.vocab["[SEP]"] +source_model.config.pad_token_id = target_tokenizer.vocab["[PAD]"] +source_model.config.sep_token_id = target_tokenizer.vocab["[SEP]"] +source_model.config.cls_token_id = target_tokenizer.vocab["[CLS]"] + +source_model.save_pretrained(save_name) +if save_name != target_model_id: + target_tokenizer.save_pretrained(save_name) + +print("Saved updated source model.")
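
Follow-up note on the dataloader/jsonl_in_seq changes: the loader now reads three optional environment variables: JSONL_LOCAL_SHUFFLE_FILES (deterministic shuffle of the manifest, seed 0), JSONL_LOCAL_SUFFIX_MAX (drop wildcard matches whose trailing number exceeds the threshold), and JSONL_LOCAL_VAL_RATIO (converted to an "every Nth line goes to validation" stride, default 1%). Below is a small standalone sketch of the suffix filter and the stride conversion; the file names and ratios are illustrative only, and accept_by_suffix simply mirrors the new _accept_file_by_suffix helper.

# Standalone illustration of the new env knobs in dataloader/jsonl_in_seq.
# Filenames and ratios below are made up for illustration.
import os
import re

def accept_by_suffix(path: str, max_suffix: int) -> bool:
    base = os.path.basename(path)
    while True:                          # strip .json.gz, .tar.gz, ... iteratively
        base, ext = os.path.splitext(base)
        if not ext:
            break
    match = re.search(r"(\d+)(?!.*\d)", base)
    if not match:
        return True                      # no numeric suffix: keep the file
    return int(match.group(1)) <= max_suffix

print(accept_by_suffix("dolma/wiki_0003.json.gz", 5))   # True  (3 <= 5)
print(accept_by_suffix("dolma/wiki_0012.json.gz", 5))   # False (12 > 5)
print(accept_by_suffix("dolma/books.json.gz", 5))       # True  (no digits)

# Validation stride: ratio 0.01 -> every 100th line, 0.02 -> every 50th, 1.0 -> every line.
for ratio in (0.01, 0.02, 1.0):
    print(ratio, max(1, int(round(1.0 / ratio))))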
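
evaluate_marco.py computes MRR@10 directly from the run results rather than via beir's EvaluateRetrieval. A toy sanity check of compute_mrr_at_k, assuming the script is importable from the repo root; the query/document ids and scores are made up, and with the only relevant document at rank 2 the expected value is 0.5.

# Toy check of compute_mrr_at_k from evaluate_marco.py (illustrative ids/scores).
from evaluate_marco import compute_mrr_at_k

run_res = {"q1": {"d1": 3.2, "d2": 2.5, "d3": 0.1}}   # retrieval scores, higher is better
qrels = {"q1": {"d2": 1}}                              # d2 is the only relevant doc

print(compute_mrr_at_k(run_res, qrels, k=10))          # 0.5 (first relevant hit at rank 2)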
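
The new SparseTrainingLoss.get_avg_hits logs how many vocabulary dimensions are non-zero in both the query and the document representation, averaged over all (q, d) pairs. A tiny numeric check of the same binarize-and-matmul idea, with made-up activations:

# Each query/doc is binarized, and the matmul counts shared non-zero dimensions per pair.
import torch

q_rep = torch.tensor([[0.7, 0.0, 0.3, 0.0]])            # 1 query, dims {0, 2} active
d_rep = torch.tensor([[0.5, 0.2, 0.0, 0.0],              # shares dim 0 -> 1 hit
                      [0.0, 0.1, 0.4, 0.9]])             # shares dim 2 -> 1 hit

q_bin = (q_rep > 0).float()
d_bin = (d_rep > 0).float()
common = q_bin @ d_bin.t()                               # [1, 2] matrix of hit counts
print(common)         # tensor([[1., 1.]])
print(common.mean())  # tensor(1.) -> average common hits per pair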
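
The get_lambda change in scripts/train/trainer.py turns the quadratic FLOPS-lambda warmup into a delayed one: the regularizer stays at zero until flops_start_T, then ramps quadratically over lambda_T steps and is capped at lambda_value. A minimal standalone sketch of that schedule (the function name and probe steps are illustrative; the values match configs/config_ft_bi.yaml):

# Delayed quadratic ramp, mirroring SparseTrainer.get_lambda with the new flops_start_T knob.
def flops_lambda(step: int, lambda_value: float, lambda_T: int, start_T: int = 0) -> float:
    if step <= start_T:
        return 0.0                      # regularizer disabled during the delay
    shifted = step - start_T
    if shifted >= lambda_T:
        return lambda_value             # fully ramped
    return lambda_value * (shifted / lambda_T) ** 2

# With flops_d_lambda=4, flops_d_T=50000, flops_start_T=0:
print(flops_lambda(1, 4.0, 50_000))         # ~1.6e-09, effectively off at the start
print(flops_lambda(25_000, 4.0, 50_000))    # 1.0, a quarter of the target at the midpoint
print(flops_lambda(60_000, 4.0, 50_000))    # 4.0, capped at lambda_value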