Skip to content

Commit

Permalink
Simplify Netty RefCounting and ByteBuf Consumption (#592)
Browse files Browse the repository at this point in the history
* Refactor HttpByteBufFormatter without refCount side effects and add in helpers for refSafe stream mappings

Signed-off-by: Andre Kurait <akurait@amazon.com>

* Add back in IReplayContexts class referencing

Signed-off-by: Andre Kurait <akurait@amazon.com>

* Rename ByteBufUtils to NettyUtils and convert util libraries to final class

Signed-off-by: Andre Kurait <akurait@amazon.com>

* Convert RefSafeHolder get to return object instead of optional

Signed-off-by: Andre Kurait <akurait@amazon.com>

* Add refSafeTransform to RefSafeStreamUtils

Signed-off-by: Andre Kurait <akurait@amazon.com>

* Enable semantic checking for @MustBeClosed with RefCounted holders and streams

Signed-off-by: Andre Kurait <akurait@amazon.com>

* Rename to createRefCntNeutralCloseableByteBufStream and prefer non-static imports

Signed-off-by: Andre Kurait <akurait@amazon.com>

* Add RefSafeStreamUtilsTest

Signed-off-by: Andre Kurait <akurait@amazon.com>

* Remove static import

Signed-off-by: Andre Kurait <akurait@amazon.com>

---------

Signed-off-by: Andre Kurait <akurait@amazon.com>
  • Loading branch information
AndreKurait authored Apr 24, 2024
1 parent d8b4ba8 commit 9df3614
Show file tree
Hide file tree
Showing 11 changed files with 327 additions and 143 deletions.
16 changes: 16 additions & 0 deletions TrafficCapture/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,22 @@ allprojects {
subprojects {
apply plugin: 'java'
apply plugin: 'maven-publish'

// TODO: Expand to do more static checking in more projects
if (project.name == "trafficReplayer" || project.name == "trafficCaptureProxyServer") {
dependencies {
annotationProcessor group: 'com.google.errorprone', name: 'error_prone_core', version: '2.26.1'
}
tasks.named('compileJava', JavaCompile) {
if (project.name == "trafficReplayer" || project.name == "trafficCaptureProxyServer") {
options.compilerArgs += [
"-XDcompilePolicy=simple",
"-Xplugin:ErrorProne -XepDisableAllChecks -Xep:MustBeClosed:ERROR -XepDisableWarningsInGeneratedCode",
]
}
}
}

task javadocJar(type: Jar, dependsOn: javadoc) {
archiveClassifier.set('javadoc')
from javadoc.destinationDir
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package org.opensearch.migrations.replay;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
Expand All @@ -11,9 +9,6 @@
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
Expand All @@ -24,6 +19,10 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.opensearch.migrations.replay.util.NettyUtils;
import org.opensearch.migrations.replay.util.RefSafeHolder;

@Slf4j
public class HttpByteBufFormatter {
Expand Down Expand Up @@ -64,61 +63,51 @@ public static String httpPacketBytesToString(HttpMessageType msgType, List<byte[
return httpPacketBytesToString(msgType, byteArrStream, DEFAULT_LINE_DELIMITER);
}

public static String httpPacketBytesToString(HttpMessageType msgType, Stream<byte[]> byteArrStream) {
return httpPacketBytesToString(msgType, byteArrStream, DEFAULT_LINE_DELIMITER);
}

public static String httpPacketBufsToString(HttpMessageType msgType, Stream<ByteBuf> byteBufStream,
boolean releaseByteBufs) {
return httpPacketBufsToString(msgType, byteBufStream, releaseByteBufs, DEFAULT_LINE_DELIMITER);
}

public static String httpPacketBytesToString(HttpMessageType msgType, List<byte[]> byteArrStream, String lineDelimiter) {
return httpPacketBytesToString(msgType,
Optional.ofNullable(byteArrStream).map(p -> p.stream()).orElse(Stream.of()), lineDelimiter);
public static String httpPacketBufsToString(HttpMessageType msgType, Stream<ByteBuf> byteBufStream) {
return httpPacketBufsToString(msgType, byteBufStream, DEFAULT_LINE_DELIMITER);
}

public static String httpPacketBytesToString(HttpMessageType msgType, Stream<byte[]> byteArrStream, String lineDelimiter) {
public static String httpPacketBytesToString(HttpMessageType msgType, List<byte[]> byteArrs, String lineDelimiter) {
// This isn't memory efficient,
// but stringifying byte bufs through a full parse and reserializing them was already really slow!
return httpPacketBufsToString(msgType, byteArrStream.map(Unpooled::wrappedBuffer), true, lineDelimiter);
try (var stream = NettyUtils.createRefCntNeutralCloseableByteBufStream(byteArrs)) {
return httpPacketBufsToString(msgType, stream, lineDelimiter);
}
}

public static String httpPacketBufsToString(HttpMessageType msgType, Stream<ByteBuf> byteBufStream,
boolean releaseByteBufs, String lineDelimiter) {
public static String httpPacketBufsToString(HttpMessageType msgType, Stream<ByteBuf> byteBufStream, String lineDelimiter) {
switch (printStyle.get().orElse(PacketPrintFormat.TRUNCATED)) {
case TRUNCATED:
return httpPacketBufsToString(byteBufStream, Utils.MAX_BYTES_SHOWN_FOR_TO_STRING, releaseByteBufs);
return httpPacketBufsToString(byteBufStream, Utils.MAX_BYTES_SHOWN_FOR_TO_STRING);
case FULL_BYTES:
return httpPacketBufsToString(byteBufStream, Long.MAX_VALUE, releaseByteBufs);
return httpPacketBufsToString(byteBufStream, Long.MAX_VALUE);
case PARSED_HTTP:
return httpPacketsToPrettyPrintedString(msgType, byteBufStream, false, releaseByteBufs,
return httpPacketsToPrettyPrintedString(msgType, byteBufStream, false,
lineDelimiter);
case PARSED_HTTP_SORTED_HEADERS:
return httpPacketsToPrettyPrintedString(msgType, byteBufStream, true, releaseByteBufs,
return httpPacketsToPrettyPrintedString(msgType, byteBufStream, true,
lineDelimiter);
default:
throw new IllegalStateException("Unknown PacketPrintFormat: " + printStyle.get());
}
}

public static String httpPacketsToPrettyPrintedString(HttpMessageType msgType, Stream<ByteBuf> byteBufStream,
boolean sortHeaders, boolean releaseByteBufs, String lineDelimiter) {
HttpMessage httpMessage = parseHttpMessageFromBufs(msgType, byteBufStream, releaseByteBufs);
var holderOp = Optional.ofNullable((httpMessage instanceof ByteBufHolder) ? (ByteBufHolder) httpMessage : null);
try {
if (httpMessage instanceof FullHttpRequest) {
return prettyPrintNettyRequest((FullHttpRequest) httpMessage, sortHeaders, lineDelimiter);
} else if (httpMessage instanceof FullHttpResponse) {
return prettyPrintNettyResponse((FullHttpResponse) httpMessage, sortHeaders, lineDelimiter);
} else if (httpMessage == null) {
return "[NULL]";
boolean sortHeaders, String lineDelimiter) {
try(var messageHolder = RefSafeHolder.create(parseHttpMessageFromBufs(msgType, byteBufStream))) {
final HttpMessage httpMessage = messageHolder.get();
if (httpMessage != null) {
if (httpMessage instanceof FullHttpRequest) {
return prettyPrintNettyRequest((FullHttpRequest) httpMessage, sortHeaders, lineDelimiter);
} else if (httpMessage instanceof FullHttpResponse) {
return prettyPrintNettyResponse((FullHttpResponse) httpMessage, sortHeaders, lineDelimiter);
} else {
throw new IllegalStateException("Embedded channel with an HttpObjectAggregator returned an " +
"unexpected object of type " + httpMessage.getClass() + ": " + httpMessage);
}
} else {
throw new IllegalStateException("Embedded channel with an HttpObjectAggregator returned an " +
"unexpected object of type " + httpMessage.getClass() + ": " + httpMessage);
return "[NULL]";
}
} finally {
holderOp.ifPresent(bbh->bbh.content().release());
}
}

Expand Down Expand Up @@ -153,58 +142,40 @@ private static String prettyPrintNettyMessage(StringJoiner sj, boolean sorted, H
* @param byteBufStream
* @return
*/
public static HttpMessage parseHttpMessageFromBufs(HttpMessageType msgType, Stream<ByteBuf> byteBufStream,
boolean releaseByteBufs) {
public static HttpMessage parseHttpMessageFromBufs(HttpMessageType msgType, Stream<ByteBuf> byteBufStream) {
EmbeddedChannel channel = new EmbeddedChannel(
msgType == HttpMessageType.REQUEST ? new HttpServerCodec() : new HttpClientCodec(),
new HttpContentDecompressor(),
new HttpObjectAggregator(Utils.MAX_PAYLOAD_SIZE_TO_PRINT) // Set max content length if needed
);

byteBufStream.forEach(b -> {
try {
channel.writeInbound(b.retainedDuplicate());
} finally {
if (releaseByteBufs) {
b.release();
}
}
});

try {
byteBufStream.forEachOrdered(b -> channel.writeInbound(b.retainedDuplicate()));
return channel.readInbound();
} finally {
channel.finishAndReleaseAll();
}
}

public static FullHttpRequest parseHttpRequestFromBufs(Stream<ByteBuf> byteBufStream, boolean releaseByteBufs) {
return (FullHttpRequest) parseHttpMessageFromBufs(HttpMessageType.REQUEST, byteBufStream, releaseByteBufs);
public static FullHttpRequest parseHttpRequestFromBufs(Stream<ByteBuf> byteBufStream) {
return (FullHttpRequest) parseHttpMessageFromBufs(HttpMessageType.REQUEST, byteBufStream);
}

public static FullHttpResponse parseHttpResponseFromBufs(Stream<ByteBuf> byteBufStream, boolean releaseByteBufs) {
return (FullHttpResponse) parseHttpMessageFromBufs(HttpMessageType.RESPONSE, byteBufStream, releaseByteBufs);
public static FullHttpResponse parseHttpResponseFromBufs(Stream<ByteBuf> byteBufStream) {
return (FullHttpResponse) parseHttpMessageFromBufs(HttpMessageType.RESPONSE, byteBufStream);
}

public static String httpPacketBufsToString(Stream<ByteBuf> byteBufStream, long maxBytesToShow,
boolean releaseByteBufs) {
public static String httpPacketBufsToString(Stream<ByteBuf> byteBufStream, long maxBytesToShow) {
if (byteBufStream == null) {
return "null";
}
return byteBufStream.map(originalByteBuf -> {
try {
var bb = originalByteBuf.duplicate();
var length = bb.readableBytes();
var str = IntStream.range(0, length).map(idx -> bb.readByte())
.limit(maxBytesToShow)
.mapToObj(b -> "" + (char) b)
.collect(Collectors.joining());
return "[" + (length > maxBytesToShow ? str + "..." : str) + "]";
} finally {
if (releaseByteBufs) {
originalByteBuf.release();
}
}})
.collect(Collectors.joining(","));
var bb = originalByteBuf.duplicate();
var length = bb.readableBytes();
var str = IntStream.range(0, length).map(idx -> bb.readByte())
.limit(maxBytesToShow)
.mapToObj(b -> "" + (char) b)
.collect(Collectors.joining());
return "[" + (length > maxBytesToShow ? str + "..." : str) + "]";
}).collect(Collectors.joining(","));
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.opensearch.migrations.replay;

import io.netty.buffer.Unpooled;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Lombok;
Expand All @@ -13,6 +12,7 @@
import java.time.Instant;
import java.util.Optional;
import java.util.stream.Stream;
import org.opensearch.migrations.replay.util.NettyUtils;

@Slf4j
@EqualsAndHashCode(exclude = "currentSegmentBytes")
Expand Down Expand Up @@ -65,17 +65,17 @@ public Stream<byte[]> stream() {
}

public String format(Optional<HttpByteBufFormatter.HttpMessageType> messageTypeOp) {
var packetBytesAsStr = messageTypeOp.map(mt-> HttpByteBufFormatter.httpPacketBytesToString(mt, packetBytes,
HttpByteBufFormatter.LF_LINE_DELIMITER))
.orElseGet(()-> HttpByteBufFormatter.httpPacketBufsToString(
packetBytes.stream().map(Unpooled::wrappedBuffer),
Utils.MAX_PAYLOAD_SIZE_TO_PRINT, true));
final StringBuilder sb = new StringBuilder("HttpMessageAndTimestamp{");
sb.append("firstPacketTimestamp=").append(firstPacketTimestamp);
sb.append(", lastPacketTimestamp=").append(lastPacketTimestamp);
sb.append(", message=[").append(packetBytesAsStr);
sb.append("]}");
return sb.toString();
try (var bufStream = NettyUtils.createRefCntNeutralCloseableByteBufStream(packetBytes)) {
var packetBytesAsStr = messageTypeOp.map(mt-> HttpByteBufFormatter.httpPacketBytesToString(mt, packetBytes,
HttpByteBufFormatter.LF_LINE_DELIMITER))
.orElseGet(()-> HttpByteBufFormatter.httpPacketBufsToString(bufStream, Utils.MAX_PAYLOAD_SIZE_TO_PRINT));
final StringBuilder sb = new StringBuilder("HttpMessageAndTimestamp{");
sb.append("firstPacketTimestamp=").append(firstPacketTimestamp);
sb.append(", lastPacketTimestamp=").append(lastPacketTimestamp);
sb.append(", message=[").append(packetBytesAsStr);
sb.append("]}");
return sb.toString();
}
}

public void addSegment(byte[] data) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
package org.opensearch.migrations.replay;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.util.ReferenceCounted;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.opensearch.migrations.replay.datatypes.TransformedPackets;
import org.opensearch.migrations.replay.tracing.IReplayContexts;

import java.time.Duration;
import java.util.Base64;
import java.util.LinkedHashMap;
Expand All @@ -17,7 +10,12 @@
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.opensearch.migrations.replay.datatypes.TransformedPackets;
import org.opensearch.migrations.replay.tracing.IReplayContexts;
import org.opensearch.migrations.replay.util.NettyUtils;
import org.opensearch.migrations.replay.util.RefSafeHolder;

/**
* TODO - This class will pull all bodies in as a byte[], even if that byte[] isn't
Expand Down Expand Up @@ -102,11 +100,6 @@ public static void fillStatusCodeMetrics(@NonNull IReplayContexts.ITupleHandling
targetResponseOp.ifPresent(r -> context.setTargetStatus((Integer) r.get(STATUS_CODE_KEY)));
}


private static Stream<ByteBuf> byteToByteBufStream(List<byte[]> incoming) {
return incoming.stream().map(Unpooled::wrappedBuffer);
}

private static byte[] getBytesFromByteBuf(ByteBuf buf) {
var bytes = new byte[buf.readableBytes()];
buf.getBytes(buf.readerIndex(), bytes);
Expand Down Expand Up @@ -138,17 +131,20 @@ private static Map<String, Object> convertRequest(@NonNull IReplayContexts.ITupl
@NonNull List<byte[]> data) {
return makeSafeMap(context, () -> {
var map = new LinkedHashMap<String, Object>();
var message = HttpByteBufFormatter.parseHttpRequestFromBufs(byteToByteBufStream(data), true);
try {
map.put("Request-URI", message.uri());
map.put("Method", message.method().toString());
map.put("HTTP-Version", message.protocolVersion().toString());
context.setMethod(message.method().toString());
context.setEndpoint(message.uri());
context.setHttpVersion(message.protocolVersion().toString());
return fillMap(map, message.headers(), message.content());
} finally {
Optional.ofNullable(message).ifPresent(ReferenceCounted::release);
try (var bufStream = NettyUtils.createRefCntNeutralCloseableByteBufStream(data);
var messageHolder = RefSafeHolder.create(HttpByteBufFormatter.parseHttpRequestFromBufs(bufStream))) {
var message = messageHolder.get();
if (message != null) {
map.put("Request-URI", message.uri());
map.put("Method", message.method().toString());
map.put("HTTP-Version", message.protocolVersion().toString());
context.setMethod(message.method().toString());
context.setEndpoint(message.uri());
context.setHttpVersion(message.protocolVersion().toString());
return fillMap(map, message.headers(), message.content());
} else {
return Map.of("Exception", "Message couldn't be parsed as a full http message");
}
}
});
}
Expand All @@ -157,18 +153,18 @@ private static Map<String, Object> convertResponse(@NonNull IReplayContexts.ITup
@NonNull List<byte[]> data, Duration latency) {
return makeSafeMap(context, () -> {
var map = new LinkedHashMap<String, Object>();
var message = HttpByteBufFormatter.parseHttpResponseFromBufs(byteToByteBufStream(data), true);
if (message == null) {
return Map.of("Exception", "Message couldn't be parsed as a full http message");
}
try {
map.put("HTTP-Version", message.protocolVersion());
map.put(STATUS_CODE_KEY, message.status().code());
map.put("Reason-Phrase", message.status().reasonPhrase());
map.put(RESPONSE_TIME_MS_KEY, latency.toMillis());
return fillMap(map, message.headers(), message.content());
} finally {
Optional.ofNullable(message).ifPresent(ReferenceCounted::release);
try (var bufStream = NettyUtils.createRefCntNeutralCloseableByteBufStream(data);
var messageHolder = RefSafeHolder.create(HttpByteBufFormatter.parseHttpResponseFromBufs(bufStream))) {
var message = messageHolder.get();
if (message != null) {
map.put("HTTP-Version", message.protocolVersion());
map.put(STATUS_CODE_KEY, message.status().code());
map.put("Reason-Phrase", message.status().reasonPhrase());
map.put(RESPONSE_TIME_MS_KEY, latency.toMillis());
return fillMap(map, message.headers(), message.content());
} else {
return Map.of("Exception", "Message couldn't be parsed as a full http message");
}
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public String toString() {
if (targetResponseDuration != null) { sj.add("targetResponseDuration=").add(targetResponseDuration+""); }
Optional.ofNullable(targetRequestData).ifPresent(d-> sj.add("targetRequestData=")
.add(d.isClosed() ? "CLOSED" : HttpByteBufFormatter.httpPacketBufsToString(
HttpByteBufFormatter.HttpMessageType.REQUEST, d.streamUnretained(), false, LF_LINE_DELIMITER)));
HttpByteBufFormatter.HttpMessageType.REQUEST, d.streamUnretained(), LF_LINE_DELIMITER)));
Optional.ofNullable(targetResponseData).filter(d->!d.isEmpty()).ifPresent(d -> sj.add("targetResponseData=")
.add(HttpByteBufFormatter.httpPacketBytesToString(HttpByteBufFormatter.HttpMessageType.RESPONSE, d, LF_LINE_DELIMITER)));
sj.add("transformStatus=").add(transformationStatus+"");
Expand Down
Loading

0 comments on commit 9df3614

Please sign in to comment.