diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionActionUtil.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionActionUtil.java new file mode 100644 index 0000000000000..8b898b3afb1e2 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionActionUtil.java @@ -0,0 +1,117 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.action; + +import org.opensearch.action.ActionRequest; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; + +/** + * ExtensionActionUtil - a class for creating and processing remote requests using byte arrays. + */ +public class ExtensionActionUtil { + + /** + * The Unicode UNIT SEPARATOR used to separate the Request class name and parameter bytes + */ + public static final byte UNIT_SEPARATOR = (byte) '\u001F'; + + /** + * @param request an instance of a request extending {@link ActionRequest}, containing information about the + * request being sent to the remote server. It is used to create a byte array containing the request data, + * which will be sent to the remote server. + * @return An Extension ActionRequest object that represents the deserialized data. + * If an error occurred during the deserialization process, the method will return {@code null}. + * @throws RuntimeException If a RuntimeException occurs while creating the proxy request bytes. + */ + public static byte[] createProxyRequestBytes(ActionRequest request) throws RuntimeException { + byte[] requestClassBytes = request.getClass().getName().getBytes(StandardCharsets.UTF_8); + byte[] requestBytes; + + try { + requestBytes = convertParamsToBytes(request); + assert requestBytes != null; + return ByteBuffer.allocate(requestClassBytes.length + 1 + requestBytes.length) + .put(requestClassBytes) + .put(UNIT_SEPARATOR) + .put(requestBytes) + .array(); + } catch (RuntimeException e) { + throw new RuntimeException("RuntimeException occurred while creating proxyRequestBytes"); + } + } + + /** + * @param requestBytes is a byte array containing the request data, used by the "createActionRequest" + * method to create an "ActionRequest" object, which represents the request model to be processed on the server. + * @return an "Action Request" object representing the request model for processing on the server, + * or {@code null} if the request data is invalid or null. + * @throws ReflectiveOperationException if an exception occurs during the reflective operation, such as when + * resolving the request class, accessing the constructor, or creating an instance using reflection + * @throws NullPointerException if a null pointer exception occurs during the creation of the ActionRequest object + */ + public static ActionRequest createActionRequest(byte[] requestBytes) throws ReflectiveOperationException { + int delimPos = delimPos(requestBytes); + String requestClassName = new String(Arrays.copyOfRange(requestBytes, 0, delimPos + 1), StandardCharsets.UTF_8).stripTrailing(); + try { + Class clazz = Class.forName(requestClassName); + Constructor constructor = clazz.getConstructor(StreamInput.class); + StreamInput requestByteStream = StreamInput.wrap(Arrays.copyOfRange(requestBytes, delimPos + 1, requestBytes.length)); + return (ActionRequest) constructor.newInstance(requestByteStream); + } catch (ReflectiveOperationException e) { + throw new ReflectiveOperationException( + "ReflectiveOperationException occurred while creating extensionAction request from bytes", + e + ); + } catch (NullPointerException e) { + throw new NullPointerException( + "NullPointerException occurred while creating extensionAction request from bytes" + e.getMessage() + ); + } + } + + /** + * Converts the given object of type T, which implements the {@link Writeable} interface, to a byte array. + * @param the type of the object to be converted to bytes, which must implement the {@link Writeable} interface. + * @param writeableObject the object of type T to be converted to bytes. + * @return a byte array containing the serialized bytes of the given object, or {@code null} if the input is invalid or null. + * @throws IllegalStateException if a failure occurs while writing the data + */ + public static byte[] convertParamsToBytes(T writeableObject) throws IllegalStateException { + try (BytesStreamOutput out = new BytesStreamOutput()) { + writeableObject.writeTo(out); + return BytesReference.toBytes(out.bytes()); + } catch (IOException ieo) { + throw new IllegalStateException("Failure writing bytes", ieo); + } + } + + /** + * Searches for the position of the unit separator byte in the given byte array. + * + * @param bytes the byte array to search for the unit separator byte. + * @return the index of the unit separator byte in the byte array, or -1 if it was not found. + */ + public static int delimPos(byte[] bytes) { + for (int offset = 0; offset < bytes.length; ++offset) { + if (bytes[offset] == ExtensionActionUtil.UNIT_SEPARATOR) { + return offset; + } + } + return -1; + } +} diff --git a/server/src/test/java/org/opensearch/extensions/action/ExtensionActionUtilTests.java b/server/src/test/java/org/opensearch/extensions/action/ExtensionActionUtilTests.java new file mode 100644 index 0000000000000..d2b889d33da9a --- /dev/null +++ b/server/src/test/java/org/opensearch/extensions/action/ExtensionActionUtilTests.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.action; + +import org.junit.Before; +import org.mockito.Mockito; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; + +import static org.opensearch.extensions.action.ExtensionActionUtil.UNIT_SEPARATOR; +import static org.opensearch.extensions.action.ExtensionActionUtil.createProxyRequestBytes; + +public class ExtensionActionUtilTests extends OpenSearchTestCase { + private byte[] myBytes; + private final String actionName = "org.opensearch.action.MyExampleRequest"; + private final byte[] actionNameBytes = MyExampleRequest.class.getName().getBytes(StandardCharsets.UTF_8); + + @Before + public void setup() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + MyExampleRequest exampleRequest = new MyExampleRequest(actionName, actionNameBytes); + exampleRequest.writeTo(out); + + byte[] requestBytes = BytesReference.toBytes(out.bytes()); + byte[] requestClass = MyExampleRequest.class.getName().getBytes(StandardCharsets.UTF_8); + this.myBytes = ByteBuffer.allocate(requestClass.length + 1 + requestBytes.length) + .put(requestClass) + .put(UNIT_SEPARATOR) + .put(requestBytes) + .array(); + } + + public void testCreateProxyRequestBytes() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + MyExampleRequest exampleRequest = new MyExampleRequest(actionName, actionNameBytes); + exampleRequest.writeTo(out); + + byte[] result = createProxyRequestBytes(exampleRequest); + assertArrayEquals(this.myBytes, result); + assertThrows(RuntimeException.class, () -> ExtensionActionUtil.createProxyRequestBytes(new MyExampleRequest(null, null))); + } + + public void testCreateActionRequest() throws ReflectiveOperationException { + ActionRequest actionRequest = ExtensionActionUtil.createActionRequest(myBytes); + assertThrows(NullPointerException.class, () -> ExtensionActionUtil.createActionRequest(null)); + assertThrows(ReflectiveOperationException.class, () -> ExtensionActionUtil.createActionRequest(actionNameBytes)); + assertNotNull(actionRequest); + assertFalse(actionRequest.getShouldStoreResult()); + } + + public void testConvertParamsToBytes() throws IOException { + Writeable mockWriteableObject = Mockito.mock(Writeable.class); + Mockito.doThrow(new IOException("Test IOException")).when(mockWriteableObject).writeTo(Mockito.any()); + assertThrows(IllegalStateException.class, () -> ExtensionActionUtil.convertParamsToBytes(mockWriteableObject)); + } + + public void testDelimPos() { + assertTrue(ExtensionActionUtil.delimPos(myBytes) > 0); + assertTrue(ExtensionActionUtil.delimPos(actionNameBytes) < 0); + assertEquals(-1, ExtensionActionUtil.delimPos(actionNameBytes)); + } + + private static class MyExampleRequest extends ActionRequest { + private final String param1; + private final byte[] param2; + + public MyExampleRequest(String param1, byte[] param2) { + this.param1 = param1; + this.param2 = param2; + } + + public MyExampleRequest(StreamInput in) throws IOException { + super(in); + param1 = in.readString(); + param2 = in.readByteArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(param1); + out.writeByteArray(param2); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + } +}