Skip to content

Commit

Permalink
Fix StreamChannelConnectionCaptureSerializer behavior for nioBuffers …
Browse files Browse the repository at this point in the history
…that have larger capacity than limit

Signed-off-by: Andre Kurait <akurait@amazon.com>
  • Loading branch information
AndreKurait committed May 1, 2024
1 parent 1809a80 commit 86a6b3a
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,34 @@ public static int maxBytesNeededForASegmentedObservation(Instant timestamp, int
int tsTagAndContentSize = CodedOutputStream.computeInt32Size(TrafficObservation.TS_FIELD_NUMBER, tsContentSize) + tsContentSize;

// Capture required bytes
int dataSize = CodedOutputStream.computeByteBufferSize(dataFieldNumber, buffer);
int dataSize = computeByteBufferRemainingSize(dataFieldNumber, buffer);
int captureTagAndContentSize = CodedOutputStream.computeInt32Size(observationFieldNumber, dataSize) + dataSize;

// Observation and closing index required bytes
return bytesNeededForObservationAndClosingIndex(tsTagAndContentSize + captureTagAndContentSize,
Integer.MAX_VALUE);
}

/**
* This function determines the number of bytes needed to write the remaining bytes in a byteBuffer and its tag.
* Use this over CodeOutputStream.computeByteBufferSize(int fieldNumber, ByteBuffer buffer) due to the latter
* relying on the ByteBuffer capacity instead of limit in size calculation.
*/
public static int computeByteBufferRemainingSize(int fieldNumber, ByteBuffer buffer) {
return CodedOutputStream.computeTagSize(fieldNumber) + computeByteBufferRemainingSizeNoTag(buffer);
}

/**
* This function determines the number of bytes needed to write the remaining bytes in a byteBuffer. Use this over
* CodeOutputStream.computeByteBufferSizeNoTag(ByteBuffer buffer) due to the latter relying on the
* ByteBuffer capacity instead of limit in size calculation.
*/
public static int computeByteBufferRemainingSizeNoTag(ByteBuffer buffer) {
int bufferSize = buffer.remaining();
return CodedOutputStream.computeUInt32SizeNoTag(bufferSize) + bufferSize;
}


/**
* This function determines the number of bytes needed to store a TrafficObservation and a closing index for a
* TrafficStream, from the provided input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Timestamp;
import com.google.protobuf.WireFormat;
import io.netty.buffer.ByteBuf;

import java.util.function.IntSupplier;
Expand Down Expand Up @@ -182,8 +183,12 @@ private void writeTimestampForNowToCurrentStream(Instant timestamp) throws IOExc
}

private void writeByteBufferToCurrentStream(int fieldNum, ByteBuffer byteBuffer) throws IOException {
if (byteBuffer.remaining() > 0) {
getOrCreateCodedOutputStream().writeByteBuffer(fieldNum, byteBuffer);
if (byteBuffer.hasRemaining()) {
// CodedOutputStream.writeByteBuffer writes based on capacity and ignores limits so prefer write
getOrCreateCodedOutputStream().writeTag(fieldNum, WireFormat.WIRETYPE_LENGTH_DELIMITED);
getOrCreateCodedOutputStream().writeUInt32NoTag(byteBuffer.remaining());
getOrCreateCodedOutputStream().write(byteBuffer.duplicate());
assert byteBuffer.hasRemaining() : "byteBuffer position should not be modified when writing.";
} else {
getOrCreateCodedOutputStream().writeUInt32NoTag(0);
}
Expand Down Expand Up @@ -279,8 +284,7 @@ void addDataMessage(int captureFieldNumber, int dataFieldNumber, Instant timesta
addDataMessage(captureFieldNumber, dataFieldNumber, timestamp, buf.nioBuffer());
}

void addDataMessage(int captureFieldNumber, int dataFieldNumber, Instant timestamp, ByteBuffer nioBuffer) throws IOException {
var readOnlyDataBuffer = nioBuffer.asReadOnlyBuffer();
void addDataMessage(int captureFieldNumber, int dataFieldNumber, Instant timestamp, ByteBuffer buffer) throws IOException {
int segmentFieldNumber;
int segmentDataFieldNumber;
if (captureFieldNumber == TrafficObservation.READ_FIELD_NUMBER) {
Expand All @@ -296,40 +300,44 @@ void addDataMessage(int captureFieldNumber, int dataFieldNumber, Instant timesta
// the potentially required bytes for simplicity. This could leave ~5 bytes of unused space in the CodedOutputStream
// when considering the case of a message that does not need segments or for the case of a smaller segment created
// from a much larger message
int messageAndOverheadBytesLeft = CodedOutputStreamSizeUtil.maxBytesNeededForASegmentedObservation(timestamp,
segmentFieldNumber, segmentDataFieldNumber, readOnlyDataBuffer);
int trafficStreamOverhead = messageAndOverheadBytesLeft - readOnlyDataBuffer.capacity();
final int messageAndOverheadBytesLeft = CodedOutputStreamSizeUtil.maxBytesNeededForASegmentedObservation(timestamp,
segmentFieldNumber, segmentDataFieldNumber, buffer);
final int dataSize = CodedOutputStreamSizeUtil.computeByteBufferRemainingSizeNoTag(buffer);
final int trafficStreamOverhead = messageAndOverheadBytesLeft - dataSize;

// Ensure that space for at least one data byte and overhead exists, otherwise a flush is necessary.
flushIfNeeded(() -> (trafficStreamOverhead + 1));
// Ensure that space for at least one data byte, one length byte, and overhead exists, otherwise a flush is necessary.
flushIfNeeded(() -> (trafficStreamOverhead + 2)).join();
assert getOrCreateCodedOutputStreamHolder().getOutputStreamSpaceLeft() == -1 ||
getOrCreateCodedOutputStreamHolder().getOutputStreamSpaceLeft() > trafficStreamOverhead
: "COS does not have space for data";

// If our message is empty or can fit in the current CodedOutputStream no chunking is needed, and we can continue
var spaceLeft = getOrCreateCodedOutputStreamHolder().getOutputStreamSpaceLeft();
if (readOnlyDataBuffer.limit() == 0 || spaceLeft == -1 || messageAndOverheadBytesLeft <= spaceLeft) {
if (!buffer.hasRemaining() || spaceLeft == -1 || messageAndOverheadBytesLeft <= spaceLeft) {
int minExpectedSpaceAfterObservation = spaceLeft - messageAndOverheadBytesLeft;
addSubstreamMessage(captureFieldNumber, dataFieldNumber, timestamp, readOnlyDataBuffer);
addSubstreamMessage(captureFieldNumber, dataFieldNumber, timestamp, buffer);
observationSizeSanityCheck(minExpectedSpaceAfterObservation, captureFieldNumber);
return;
}

while(readOnlyDataBuffer.position() < readOnlyDataBuffer.limit()) {
var readBuffer = buffer.duplicate();
while(readBuffer.hasRemaining()) {
flushIfNeeded(() -> (trafficStreamOverhead + 2)).join();
// COS checked for unbounded limit above
int availableCOSSpace = getOrCreateCodedOutputStreamHolder().getOutputStreamSpaceLeft();
int chunkBytes = messageAndOverheadBytesLeft > availableCOSSpace ? availableCOSSpace - trafficStreamOverhead : readOnlyDataBuffer.limit() - readOnlyDataBuffer.position();
ByteBuffer bb = readOnlyDataBuffer.slice();
bb.limit(chunkBytes);
bb = bb.slice();
readOnlyDataBuffer.position(readOnlyDataBuffer.position() + chunkBytes);
addSubstreamMessage(segmentFieldNumber, segmentDataFieldNumber, timestamp, bb);
int minExpectedSpaceAfterObservation = availableCOSSpace - chunkBytes - trafficStreamOverhead;
final int availableCOSSpace = getOrCreateCodedOutputStreamHolder().getOutputStreamSpaceLeft();
final int maxLengthSpace = CodedOutputStream.computeUInt32SizeNoTag(readBuffer.remaining());
final int maxBytesSpace = availableCOSSpace - trafficStreamOverhead - maxLengthSpace;
final int nextChunkBytes = Math.min(maxBytesSpace, readBuffer.remaining());

var dataBytes = new byte[nextChunkBytes];
readBuffer.get(dataBytes, 0, nextChunkBytes);
addSubstreamMessage(segmentFieldNumber, segmentDataFieldNumber, timestamp, ByteBuffer.wrap(dataBytes));

final int minExpectedSpaceAfterObservation = maxBytesSpace - nextChunkBytes;
observationSizeSanityCheck(minExpectedSpaceAfterObservation, segmentFieldNumber);
// 1 to N-1 chunked messages
if (readOnlyDataBuffer.position() < readOnlyDataBuffer.limit()) {
flushCommitAndResetStream(false);
messageAndOverheadBytesLeft = messageAndOverheadBytesLeft - chunkBytes;
}
}
writeEndOfSegmentMessage(timestamp);

}

void addSubstreamMessage(int captureFieldNumber, int dataFieldNumber, int dataCountFieldNumber, int dataCount,
Expand All @@ -342,7 +350,7 @@ void addSubstreamMessage(int captureFieldNumber, int dataFieldNumber, int dataCo
segmentCountSize = CodedOutputStream.computeInt32Size(dataCountFieldNumber, dataCount);
}
if (byteBuffer.remaining() > 0) {
dataSize = CodedOutputStream.computeByteBufferSize(dataFieldNumber, byteBuffer);
dataSize = CodedOutputStreamSizeUtil.computeByteBufferRemainingSize(dataFieldNumber, byteBuffer);
captureClosureLength = CodedOutputStream.computeInt32SizeNoTag(dataSize + segmentCountSize);
}
beginSubstreamObservation(timestamp, captureFieldNumber, captureClosureLength + dataSize + segmentCountSize);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package org.opensearch.migrations.trafficcapture;

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import java.nio.ByteBuffer;
import java.time.Instant;

class CodedOutputStreamSizeUtilTest {

@Test
void testGetSizeOfTimestamp() {
// Timestamp with only seconds (no explicit nanoseconds)
Instant timestampSecondsOnly = Instant.parse("2024-01-01T00:00:00Z");
int sizeSecondsOnly = CodedOutputStreamSizeUtil.getSizeOfTimestamp(timestampSecondsOnly);
assertEquals( 6, sizeSecondsOnly);

// Timestamp with both seconds and nanoseconds
Instant timestampWithNanos = Instant.parse("2024-12-31T23:59:59.123456789Z");
int sizeWithNanos = CodedOutputStreamSizeUtil.getSizeOfTimestamp(timestampWithNanos);
assertEquals( 11, sizeWithNanos);
}

@Test
void testMaxBytesNeededForASegmentedObservation() {
Instant timestamp = Instant.parse("2024-01-01T00:00:00Z");
ByteBuffer buffer = ByteBuffer.allocate(100).limit(50);
buffer.position(25);
int result = CodedOutputStreamSizeUtil.maxBytesNeededForASegmentedObservation(timestamp, 1, 2, buffer);
assertEquals(45, result);
}

@Test
void test_computeByteBufferRemainingSize() {
ByteBuffer buffer = ByteBuffer.allocate(100).limit(50);
int result = CodedOutputStreamSizeUtil.computeByteBufferRemainingSize(2, buffer);
assertEquals(52, result);
}

@Test
void test_computeByteBufferRemainingSize_ByteBufferAtCapacity() {
ByteBuffer buffer = ByteBuffer.allocate(200);
int result = CodedOutputStreamSizeUtil.computeByteBufferRemainingSize(2, buffer);
assertEquals(203, result);
}

@Test
void test_computeByteBufferRemainingSize_EmptyByteBuffer() {
ByteBuffer buffer = ByteBuffer.allocate(0);
int result = CodedOutputStreamSizeUtil.computeByteBufferRemainingSize(2, buffer);
assertEquals(2, result);
}

@Test
void testBytesNeededForObservationAndClosingIndex() {
int observationContentSize = 50;
int numberOfTrafficStreamsSoFar = 10;

int result = CodedOutputStreamSizeUtil.bytesNeededForObservationAndClosingIndex(observationContentSize, numberOfTrafficStreamsSoFar);
assertEquals(54, result);
}

@Test
void testBytesNeededForObservationAndClosingIndex_WithZeroContent() {
int observationContentSize = 0;
int numberOfTrafficStreamsSoFar = 0;

int result = CodedOutputStreamSizeUtil.bytesNeededForObservationAndClosingIndex(observationContentSize, numberOfTrafficStreamsSoFar);
assertEquals(4, result);
}

@Test
void testBytesNeededForObservationAndClosingIndex_VariousIndices() {
int observationContentSize = 20;

// Test with increasing indices to verify scaling of index size
int[] indices = new int[]{1, 1000, 100000};
int[] expectedResults = new int[]{24, 25, 26};

for (int i = 0; i < indices.length; i++) {
int result = CodedOutputStreamSizeUtil.bytesNeededForObservationAndClosingIndex(observationContentSize, indices[i]);
assertEquals(expectedResults[i], result);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,55 @@ public void testWriteIsHandledForBufferAllocatedLargerThanWritten()
var serializer = createSerializerWithTestHandler(outputBuffersCreated, getEstimatedTrafficStreamByteSize(1, 200));

ByteBuffer byteBuffer = ByteBuffer.allocateDirect(100);
byteBuffer.limit(50);
byteBuffer.putInt(1);
byteBuffer.put(FAKE_READ_PACKET_DATA.getBytes(StandardCharsets.UTF_8));
byteBuffer.flip();

serializer.addDataMessage(TrafficObservation.WRITE_FIELD_NUMBER, WriteObservation.DATA_FIELD_NUMBER, REFERENCE_TIMESTAMP, byteBuffer);
var future = serializer.flushCommitAndResetStream(true);
future.get();

Assertions.assertEquals(0, byteBuffer.position());

var outputBuffersList = new ArrayList<>(outputBuffersCreated);
TrafficStream reconstitutedTrafficStream = TrafficStream.parseFrom(outputBuffersList.get(0));
Assertions.assertEquals(1, reconstitutedTrafficStream.getSubStream(0).getWrite().getData().size());
Assertions.assertEquals(FAKE_READ_PACKET_DATA, reconstitutedTrafficStream.getSubStream(0).getWrite().getData().toStringUtf8());
}

@Test
public void testWriteIsHandledForBufferAllocatedLargerThanWrittenWithChunking()
throws IOException, ExecutionException, InterruptedException {
var outputBuffersCreated = new ConcurrentLinkedQueue<ByteBuffer>();
var serializer = createSerializerWithTestHandler(outputBuffersCreated, getEstimatedTrafficStreamByteSize(1, 4));

ByteBuffer byteBuffer = ByteBuffer.allocate(100);
byteBuffer.put(FAKE_READ_PACKET_DATA.getBytes(StandardCharsets.UTF_8));
byteBuffer.flip();

Assertions.assertEquals(0, byteBuffer.position());
Assertions.assertEquals(100, byteBuffer.capacity());
Assertions.assertEquals(16, byteBuffer.limit());

serializer.addDataMessage(TrafficObservation.WRITE_FIELD_NUMBER, WriteObservation.DATA_FIELD_NUMBER, REFERENCE_TIMESTAMP, byteBuffer);
var future = serializer.flushCommitAndResetStream(true);
future.get();

Assertions.assertEquals(0, byteBuffer.position());

List<TrafficObservation> observations = new ArrayList<>();
for (ByteBuffer buffer : outputBuffersCreated) {
var trafficStream = TrafficStream.parseFrom(buffer);
observations.add(trafficStream.getSubStream(0));
}

StringBuilder reconstructedData = new StringBuilder();
for (TrafficObservation observation : observations) {
var stringChunk = observation.getWriteSegment().getData().toStringUtf8();
reconstructedData.append(stringChunk);
}
Assertions.assertEquals(FAKE_READ_PACKET_DATA, reconstructedData.toString());
}


@Test
public void testWithLimitlessCodedOutputStreamHolder()
throws IOException, ExecutionException, InterruptedException {
Expand Down

0 comments on commit 86a6b3a

Please sign in to comment.