Merge pull request opensearch-project#553 from AndreKurait/FixChunkedHeaders

Fix chunked headers

AndreKurait authored Apr 10, 2024
2 parents 48110cb + b04d2a1 commit 49355f9
Showing 17 changed files with 437 additions and 215 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -11,6 +11,7 @@ logs

# Build files
plugins/opensearch/loggable-transport-netty4/.gradle/
TrafficCapture/**/out/

RFS/.gradle/
RFS/bin/
@@ -1,5 +1,10 @@
package org.opensearch.migrations.replay;

import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@@ -59,7 +64,17 @@ public void addingCompressionRequestHeaderCompressesPayload() throws ExecutionEx
() -> "AddCompressionEncodingTest.fullyProcessedResponse");
fullyProcessedResponse.get();

try (var bais = new ByteArrayInputStream(testPacketCapture.getBytesCaptured());

EmbeddedChannel channel = new EmbeddedChannel(
new HttpServerCodec(),
new HttpObjectAggregator(Utils.MAX_PAYLOAD_SIZE_TO_PRINT) // Set max content length if needed
);

channel.writeInbound(Unpooled.wrappedBuffer(testPacketCapture.getBytesCaptured()));
var compressedRequest = ((FullHttpRequest) channel.readInbound());
var compressedByteArr = new byte[compressedRequest.content().readableBytes()];
compressedRequest.content().getBytes(0, compressedByteArr);
try (var bais = new ByteArrayInputStream(compressedByteArr);
var unzipStream = new GZIPInputStream(bais);
var isr = new InputStreamReader(unzipStream, StandardCharsets.UTF_8);
var br = new BufferedReader(isr)) {
@@ -76,5 +91,6 @@ public void addingCompressionRequestHeaderCompressesPayload() throws ExecutionEx
} while (true);
Assertions.assertEquals(numParts*payloadPartSize, counter);
}
compressedRequest.release();
}
}
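
Note on the test change above: instead of gunzipping the captured bytes directly, the test now replays them through Netty's server codec so that chunked transfer framing is stripped before the gzip stream is read. A minimal standalone sketch of that decode pattern (the 1 MiB aggregation limit and the helper name are illustrative, not from this commit):

```java
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;

import java.io.ByteArrayInputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.zip.GZIPInputStream;

public class CapturedRequestDecoder {
    // Parse raw captured wire bytes back into a single FullHttpRequest, then gunzip its body.
    static String decodeGzippedBody(byte[] capturedBytes) throws Exception {
        EmbeddedChannel channel = new EmbeddedChannel(
                new HttpServerCodec(),                  // splits bytes into HttpRequest + HttpContent, de-chunking as it goes
                new HttpObjectAggregator(1024 * 1024)); // illustrative max content length
        try {
            channel.writeInbound(Unpooled.wrappedBuffer(capturedBytes));
            FullHttpRequest request = channel.readInbound();
            var body = new byte[request.content().readableBytes()];
            request.content().getBytes(0, body);
            request.release();                          // FullHttpRequest holds a refcounted buffer
            try (var in = new InputStreamReader(
                    new GZIPInputStream(new ByteArrayInputStream(body)), StandardCharsets.UTF_8)) {
                var sb = new StringBuilder();
                int c;
                while ((c = in.read()) >= 0) { sb.append((char) c); }
                return sb.toString();
            }
        } finally {
            channel.finishAndReleaseAll();              // release anything still queued in the channel
        }
    }
}
```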
@@ -15,7 +15,6 @@
import lombok.extern.slf4j.Slf4j;

import java.nio.charset.StandardCharsets;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -29,6 +28,10 @@
@Slf4j
public class HttpByteBufFormatter {

public static final String CRLF_LINE_DELIMITER = "\r\n";
public static final String LF_LINE_DELIMITER = "\n";
private static final String DEFAULT_LINE_DELIMITER = CRLF_LINE_DELIMITER;

private static final ThreadLocal<Optional<PacketPrintFormat>> printStyle =
ThreadLocal.withInitial(Optional::empty);

@@ -58,41 +61,56 @@ public static <T> T setPrintStyleFor(PacketPrintFormat packetPrintFormat, Suppli
public enum HttpMessageType { REQUEST, RESPONSE }

public static String httpPacketBytesToString(HttpMessageType msgType, List<byte[]> byteArrStream) {
return httpPacketBytesToString(msgType,
Optional.ofNullable(byteArrStream).map(p -> p.stream()).orElse(Stream.of()));
return httpPacketBytesToString(msgType, byteArrStream, DEFAULT_LINE_DELIMITER);
}

public static String httpPacketBytesToString(HttpMessageType msgType, Stream<byte[]> byteArrStream) {
return httpPacketBytesToString(msgType, byteArrStream, DEFAULT_LINE_DELIMITER);
}

public static String httpPacketBufsToString(HttpMessageType msgType, Stream<ByteBuf> byteBufStream,
boolean releaseByteBufs) {
return httpPacketBufsToString(msgType, byteBufStream, releaseByteBufs, DEFAULT_LINE_DELIMITER);
}

public static String httpPacketBytesToString(HttpMessageType msgType, List<byte[]> byteArrStream, String lineDelimiter) {
return httpPacketBytesToString(msgType,
Optional.ofNullable(byteArrStream).map(p -> p.stream()).orElse(Stream.of()), lineDelimiter);
}

public static String httpPacketBytesToString(HttpMessageType msgType, Stream<byte[]> byteArrStream, String lineDelimiter) {
// This isn't memory efficient,
// but stringifying byte bufs through a full parse and reserializing them was already really slow!
return httpPacketBufsToString(msgType, byteArrStream.map(Unpooled::wrappedBuffer), true);
return httpPacketBufsToString(msgType, byteArrStream.map(Unpooled::wrappedBuffer), true, lineDelimiter);
}

public static String httpPacketBufsToString(HttpMessageType msgType, Stream<ByteBuf> byteBufStream,
boolean releaseByteBufs) {
boolean releaseByteBufs, String lineDelimiter) {
switch (printStyle.get().orElse(PacketPrintFormat.TRUNCATED)) {
case TRUNCATED:
return httpPacketBufsToString(byteBufStream, Utils.MAX_BYTES_SHOWN_FOR_TO_STRING, releaseByteBufs);
case FULL_BYTES:
return httpPacketBufsToString(byteBufStream, Long.MAX_VALUE, releaseByteBufs);
case PARSED_HTTP:
return httpPacketsToPrettyPrintedString(msgType, byteBufStream, false, releaseByteBufs);
return httpPacketsToPrettyPrintedString(msgType, byteBufStream, false, releaseByteBufs,
lineDelimiter);
case PARSED_HTTP_SORTED_HEADERS:
return httpPacketsToPrettyPrintedString(msgType, byteBufStream, true, releaseByteBufs);
return httpPacketsToPrettyPrintedString(msgType, byteBufStream, true, releaseByteBufs,
lineDelimiter);
default:
throw new IllegalStateException("Unknown PacketPrintFormat: " + printStyle.get());
}
}

public static String httpPacketsToPrettyPrintedString(HttpMessageType msgType, Stream<ByteBuf> byteBufStream,
boolean sortHeaders, boolean releaseByteBufs) {
boolean sortHeaders, boolean releaseByteBufs, String lineDelimiter) {
HttpMessage httpMessage = parseHttpMessageFromBufs(msgType, byteBufStream, releaseByteBufs);
var holderOp = Optional.ofNullable((httpMessage instanceof ByteBufHolder) ? (ByteBufHolder) httpMessage : null);
try {
if (httpMessage instanceof FullHttpRequest) {
return prettyPrintNettyRequest((FullHttpRequest) httpMessage, sortHeaders);
return prettyPrintNettyRequest((FullHttpRequest) httpMessage, sortHeaders, lineDelimiter);
} else if (httpMessage instanceof FullHttpResponse) {
return prettyPrintNettyResponse((FullHttpResponse) httpMessage, sortHeaders);
return prettyPrintNettyResponse((FullHttpResponse) httpMessage, sortHeaders, lineDelimiter);
} else if (httpMessage == null) {
return "[NULL]";
} else {
@@ -104,14 +122,14 @@ public static String httpPacketsToPrettyPrintedString(HttpMessageType msgType, S
}
}

public static String prettyPrintNettyRequest(FullHttpRequest msg, boolean sortHeaders) {
var sj = new StringJoiner("\n");
public static String prettyPrintNettyRequest(FullHttpRequest msg, boolean sortHeaders, String lineDelimiter) {
var sj = new StringJoiner(lineDelimiter);
sj.add(msg.method() + " " + msg.uri() + " " + msg.protocolVersion().text());
return prettyPrintNettyMessage(sj, sortHeaders, msg, msg.content());
}

static String prettyPrintNettyResponse(FullHttpResponse msg, boolean sortHeaders) {
var sj = new StringJoiner("\n");
public static String prettyPrintNettyResponse(FullHttpResponse msg, boolean sortHeaders, String lineDelimiter) {
var sj = new StringJoiner(lineDelimiter);
sj.add(msg.protocolVersion().text() + " " + msg.status().code() + " " + msg.status().reasonPhrase());
return prettyPrintNettyMessage(sj, sortHeaders, msg, msg.content());
}
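The practical effect of threading the delimiter through every overload: callers keep protocol-accurate CRLF line endings by default, or opt into LF where the string is destined for logs. A usage sketch (requestPackets stands in for a List<byte[]> of captured request packets):

```java
import org.opensearch.migrations.replay.HttpByteBufFormatter;

import java.util.List;

class FormatterUsage {
    static void demo(List<byte[]> requestPackets) {
        // CRLF by default -- matches the bytes that actually go on the wire.
        String wireText = HttpByteBufFormatter.httpPacketBytesToString(
                HttpByteBufFormatter.HttpMessageType.REQUEST, requestPackets);

        // LF for log-friendly output, as toString()/format() below now pass explicitly.
        String logText = HttpByteBufFormatter.httpPacketBytesToString(
                HttpByteBufFormatter.HttpMessageType.REQUEST, requestPackets,
                HttpByteBufFormatter.LF_LINE_DELIMITER);

        System.out.println(wireText.equals(logText)); // false whenever the message spans multiple lines
    }
}
```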
@@ -65,7 +65,8 @@ public Stream<byte[]> stream() {
}

public String format(Optional<HttpByteBufFormatter.HttpMessageType> messageTypeOp) {
var packetBytesAsStr = messageTypeOp.map(mt-> HttpByteBufFormatter.httpPacketBytesToString(mt, packetBytes))
var packetBytesAsStr = messageTypeOp.map(mt-> HttpByteBufFormatter.httpPacketBytesToString(mt, packetBytes,
HttpByteBufFormatter.LF_LINE_DELIMITER))
.orElseGet(()-> HttpByteBufFormatter.httpPacketBufsToString(
packetBytes.stream().map(Unpooled::wrappedBuffer),
Utils.MAX_PAYLOAD_SIZE_TO_PRINT, true));
@@ -1,5 +1,7 @@
package org.opensearch.migrations.replay;

import static org.opensearch.migrations.replay.HttpByteBufFormatter.LF_LINE_DELIMITER;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.opensearch.migrations.replay.datatypes.HttpRequestTransformationStatus;
@@ -52,9 +54,9 @@ public String toString() {
if (targetResponseDuration != null) { sj.add("targetResponseDuration=").add(targetResponseDuration+""); }
Optional.ofNullable(targetRequestData).ifPresent(d-> sj.add("targetRequestData=")
.add(d.isClosed() ? "CLOSED" : HttpByteBufFormatter.httpPacketBufsToString(
HttpByteBufFormatter.HttpMessageType.REQUEST, d.streamUnretained(), false)));
HttpByteBufFormatter.HttpMessageType.REQUEST, d.streamUnretained(), false, LF_LINE_DELIMITER)));
Optional.ofNullable(targetResponseData).filter(d->!d.isEmpty()).ifPresent(d -> sj.add("targetResponseData=")
.add(HttpByteBufFormatter.httpPacketBytesToString(HttpByteBufFormatter.HttpMessageType.RESPONSE, d)));
.add(HttpByteBufFormatter.httpPacketBytesToString(HttpByteBufFormatter.HttpMessageType.RESPONSE, d, LF_LINE_DELIMITER)));
sj.add("transformStatus=").add(transformationStatus+"");
sj.add("errorCause=").add(errorCause == null ? "none" : errorCause.toString());
return sj.toString();
@@ -119,6 +119,7 @@ public DiagnosticTrackableCompletableFuture<String, TransformedOutputAndResult<R
.log();
return redriveWithoutTransformation(pipelineOrchestrator.packetReceiver, e);
} finally {
channel.finishAndReleaseAll();
var cf = channel.close();
if (cf.cause() != null) {
log.atInfo().setCause(cf.cause()).setMessage("Exception encountered during write").log();
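Note: EmbeddedChannel.finishAndReleaseAll() marks the channel finished and releases any reference-counted messages still queued inbound or outbound, so an exception during transformation no longer leaks the ByteBufs parked in the pipeline. The general cleanup shape, sketched with an empty pipeline for brevity:

```java
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;

public class EmbeddedChannelCleanup {
    static void processSafely(byte[] payload) {
        EmbeddedChannel channel = new EmbeddedChannel(); // transformation handlers omitted in this sketch
        try {
            channel.writeInbound(Unpooled.wrappedBuffer(payload));
        } finally {
            channel.finishAndReleaseAll(); // drains and releases queued inbound/outbound messages
            channel.close();
        }
    }
}
```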
@@ -65,9 +65,8 @@ public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg)
} else if (msg instanceof HttpContent) {
ctx.fireChannelRead(msg);
} else {
// ByteBufs shouldn't come through, but in case there's a regression in
// RequestPipelineOrchestrator.removeThisAndPreviousHandlers to remove the handlers
// in order rather in reverse order
assert false: "Only HttpRequest and HttpContent should come through here as per RequestPipelineOrchestrator";
// In case message comes through, pass downstream
super.channelRead(ctx, msg);
}
}
@@ -101,13 +100,15 @@ private void handlePayloadNeutralTransformationOrThrow(ChannelHandlerContext ctx
} else if (headerFieldsAreIdentical(request, httpJsonMessage)) {
log.info(diagnosticLabel + "Transformation isn't necessary. " +
"Resetting the processing pipeline to let the caller send the original network bytes as-is.");
while (pipeline.first() != null) {
pipeline.removeFirst();
}
RequestPipelineOrchestrator.removeAllHandlers(pipeline);

} else if (headerFieldIsIdentical("content-encoding", request, httpJsonMessage) &&
headerFieldIsIdentical("transfer-encoding", request, httpJsonMessage)) {
log.info(diagnosticLabel + "There were changes to the headers that require the message to be reformatted " +
"but the payload doesn't need to be transformed.");
// By adding the baseline handlers and removing this and previous handlers in reverse order,
// we will cause the upstream handlers to flush their in-progress accumulated ByteBufs downstream
// to be processed accordingly
requestPipelineOrchestrator.addBaselineHandlers(pipeline);
ctx.fireChannelRead(httpJsonMessage);
RequestPipelineOrchestrator.removeThisAndPreviousHandlers(pipeline, this);
@@ -136,9 +137,10 @@ private void handlePayloadNeutralTransformationOrThrow(ChannelHandlerContext ctx
private boolean headerFieldsAreIdentical(HttpRequest request, HttpJsonMessageWithFaultingPayload httpJsonMessage) {
if (!request.uri().equals(httpJsonMessage.path()) ||
!request.method().toString().equals(httpJsonMessage.method()) ||
request.headers().size() != httpJsonMessage.headers().strictHeadersMap.size()) {
request.headers().names().size() != httpJsonMessage.headers().strictHeadersMap.size()) {
return false;
}
// Depends on header size check above for correctness
for (var headerName : httpJsonMessage.headers().keySet()) {
if (!headerFieldIsIdentical(headerName, request, httpJsonMessage)) {
return false;
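This one-word change is the header-counting fix in this file: Netty's HttpHeaders.size() counts name/value entries, while the transformed message's strictHeadersMap is keyed by distinct names, so a repeated header name could make the counts disagree even when the headers matched, forcing a spurious reformat. names().size() compares like with like. A small illustration of the difference:

```java
import io.netty.handler.codec.http.DefaultHttpHeaders;

public class HeaderCountDemo {
    public static void main(String[] args) {
        var headers = new DefaultHttpHeaders();
        headers.add("Host", "example.com");
        headers.add("Cookie", "a=1");
        headers.add("Cookie", "b=2"); // same name, second value

        System.out.println(headers.size());         // 3 -- one per name/value entry
        System.out.println(headers.names().size()); // 2 -- distinct header names
    }
}
```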
@@ -1,16 +1,13 @@
package org.opensearch.migrations.replay.datahandlers.http;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetectorFactory;
import lombok.Lombok;
import lombok.extern.slf4j.Slf4j;

import java.io.ByteArrayOutputStream;
@@ -19,7 +16,6 @@
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

/**
* This class does the remaining serialization of the contents coming into it into ByteBuf
@@ -37,8 +33,6 @@
*/
@Slf4j
public class NettyJsonToByteBufHandler extends ChannelInboundHandlerAdapter {
// TODO: Eventually, we can count up the size of all of the entries in the headers - but for now, I'm being lazy
public static final int MAX_HEADERS_BYTE_SIZE = 64 * 1024;
List<List<Integer>> sharedInProgressChunkSizes;
ByteBuf inProgressByteBuf;
int payloadBufferIndex;
@@ -133,7 +127,7 @@ private void writeHeadersIntoByteBufs(ChannelHandlerContext ctx,
var headerChunkSizes = sharedInProgressChunkSizes.get(0);
try {
if (headerChunkSizes.size() > 1) {
writeHeadersAsChunks(ctx, httpJson, headerChunkSizes, MAX_HEADERS_BYTE_SIZE);
writeHeadersAsChunks(ctx, httpJson, headerChunkSizes);
return;
}
} catch (Exception e) {
@@ -149,28 +143,31 @@

private static void writeHeadersAsChunks(ChannelHandlerContext ctx,
HttpJsonMessageWithFaultingPayload httpJson,
List<Integer> headerChunkSizes,
int maxLastBufferSize)
throws IOException
{
AtomicInteger chunkIdx = new AtomicInteger(headerChunkSizes.size());
var bufs = headerChunkSizes.stream()
.map(i -> ctx.alloc().buffer(chunkIdx.decrementAndGet()==0?maxLastBufferSize:i).retain())
.toArray(ByteBuf[]::new);
CompositeByteBuf cbb = null;
List<Integer> headerChunkSizes) throws IOException {
var initialSize = headerChunkSizes.stream().mapToInt(Integer::intValue).sum();

ByteBuf buf = null;
try {
cbb = ctx.alloc().compositeBuffer(bufs.length);
cbb.addComponents(true, bufs);
log.debug("cbb.refcnt=" + cbb.refCnt());
try (var bbos = new ByteBufOutputStream(cbb)) {
buf = ctx.alloc().buffer(initialSize);
try (var bbos = new ByteBufOutputStream(buf)) {
writeHeadersIntoStream(httpJson, bbos);
}
for (var bb : bufs) {
ctx.fireChannelRead(bb);

int index = 0;
var chunkSizeIterator = headerChunkSizes.iterator();
while (index < buf.writerIndex()) {
if (!chunkSizeIterator.hasNext()) {
throw Lombok.sneakyThrow(new IllegalStateException("Ran out of input chunks for mapping"));
}
var inputChunkSize = chunkSizeIterator.next();
var scaledChunkSize = (int) (((long) buf.writerIndex() * inputChunkSize) + (initialSize - 1)) / initialSize;
int actualChunkSize = Math.min(buf.writerIndex() - index, scaledChunkSize);
ctx.fireChannelRead(buf.retainedSlice(index, actualChunkSize));
index += actualChunkSize;
}
} finally {
if (cbb != null) {
cbb.release();
if (buf != null) {
buf.release();
}
}
}
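The old path pre-sized one buffer per recorded chunk (capping the last at 64 KiB) and stitched them into a CompositeByteBuf, so rewritten headers that outgrew the recorded sizes were a problem. The new path serializes the headers once, then re-slices the result so each emitted chunk's share is proportional to the corresponding chunk in the original capture, rounding up; retainedSlice shares the underlying buffer rather than copying. A standalone sketch of just the size computation (method and variable names are mine, not the project's):

```java
import java.util.ArrayList;
import java.util.List;

public class ChunkRescaler {
    // Split totalBytes into chunks proportional to originalChunkSizes, rounding up,
    // mirroring the retainedSlice loop above.
    static List<Integer> rescale(int totalBytes, List<Integer> originalChunkSizes) {
        int initialSize = originalChunkSizes.stream().mapToInt(Integer::intValue).sum();
        var out = new ArrayList<Integer>();
        int index = 0;
        for (int inputChunkSize : originalChunkSizes) {
            if (index >= totalBytes) break; // nothing left to emit
            int scaled = (int) ((((long) totalBytes * inputChunkSize) + (initialSize - 1)) / initialSize);
            int actual = Math.min(totalBytes - index, scaled);
            out.add(actual);
            index += actual;
        }
        return out;
    }

    public static void main(String[] args) {
        // 100 captured header bytes in chunks of 30/70; the rewritten headers are 150 bytes.
        System.out.println(rescale(150, List.of(30, 70))); // [45, 105]
    }
}
```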
@@ -73,6 +73,12 @@ static void removeThisAndPreviousHandlers(ChannelPipeline pipeline, ChannelHandl
}
}

static void removeAllHandlers(ChannelPipeline pipeline) {
while (pipeline.first() != null) {
pipeline.removeLast();
}
}

void addContentRepackingHandlers(ChannelHandlerContext ctx,
IAuthTransformer.StreamingFullMessageTransformer authTransfomer) {
addContentParsingHandlers(ctx, null, authTransfomer);
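removeAllHandlers drains user handlers from the tail while testing emptiness at the head; ChannelPipeline.first() returns null once only the built-in head/tail sentinels remain. A throwaway sketch of the same loop (the two LoggingHandlers are arbitrary sharable placeholders, not the replayer's handlers):

```java
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.logging.LoggingHandler;

public class PipelineResetDemo {
    public static void main(String[] args) {
        var channel = new EmbeddedChannel(new LoggingHandler(), new LoggingHandler());
        var pipeline = channel.pipeline();
        while (pipeline.first() != null) { // first() is null once no user handlers remain
            pipeline.removeLast();
        }
        System.out.println(pipeline.first()); // null: pipeline fully reset
        channel.close();
    }
}
```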