From 2145397e95017fb08e5dc5bce749be50c90a6f8d Mon Sep 17 00:00:00 2001
From: aryanorpe
Date: Mon, 21 Apr 2025 22:30:04 +0400
Subject: [PATCH 1/9] Use batch mapping for the data.map call in
 load_and_validate_dataset (used by process_messages_into_input_ids) in
 data_process.py

Signed-off-by: aryanorpe
---
 src/instructlab/training/data_process.py | 44 +++++++++++++-----------
 1 file changed, 24 insertions(+), 20 deletions(-)

diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index 49b99c56..547c27bc 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -723,30 +723,32 @@ def pretraining_is_using_legacy_granite_chat_template(ds: Dataset) -> bool:
     return False
 
 
-def ensure_dataset_is_compatible_with_legacy_format(
-    sample: t.Dict[str, t.Any],
-) -> t.Dict[str, t.Any]:
+def ensure_dataset_is_compatible_with_legacy_format(batch: t.Dict[str, t.List[t.Any]]) -> t.Dict[str, t.List[t.Any]]:
     """
-    Given a sample that uses the legacy pre-training format, we unroll the samples into ones with the
-    original messages contents.
+    Given a batch of samples using the legacy pre-training format, unroll the samples into ones with
+    the original messages contents.
     """
-    # deepcopy to prevent re-referencing the existing objects
-    new_sample = {
-        "messages": [],
-        "unmask": sample.get("unmask", False),
-    }
-    for msg in sample["messages"]:
-        if msg["role"] != "pretraining":
-            new_sample["messages"].append(msg)
-            continue
+    processed_messages = []
+    unmask_flags = []
 
-        # handle unmasking
-        new_sample["messages"].extend(
-            extract_messages_from_pretraining_text(msg["content"])
-        )
-        new_sample["unmask"] = True
+    for messages, unmask_flag in zip(batch["messages"], batch.get("unmask", [False] * len(batch["messages"]))):
+        new_messages = []
+        unmask = unmask_flag
 
-    return new_sample
+        for msg in messages:
+            if msg["role"] != "pretraining":
+                new_messages.append(msg)
+            else:
+                new_messages.extend(extract_messages_from_pretraining_text(msg["content"]))
+                unmask = True  # if any pretraining message is found, set unmask to True
+
+        processed_messages.append(new_messages)
+        unmask_flags.append(unmask)
+
+    return {
+        "messages": processed_messages,
+        "unmask": unmask_flags,
+    }
 
 
 def filter_samples_by_length(
@@ -876,6 +878,8 @@ def load_and_validate_dataset(data_path: str, num_procs: int) -> Dataset:
 
     return data.map(
         ensure_dataset_is_compatible_with_legacy_format,
+        batched=True,
+        batch_size=1000,
         num_proc=num_procs,
         desc="Ensuring dataset is compatible with legacy format.",
    )

From c092c46f2fbc17fea5f390d7b1fc2b617002b234 Mon Sep 17 00:00:00 2001
From: aryanorpe
Date: Wed, 30 Apr 2025 19:40:48 +0400
Subject: [PATCH 2/9] Ran 'make fix' to resolve linter errors.
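
For context, ensure_dataset_is_compatible_with_legacy_format (reformatted
here) now receives a columnar batch, that is, a dict of column lists, rather
than a single sample. A minimal sketch of how datasets.Dataset.map feeds it
under batched=True (toy data for illustration, not taken from this repo):

    from datasets import Dataset

    from instructlab.training.data_process import (
        ensure_dataset_is_compatible_with_legacy_format,
    )

    ds = Dataset.from_dict(
        {
            "messages": [[{"role": "user", "content": "hi"}]],
            "unmask": [False],
        }
    )
    # map() slices the table into {"messages": [...], "unmask": [...]} chunks
    # of up to batch_size rows and passes each chunk in a single call.
    ds = ds.map(
        ensure_dataset_is_compatible_with_legacy_format,
        batched=True,
        batch_size=1000,
    )

One Python call per batch instead of one per row is what removes the
per-sample overhead this series targets.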
Signed-off-by: aryanorpe
---
 src/instructlab/training/data_process.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index 547c27bc..48ce3459 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -723,7 +723,9 @@ def pretraining_is_using_legacy_granite_chat_template(ds: Dataset) -> bool:
     return False
 
 
-def ensure_dataset_is_compatible_with_legacy_format(batch: t.Dict[str, t.List[t.Any]]) -> t.Dict[str, t.List[t.Any]]:
+def ensure_dataset_is_compatible_with_legacy_format(
+    batch: t.Dict[str, t.List[t.Any]],
+) -> t.Dict[str, t.List[t.Any]]:
     """
     Given a batch of samples using the legacy pre-training format, unroll the samples into ones with
     the original messages contents.
@@ -731,7 +733,9 @@ def ensure_dataset_is_compatible_with_legacy_format(
     processed_messages = []
     unmask_flags = []
 
-    for messages, unmask_flag in zip(batch["messages"], batch.get("unmask", [False] * len(batch["messages"]))):
+    for messages, unmask_flag in zip(
+        batch["messages"], batch.get("unmask", [False] * len(batch["messages"]))
+    ):
         new_messages = []
         unmask = unmask_flag
 
@@ -739,7 +743,9 @@
             if msg["role"] != "pretraining":
                 new_messages.append(msg)
             else:
-                new_messages.extend(extract_messages_from_pretraining_text(msg["content"]))
+                new_messages.extend(
+                    extract_messages_from_pretraining_text(msg["content"])
+                )
                 unmask = True  # if any pretraining message is found, set unmask to True

From 34dffde889f745f1b76d1c773d81a586c400cc3e Mon Sep 17 00:00:00 2001
From: aryanorpe
Date: Wed, 30 Apr 2025 22:23:22 +0400
Subject: [PATCH 3/9] Added batch mapping for the `process_samples` function
 in `process_messages_into_input_ids`.
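
The batched entry point added here follows the usual pattern for adapting a
per-sample function to map(..., batched=True): re-row the columnar batch,
apply the single-sample function, then re-column the results. A generic,
self-contained sketch of that pattern (illustrative names, not the exact
repo code):

    import typing as t

    def as_batched(fn: t.Callable[[dict], dict]):
        """Adapt a per-sample dict -> dict function to columnar batches."""

        def wrapper(batch: t.Dict[str, t.List[t.Any]]) -> t.Dict[str, t.List[t.Any]]:
            n = len(next(iter(batch.values())))
            # re-row: {"a": [1, 2]} -> [{"a": 1}, {"a": 2}]
            outs = [fn({key: batch[key][i] for key in batch}) for i in range(n)]
            # re-column: [{"x": 1}, {"x": 2}] -> {"x": [1, 2]}
            return {key: [out[key] for out in outs] for key in outs[0]}

        return wrapper

The wrapper in this commit is the same idea specialized to the "messages" and
"unmask" columns on the way in and the input_ids and labels lists on the way
out.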
Signed-off-by: aryanorpe
---
 src/instructlab/training/data_process.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index 48ce3459..6bab637b 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -589,7 +589,7 @@ def unmask_messages(
     )
 
 
-def unmask_sample(
+def unmask_sample_single(
     sample: t.Dict[str, t.Any], tokenizer: PreTrainedTokenizer
 ) -> ProcessedMessagesData:
     """
@@ -618,6 +618,25 @@
     return unmask_messages(sample["messages"], tokenizer, unmask_roles)
 
 
+def unmask_sample(
+    batch: t.Dict[str, t.List[t.Any]], tokenizer: PreTrainedTokenizer
+) -> t.Dict[str, t.List[t.Any]]:
+    input_ids_list = []
+    labels_list = []
+
+    for i in range(len(batch["messages"])):
+        sample = {key: batch[key][i] for key in batch}
+        result = unmask_sample_single(sample, tokenizer)
+
+        input_ids_list.append(result["input_ids"])
+        labels_list.append(result["labels"])
+
+    return {
+        "input_ids": input_ids_list,
+        "labels": labels_list,
+    }
+
+
 def extract_messages_from_pretraining_text(text: str) -> t.List[Message]:
     """
     Given a message from a pretraining message that was formatted using either the generic
@@ -925,6 +944,8 @@ def process_samples(
     # Process the dataset
     processed_data = data.map(
         process_sample_fn,
+        batched=True,
+        batch_size=1000,
         num_proc=num_cpu_procs,
         desc="Converting samples into input_ids and labels...",
         load_from_cache_file=False,

From 54f5f953dbf5cc04760c5848e4ee9e8ce003a151 Mon Sep 17 00:00:00 2001
From: Aryan Orpe <53704316+aryanorpe@users.noreply.github.com>
Date: Sat, 17 May 2025 20:18:37 +0400
Subject: [PATCH 4/9] Update src/instructlab/training/data_process.py

Co-authored-by: James Kunstle <52969093+JamesKunstle@users.noreply.github.com>
Signed-off-by: Aryan Orpe <53704316+aryanorpe@users.noreply.github.com>
---
 src/instructlab/training/data_process.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index 6bab637b..49df827c 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -618,7 +618,7 @@ def unmask_sample_single(
     return unmask_messages(sample["messages"], tokenizer, unmask_roles)
 
 
-def unmask_sample(
+def unmask_batch(
     batch: t.Dict[str, t.List[t.Any]], tokenizer: PreTrainedTokenizer
 ) -> t.Dict[str, t.List[t.Any]]:
     input_ids_list = []

From ad4d71867670ad66c55fa89f95f2b290962fee16 Mon Sep 17 00:00:00 2001
From: aryanorpe
Date: Sun, 18 May 2025 09:13:59 +0400
Subject: [PATCH 5/9] Updated code to use renamed function unmask_sample ->
 unmask_batch

Signed-off-by: aryanorpe
---
 src/instructlab/training/data_process.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index 49df827c..c3ebf97f 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -939,7 +939,7 @@
     """Process samples to generate input_ids and labels."""
 
     # Create a wrapper function for unmask_sample
-    process_sample_fn = partial(unmask_sample, tokenizer=tokenizer)
+    process_sample_fn = partial(unmask_batch, tokenizer=tokenizer)
 
     # Process the dataset
     processed_data = data.map(

From 7e9ec6a1e78800f20704b2267c8988a6c32103f1 Mon Sep 17 00:00:00 2001
From: aryanorpe
Date: Sun, 18 May 2025 09:18:38 +0400
Subject: [PATCH 6/9] Added configurable batch size to the process_samples
 function
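
With the default kept at 1000, existing callers are unchanged, while callers
that know their data can tune the value. A hypothetical call site (the
parameter values here are examples only):

    processed = process_samples(
        data,
        tokenizer,
        num_cpu_procs=8,
        batch_size=256,  # smaller batches keep peak memory per worker down
    )

Larger batches amortize more of the per-call overhead of map(); smaller ones
bound how many samples are materialized at once, so the best value depends on
sample length and available memory.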

Signed-off-by: aryanorpe
---
 src/instructlab/training/data_process.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index c3ebf97f..1e8c2d55 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -934,7 +934,10 @@ def configure_tokenizer(model_path: str) -> PreTrainedTokenizer:
 
 def process_samples(
-    data: Dataset, tokenizer: PreTrainedTokenizer, num_cpu_procs: int
+    data: Dataset,
+    tokenizer: PreTrainedTokenizer,
+    num_cpu_procs: int,
+    batch_size: int = 1000,
 ) -> Dataset:
     """Process samples to generate input_ids and labels."""
 
@@ -945,7 +948,7 @@ def process_samples(
     processed_data = data.map(
         process_sample_fn,
         batched=True,
-        batch_size=1000,
+        batch_size=batch_size,
         num_proc=num_cpu_procs,
         desc="Converting samples into input_ids and labels...",
         load_from_cache_file=False,

From ae7e1b34366c84ba27ca59122587376f27de117a Mon Sep 17 00:00:00 2001
From: aryanorpe
Date: Sun, 18 May 2025 09:56:12 +0400
Subject: [PATCH 7/9] Added test for the process_samples function in
 data_process.py under tests/unit

Signed-off-by: aryanorpe
---
 tests/unit/test_data_process.py | 43 +++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)
 create mode 100644 tests/unit/test_data_process.py

diff --git a/tests/unit/test_data_process.py b/tests/unit/test_data_process.py
new file mode 100644
index 00000000..465e167d
--- /dev/null
+++ b/tests/unit/test_data_process.py
@@ -0,0 +1,43 @@
+import pytest
+from datasets import Dataset
+from transformers import LlamaTokenizerFast
+from instructlab.training.data_process import process_samples
+
+
+@pytest.fixture(scope="module")
+def tokenizer():
+    tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
+
+    # Ensure UNMASK tokens are treated atomically
+    tokenizer.add_special_tokens({
+        "additional_special_tokens": ["<|UNMASK_BEGIN|>", "<|UNMASK_END|>"]
+    })
+
+    # Safety: add a pad token if it's missing
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token or ""
+
+    return tokenizer
+
+
+def test_process_samples_outputs_input_ids_and_labels(tokenizer):
+    dummy_data = Dataset.from_dict({
+        "messages": [[
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+            {"role": "pretraining", "content": "Some pretraining text"}
+        ]],
+        "unmask": [True],
+    })
+
+    # Run the function
+    processed = process_samples(dummy_data, tokenizer, num_cpu_procs=1, batch_size=1)
+
+    # Check the structure of the output
+    assert "input_ids" in processed.column_names
+    assert "labels" in processed.column_names
+
+    # Sanity check one sample
+    sample = processed[0]
+    assert isinstance(sample["input_ids"], list)
+    assert isinstance(sample["labels"], list)
+    assert len(sample["input_ids"]) == len(sample["labels"])
+    assert all(isinstance(x, int) for x in sample["input_ids"])

From 8e26794d682807c55e3cc6621bfc7fc984aa00bf Mon Sep 17 00:00:00 2001
From: aryanorpe
Date: Sun, 18 May 2025 09:59:22 +0400
Subject: [PATCH 8/9] Added test for process_samples function in
 data_process.py in tests/unit with linting.
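
The suite can be invoked through pytest's Python entry point, assuming pytest
is installed and the Hugging Face model used by the tokenizer fixture is
reachable or already cached locally:

    import pytest

    raise SystemExit(pytest.main(["-q", "tests/unit/test_data_process.py"]))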
Signed-off-by: aryanorpe
---
 tests/unit/test_data_process.py | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/tests/unit/test_data_process.py b/tests/unit/test_data_process.py
index 465e167d..15b2fe29 100644
--- a/tests/unit/test_data_process.py
+++ b/tests/unit/test_data_process.py
@@ -1,16 +1,20 @@
-import pytest
+# Third Party
 from datasets import Dataset
 from transformers import LlamaTokenizerFast
+import pytest
+
+# First Party
 from instructlab.training.data_process import process_samples
+
 
 @pytest.fixture(scope="module")
 def tokenizer():
     tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
 
     # Ensure UNMASK tokens are treated atomically
-    tokenizer.add_special_tokens({
-        "additional_special_tokens": ["<|UNMASK_BEGIN|>", "<|UNMASK_END|>"]
-    })
+    tokenizer.add_special_tokens(
+        {"additional_special_tokens": ["<|UNMASK_BEGIN|>", "<|UNMASK_END|>"]}
+    )
 
     # Safety: add a pad token if it's missing
     if tokenizer.pad_token is None:
@@ -18,15 +22,20 @@ def tokenizer():
     return tokenizer
 
+
 def test_process_samples_outputs_input_ids_and_labels(tokenizer):
-    dummy_data = Dataset.from_dict({
-        "messages": [[
-            {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there!"},
-            {"role": "pretraining", "content": "Some pretraining text"}
-        ]],
-        "unmask": [True],
-    })
+    dummy_data = Dataset.from_dict(
+        {
+            "messages": [
+                [
+                    {"role": "user", "content": "Hello"},
+                    {"role": "assistant", "content": "Hi there!"},
+                    {"role": "pretraining", "content": "Some pretraining text"},
+                ]
+            ],
+            "unmask": [True],
+        }
+    )
 
     # Run the function
     processed = process_samples(dummy_data, tokenizer, num_cpu_procs=1, batch_size=1)

From 41b7c05b8d310385f2ab98fd9764c110f380cf87 Mon Sep 17 00:00:00 2001
From: aryanorpe
Date: Sun, 18 May 2025 14:07:37 +0400
Subject: [PATCH 9/9] Filled up dataset with 100 samples in
 test_process_samples_outputs_input_ids_and_labels function.
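
With 100 samples and batch_size=8, map() should split the table into thirteen
chunks (twelve full batches of eight plus a final batch of four), so the
batched path is exercised across several batches, including an uneven final
one. A quick check of that arithmetic:

    import math

    assert math.ceil(100 / 8) == 13  # 12 * 8 = 96, plus one final batch of 4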
Signed-off-by: aryanorpe
---
 tests/unit/test_data_process.py | 44 +++++++++++++++++++--------------
 1 file changed, 26 insertions(+), 18 deletions(-)

diff --git a/tests/unit/test_data_process.py b/tests/unit/test_data_process.py
index 15b2fe29..fcb1cac9 100644
--- a/tests/unit/test_data_process.py
+++ b/tests/unit/test_data_process.py
@@ -24,29 +24,37 @@ def tokenizer():
 
 
 def test_process_samples_outputs_input_ids_and_labels(tokenizer):
+    # Create a dummy dataset of 100 samples
+    messages = [
+        [
+            {"role": "user", "content": f"Hello {i}"},
+            {"role": "assistant", "content": f"Hi there {i}!"},
+            {"role": "pretraining", "content": f"Pretraining text {i}"},
+        ]
+        for i in range(100)
+    ]
+
+    unmask_flags = [True for _ in range(100)]
+
     dummy_data = Dataset.from_dict(
         {
-            "messages": [
-                [
-                    {"role": "user", "content": "Hello"},
-                    {"role": "assistant", "content": "Hi there!"},
-                    {"role": "pretraining", "content": "Some pretraining text"},
-                ]
-            ],
-            "unmask": [True],
+            "messages": messages,
+            "unmask": unmask_flags,
         }
     )
 
-    # Run the function
-    processed = process_samples(dummy_data, tokenizer, num_cpu_procs=1, batch_size=1)
+    # Use realistic batch size
+    processed = process_samples(dummy_data, tokenizer, num_cpu_procs=1, batch_size=8)
 
-    # Check the structure of the output
+    # Check the structure
     assert "input_ids" in processed.column_names
     assert "labels" in processed.column_names
-
-    # Sanity check one sample
-    sample = processed[0]
-    assert isinstance(sample["input_ids"], list)
-    assert isinstance(sample["labels"], list)
-    assert len(sample["input_ids"]) == len(sample["labels"])
-    assert all(isinstance(x, int) for x in sample["input_ids"])
+    assert len(processed) == 100
+
+    # Check that input_ids and labels exist and match length for a few random samples
+    for i in [0, 25, 50, 99]:
+        sample = processed[i]
+        assert isinstance(sample["input_ids"], list)
+        assert isinstance(sample["labels"], list)
+        assert len(sample["input_ids"]) == len(sample["labels"])
+        assert all(isinstance(x, int) for x in sample["input_ids"])
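
For anyone reviewing the performance motivation behind this series, a rough
timing harness for comparing batch sizes follows. It is illustrative only:
the dataset contents, tokenizer setup, and sizes below are assumptions, and
absolute timings will vary by machine.

    import time

    from datasets import Dataset
    from transformers import LlamaTokenizerFast

    from instructlab.training.data_process import process_samples

    tok = LlamaTokenizerFast.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
    tok.add_special_tokens(
        {"additional_special_tokens": ["<|UNMASK_BEGIN|>", "<|UNMASK_END|>"]}
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Synthetic conversations, mirroring the shape used by the unit test.
    data = Dataset.from_dict(
        {
            "messages": [
                [
                    {"role": "user", "content": f"question {i}"},
                    {"role": "assistant", "content": f"answer {i}"},
                ]
                for i in range(1_000)
            ],
            "unmask": [False] * 1_000,
        }
    )

    # Time the same workload at several batch sizes.
    for bs in (1, 8, 100, 1000):
        start = time.perf_counter()
        process_samples(data, tok, num_cpu_procs=1, batch_size=bs)
        print(f"batch_size={bs}: {time.perf_counter() - start:.2f}s")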