diff --git a/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java b/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java index a202aac426c..ace88a81242 100644 --- a/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java +++ b/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java @@ -28,11 +28,13 @@ import org.springframework.util.Assert; /** - * A {@link TextSplitter} that splits text into chunks of a target size in tokens. + * A {@link TextSplitter} that splits text into chunks of a target size in tokens. Now + * supports overlapping tokens between chunks. * * @author Raphael Yu * @author Christian Tzolov * @author Ricken Bazolo + * @author Enginner JiaXing */ public class TokenTextSplitter extends TextSplitter { @@ -46,39 +48,42 @@ public class TokenTextSplitter extends TextSplitter { private static final boolean KEEP_SEPARATOR = true; + private static final int DEFAULT_OVERLAP_SIZE = 0; + private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry(); private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE); - // The target size of each text chunk in tokens private final int chunkSize; - // The minimum size of each text chunk in characters private final int minChunkSizeChars; - // Discard chunks shorter than this private final int minChunkLengthToEmbed; - // The maximum number of chunks to generate from a text private final int maxNumChunks; private final boolean keepSeparator; + private final int overlapSize; + public TokenTextSplitter() { - this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR); + this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR, + DEFAULT_OVERLAP_SIZE); } public TokenTextSplitter(boolean keepSeparator) { - this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator); + this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator, + DEFAULT_OVERLAP_SIZE); } public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks, - boolean keepSeparator) { + boolean keepSeparator, int overlapSize) { this.chunkSize = chunkSize; this.minChunkSizeChars = minChunkSizeChars; this.minChunkLengthToEmbed = minChunkLengthToEmbed; this.maxNumChunks = maxNumChunks; this.keepSeparator = keepSeparator; + this.overlapSize = overlapSize; } public static Builder builder() { @@ -97,59 +102,52 @@ protected List doSplit(String text, int chunkSize) { List tokens = getEncodedTokens(text); List chunks = new ArrayList<>(); + + int start = 0; int num_chunks = 0; - while (!tokens.isEmpty() && num_chunks < this.maxNumChunks) { - List chunk = tokens.subList(0, Math.min(chunkSize, tokens.size())); + + while (start < tokens.size() && num_chunks < this.maxNumChunks) { + int end = Math.min(start + chunkSize, tokens.size()); + List chunk = tokens.subList(start, end); String chunkText = decodeTokens(chunk); - // Skip the chunk if it is empty or whitespace if (chunkText.trim().isEmpty()) { - tokens = tokens.subList(chunk.size(), tokens.size()); + start = end; continue; } - // Find the last period or punctuation mark in the chunk int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'), Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n')))); if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) { - // Truncate the chunk text at the punctuation mark chunkText = chunkText.substring(0, lastPunctuation + 1); } - String chunkTextToAppend = (this.keepSeparator) ? chunkText.trim() + String chunkTextToAppend = this.keepSeparator ? chunkText.trim() : chunkText.replace(System.lineSeparator(), " ").trim(); + if (chunkTextToAppend.length() > this.minChunkLengthToEmbed) { chunks.add(chunkTextToAppend); + num_chunks++; } - // Remove the tokens corresponding to the chunk text from the remaining tokens - tokens = tokens.subList(getEncodedTokens(chunkText).size(), tokens.size()); - - num_chunks++; - } - - // Handle the remaining tokens - if (!tokens.isEmpty()) { - String remaining_text = decodeTokens(tokens).replace(System.lineSeparator(), " ").trim(); - if (remaining_text.length() > this.minChunkLengthToEmbed) { - chunks.add(remaining_text); - } + // Move start forward by chunkSize - overlapSize to allow overlap + start += chunkSize - this.overlapSize; } return chunks; } - private List getEncodedTokens(String text) { + List getEncodedTokens(String text) { Assert.notNull(text, "Text must not be null"); return this.encoding.encode(text).boxed(); } private String decodeTokens(List tokens) { Assert.notNull(tokens, "Tokens must not be null"); - var tokensIntArray = new IntArrayList(tokens.size()); - tokens.forEach(tokensIntArray::add); - return this.encoding.decode(tokensIntArray); + IntArrayList tokenArray = new IntArrayList(tokens.size()); + tokens.forEach(tokenArray::add); + return this.encoding.decode(tokenArray); } public static final class Builder { @@ -164,6 +162,8 @@ public static final class Builder { private boolean keepSeparator = KEEP_SEPARATOR; + private int overlapSize = DEFAULT_OVERLAP_SIZE; + private Builder() { } @@ -192,9 +192,14 @@ public Builder withKeepSeparator(boolean keepSeparator) { return this; } + public Builder withOverlapSize(int overlapSize) { + this.overlapSize = overlapSize; + return this; + } + public TokenTextSplitter build() { return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed, - this.maxNumChunks, this.keepSeparator); + this.maxNumChunks, this.keepSeparator, this.overlapSize); } } diff --git a/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java b/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java index e803c8a4e40..dd79bab5695 100644 --- a/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java +++ b/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java @@ -20,17 +20,24 @@ import java.util.Map; import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertNotNull; import org.springframework.ai.document.DefaultContentFormatter; import org.springframework.ai.document.Document; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * @author Ricken Bazolo */ public class TokenTextSplitterTest { + private final String SAMPLE_TEXT = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Vestibulum volutpat augue et turpis facilisis, id porta ligula interdum. " + + "Proin condimentum justo sed lectus fermentum, a pretium orci iaculis. " + + "Mauris nec pharetra libero. Nulla facilisi. Sed consequat velit id eros volutpat dignissim."; + @Test public void testTokenTextSplitterBuilderWithDefaultValues() { @@ -112,4 +119,83 @@ public void testTokenTextSplitterBuilderWithAllFields() { assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1"); } + @Test + void testSplitWithOverlap() { + TokenTextSplitter splitter = TokenTextSplitter.builder() + .withChunkSize(40) + .withOverlapSize(10) + .withMinChunkLengthToEmbed(5) + .build(); + + List chunks = splitter.splitText(SAMPLE_TEXT); + + assertNotNull(chunks); + assertTrue(chunks.size() > 1, "Text should be split into multiple chunks"); + + // Compare overlapping tokens between consecutive chunks + List allTokens = splitter.getEncodedTokens(SAMPLE_TEXT); + + for (int i = 1; i < chunks.size(); i++) { + List prevTokens = splitter.getEncodedTokens(chunks.get(i - 1)); + List currTokens = splitter.getEncodedTokens(chunks.get(i)); + + int overlap = getOverlapSize(prevTokens, currTokens); + + // Allow some deviation due to punctuation or sentence trimming + assertTrue(overlap >= 5 && overlap <= 15, + "Expected ~10 overlapping tokens between chunks, but got " + overlap); + } + } + + @Test + void testSplitWithoutOverlap() { + TokenTextSplitter splitter = TokenTextSplitter.builder().withChunkSize(40).withOverlapSize(0).build(); + + List chunks = splitter.splitText(SAMPLE_TEXT); + + assertNotNull(chunks); + assertTrue(chunks.size() > 1); + + for (int i = 1; i < chunks.size(); i++) { + List prev = splitter.getEncodedTokens(chunks.get(i - 1)); + List curr = splitter.getEncodedTokens(chunks.get(i)); + + assertTrue(noOverlap(prev, curr), "There should be no overlap between chunks"); + } + } + + @Test + void testEmptyText() { + TokenTextSplitter splitter = TokenTextSplitter.builder().withChunkSize(50).withOverlapSize(10).build(); + + List chunks = splitter.splitText(" "); + assertTrue(chunks.isEmpty(), "Empty or whitespace-only input should return no chunks"); + } + + /** + * Calculate the number of overlapping tokens between the end of the previous chunk + * and the start of the current chunk. + */ + private int getOverlapSize(List prev, List curr) { + int maxOverlap = Math.min(prev.size(), curr.size()); + for (int i = maxOverlap; i > 0; i--) { + if (prev.subList(prev.size() - i, prev.size()).equals(curr.subList(0, i))) { + return i; + } + } + return 0; + } + + /** + * Check whether there is no overlap between the two token lists. + */ + private boolean noOverlap(List prev, List curr) { + for (int len = Math.min(prev.size(), curr.size()); len > 0; len--) { + if (prev.subList(prev.size() - len, prev.size()).equals(curr.subList(0, len))) { + return false; + } + } + return true; + } + }