diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index c3ad98c51a6ad..955b289bec497 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -19,11 +19,14 @@ import io.netty.channel.ChannelOption; import io.netty.channel.FixedRecvByteBufAllocator; import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.socket.nio.NioChannelOption; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.http.HttpContentCompressor; import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpMessage; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseEncoder; @@ -417,21 +420,39 @@ protected HttpMessage createMessage(String[] initialLine) throws Exception { ); } - ch.pipeline() - .addLast("decoder_compress", new HttpContentDecompressor()) // this handles request body decompression - .addLast("encoder", new HttpResponseEncoder() { - @Override - protected boolean isContentAlwaysEmpty(HttpResponse msg) { - // non-chunked responses (Netty4HttpResponse extends Netty's DefaultFullHttpResponse) with chunked transfer - // encoding are only sent by us in response to HEAD requests and must always have an empty body - if (msg instanceof Netty4FullHttpResponse netty4FullHttpResponse && HttpUtil.isTransferEncodingChunked(msg)) { - assert netty4FullHttpResponse.content().isReadable() == false; - return true; - } - return super.isContentAlwaysEmpty(msg); 
+ ch.pipeline().addLast("decoder_compress", new HttpContentDecompressor() { // this handles request body decompression + private String currentUri; + + @Override + protected void decode(ChannelHandlerContext ctx, HttpObject msg, java.util.List out) throws Exception { + if (msg instanceof HttpRequest request) { + currentUri = request.uri(); + } + super.decode(ctx, msg, out); + } + + @Override + protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception { + if (currentUri != null && currentUri.startsWith("/_prometheus") && "snappy".equalsIgnoreCase(contentEncoding)) { + // Prometheus remote write uses raw Snappy block format, not the framed + // format that Netty's SnappyFrameDecoder expects. Skip auto-decompression + // and let the application layer handle it. + return null; } - }) - .addLast(new Netty4HttpContentSizeHandler(decoder, handlingSettings.maxContentLength())); + return super.newContentDecoder(contentEncoding); + } + }).addLast("encoder", new HttpResponseEncoder() { + @Override + protected boolean isContentAlwaysEmpty(HttpResponse msg) { + // non-chunked responses (Netty4HttpResponse extends Netty's DefaultFullHttpResponse) with chunked transfer + // encoding are only sent by us in response to HEAD requests and must always have an empty body + if (msg instanceof Netty4FullHttpResponse netty4FullHttpResponse && HttpUtil.isTransferEncodingChunked(msg)) { + assert netty4FullHttpResponse.content().isReadable() == false; + return true; + } + return super.isContentAlwaysEmpty(msg); + } + }).addLast(new Netty4HttpContentSizeHandler(decoder, handlingSettings.maxContentLength())); if (handlingSettings.compression()) { ch.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.compressionLevel()) { diff --git a/server/src/main/java/org/elasticsearch/rest/IndexingPressureAwareContentAggregator.java b/server/src/main/java/org/elasticsearch/rest/IndexingPressureAwareContentAggregator.java index 
5f19f7fba865f..a045e9790d184 100644 --- a/server/src/main/java/org/elasticsearch/rest/IndexingPressureAwareContentAggregator.java +++ b/server/src/main/java/org/elasticsearch/rest/IndexingPressureAwareContentAggregator.java @@ -16,7 +16,9 @@ import org.elasticsearch.core.Releasables; import org.elasticsearch.index.IndexingPressure; +import java.io.IOException; import java.util.ArrayList; +import java.util.Objects; /** * Accumulates a streamed HTTP request body while tracking memory usage via {@link IndexingPressure}. @@ -35,6 +37,28 @@ */ public class IndexingPressureAwareContentAggregator implements BaseRestHandler.RequestBodyChunkConsumer { + /** + * Transforms the accumulated request body before it is handed to the {@link CompletionHandler}. + * Implementations must release the input reference when they produce new output. + */ + @FunctionalInterface + public interface BodyPostProcessor { + + BodyPostProcessor NOOP = (body, size) -> body; + + /** + * Post-processes the accumulated request body (e.g. decompression). + * + * @param body The accumulated raw body to process. + * Unless the post-processor returns the same reference, it is responsible for closing it. + * The caller must not use this reference after this method returns. + * @param maxSize The maximum permitted size for the result. + * @return The post-processed body. Must not exceed {@code maxSize}. The caller is responsible for closing the returned reference. + * @throws IOException on processing failure + */ + ReleasableBytesReference process(ReleasableBytesReference body, long maxSize) throws IOException; + } + /** * Callback for request body accumulation lifecycle events. 
*/ @@ -61,6 +85,7 @@ public interface CompletionHandler { private final IndexingPressure.Coordinating coordinating; private final long maxRequestSize; private final CompletionHandler completionHandler; + private final BodyPostProcessor bodyPostProcessor; private ArrayList chunks; private long accumulatedSize; @@ -70,12 +95,14 @@ public IndexingPressureAwareContentAggregator( RestRequest request, IndexingPressure.Coordinating coordinating, long maxRequestSize, - CompletionHandler completionHandler + CompletionHandler completionHandler, + BodyPostProcessor bodyPostProcessor ) { this.request = request; this.coordinating = coordinating; this.maxRequestSize = maxRequestSize; this.completionHandler = completionHandler; + this.bodyPostProcessor = Objects.requireNonNull(bodyPostProcessor); } @Override @@ -91,21 +118,7 @@ public void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boo } accumulatedSize += chunk.length(); - if (accumulatedSize > maxRequestSize) { - chunk.close(); - closed = true; - if (chunks != null) { - Releasables.close(chunks); - chunks = null; - } - coordinating.close(); - completionHandler.onFailure( - channel, - new ElasticsearchStatusException( - "request body too large, max [" + maxRequestSize + "] bytes", - RestStatus.REQUEST_ENTITY_TOO_LARGE - ) - ); + if (failIfAboveLimit(channel, chunk)) { return; } @@ -126,16 +139,55 @@ public void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boo } chunks = null; + try { + fullBody = bodyPostProcessor.process(fullBody, maxRequestSize); + } catch (Exception e) { + closeOnFailure(channel, e, fullBody); + return; + } + accumulatedSize = fullBody.length(); + if (failIfAboveLimit(channel, fullBody)) { + return; + } + long excess = maxRequestSize - accumulatedSize; if (excess > 0) { coordinating.reduceBytes(excess); } - closed = true; completionHandler.onComplete(channel, fullBody, coordinating); } } + /** + * @return {@code true} if the limit was exceeded and failure handling was 
performed, otherwise {@code false}. + */ + private boolean failIfAboveLimit(RestChannel channel, Releasable releasable) { + if (accumulatedSize > maxRequestSize) { + closeOnFailure( + channel, + new ElasticsearchStatusException( + "request body too large, max [" + maxRequestSize + "] bytes", + RestStatus.REQUEST_ENTITY_TOO_LARGE + ), + releasable + ); + return true; + } + return false; + } + + private void closeOnFailure(RestChannel channel, Exception e, Releasable releasable) { + releasable.close(); + if (chunks != null) { + Releasables.close(chunks); + chunks = null; + } + closed = true; + coordinating.close(); + completionHandler.onFailure(channel, e); + } + @Override public void streamClose() { if (closed == false) { diff --git a/server/src/test/java/org/elasticsearch/rest/IndexingPressureAwareContentAggregatorTests.java b/server/src/test/java/org/elasticsearch/rest/IndexingPressureAwareContentAggregatorTests.java index 41da649dbdbca..ae1aa80ea2516 100644 --- a/server/src/test/java/org/elasticsearch/rest/IndexingPressureAwareContentAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/rest/IndexingPressureAwareContentAggregatorTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.rest; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.settings.Settings; @@ -21,6 +22,7 @@ import org.elasticsearch.test.rest.FakeRestRequest; import org.junit.After; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Optional; @@ -177,6 +179,60 @@ public void testReducesPressureToActualSize() { assertEquals(0, indexingPressure.stats().getCurrentCoordinatingBytes()); } + public void testPostProcessorExpandsContent() { + long maxSize = 1024; + int compressedSize = 50; + int expandedSize = 200; + byte[] expanded = randomByteArrayOfLength(expandedSize); + + 
initAggregator(maxSize, (body, max) -> { + body.close(); + return new ReleasableBytesReference(new BytesArray(expanded), () -> {}); + }); + + assertEquals(maxSize, indexingPressure.stats().getCurrentCoordinatingBytes()); + + var chunk = randomReleasableBytesReference(compressedSize); + stream.sendNext(chunk, true); + + assertNotNull(contentRef.get()); + assertEquals(expandedSize, contentRef.get().length()); + assertEquals(expandedSize, indexingPressure.stats().getCurrentCoordinatingBytes()); + + pressureRef.get().close(); + assertEquals(0, indexingPressure.stats().getCurrentCoordinatingBytes()); + } + + public void testPostProcessorResultExceedsMaxSize() { + long maxSize = 100; + int compressedSize = 50; + int expandedSize = 200; + byte[] expanded = randomByteArrayOfLength(expandedSize); + + initAggregator(maxSize, (body, max) -> { + body.close(); + return new ReleasableBytesReference(new BytesArray(expanded), () -> {}); + }); + + var chunk = randomReleasableBytesReference(compressedSize); + stream.sendNext(chunk, true); + + assertTooLargeRejected(); + } + + public void testPostProcessorThrowsReleasesResources() { + long maxSize = 1024; + initAggregator(maxSize, (body, max) -> { throw new IOException("decompression failed"); }); + + var chunk = randomReleasableBytesReference(64); + stream.sendNext(chunk, true); + + assertNull(contentRef.get()); + assertNotNull(channel.capturedResponse()); + assertFalse(chunk.hasReferences()); + assertEquals(0, indexingPressure.stats().getCurrentCoordinatingBytes()); + } + private RestRequest newStreamedRequest(FakeHttpBodyStream stream) { var httpRequest = new FakeRestRequest.FakeHttpRequest( RestRequest.Method.POST, @@ -195,6 +251,10 @@ private void assertTooLargeRejected() { } private void initAggregator(long maxSize) { + initAggregator(maxSize, IndexingPressureAwareContentAggregator.BodyPostProcessor.NOOP); + } + + private void initAggregator(long maxSize, IndexingPressureAwareContentAggregator.BodyPostProcessor postProcessor) { 
var request = newStreamedRequest(stream); channel = new FakeRestChannel(request, true, 1); var coordinating = indexingPressure.markCoordinatingOperationStarted(1, maxSize, false); @@ -213,7 +273,8 @@ public void onComplete(RestChannel ch, ReleasableBytesReference content, Releasa public void onFailure(RestChannel ch, Exception e) { ch.sendResponse(new RestResponse(RestStatus.REQUEST_ENTITY_TOO_LARGE, e.getMessage())); } - } + }, + postProcessor ); stream.setHandler(new HttpBody.ChunkHandler() { @Override diff --git a/x-pack/plugin/prometheus/src/javaRestTest/java/org/elasticsearch/xpack/prometheus/PrometheusRemoteWriteRestIT.java b/x-pack/plugin/prometheus/src/javaRestTest/java/org/elasticsearch/xpack/prometheus/PrometheusRemoteWriteRestIT.java index e7b160e831b0e..ee14fa9d88c5f 100644 --- a/x-pack/plugin/prometheus/src/javaRestTest/java/org/elasticsearch/xpack/prometheus/PrometheusRemoteWriteRestIT.java +++ b/x-pack/plugin/prometheus/src/javaRestTest/java/org/elasticsearch/xpack/prometheus/PrometheusRemoteWriteRestIT.java @@ -7,6 +7,11 @@ package org.elasticsearch.xpack.prometheus; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.compression.Snappy; + +import org.apache.http.HttpHeaders; import org.apache.http.entity.ByteArrayEntity; import org.apache.http.entity.ContentType; import org.apache.http.util.EntityUtils; @@ -253,7 +258,8 @@ private void sendAndAssertSuccess(RemoteWrite.WriteRequest writeRequest) throws private void sendAndAssertSuccess(RemoteWrite.WriteRequest writeRequest, String endpoint) throws IOException { Request request = new Request("POST", endpoint); - request.setEntity(new ByteArrayEntity(writeRequest.toByteArray(), ContentType.create("application/x-protobuf"))); + request.setEntity(new ByteArrayEntity(snappyEncode(writeRequest.toByteArray()), ContentType.create("application/x-protobuf"))); + request.setOptions(request.getOptions().toBuilder().addHeader(HttpHeaders.CONTENT_ENCODING, "snappy")); 
Response response = client().performRequest(request); assertThat(response.getStatusLine().getStatusCode(), equalTo(204)); } @@ -264,7 +270,8 @@ private String sendAndAssertBadRequest(RemoteWrite.WriteRequest writeRequest) th private String sendAndAssertBadRequest(RemoteWrite.WriteRequest writeRequest, String endpoint) throws IOException { Request request = new Request("POST", endpoint); - request.setEntity(new ByteArrayEntity(writeRequest.toByteArray(), ContentType.create("application/x-protobuf"))); + request.setEntity(new ByteArrayEntity(snappyEncode(writeRequest.toByteArray()), ContentType.create("application/x-protobuf"))); + request.setOptions(request.getOptions().toBuilder().addHeader(HttpHeaders.CONTENT_ENCODING, "snappy")); ResponseException e = expectThrows(ResponseException.class, () -> client().performRequest(request)); assertThat(e.getResponse().getStatusLine().getStatusCode(), equalTo(400)); return EntityUtils.toString(e.getResponse().getEntity()); @@ -337,4 +344,19 @@ private boolean dataStreamExists(String dataStream) throws IOException { throw e; } } + + private static byte[] snappyEncode(byte[] input) { + ByteBuf in = Unpooled.wrappedBuffer(input); + ByteBuf out = Unpooled.buffer(input.length); + try { + new Snappy().encode(in, out, input.length); + byte[] result = new byte[out.readableBytes()]; + out.readBytes(result); + return result; + } finally { + in.release(); + out.release(); + } + } + } diff --git a/x-pack/plugin/prometheus/src/main/java/org/elasticsearch/xpack/prometheus/PrometheusPlugin.java b/x-pack/plugin/prometheus/src/main/java/org/elasticsearch/xpack/prometheus/PrometheusPlugin.java index 61f0b743a7c77..e8b0cd00032f9 100644 --- a/x-pack/plugin/prometheus/src/main/java/org/elasticsearch/xpack/prometheus/PrometheusPlugin.java +++ b/x-pack/plugin/prometheus/src/main/java/org/elasticsearch/xpack/prometheus/PrometheusPlugin.java @@ -7,9 +7,11 @@ package org.elasticsearch.xpack.prometheus; +import org.apache.lucene.util.BytesRef; import 
org.apache.lucene.util.SetOnce; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.FeatureFlag; @@ -43,6 +45,7 @@ public class PrometheusPlugin extends Plugin implements ActionPlugin { private final SetOnce indexTemplateRegistry = new SetOnce<>(); private final SetOnce indexingPressure = new SetOnce<>(); + private final SetOnce> recycler = new SetOnce<>(); private final boolean enabled; private final long maxProtobufContentLengthBytes; @@ -56,6 +59,7 @@ public Collection createComponents(PluginServices services) { Settings settings = services.environment().settings(); ClusterService clusterService = services.clusterService(); indexingPressure.set(services.indexingPressure()); + recycler.set(services.bigArrays().bytesRefRecycler()); indexTemplateRegistry.set( new PrometheusIndexTemplateRegistry( settings, @@ -93,7 +97,7 @@ public Collection getRestHandlers( ) { if (enabled) { assert indexingPressure.get() != null : "indexing pressure must be set if plugin is enabled"; - return List.of(new PrometheusRemoteWriteRestAction(indexingPressure.get(), maxProtobufContentLengthBytes)); + return List.of(new PrometheusRemoteWriteRestAction(indexingPressure.get(), maxProtobufContentLengthBytes, recycler.get())); } return List.of(); } diff --git a/x-pack/plugin/prometheus/src/main/java/org/elasticsearch/xpack/prometheus/rest/PrometheusRemoteWriteRestAction.java b/x-pack/plugin/prometheus/src/main/java/org/elasticsearch/xpack/prometheus/rest/PrometheusRemoteWriteRestAction.java index 14869cfc037b5..bce900bb561dc 100644 --- a/x-pack/plugin/prometheus/src/main/java/org/elasticsearch/xpack/prometheus/rest/PrometheusRemoteWriteRestAction.java +++ 
b/x-pack/plugin/prometheus/src/main/java/org/elasticsearch/xpack/prometheus/rest/PrometheusRemoteWriteRestAction.java @@ -7,12 +7,15 @@ package org.elasticsearch.xpack.prometheus.rest; +import org.apache.http.HttpHeaders; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.core.Releasable; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.logging.LogManager; @@ -42,10 +45,12 @@ public class PrometheusRemoteWriteRestAction extends BaseRestHandler { private final IndexingPressure indexingPressure; private final long maxRequestSizeBytes; + private final Recycler recycler; - public PrometheusRemoteWriteRestAction(IndexingPressure indexingPressure, long maxRequestSizeBytes) { + public PrometheusRemoteWriteRestAction(IndexingPressure indexingPressure, long maxRequestSizeBytes, Recycler recycler) { this.indexingPressure = indexingPressure; this.maxRequestSizeBytes = maxRequestSizeBytes; + this.recycler = recycler; } @Override @@ -82,6 +87,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli var coordinating = indexingPressure.markCoordinatingOperationStarted(1, maxRequestSizeBytes, false); + // while the remote write spec mandates snappy, we intentionally want to allow additional compression formats + var bodyPostProcessor = "snappy".equals(request.header(HttpHeaders.CONTENT_ENCODING)) + ? 
new SnappyBlockDecoder(recycler) + : IndexingPressureAwareContentAggregator.BodyPostProcessor.NOOP; + return new IndexingPressureAwareContentAggregator( request, coordinating, @@ -131,7 +141,8 @@ public void onFailure(RestChannel channel, Exception e) { new RestResponse(ExceptionsHelper.status(e), RestResponse.TEXT_CONTENT_TYPE, new BytesArray(e.getMessage())) ); } - } + }, + bodyPostProcessor ); } } diff --git a/x-pack/plugin/prometheus/src/main/java/org/elasticsearch/xpack/prometheus/rest/SnappyBlockDecoder.java b/x-pack/plugin/prometheus/src/main/java/org/elasticsearch/xpack/prometheus/rest/SnappyBlockDecoder.java new file mode 100644 index 0000000000000..f244f054e313e --- /dev/null +++ b/x-pack/plugin/prometheus/src/main/java/org/elasticsearch/xpack/prometheus/rest/SnappyBlockDecoder.java @@ -0,0 +1,318 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package org.elasticsearch.xpack.prometheus.rest; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.recycler.Recycler; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.rest.IndexingPressureAwareContentAggregator; +import org.elasticsearch.rest.RestStatus; + +import java.io.IOException; +import java.util.ArrayList; + +/** + * Decodes a Snappy block-compressed request body directly into recycled 16 KiB pages. + *

+ * This is a fork of + * + * {@code io.netty.handler.codec.compression.Snappy} adapted to operate + * on a {@link BytesReference} input (via {@link StreamInput}, zero-copy across chunks) and + * recycled pages for output, following the same approach as + * {@code org.elasticsearch.transport.Lz4TransportDecompressor}. + *

+ * Neither the input nor the output requires a contiguous {@code byte[]} allocation. + * Output pages are pre-allocated from the preamble-declared uncompressed length, + * providing an additional bound that decoded output cannot exceed the declared size. + * + * @see Snappy format description + */ +public final class SnappyBlockDecoder implements IndexingPressureAwareContentAggregator.BodyPostProcessor { + + private static final int LITERAL = 0; + private static final int COPY_1_BYTE_OFFSET = 1; + private static final int COPY_2_BYTE_OFFSET = 2; + private static final int COPY_4_BYTE_OFFSET = 3; + + private final Recycler recycler; + + public SnappyBlockDecoder(Recycler recycler) { + this.recycler = recycler; + } + + @Override + public ReleasableBytesReference process(ReleasableBytesReference body, long maxSize) throws IOException { + try (body) { + return decode(body.streamInput(), maxSize); + } + } + + private ReleasableBytesReference decode(StreamInput in, long maxSize) throws IOException { + if (in.available() <= 0) { + throw new IOException("empty snappy input"); + } + + int uncompressedLength = 0; + int shift = 0; + while (in.available() > 0) { + int b = readUByte(in); + uncompressedLength |= (b & 0x7F) << shift; + if ((b & 0x80) == 0) { + break; + } + shift += 7; + if (shift >= 32) { + throw new IOException("snappy preamble is too large"); + } + } + + if (uncompressedLength < 0) { + throw new IOException("negative snappy uncompressed length: " + uncompressedLength); + } + if (uncompressedLength > maxSize) { + throw new ElasticsearchStatusException( + "snappy decompressed size [" + uncompressedLength + "] exceeds maximum [" + maxSize + "] bytes", + RestStatus.REQUEST_ENTITY_TOO_LARGE + ); + } + if (uncompressedLength == 0) { + return ReleasableBytesReference.empty(); + } + + var out = new PagedOutput(recycler, uncompressedLength); + try { + while (in.available() > 0 && out.written < uncompressedLength) { + int tag = readUByte(in); + switch (tag & 0x03) { + 
case LITERAL: { + int literalLen = readLiteralLength(tag, in); + out.writeLiteral(in, literalLen); + break; + } + + case COPY_1_BYTE_OFFSET: { + int length = 4 + ((tag & 0x1C) >> 2); + int offset = ((tag & 0xE0) << 8 >> 5) | readUByte(in); + validateOffset(offset, out.written); + out.selfCopy(offset, length); + break; + } + + case COPY_2_BYTE_OFFSET: { + int length = 1 + ((tag >> 2) & 0x3F); + int b0 = readUByte(in); + int b1 = readUByte(in); + int offset = b0 | (b1 << 8); + validateOffset(offset, out.written); + out.selfCopy(offset, length); + break; + } + + case COPY_4_BYTE_OFFSET: { + int length = 1 + ((tag >> 2) & 0x3F); + int b0 = readUByte(in); + int b1 = readUByte(in); + int b2 = readUByte(in); + int b3 = readUByte(in); + int offset = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24); + validateOffset(offset, out.written); + out.selfCopy(offset, length); + break; + } + } + } + + if (out.written != uncompressedLength) { + throw new IOException("snappy: expected " + uncompressedLength + " bytes but decoded " + out.written); + } + + return out.toBytesReference(); + } catch (Exception e) { + out.close(); + throw e; + } + } + + private static int readUByte(StreamInput in) throws IOException { + return in.readByte() & 0xFF; + } + + /** + * Reads the literal length from the tag and any following length bytes. 
+ */ + private static int readLiteralLength(int tag, StreamInput in) throws IOException { + int val = (tag >> 2) & 0x3F; + return switch (val) { + case 60 -> readUByte(in) + 1; + case 61 -> (readUByte(in) | (readUByte(in) << 8)) + 1; + case 62 -> (readUByte(in) | (readUByte(in) << 8) | (readUByte(in) << 16)) + 1; + case 63 -> (readUByte(in) | (readUByte(in) << 8) | (readUByte(in) << 16) | (readUByte(in) << 24)) + 1; + default -> val + 1; + }; + } + + private static void validateOffset(int offset, int writtenSoFar) throws IOException { + if (offset <= 0) { + throw new IOException("snappy: invalid copy offset " + offset); + } + if (offset > writtenSoFar) { + throw new IOException("snappy: copy offset " + offset + " exceeds decoded bytes " + writtenSoFar); + } + } + + /** + * Output buffer backed by recycled pages that supports both sequential writes and random + * read access (needed for Snappy back-reference copies). + */ + static final class PagedOutput { + private final ArrayList> pages; + private final int pageSize; + private final int uncompressedLength; + int written; + + PagedOutput(Recycler recycler, int uncompressedLength) { + this.pageSize = recycler.pageSize(); + this.uncompressedLength = uncompressedLength; + int numPages = Math.ceilDiv(uncompressedLength, pageSize); + this.pages = new ArrayList<>(numPages); + try { + for (int i = 0; i < numPages; i++) { + pages.add(recycler.obtain()); + } + } catch (Exception e) { + Releasables.close(pages); + throw e; + } + } + + /** + * Reads {@code length} bytes from the input and writes them directly into + * output pages (literal copy, no intermediate buffer). 
+ */ + void writeLiteral(StreamInput in, int length) throws IOException { + validateOutputLength(length); + int remaining = length; + while (remaining > 0) { + int pageOff = written % pageSize; + BytesRef page = pages.get(written / pageSize).v(); + int spaceInPage = pageSize - pageOff; + int toCopy = Math.min(remaining, spaceInPage); + + // Copy from the input stream directly into the output page + in.readBytes(page.bytes, page.offset + pageOff, toCopy); + written += toCopy; + remaining -= toCopy; + } + } + + /** + * Copies {@code length} bytes from {@code backOffset} bytes behind the write position, + * handling the overlap case where {@code backOffset < length} (run-length repetition). + */ + void selfCopy(int backOffset, int length) throws IOException { + validateOutputLength(length); + int srcPos = written - backOffset; + if (backOffset >= length) { + bulkCopy(srcPos, length); + } else { + int remaining = length; + // Each iteration doubles the copyable region: after copying backOffset bytes, + // the destination now contains those bytes too, extending the available source. + int copyable = backOffset; + while (remaining > 0) { + int toCopy = Math.min(remaining, copyable); + bulkCopy(srcPos, toCopy); + remaining -= toCopy; + copyable += toCopy; + } + } + } + + private void validateOutputLength(int length) throws IOException { + // widen to long to avoid integer overflow + if ((long) written + length > uncompressedLength) { + throw new IOException( + "snappy: output of " + + ((long) written + length) + + " bytes would exceed declared uncompressed length of " + + uncompressedLength + ); + } + } + + /** + * Bulk-copies {@code length} bytes from absolute position {@code srcPos} to the current + * write position, handling page-boundary crossings on both source and destination. 
+ */ + private void bulkCopy(int srcPos, int length) { + int remaining = length; + int src = srcPos; + while (remaining > 0) { + int srcPageOff = src % pageSize; + int dstPageOff = written % pageSize; + + BytesRef srcPage = pages.get(src / pageSize).v(); + BytesRef dstPage = pages.get(written / pageSize).v(); + + int srcAvail = pageSize - srcPageOff; + int dstAvail = pageSize - dstPageOff; + int toCopy = Math.min(remaining, Math.min(srcAvail, dstAvail)); + + System.arraycopy(srcPage.bytes, srcPage.offset + srcPageOff, dstPage.bytes, dstPage.offset + dstPageOff, toCopy); + + src += toCopy; + written += toCopy; + remaining -= toCopy; + } + } + + /** Assembles the written pages into a {@link ReleasableBytesReference}. */ + ReleasableBytesReference toBytesReference() { + assert pages.isEmpty() == false + : "toBytesReference() should only be called when uncompressedLength > 0, so at least one page must exist"; + if (pages.size() == 1) { + Recycler.V page = pages.getFirst(); + BytesRef ref = page.v(); + return new ReleasableBytesReference(new BytesArray(ref.bytes, ref.offset, written), page); + } + BytesReference[] refs = new BytesReference[pages.size()]; + int remaining = written; + for (int i = 0; i < pages.size(); i++) { + BytesRef ref = pages.get(i).v(); + int len = Math.min(remaining, pageSize); + refs[i] = new BytesArray(ref.bytes, ref.offset, len); + remaining -= len; + } + return new ReleasableBytesReference(CompositeBytesReference.of(refs), Releasables.wrap(pages)); + } + + void close() { + Releasables.close(pages); + } + } +} diff --git a/x-pack/plugin/prometheus/src/test/java/org/elasticsearch/xpack/prometheus/rest/PrometheusRemoteWriteRestActionTests.java b/x-pack/plugin/prometheus/src/test/java/org/elasticsearch/xpack/prometheus/rest/PrometheusRemoteWriteRestActionTests.java index 9bed8eb3d55ea..20da6c278f361 100644 --- a/x-pack/plugin/prometheus/src/test/java/org/elasticsearch/xpack/prometheus/rest/PrometheusRemoteWriteRestActionTests.java +++ 
    // NOTE(review): reconstructed post-patch view of a diff fragment of
    // PrometheusRemoteWriteRestActionTests. The patch also adds two imports to this file:
    // org.elasticsearch.common.bytes.BytesArray and org.elasticsearch.transport.BytesRefRecycler.

    /**
     * Happy path without "Content-Encoding: snappy" — the raw body must reach the transport
     * action unmodified (length equals the bodySize fed in below) and the handler must
     * answer 204 NO_CONTENT.
     */
    public void testSuccessfulWriteWithoutSnappy() {
        client = new NoOpNodeClient(threadPool) {
            @Override
            @SuppressWarnings("unchecked")
            // NOTE(review): doExecute's generic type parameters appear lost in extraction — confirm
            // against the original signature (ActionType<Response>, Request, ActionListener<Response>).
            public void doExecute(
                ActionType actionType,
                Request req,
                ActionListener listener
            ) {
                assertThat(actionType, equalTo(PrometheusRemoteWriteTransportAction.TYPE));
                var remoteWriteRequest = (PrometheusRemoteWriteTransportAction.RemoteWriteRequest) req;
                // 64 matches the bodySize passed to executeRemoteWrite below.
                assertThat(remoteWriteRequest.remoteWriteRequest.length(), equalTo(64));
                remoteWriteRequest.close();
                listener.onResponse((Response) new PrometheusRemoteWriteTransportAction.RemoteWriteResponse());
            }
        };
        try (var response = executeRemoteWrite(1024, 64, false)) {
            assertThat(response.status(), equalTo(RestStatus.NO_CONTENT));
        }
    }

    // Convenience overload: pre-existing callers default to a snappy-encoded body.
    private RestResponse executeRemoteWrite(int maxSize, int bodySize) {
        return executeRemoteWrite(maxSize, bodySize, true);
    }

    /**
     * Drives the REST action with a fake chunked HTTP body of {@code bodySize} random bytes,
     * optionally snappy-compressed, and returns the captured response.
     */
    private RestResponse executeRemoteWrite(int maxSize, int bodySize, boolean snappy) {
        var stream = new FakeHttpBodyStream();
        var action = new PrometheusRemoteWriteRestAction(indexingPressure, maxSize, BytesRefRecycler.NON_RECYCLING_INSTANCE);
        // Only advertise snappy content-encoding when the body is actually compressed.
        var headers = snappy
            ? Map.of("Content-Type", List.of("application/x-protobuf"), "Content-Encoding", List.of("snappy"))
            : Map.of("Content-Type", List.of("application/x-protobuf"));
        var httpRequest = new FakeRestRequest.FakeHttpRequest(RestRequest.Method.POST, "/_prometheus/api/v1/write", headers, stream);
        var request = RestRequest.request(parserConfig(), httpRequest, new FakeRestRequest.FakeHttpChannel(null));
        var channel = new FakeRestChannel(request, true, 1);
        var consumer = (BaseRestHandler.RequestBodyChunkConsumer) action.prepareRequest(request, client);
        // NOTE(review): unchanged mid-method context is elided by the diff hunk boundaries here
        // (it ends in a try/catch rethrowing checked exceptions as AssertionError) — consult the
        // original file for the omitted lines.
        byte[] body = randomByteArrayOfLength(bodySize);
        if (snappy) {
            body = SnappyBlockDecoderTests.snappyEncode(body);
        }
        stream.sendNext(new ReleasableBytesReference(new BytesArray(body), () -> {}), true);
        RestResponse response = channel.capturedResponse();
        assertNotNull(response);
        return response;
    }

// (patch metadata: the following span is the new file SnappyBlockDecoderTests.java)
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.prometheus.rest;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.compression.Snappy;

import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.MockBytesRefRecycler;
import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.test.ESTestCase;
import org.junit.After;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Arrays;

import static org.elasticsearch.common.bytes.BytesReferenceTestUtils.equalBytes;

/**
 * Tests for {@code SnappyBlockDecoder}: round-trips against Netty's Snappy encoder, hand-crafted
 * streams for tag encodings Netty never emits (e.g. COPY_4_BYTE_OFFSET, literal length cases
 * 60-63), multi-page output assembly, and error paths (oversized output, malformed/truncated
 * input, snappy bombs, invalid copy offsets, varint preamble overflow).
 */
public class SnappyBlockDecoderTests extends ESTestCase {

    private final MockBytesRefRecycler recycler = new MockBytesRefRecycler();
    private final SnappyBlockDecoder decoder = new SnappyBlockDecoder(recycler);

    @After
    public void closeRecycler() {
        recycler.close();
    }

    public void testDecodeLiteralOnly() throws IOException {
        byte[] original = randomByteArrayOfLength(between(1, 1000));
        byte[] compressed = snappyEncode(original);

        try (var result = decoder.process(new ReleasableBytesReference(new BytesArray(compressed), () -> {}), original.length)) {
            assertEquals(original.length, result.length());
            assertArrayEquals(original, BytesReference.toBytes(result));
        }
    }

    public void testDecodeEmpty() throws IOException {
        // Snappy block: preamble = 0 (empty)
        byte[] compressed = new byte[] { 0 };
        try (var result = decoder.process(new ReleasableBytesReference(new BytesArray(compressed), () -> {}), 1024)) {
            assertEquals(0, result.length());
        }
    }

    public void testDecodeRejectsOversizedOutput() {
        int length = between(10, 1000);
        byte[] original = randomByteArrayOfLength(length);
        byte[] compressed = snappyEncode(original);
        // maxSize strictly smaller than the declared uncompressed length must be rejected.
        int maxSize = between(1, length - 1);

        var ex = expectThrows(Exception.class, () -> {
            try (var result = decoder.process(new ReleasableBytesReference(new BytesArray(compressed), () -> {}), maxSize)) {
                fail("expected exception for input of length " + length + " with maxSize " + maxSize);
            }
        });
        assertTrue(ex.getMessage().contains("exceeds maximum"));
    }

    public void testDecodeWithBackReference() throws IOException {
        // Create data with repetition to trigger back-references
        byte[] original = new byte[256];
        byte[] pattern = randomByteArrayOfLength(8);
        for (int i = 0; i < original.length; i++) {
            original[i] = pattern[i % pattern.length];
        }

        byte[] compressed = snappyEncode(original);
        // Compressed should be smaller due to repetition
        assertTrue("expected compression, got " + compressed.length + " >= " + original.length, compressed.length < original.length);

        try (var result = decoder.process(new ReleasableBytesReference(new BytesArray(compressed), () -> {}), original.length)) {
            assertEquals(original.length, result.length());
            assertArrayEquals(original, BytesReference.toBytes(result));
        }
    }

    public void testDecodeReleasesInput() throws IOException {
        byte[] original = randomByteArrayOfLength(64);
        byte[] compressed = snappyEncode(original);
        var input = new ReleasableBytesReference(new BytesArray(compressed), () -> {});

        try (var result = decoder.process(input, 1024)) {
            assertFalse("input should be released after process()", input.hasReferences());
        }
    }

    public void testDecodeMalformedInput() {
        // Preamble says 10 bytes uncompressed, then a COPY_4_BYTE_OFFSET tag with insufficient data
        byte[] garbage = new byte[] { 10, (byte) 0xFF };
        expectThrows(IOException.class, () -> {
            try (var result = decoder.process(new ReleasableBytesReference(new BytesArray(garbage), () -> {}), 1024)) {
                fail("expected IOException for malformed input");
            }
        });
    }

    public void testDecodeEmptyInput() {
        expectThrows(IOException.class, () -> {
            try (var result = decoder.process(new ReleasableBytesReference(BytesArray.EMPTY, () -> {}), 1024)) {
                fail("expected IOException for empty input");
            }
        });
    }

    public void testDecodeRoundTripsVariousSizes() throws IOException {
        // Sizes chosen around literal-length encoding breakpoints (60/61, 255/256) and page edges.
        for (int size : new int[] { 1, 15, 16, 60, 61, 255, 256, 4096, 16384 }) {
            byte[] original = randomByteArrayOfLength(size);
            byte[] compressed = snappyEncode(original);
            try (var result = decoder.process(new ReleasableBytesReference(new BytesArray(compressed), () -> {}), size)) {
                assertEquals(size, result.length());
                assertArrayEquals("failed for size " + size, original, BytesReference.toBytes(result));
            }
        }
    }

    public void testDecodeSpansMultiplePages() throws IOException {
        // 48 KiB of repeating pattern exercises page boundaries (page = 16 KiB) and back-references
        int size = 48 * 1024;
        byte[] original = new byte[size];
        byte[] pattern = randomByteArrayOfLength(between(4, 64));
        for (int i = 0; i < size; i++) {
            original[i] = pattern[i % pattern.length];
        }
        byte[] compressed = snappyEncode(original);
        assertTrue("expected compression", compressed.length < original.length);
        try (var result = decoder.process(new ReleasableBytesReference(new BytesArray(compressed), () -> {}), size)) {
            assertEquals(size, result.length());
            assertArrayEquals(original, BytesReference.toBytes(result));
        }
    }

    public void testDecodeLargeRandomDataAcrossPages() throws IOException {
        // Random data > 16 KiB doesn't compress well but exercises multi-page literal writes
        int size = 32 * 1024 + between(1, 1000);
        byte[] original = randomByteArrayOfLength(size);
        byte[] compressed = snappyEncode(original);
        try (var result = decoder.process(new ReleasableBytesReference(new BytesArray(compressed), () -> {}), size)) {
            assertEquals(size, result.length());
            assertArrayEquals(original, BytesReference.toBytes(result));
        }
    }

    public void testDecodeBackRefAcrossPageBoundary() throws IOException {
        // Place a pattern just before a page boundary, then repeat it across the boundary.
        // Page size is 16384, so we put ~16380 bytes of random data, then a 16-byte pattern,
        // then repeat the pattern several times to create a back-ref that spans pages.
        int prefixLen = 16384 - 16;
        byte[] prefix = randomByteArrayOfLength(prefixLen);
        byte[] pattern = randomByteArrayOfLength(16);
        int repetitions = 20;
        int totalSize = prefixLen + pattern.length * (1 + repetitions);
        byte[] original = new byte[totalSize];
        System.arraycopy(prefix, 0, original, 0, prefixLen);
        for (int i = 0; i <= repetitions; i++) {
            System.arraycopy(pattern, 0, original, prefixLen + i * pattern.length, pattern.length);
        }
        byte[] compressed = snappyEncode(original);
        try (
            var result = decoder.process(new ReleasableBytesReference(new BytesArray(compressed), () -> {}), totalSize + between(0, 1024))
        ) {
            assertEquals(totalSize, result.length());
            assertArrayEquals(original, BytesReference.toBytes(result));
        }
    }

    public void testDecodeWithCompositeInput() throws IOException {
        // Split compressed data across multiple BytesReference chunks to exercise InputCursor chunk transitions
        byte[] original = randomByteArrayOfLength(between(100, 1000));
        byte[] compressed = snappyEncode(original);

        // Split into 3 chunks at arbitrary points
        int split1 = between(1, compressed.length - 2);
        int split2 = between(split1 + 1, compressed.length - 1);
        BytesReference composite = CompositeBytesReference.of(
            new BytesArray(compressed, 0, split1),
            new BytesArray(compressed, split1, split2 - split1),
            new BytesArray(compressed, split2, compressed.length - split2)
        );

        try (var result = decoder.process(new ReleasableBytesReference(composite, () -> {}), original.length + between(0, 1024))) {
            assertEquals(original.length, result.length());
            assertArrayEquals(original, BytesReference.toBytes(result));
        }
    }

    public void testDecodeCopy4ByteOffset() throws IOException {
        // Hand-craft a Snappy stream with COPY_4_BYTE_OFFSET (tag type 3).
        // Netty's encoder never produces this tag type, so manual construction is needed.
        // Use an offset > 65535 to exercise the 4-byte offset path specifically.
        int literalLen = 65536;
        byte[] literal = randomByteArrayOfLength(literalLen);
        int copyOffset = literalLen;
        int copyLength = 10;
        int total = literalLen + copyLength;

        var buf = new ByteArrayOutputStream();
        writeVarint(buf, total);
        buf.write(61 << 2); // literal tag, case 61 (2-byte length)
        int encodedLiteralLen = literalLen - 1;
        buf.write(encodedLiteralLen & 0xFF);
        buf.write((encodedLiteralLen >> 8) & 0xFF);
        buf.write(literal);
        buf.write(0x03 | ((copyLength - 1) << 2)); // COPY_4_BYTE_OFFSET tag
        writeIntLE(buf, copyOffset);

        // Expected output: the literal followed by its first copyLength bytes repeated.
        byte[] expected = new byte[total];
        System.arraycopy(literal, 0, expected, 0, literalLen);
        System.arraycopy(literal, 0, expected, literalLen, copyLength);

        byte[] compressed = buf.toByteArray();
        try (var result = decoder.process(releasable(compressed), total)) {
            assertEquals(total, result.length());
            assertArrayEquals(expected, BytesReference.toBytes(result));
        }
    }

    public void testDecodeRunLengthRepetition() throws IOException {
        // Single byte repeated forces copies with offset=1, which triggers
        // the backOffset < length path in selfCopy (run-length encoding).
        byte value = randomByte();
        int size = between(256, 1024);
        byte[] original = new byte[size];
        Arrays.fill(original, value);

        byte[] compressed = snappyEncode(original);
        assertTrue("expected compression", compressed.length < original.length);

        try (var result = decoder.process(releasable(compressed), size)) {
            assertEquals(size, result.length());
            byte[] decoded = BytesReference.toBytes(result);
            for (int i = 0; i < size; i++) {
                assertEquals("mismatch at index " + i, value, decoded[i]);
            }
        }
    }

    public void testDecodeLiteralLengthCase60() throws IOException {
        // Case 60: 1 extra byte encodes literal lengths 61-256
        int length = between(61, 256);
        assertDecodesHandCraftedLiteral(length, 60);
    }

    public void testDecodeLiteralLengthCase61() throws IOException {
        // Case 61: 2 extra bytes encode literal lengths 257-65536
        int length = between(257, 65536);
        assertDecodesHandCraftedLiteral(length, 61);
    }

    public void testDecodeLiteralLengthCase62() throws IOException {
        // Case 62: 3 extra bytes encode literal lengths 65537+
        int length = between(65537, 70000);
        assertDecodesHandCraftedLiteral(length, 62);
    }

    public void testDecodeLiteralLengthCase63() throws IOException {
        // Case 63: 4 extra bytes encode the literal length
        int length = between(1, 1000);
        assertDecodesHandCraftedLiteral(length, 63);
    }

    /**
     * Hand-crafts a block consisting of a single literal whose length is encoded using the
     * given tag case (60-63) and asserts it round-trips through the decoder.
     */
    private void assertDecodesHandCraftedLiteral(int length, int caseNum) throws IOException {
        byte[] data = randomByteArrayOfLength(length);
        var buf = new ByteArrayOutputStream();
        writeVarint(buf, length);
        buf.write(caseNum << 2); // literal tag with specified case number
        int encodedLen = length - 1;
        int extraBytes = caseNum - 59; // case 60 → 1 byte, 61 → 2, 62 → 3, 63 → 4
        for (int i = 0; i < extraBytes; i++) {
            buf.write((encodedLen >> (i * 8)) & 0xFF); // little-endian length bytes
        }
        buf.write(data);

        byte[] compressed = buf.toByteArray();
        try (var result = decoder.process(releasable(compressed), length)) {
            assertEquals(length, result.length());
            assertArrayEquals(data, BytesReference.toBytes(result));
        }
    }

    public void testDecodeOverflowVarintPreamble() {
        // 5-byte varint that decodes to -1 (0xFFFFFFFF), triggering the negative length check
        byte[] overflow = new byte[] { (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, 0x0F };
        var ex = expectThrows(IOException.class, () -> {
            try (var result = decoder.process(releasable(overflow), Long.MAX_VALUE)) {
                fail("should not reach here");
            }
        });
        assertTrue(ex.getMessage(), ex.getMessage().contains("negative"));
    }

    public void testDecodeTooLargeVarintPreamble() {
        // 5 continuation bytes → shift reaches 35, exceeding the 32-bit limit
        byte[] tooLarge = new byte[] { (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, 0x00 };
        var ex = expectThrows(IOException.class, () -> {
            try (var result = decoder.process(releasable(tooLarge), Long.MAX_VALUE)) {
                fail("should not reach here");
            }
        });
        assertTrue(ex.getMessage(), ex.getMessage().contains("preamble"));
    }

    public void testDecodeReleasesResourcesOnError() throws IOException {
        // Dedicated recycler so closing it at the end verifies all pages were returned.
        var trackingRecycler = new MockBytesRefRecycler();
        var trackingDecoder = new SnappyBlockDecoder(trackingRecycler);

        // Stream claims 40000 bytes but only provides a 20000-byte literal.
        // The decoder allocates pages for the literal, then fails on the length mismatch.
        byte[] literal = randomByteArrayOfLength(20000);
        var buf = new ByteArrayOutputStream();
        writeVarint(buf, 40000);
        buf.write(61 << 2); // literal tag, case 61 (2-byte length)
        int len = 20000 - 1;
        buf.write(len & 0xFF);
        buf.write((len >> 8) & 0xFF);
        try {
            buf.write(literal);
        } catch (IOException e) {
            throw new AssertionError(e);
        }

        byte[] compressed = buf.toByteArray();
        expectThrows(IOException.class, () -> {
            try (var result = trackingDecoder.process(releasable(compressed), 40000)) {
                fail("expected IOException for truncated stream");
            }
        });
        trackingRecycler.close();
    }

    public void testDecodeLiteralExceedsDeclaredLength() throws IOException {
        // Preamble claims 5 bytes, but the literal tag says 10 bytes
        var buf = new ByteArrayOutputStream();
        writeVarint(buf, 5);
        buf.write(9 << 2); // literal tag, val=9, length=10
        buf.write(new byte[10], 0, 10);

        byte[] compressed = buf.toByteArray();
        var ex = expectThrows(IOException.class, () -> {
            try (var result = decoder.process(releasable(compressed), 1024)) {
                fail();
            }
        });
        assertTrue(ex.getMessage(), ex.getMessage().contains("would exceed declared uncompressed length"));
    }

    public void testDecodeCopyExceedsDeclaredLength() throws IOException {
        // Preamble claims 6 bytes, literal writes 5, then a copy tries to write 4 more
        var buf = new ByteArrayOutputStream();
        writeVarint(buf, 6);
        buf.write(4 << 2); // literal tag, val=4, length=5
        buf.write(new byte[5], 0, 5);
        // COPY_1_BYTE_OFFSET: length=4, offset=5
        buf.write(0x01); // length = 4 + ((0x01 & 0x1C) >> 2) = 4
        buf.write(5);

        byte[] compressed = buf.toByteArray();
        var ex = expectThrows(IOException.class, () -> {
            try (var result = decoder.process(releasable(compressed), 1024)) {
                fail();
            }
        });
        assertTrue(ex.getMessage(), ex.getMessage().contains("would exceed declared uncompressed length"));
    }

    public void testDecodeSnappyBomb() throws IOException {
        // A small compressed payload that claims a huge uncompressed length (just under maxSize)
        // but the actual tags try to produce even more output than declared.
        // Preamble declares 100 bytes, a single 1-byte literal, then a run-length copy
        // with offset=1 and length=60 that fills to 61, then another copy pushing past 100.
        var buf = new ByteArrayOutputStream();
        writeVarint(buf, 100);
        buf.write(0 << 2); // literal tag, val=0, length=1
        buf.write(0x42);
        // COPY_2_BYTE_OFFSET: length = 1 + ((tag >> 2) & 0x3F), offset in next 2 bytes
        // length=64 (max for this tag type), offset=1 (run-length expansion)
        buf.write(0x02 | (63 << 2)); // COPY_2_BYTE_OFFSET, length = 1 + 63 = 64
        buf.write(1); // offset low
        buf.write(0); // offset high
        // Now written=65. Another copy of 64 would push to 129, exceeding declared 100.
        buf.write(0x02 | (63 << 2));
        buf.write(1);
        buf.write(0);

        byte[] compressed = buf.toByteArray();
        var ex = expectThrows(IOException.class, () -> {
            try (var result = decoder.process(releasable(compressed), 1024)) {
                fail();
            }
        });
        assertTrue(ex.getMessage(), ex.getMessage().contains("would exceed declared uncompressed length"));
    }

    public void testDecodeInvalidCopyOffsetZero() throws IOException {
        // COPY_2_BYTE_OFFSET with offset=0 triggers "invalid copy offset" error
        var buf = new ByteArrayOutputStream();
        writeVarint(buf, 10);
        buf.write(4 << 2); // literal tag, val=4, length=5
        buf.write(new byte[5], 0, 5);
        buf.write(0x02 | (4 << 2)); // COPY_2_BYTE_OFFSET, length=5
        buf.write(0); // offset low byte
        buf.write(0); // offset high byte

        byte[] compressed = buf.toByteArray();
        var ex = expectThrows(IOException.class, () -> {
            try (var result = decoder.process(releasable(compressed), 10)) {
                fail("expected IOException for zero copy offset");
            }
        });
        assertTrue(ex.getMessage(), ex.getMessage().contains("invalid copy offset"));
    }

    public void testDecodeCopyOffsetExceedsWritten() throws IOException {
        // COPY_1_BYTE_OFFSET with offset=5 but only 2 bytes written so far
        var buf = new ByteArrayOutputStream();
        writeVarint(buf, 10);
        buf.write(1 << 2); // literal tag, val=1, length=2
        buf.write(new byte[2], 0, 2);
        buf.write(0x01); // COPY_1_BYTE_OFFSET, length=4, offset_high=0
        buf.write(5); // offset low byte

        byte[] compressed = buf.toByteArray();
        var ex = expectThrows(IOException.class, () -> {
            try (var result = decoder.process(releasable(compressed), 10)) {
                fail("expected IOException for offset exceeding written bytes");
            }
        });
        assertTrue(ex.getMessage(), ex.getMessage().contains("exceeds decoded bytes"));
    }

    public void testDecodeTruncatedVarintPreamble() throws IOException {
        // Varint byte 0x80 has continuation bit set but no following byte.
        // The partial value is 0, so the decoder treats it as an empty block.
        byte[] truncated = new byte[] { (byte) 0x80 };
        try (var result = decoder.process(releasable(truncated), 1024)) {
            assertEquals(0, result.length());
        }
    }

    public void testDecodeTrailingDataAfterBlock() throws IOException {
        // Valid compressed block with trailing bytes after the complete block.
        // Exercises the out.written < uncompressedLength loop exit condition.
        byte[] original = randomByteArrayOfLength(between(1, 100));
        byte[] compressed = snappyEncode(original);
        int trailingLen = between(1, 10);
        byte[] withTrailing = new byte[compressed.length + trailingLen];
        System.arraycopy(compressed, 0, withTrailing, 0, compressed.length);
        for (int i = compressed.length; i < withTrailing.length; i++) {
            withTrailing[i] = randomByte();
        }

        try (var result = decoder.process(releasable(withTrailing), original.length)) {
            assertEquals(original.length, result.length());
            assertArrayEquals(original, BytesReference.toBytes(result));
        }
    }

    /**
     * Randomly synthesize a valid Snappy block as a sequence of arbitrary tagged elements and assert it decodes correctly.
     */
    public void testSyntheticCompressedStream() throws IOException {
        final var uncompressed = new byte[scaledRandomIntBetween(0, ByteSizeUnit.MB.toIntBytes(32))];
        final var compressed = new RecyclerBytesStreamOutput(recycler);
        writeVarint(compressed, uncompressed.length);

        int uncompressedPosition = 0;
        while (true) {
            int remaining = uncompressed.length - uncompressedPosition;
            if (remaining == 0) {
                break;
            }
            // The first element must be a literal (a copy has nothing to reference yet).
            if (uncompressedPosition == 0 || randomBoolean()) {
                final var literal = randomByteArrayOfLength(scaledRandomIntBetween(1, remaining));
                System.arraycopy(literal, 0, uncompressed, uncompressedPosition, literal.length);
                writeLiteralLength(compressed, literal.length);
                compressed.write(literal);
                uncompressedPosition += literal.length;
            } else {
                final int copyLength = between(1, Math.min(remaining, 64));
                int copyPosition = between(0, uncompressedPosition - 1);
                writeCopy(compressed, uncompressedPosition - copyPosition, copyLength);
                // Byte-at-a-time so the copy may legitimately overlap its own output.
                for (int i = 0; i < copyLength; i++) {
                    uncompressed[uncompressedPosition++] = uncompressed[copyPosition++];
                }
            }
        }

        try (var decoded = decoder.process(compressed.moveToBytesReference(), uncompressed.length + between(0, 1024))) {
            assertThat(decoded, equalBytes(new BytesArray(uncompressed)));
        }
    }

    /**
     * Writes a literal tag for {@code length} bytes, randomly choosing among all tag cases
     * (inline, 60, 61, 62, 63) that can represent it — little-endian extra length bytes.
     */
    private void writeLiteralLength(OutputStream out, int length) throws IOException {
        int offsetLength = length - 1;
        if (offsetLength > 0xFFFFFF || randomBoolean()) {
            out.write(63 << 2);
            out.write(offsetLength & 0xFF);
            out.write((offsetLength >> 8) & 0xFF);
            out.write((offsetLength >> 16) & 0xFF);
            out.write((offsetLength >> 24) & 0xFF);
        } else if (offsetLength > 0xFFFF || randomBoolean()) {
            out.write(62 << 2);
            out.write(offsetLength & 0xFF);
            out.write((offsetLength >> 8) & 0xFF);
            out.write((offsetLength >> 16) & 0xFF);
        } else if (offsetLength > 0xFF || randomBoolean()) {
            out.write(61 << 2);
            out.write(offsetLength & 0xFF);
            out.write((offsetLength >> 8) & 0xFF);
        } else if (offsetLength > 59 || randomBoolean()) {
            out.write(60 << 2);
            out.write(offsetLength & 0xFF);
        } else {
            out.write(offsetLength << 2);
        }
    }

    /**
     * Writes a copy tag, randomly choosing among the 4-, 2- and 1-byte offset encodings that
     * can represent the given offset/length (1-byte form needs offset <= 0x7FF and length 4-11).
     */
    private void writeCopy(OutputStream out, int offset, int length) throws IOException {
        if (offset > 0xFFFF || randomBoolean()) {
            out.write(0x03 | ((length - 1) << 2));
            out.write(offset & 0xFF);
            out.write((offset >> 8) & 0xFF);
            out.write((offset >> 16) & 0xFF);
            out.write((offset >> 24) & 0xFF);
        } else if (offset > 0x7FF || ((length - 4) | 7) != 7 || randomBoolean()) {
            out.write(0x02 | ((length - 1) << 2));
            out.write(offset & 0xFF);
            out.write((offset >> 8) & 0xFF);
        } else {
            out.write(0x01 | (length - 4 << 2) | (((offset >> 8) & 0x07) << 5));
            out.write(offset & 0xFF);
        }
    }

    /** Wraps a byte array in a {@link ReleasableBytesReference} with a no-op releasable. */
    private static ReleasableBytesReference releasable(byte[] data) {
        return new ReleasableBytesReference(new BytesArray(data), () -> {});
    }

    /** Writes {@code value} as a Snappy (protobuf-style) unsigned LEB128 varint. */
    private static void writeVarint(OutputStream buf, int value) throws IOException {
        while ((value & ~0x7F) != 0) {
            buf.write((value & 0x7F) | 0x80);
            value >>>= 7;
        }
        buf.write(value);
    }

    /** Writes {@code value} as 4 little-endian bytes (COPY_4_BYTE_OFFSET operand). */
    private static void writeIntLE(ByteArrayOutputStream buf, int value) {
        buf.write(value & 0xFF);
        buf.write((value >> 8) & 0xFF);
        buf.write((value >> 16) & 0xFF);
        buf.write((value >> 24) & 0xFF);
    }

    /** Encodes data using Netty's Snappy encoder for test round-trip verification. */
    static byte[] snappyEncode(byte[] input) {
        ByteBuf in = Unpooled.wrappedBuffer(input);
        ByteBuf out = Unpooled.buffer(input.length);
        try {
            new Snappy().encode(in, out, input.length);
            byte[] result = new byte[out.readableBytes()];
            out.readBytes(result);
            return result;
        } finally {
            in.release();
            out.release();
        }
    }
}