3 changes: 2 additions & 1 deletion .gitignore
@@ -188,4 +188,5 @@ scores
 *.tar.xz
 *.tar.lzma
 *.tar.lz
-*.jsonl
+*.jsonl
+pretrain
10 changes: 0 additions & 10 deletions .pre-commit-config.yaml

This file was deleted.

File renamed without changes.
37 changes: 37 additions & 0 deletions configs/config_ft_bi.yaml
@@ -0,0 +1,37 @@

logging_steps: 50
# log_level_replica: info

inf_free: false
model_name_or_path: answerdotai/ModernBERT-base
tokenizer_name: answerdotai/ModernBERT-base
use_l0: false

max_seq_length: 256
train_file: data/msmarco_hard_negatives
data_type: kd
loss_types: [marginmse]
sample_num_one_query: 2
use_in_batch_negatives: false
flops_d_lambda: 6
flops_d_T: 50000
flops_q_lambda: 10
flops_q_T: 50000
flops_d_thresh: 180
flops_q_thresh: 90

output_dir: output/paper/bi/modernbert-base
per_device_eval_batch_size: 50
per_device_train_batch_size: 20

log_level: info
max_steps: 150000
fp16: true
learning_rate: 0.00002
weight_decay: 0.01
lr_scheduler_type: linear
warmup_steps: 6000
save_strategy: steps
save_steps: 150000
dataloader_drop_last: true
max_grad_norm: null
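
The flops_* settings configure the FLOPS sparsity regularizer for the query and document encoders. As a point of reference, SPLADE-style training typically ramps each FLOPS weight quadratically from 0 to its target value over the first T steps; whether this repository uses exactly that schedule, and how flops_d_thresh / flops_q_thresh gate it, is not visible in this diff, so the sketch below is illustrative only.

```python
# Illustrative only: SPLADE-style quadratic warmup of the FLOPS weights, wired to the
# values in config_ft_bi.yaml. The actual schedule and the role of the *_thresh
# fields in this repo are assumptions.
def flops_weight(step: int, lam: float, T: int) -> float:
    """Ramp the regularization weight quadratically from 0 to lam over T steps."""
    return lam * min(1.0, step / T) ** 2

flops_q_lambda, flops_q_T = 10.0, 50_000
flops_d_lambda, flops_d_T = 6.0, 50_000

for step in (0, 25_000, 50_000, 150_000):
    print(step, flops_weight(step, flops_q_lambda, flops_q_T),
          flops_weight(step, flops_d_lambda, flops_d_T))
# 0 -> 0.0/0.0, 25000 -> 2.5/1.5, 50000 and beyond -> 10.0/6.0
```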
32 changes: 32 additions & 0 deletions configs/config_ft_noidf.yaml
@@ -0,0 +1,32 @@
logging_steps: 50
# log_level_replica: info

inf_free: true
model_name_or_path: bert-base-uncased
tokenizer_name: bert-base-uncased
use_l0: false

max_seq_length: 256
train_file: data/msmarco_hard_negatives
data_type: kd
loss_types: [kldiv]
sample_num_one_query: 2
use_in_batch_negatives: false
flops_d_lambda: 0.004
flops_d_T: 10000

output_dir: output/paper/woidf/bert-base-uncased
per_device_eval_batch_size: 50
per_device_train_batch_size: 20

log_level: info
max_steps: 100000
fp16: true
learning_rate: 0.00002
weight_decay: 0.01
lr_scheduler_type: linear
warmup_steps: 6000
save_strategy: steps
save_steps: 100000
dataloader_drop_last: true
max_grad_norm: null
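
Compared to config_ft_bi.yaml, this run uses bert-base-uncased with inf_free: true, a KL-divergence distillation loss, and only a document-side FLOPS weight. The YAML files are flat key/value maps, so they are easy to inspect directly; how the repo's own parse_args entry point consumes them is not shown in this diff, and the snippet below only demonstrates reading the file.

```python
# Standalone inspection of the config above; feeding it to the training entry point
# is repo-specific and not shown in this diff.
import yaml

with open("configs/config_ft_noidf.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["inf_free"] is True and cfg["loss_types"] == ["kldiv"]
warmup_fraction = cfg["warmup_steps"] / cfg["max_steps"]  # 6000 / 100000 = 0.06
print(cfg["model_name_or_path"], cfg["flops_d_lambda"], warmup_fraction)
```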
71 changes: 70 additions & 1 deletion dataloader/jsonl_in_seq/jsonl_in_seq.py
@@ -6,6 +6,8 @@

import glob
import os
import random
import re
from typing import Iterator, List, Tuple

import datasets
@@ -66,6 +68,14 @@ def _split_generators(self, dl_manager: datasets.DownloadManager):
            candidates = [line.strip() for line in f if line.strip()]
        files = self._expand_file_patterns(candidates)

        # Optionally shuffle file order before any iteration/printing
        shuffle_env = (
            os.environ.get("JSONL_LOCAL_SHUFFLE_FILES", "false").strip().lower()
        )
        if shuffle_env in {"1", "true", "yes", "y", "on"}:
            random.seed(0)
            random.shuffle(files)
        print(f"All files: {files}")
        # No download: files are local paths listed in manifest
        return [
            datasets.SplitGenerator(
@@ -86,6 +96,14 @@ def _expand_file_patterns(self, candidates: List[str]) -> List[str]:
        Only include existing files (exclude directories).
        """
        expanded: List[str] = []
        # Optional: threshold for numeric suffix when using wildcard patterns
        suffix_max_env = os.environ.get("JSONL_LOCAL_SUFFIX_MAX")
        suffix_max: int | None = None
        if suffix_max_env is not None:
            try:
                suffix_max = int(suffix_max_env)
            except Exception:
                suffix_max = None
        for pattern in candidates:
            has_wildcard = (
                ("*" in pattern)
@@ -96,6 +114,13 @@ def _expand_file_patterns(self, candidates: List[str]) -> List[str]:
            if has_wildcard:
                matches = glob.glob(pattern, recursive=True)
                file_matches = [p for p in matches if os.path.isfile(p)]
                # If threshold is set, filter by parsed numeric suffix. If parsing fails, keep.
                if suffix_max is not None:
                    file_matches = [
                        p
                        for p in file_matches
                        if self._accept_file_by_suffix(p, suffix_max)
                    ]
                file_matches.sort()
                expanded.extend(file_matches)
            else:
@@ -114,8 +139,26 @@ def _expand_file_patterns(self, candidates: List[str]) -> List[str]:
    def _generate_examples(
        self, files: List[str], split_name: str
    ) -> Iterator[Tuple[str, dict]]:
        print(f"Generating examples for {split_name} from {files}")
        # validation: every 100th line (global index % 100 == 0)
        # train: the rest
        # Determine validation ratio via env (default 1%) by converting to a stride: every Nth line goes to validation
        val_ratio_env = os.environ.get("JSONL_LOCAL_VAL_RATIO")
        default_ratio = 0.01
        try:
            if val_ratio_env is not None:
                ratio = float(val_ratio_env)
            else:
                ratio = default_ratio
        except Exception:
            ratio = default_ratio
        # clamp ratio to (0, 1]
        if ratio <= 0:
            ratio = default_ratio
        if ratio > 1:
            ratio = 1.0
        validation_stride = max(1, int(round(1.0 / ratio)))

        global_index = 0
        for file_idx, fp in enumerate(files):
            if not os.path.exists(fp):
@@ -140,7 +183,7 @@ def _generate_examples(
                         global_index += 1
                         continue
 
-                    is_validation = global_index % 100 == 0
+                    is_validation = global_index % validation_stride == 0
                     if split_name == "validation" and not is_validation:
                         global_index += 1
                         continue
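
To make the stride arithmetic concrete: the ratio from JSONL_LOCAL_VAL_RATIO is converted into a stride, and every stride-th line (by global index) lands in the validation split. A quick check of the mapping used above:

```python
# Mirrors validation_stride = max(1, int(round(1.0 / ratio))) from the hunk above.
for ratio in (0.01, 0.02, 0.5, 1.0):
    stride = max(1, int(round(1.0 / ratio)))
    print(ratio, stride)
# 0.01 -> 100 (default: every 100th line), 0.02 -> 50, 0.5 -> 2, 1.0 -> 1
```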
@@ -151,3 +194,29 @@ def _generate_examples(
                    key = f"{file_idx}-{line_idx}"
                    yield key, {"text": row["text"]}
                    global_index += 1

    def _accept_file_by_suffix(self, path: str, max_suffix: int) -> bool:
        """
        Return True if the file should be kept under the numeric suffix threshold.
        Logic:
        - Strip multi-part extensions (e.g., .json.gz -> base name)
        - Find the last continuous digit sequence in the remaining base name
        - If digits are found and parseable, keep only if number <= max_suffix
        - If parsing fails or digits not found, keep (as requested)
        """
        name = os.path.basename(path)
        base = name
        # Strip all extensions iteratively (handles .json.gz, .tar.gz, etc.)
        while True:
            base_no_ext, ext = os.path.splitext(base)
            if not ext:
                break
            base = base_no_ext
        match = re.search(r"(\d+)(?!.*\d)", base)
        if not match:
            return True
        try:
            value = int(match.group(1))
        except Exception:
            return True
        return value <= max_suffix
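
All three switches are read from the environment when the builder prepares its splits, so they must be set before the dataset is loaded. The sketch below shows the knobs together with the suffix-filter behaviour; the manifest layout and the exact load_dataset arguments this loader expects are not shown in the diff, so they are omitted here.

```python
# Usage sketch for the environment switches added in this file. Set them before
# load_dataset() runs the builder; loader-specific arguments (e.g. the manifest)
# are omitted because they are not shown in this diff.
import os
import re

os.environ["JSONL_LOCAL_SHUFFLE_FILES"] = "true"  # deterministic shuffle (seed 0) of the file list
os.environ["JSONL_LOCAL_SUFFIX_MAX"] = "99"       # keep wildcard matches whose numeric suffix <= 99
os.environ["JSONL_LOCAL_VAL_RATIO"] = "0.02"      # every 50th line goes to validation

# Suffix filtering, mirroring _accept_file_by_suffix:
#   "shard_00123.json.gz" -> extensions stripped -> last digit run 123 -> 123 > 99 -> dropped
#   "corpus.jsonl"        -> no digits found                           -> kept
print(re.search(r"(\d+)(?!.*\d)", "shard_00123").group(1))  # "00123"
```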
33 changes: 25 additions & 8 deletions evaluate_beir.py
@@ -18,7 +18,7 @@
 
 from scripts.args import nano_beir_datasets, parse_args
 from scripts.dataset.data_utils import cached
-from scripts.dataset.dataset import BEIRCorpusDataset
+from scripts.dataset.dataset import BEIRCorpusDataset, HFDatasetWrapper
 from scripts.ingest import ingest
 from scripts.search import search
 from scripts.utils import emit_metrics, get_model, set_logging
@@ -51,10 +51,14 @@ def get_suffix(model_args, data_args):
 def load_beir_from_hf(
     dataset_name: str = "nfcorpus",
     split: str = "test",
+    load_corpus: bool = True,
 ) -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
-    ds_corpus = load_dataset(
-        f"BEIR/{dataset_name}", "corpus", split="corpus", trust_remote_code=True
-    )
+    if load_corpus:
+        ds_corpus = load_dataset(
+            f"BEIR/{dataset_name}", "corpus", split="corpus", trust_remote_code=True
+        )
+    else:
+        ds_corpus = None
     ds_queries = load_dataset(
         f"BEIR/{dataset_name}", "queries", split="queries", trust_remote_code=True
     )
@@ -64,8 +68,11 @@ def load_beir_from_hf(
 
     # Build BEIR-style corpus
     corpus: Dict[str, Dict[str, str]] = {}
-    for r in ds_corpus:
-        corpus[str(r["_id"])] = {"title": r["title"], "text": r["text"]}
+    if load_corpus:
+        for r in ds_corpus:
+            corpus[str(r["_id"])] = {"title": r["title"], "text": r["text"]}
+    else:
+        corpus = None
 
     # Build BEIR-style queries
     queries: Dict[str, str] = {}
@@ -147,14 +154,22 @@ def evaluate_beir(model_args, data_args, training_args, model, accelerator):
     }
     avg_res = dict()
     for dataset in datasets:
-        corpus, queries, qrels = load_beir_from_hf(dataset_name=dataset, split="test")
+        _, queries, qrels = load_beir_from_hf(
+            dataset_name=dataset, split="test", load_corpus=False
+        )
+        corpus = HFDatasetWrapper(
+            load_dataset(
+                f"BEIR/{dataset}", "corpus", split="corpus", trust_remote_code=True
+            ),
+            sample_function=lambda x: (x["_id"], x["title"] + " " + x["text"]),
+        )
         logger.info(
             f"Loaded {dataset} with {len(corpus)} documents and {len(queries)} queries"
         )
         if not data_args.skip_ingest:
             asyncio.run(
                 ingest(
-                    dataset=BEIRCorpusDataset(corpus=corpus),
+                    dataset=corpus,
                     model=model,
                     out_dir=beir_eval_dir,
                     index_name=dataset,
Expand Down Expand Up @@ -352,6 +367,7 @@ def main():

model = get_model(model_args)
accelerator = Accelerator(mixed_precision="fp16")
accelerator.prepare(model)
accelerator.wait_for_everyone()

evaluate_beir(model_args, data_args, training_args, model, accelerator)
@@ -369,6 +385,7 @@ def main():
            model_args.model_name_or_path, "idf.json"
        )
        model = get_model(model_args)
        accelerator.prepare(model)
        evaluate_nano_beir(
            model_args, data_args, training_args, model, accelerator, step
        )
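
For context on the corpus change in evaluate_beir: the BEIR corpus is no longer materialized into an in-memory dict; the raw HF dataset is passed to ingest through HFDatasetWrapper with a sample_function that maps each row to an (id, "title text") pair. The wrapper's real implementation lives in scripts/dataset/dataset.py and is not part of this diff, so the sketch below is only an assumed minimal shape consistent with how it is constructed above (including supporting len(corpus) in the log line).

```python
# Assumed minimal shape of HFDatasetWrapper, inferred from its call site above;
# the real class in scripts/dataset/dataset.py may differ.
from torch.utils.data import Dataset


class HFDatasetWrapper(Dataset):
    def __init__(self, dataset, sample_function=None):
        self.dataset = dataset
        self.sample_function = sample_function

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        # For the BEIR corpus above this yields (doc_id, "title text")
        return self.sample_function(row) if self.sample_function else row
```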