Skip to content

Commit

Permalink
Add authentication verifier implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Nied <petern@amazon.com>
  • Loading branch information
peternied committed Oct 3, 2023
1 parent 6b02da8 commit cc3db41
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,6 @@ protected final RestChannelConsumer prepareRequest(RestRequest request, NodeClie
return channel -> {
final SecurityRequestChannel securityRequest = SecurityRequestFactory.from(request, channel);


// check if .opendistro_security index has been initialized
if (!ensureIndexExists()) {
internalSeverError(channel, RequestContentValidator.ValidationError.SECURITY_NOT_INITIALIZED.message());
Expand Down
71 changes: 71 additions & 0 deletions src/main/java/org/opensearch/security/filter/NettyRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package org.opensearch.security.filter;

import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import javax.net.ssl.SSLEngine;

import org.opensearch.rest.RestRequest.Method;

import io.netty.handler.codec.http.HttpRequest;

/**
* Wraps the functionality of HttpRequest for use in the security plugin
*/
public class NettyRequest implements SecurityRequest {
protected final HttpRequest underlyingRequest;

NettyRequest(final HttpRequest request) {
this.underlyingRequest = request;
}

@Override
public Map<String, List<String>> getHeaders() {
final Map<String, List<String>> headers = new HashMap<>();
underlyingRequest.headers().forEach(h -> headers.put(h.getKey(), List.of(h.getValue())));
return headers;
}

@Override
public SSLEngine getSSLEngine() {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'getSSLEngine'");
}

@Override
public String path() {
try {
return new URL(underlyingRequest.uri()).getPath();
} catch (final MalformedURLException e) {
return "";
}
}

@Override
public Method method() {
return Method.valueOf(underlyingRequest.method().name());
}

@Override
public Optional<InetSocketAddress> getRemoteAddress() {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'getRemoteAddress'");
}

@Override
public String uri() {
return underlyingRequest.uri();
}

@Override
public Map<String, String> params() {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'params'");
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package org.opensearch.security.filter;

import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

import org.apache.commons.lang3.tuple.Triple;
import io.netty.handler.codec.http.HttpRequest;

public class NettyRequestChannel extends NettyRequest implements SecurityRequestChannel {

private final AtomicReference<Triple<Integer, Map<String, String>, String>> completedResult = new AtomicReference<>();
NettyRequestChannel(final HttpRequest request) {
super(request);
}

@Override
public boolean hasCompleted() {
return completedResult.get() != null;
}

@Override
public boolean completeWithResponse(int statusCode, Map<String, String> headers, String body) {
if (hasCompleted()) {
throw new UnsupportedOperationException("This channel has already completed");
}

completedResult.set(Triple.of(statusCode, headers, body));

return true;
}

/** Accessor to get the completed response */
public Triple<Integer, Map<String, String>, String> getCompletedRequest() {
return completedResult.get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;

import io.netty.handler.codec.http.HttpRequest;

/**
* Generates wrapped versions of requests for use in the security plugin
*/
Expand All @@ -17,4 +19,9 @@ public static SecurityRequest from(final RestRequest request) {
public static SecurityRequestChannel from(final RestRequest request, final RestChannel channel) {
return new OpenSearchRequestChannel(request, channel);
}

/** Creates a security request from a netty HttpRequest object */
public static SecurityRequestChannel from(HttpRequest request) {
return new NettyRequestChannel(request);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package org.opensearch.security.http;

import java.util.Optional;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.security.filter.NettyRequestChannel;
import org.opensearch.security.filter.SecurityRequestChannel;
import org.opensearch.security.filter.SecurityRequestFactory;
import org.opensearch.security.filter.SecurityRestFilter;

import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.ReferenceCountUtil;

public class AuthenticationVerifer extends ChannelInboundHandlerAdapter {

final static Logger log = LogManager.getLogger(AuthenticationVerifer.class);

private SecurityRestFilter restFilter;

public AuthenticationVerifer(SecurityRestFilter restFilter) {
this.restFilter = restFilter;
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (!(msg instanceof HttpRequest)) {
ctx.fireChannelRead(msg);
}

final HttpRequest request = (HttpRequest) msg;
final Optional<FullHttpResponse> shouldResponse = getAuthenticationResponse(request);
if (shouldResponse.isPresent()) {
ctx.writeAndFlush(shouldResponse.get()).addListener(ChannelFutureListener.CLOSE);
} else {
// Let the request pass to the next channel handler
ctx.fireChannelRead(msg);
}
}

private Optional<FullHttpResponse> getAuthenticationResponse(HttpRequest request) {

log.info("Checking if request is authenticated:\n" + request);

final NettyRequestChannel requestChannel = (NettyRequestChannel) SecurityRequestFactory.from(request);
restFilter.checkAndAuthenticateRequest(requestChannel);

if (requestChannel.hasCompleted()) {
final FullHttpResponse response = new DefaultFullHttpResponse(
request.protocolVersion(),
HttpResponseStatus.valueOf(requestChannel.getCompletedRequest().getLeft()),
Unpooled.copiedBuffer(requestChannel.getCompletedRequest().getRight().getBytes()));
requestChannel.getCompletedRequest().getMiddle().forEach((key, value) -> response.headers().set(key, value));
return Optional.of(response);
}

return Optional.empty();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.opensearch.security.http;

import java.util.Optional;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.security.filter.NettyRequestChannel;
import org.opensearch.security.filter.SecurityRequestChannel;
import org.opensearch.security.filter.SecurityRequestFactory;
import org.opensearch.security.filter.SecurityRestFilter;

import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.ReferenceCountUtil;

public class AuthenticationVerifer extends ChannelInboundHandlerAdapter {

final static Logger log = LogManager.getLogger(AuthenticationVerifer.class);

private SecurityRestFilter restFilter;

public AuthenticationVerifer(SecurityRestFilter restFilter) {
this.restFilter = restFilter;
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (!(msg instanceof HttpRequest)) {
ctx.fireChannelRead(msg);
}

final HttpRequest request = (HttpRequest) msg;
final Optional<FullHttpResponse> shouldResponse = getAuthenticationResponse(request);
if (shouldResponse.isPresent()) {
ctx.writeAndFlush(shouldResponse.get()).addListener(ChannelFutureListener.CLOSE);
} else {
// Let the request pass to the next channel handler
ctx.fireChannelRead(msg);
}
}

private Optional<FullHttpResponse> getAuthenticationResponse(HttpRequest request) {

log.info("Checking if request is authenticated:\n" + request);

final NettyRequestChannel requestChannel = (NettyRequestChannel) SecurityRequestFactory.from(request);

try {
restFilter.checkAndAuthenticateRequest(requestChannel);
} catch (Exception e) {
log.error(e);
}

if (requestChannel.hasCompleted()) {
final FullHttpResponse response = new DefaultFullHttpResponse(
request.protocolVersion(),
HttpResponseStatus.valueOf(requestChannel.getCompletedRequest().getLeft()),
Unpooled.copiedBuffer(requestChannel.getCompletedRequest().getRight().getBytes()));
requestChannel.getCompletedRequest().getMiddle().forEach((key, value) -> response.headers().set(key, value));
return Optional.of(response);
}

return Optional.empty();
}

}

0 comments on commit cc3db41

Please sign in to comment.