Commits (27)
0b16294  Update assistant mask exception. (pramodith, Aug 20, 2025)
cb0258d  use model_max_length (pramodith, Aug 20, 2025)
f7dd77e  Merge branch 'main' into pramodith/update_assistant_token_exception (pramodith, Aug 20, 2025)
067cf05  Log % of trainable assistant tokens (pramodith, Aug 21, 2025)
faf3f69  filter out rows with no assistant tokens. (pramodith, Aug 21, 2025)
bb22c8f  address co-pilot comments. (pramodith, Aug 21, 2025)
4b0c21c  IterableDataset doesn't have num_rows. (pramodith, Aug 21, 2025)
3eed0cc  Merge branch 'main' into pramodith/update_assistant_token_exception (pramodith, Aug 21, 2025)
8c758c7  Merge branch 'main' into pramodith/update_assistant_token_exception (pramodith, Aug 22, 2025)
023ddaf  Compute total trainable tokens for all types of datasets after trunca… (pramodith, Aug 22, 2025)
054e4a7  input_ids should use length for total token count. (pramodith, Aug 22, 2025)
0400a27  precommit (pramodith, Aug 22, 2025)
bce8654  account for iterable dataset. (pramodith, Aug 22, 2025)
a25d404  revert (pramodith, Aug 22, 2025)
e3efa70  Merge branch 'main' into pramodith/update_assistant_token_exception (qgallouedec, Aug 27, 2025)
3fd22ef  revert (qgallouedec, Aug 27, 2025)
0fa5da8  Merge branch 'main' into pramodith/update_assistant_token_exception (pramodith, Aug 27, 2025)
f6738f2  Add tests and the case for conversational + assistant only. (pramodith, Aug 27, 2025)
3601e68  Merge branch 'main' into pramodith/update_assistant_token_exception (pramodith, Aug 27, 2025)
09c1ed8  Merge branch 'main' into pramodith/update_assistant_token_exception (pramodith, Sep 2, 2025)
b4cc0a3  Use batched and polars. (pramodith, Sep 9, 2025)
8cb4ce3  Merge branch 'main' into pramodith/update_assistant_token_exception (pramodith, Sep 9, 2025)
9da4b44  Merge branch 'pramodith/update_assistant_token_exception' of https://… (pramodith, Sep 10, 2025)
57723c6  Merge branch 'main' into pramodith/update_assistant_token_exception (pramodith, Sep 10, 2025)
caf0c6e  Add polars to requirements. (pramodith, Sep 10, 2025)
54e4eed  Merge branch 'pramodith/update_assistant_token_exception' of https://… (pramodith, Sep 10, 2025)
f5fdb2f  Merge branch 'main' into pramodith/update_assistant_token_exception (qgallouedec, Sep 13, 2025)
18 changes: 18 additions & 0 deletions trl/trainer/sft_trainer.py
@@ -1009,9 +1009,27 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss)
                # Packing adds new column "seq_lengths" needed for document aware FlashAttention
                dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
            elif args.max_length is not None:
                if args.assistant_only_loss:
                    total_assistant_tokens_before_truncation = sum([sum(row["assistant_masks"]) for row in dataset])
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
                dataset = truncate_dataset(dataset, args.max_length, map_kwargs)
                if args.assistant_only_loss:
                    dataset = dataset.filter(lambda row: 1 in row["assistant_masks"])
                    total_assistant_tokens_after_truncation = sum([sum(row["assistant_masks"]) for row in dataset])
                    if total_assistant_tokens_after_truncation == 0:
                        raise RuntimeError(
                            "After truncation, the dataset has no trainable assistant tokens. This usually means that "
                            "the max length is too short."
                        )
                    percentage_of_retained_assistant_tokens = (
                        total_assistant_tokens_after_truncation / total_assistant_tokens_before_truncation
                    ) * 100
                    logger.info(
                        f"Total number of trainable assistant tokens after truncation: {total_assistant_tokens_after_truncation}. "
                        f"Percentage of retained assistant tokens after truncating dataset: {percentage_of_retained_assistant_tokens:.2f}%"
                    )

            # For Liger kernel, ensure only the essential columns
            if args.use_liger_kernel:
                collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"}
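
For reference, here is a minimal, self-contained sketch (not part of this diff) of what the added assistant-only-loss bookkeeping does: count trainable assistant tokens before truncation, truncate, drop rows that lose all assistant tokens, and report the retention percentage. The toy rows and max_length below are hypothetical, and plain Python lists stand in for the datasets.Dataset handled in sft_trainer.py.

# Minimal sketch: assistant-token retention after truncation (toy values, not from the PR).
max_length = 4  # hypothetical truncation length
rows = [
    {"input_ids": [1, 2, 3, 4, 5, 6], "assistant_masks": [0, 0, 0, 1, 1, 1]},
    {"input_ids": [7, 8, 9, 10], "assistant_masks": [0, 0, 1, 1]},
    {"input_ids": [11, 12, 13, 14, 15], "assistant_masks": [0, 0, 0, 0, 1]},
]

# Trainable assistant tokens before truncation (a mask value of 1 marks an assistant token).
before = sum(sum(row["assistant_masks"]) for row in rows)

# Truncate every column to max_length, mirroring in spirit what truncate_dataset does.
truncated = [{key: values[:max_length] for key, values in row.items()} for row in rows]

# Drop rows that no longer contain any assistant tokens, as the added filter does.
kept = [row for row in truncated if 1 in row["assistant_masks"]]

after = sum(sum(row["assistant_masks"]) for row in kept)
if after == 0:
    raise RuntimeError("After truncation, the dataset has no trainable assistant tokens.")

print(f"Retained {after}/{before} assistant tokens ({after / before * 100:.2f}%)")
# With these toy rows: Retained 3/6 assistant tokens (50.00%)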