diff --git a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/RestClientProxyExchange.java b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/RestClientProxyExchange.java index 37493b6b33..1ace512684 100644 --- a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/RestClientProxyExchange.java +++ b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/RestClientProxyExchange.java @@ -22,12 +22,14 @@ import org.springframework.cloud.gateway.server.mvc.common.MvcUtils; import org.springframework.http.client.ClientHttpResponse; -import org.springframework.util.StreamUtils; +import org.springframework.util.Assert; import org.springframework.web.client.RestClient; import org.springframework.web.servlet.function.ServerResponse; public class RestClientProxyExchange implements ProxyExchange { + private static final int BUFFER_SIZE = 16384; + private final RestClient restClient; public RestClientProxyExchange(RestClient restClient) { @@ -44,7 +46,33 @@ public ServerResponse exchange(Request request) { } private static int copyBody(Request request, OutputStream outputStream) throws IOException { - return StreamUtils.copy(request.getServerRequest().servletRequest().getInputStream(), outputStream); + return copy(request.getServerRequest().servletRequest().getInputStream(), outputStream); + } + + private static int copy(InputStream inputStream, OutputStream outputStream) throws IOException { + Assert.notNull(inputStream, "No InputStream specified"); + Assert.notNull(outputStream, "No OutputStream specified"); + + int readBytes; + var totalReadBytes = 0; + var buffer = new byte[BUFFER_SIZE]; + + while ((readBytes = inputStream.read(buffer)) != -1) { + outputStream.write(buffer, 0, readBytes); + outputStream.flush(); + if (totalReadBytes < Integer.MAX_VALUE) { + try { + totalReadBytes = Math.addExact(totalReadBytes, readBytes); + } + catch (ArithmeticException e) { + totalReadBytes = Integer.MAX_VALUE; + } + } + } + + outputStream.flush(); + + return totalReadBytes; } private static ServerResponse doExchange(Request request, ClientHttpResponse clientResponse) throws IOException { @@ -59,7 +87,7 @@ private static ServerResponse doExchange(Request request, ClientHttpResponse cli InputStream inputStream = MvcUtils.getAttribute(request.getServerRequest(), MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR); // copy body from request to clientHttpRequest - StreamUtils.copy(inputStream, httpServletResponse.getOutputStream()); + copy(inputStream, httpServletResponse.getOutputStream()); } return null; });