diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java index 02b09d7500..f1e0248a6a 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java @@ -15,36 +15,38 @@ package com.amazon.dlic.auth.http.jwt; +import static org.apache.http.HttpHeaders.AUTHORIZATION; + import java.nio.file.Path; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Collection; +import java.util.Map; +import java.util.Optional; import java.util.Map.Entry; import java.util.regex.Pattern; import com.google.common.annotations.VisibleForTesting; import org.apache.cxf.rs.security.jose.jwt.JwtClaims; import org.apache.cxf.rs.security.jose.jwt.JwtToken; -import org.apache.http.HttpHeaders; -import org.apache.logging.log4j.Logger; +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.LogManager; import org.opensearch.OpenSearchSecurityException; import org.opensearch.SpecialPermission; import org.opensearch.common.Strings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestStatus; import com.amazon.dlic.auth.http.jwt.keybyoidc.AuthenticatorUnavailableException; import com.amazon.dlic.auth.http.jwt.keybyoidc.BadCredentialsException; import com.amazon.dlic.auth.http.jwt.keybyoidc.JwtVerifier; import com.amazon.dlic.auth.http.jwt.keybyoidc.KeyProvider; + import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.user.AuthCredentials; public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator { @@ -66,8 +68,8 @@ public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) { jwtUrlParameter = settings.get("jwt_url_parameter"); - jwtHeaderName = settings.get("jwt_header", HttpHeaders.AUTHORIZATION); - isDefaultAuthHeader = HttpHeaders.AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); + jwtHeaderName = settings.get("jwt_header", AUTHORIZATION); + isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); rolesKey = settings.get("roles_key"); subjectKey = settings.get("subject_key"); clockSkewToleranceSeconds = settings.getAsInt("jwt_clock_skew_tolerance_seconds", DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS); @@ -83,8 +85,9 @@ public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) { } @Override - public AuthCredentials extractCredentials(RestRequest request, ThreadContext context) - throws OpenSearchSecurityException { + @SuppressWarnings("removal") + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) + throws OpenSearchSecurityException { final SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -101,7 +104,7 @@ public AuthCredentials run() { return creds; } - private AuthCredentials extractCredentials0(final RestRequest request) throws OpenSearchSecurityException { + private AuthCredentials extractCredentials0(final SecurityRequest request) throws OpenSearchSecurityException { String jwtString = getJwtTokenString(request); @@ -142,7 +145,7 @@ private AuthCredentials extractCredentials0(final RestRequest request) throws Op } - protected String getJwtTokenString(RestRequest request) { + protected String getJwtTokenString(SecurityRequest request) { String jwtToken = request.header(jwtHeaderName); if (isDefaultAuthHeader && jwtToken != null && BASIC.matcher(jwtToken).matches()) { jwtToken = null; @@ -150,10 +153,10 @@ protected String getJwtTokenString(RestRequest request) { if (jwtUrlParameter != null) { if (jwtToken == null || jwtToken.isEmpty()) { - jwtToken = request.param(jwtUrlParameter); + jwtToken = request.params().get(jwtUrlParameter); } else { // just consume to avoid "contains unrecognized parameter" - request.param(jwtUrlParameter); + request.params().get(jwtUrlParameter); } } @@ -234,11 +237,10 @@ public String[] extractRoles(JwtClaims claims) { protected abstract KeyProvider initKeyProvider(Settings settings, Path configPath) throws Exception; @Override - public boolean reRequestAuthentication(RestChannel channel, AuthCredentials authCredentials) { - final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); - wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + public Optional reRequestAuthentication(final SecurityRequest request, AuthCredentials authCredentials) { + return Optional.of( + new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, Map.of("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""), "") + ); } } diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java index b417df047c..44800c88bf 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java @@ -15,6 +15,8 @@ package com.amazon.dlic.auth.http.jwt; +import static org.apache.http.HttpHeaders.AUTHORIZATION; + import java.nio.file.Path; import java.security.AccessController; import java.security.Key; @@ -25,22 +27,22 @@ import java.security.spec.InvalidKeySpecException; import java.security.spec.X509EncodedKeySpec; import java.util.Collection; +import java.util.Map; +import java.util.Optional; import java.util.Map.Entry; import java.util.regex.Pattern; -import org.apache.http.HttpHeaders; import org.apache.logging.log4j.Logger; +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.opensearch.OpenSearchSecurityException; import org.opensearch.SpecialPermission; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.user.AuthCredentials; import io.jsonwebtoken.Claims; @@ -106,8 +108,8 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) { } jwtUrlParameter = settings.get("jwt_url_parameter"); - jwtHeaderName = settings.get("jwt_header", HttpHeaders.AUTHORIZATION); - isDefaultAuthHeader = HttpHeaders.AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); + jwtHeaderName = settings.get("jwt_header", AUTHORIZATION); + isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); rolesKey = settings.get("roles_key"); subjectKey = settings.get("subject_key"); jwtParser = _jwtParser; @@ -115,7 +117,9 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) { @Override - public AuthCredentials extractCredentials(RestRequest request, ThreadContext context) throws OpenSearchSecurityException { + @SuppressWarnings("removal") + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) + throws OpenSearchSecurityException { final SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -132,7 +136,7 @@ public AuthCredentials run() { return creds; } - private AuthCredentials extractCredentials0(final RestRequest request) { + private AuthCredentials extractCredentials0(final SecurityRequest request) { if (jwtParser == null) { log.error("Missing Signing Key. JWT authentication will not work"); return null; @@ -143,11 +147,11 @@ private AuthCredentials extractCredentials0(final RestRequest request) { jwtToken = null; } - if((jwtToken == null || jwtToken.isEmpty()) && jwtUrlParameter != null) { - jwtToken = request.param(jwtUrlParameter); + if ((jwtToken == null || jwtToken.isEmpty()) && jwtUrlParameter != null) { + jwtToken = request.params().get(jwtUrlParameter); } else { - //just consume to avoid "contains unrecognized parameter" - request.param(jwtUrlParameter); + // just consume to avoid "contains unrecognized parameter" + request.params().get(jwtUrlParameter); } if (jwtToken == null || jwtToken.length() == 0) { @@ -198,11 +202,10 @@ private AuthCredentials extractCredentials0(final RestRequest request) { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED,""); - wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + public Optional reRequestAuthentication(final SecurityRequest channel, AuthCredentials creds) { + return Optional.of( + new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, Map.of("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""), "") + ); } @Override @@ -210,7 +213,7 @@ public String getType() { return "jwt"; } - protected String extractSubject(final Claims claims, final RestRequest request) { + protected String extractSubject(final Claims claims, final SecurityRequest request) { String subject = claims.getSubject(); if(subjectKey != null) { // try to get roles from claims, first as Object to avoid having to catch the ExpectedTypeException @@ -229,17 +232,20 @@ protected String extractSubject(final Claims claims, final RestRequest request) } @SuppressWarnings("unchecked") - protected String[] extractRoles(final Claims claims, final RestRequest request) { - // no roles key specified - if(rolesKey == null) { - return new String[0]; - } - // try to get roles from claims, first as Object to avoid having to catch the ExpectedTypeException - final Object rolesObject = claims.get(rolesKey, Object.class); - if(rolesObject == null) { - log.warn("Failed to get roles from JWT claims with roles_key '{}'. Check if this key is correct and available in the JWT payload.", rolesKey); - return new String[0]; - } + protected String[] extractRoles(final Claims claims, final SecurityRequest request) { + // no roles key specified + if (rolesKey == null) { + return new String[0]; + } + // try to get roles from claims, first as Object to avoid having to catch the ExpectedTypeException + final Object rolesObject = claims.get(rolesKey, Object.class); + if (rolesObject == null) { + log.warn( + "Failed to get roles from JWT claims with roles_key '{}'. Check if this key is correct and available in the JWT payload.", + rolesKey + ); + return new String[0]; + } String[] roles = String.valueOf(rolesObject).split(","); diff --git a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java index 10634a626a..e8136d95ad 100644 --- a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java @@ -15,6 +15,8 @@ package com.amazon.dlic.auth.http.kerberos; +import static org.apache.http.HttpStatus.SC_UNAUTHORIZED; + import java.io.Serializable; import java.nio.file.Files; import java.nio.file.Path; @@ -26,15 +28,20 @@ import java.security.PrivilegedExceptionAction; import java.util.Base64; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; +import java.util.Optional; import java.util.Set; import javax.security.auth.Subject; import javax.security.auth.login.LoginException; -import org.apache.logging.log4j.Logger; +import com.google.common.base.Strings; + import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.SpecialPermission; import org.opensearch.common.settings.Settings; @@ -42,11 +49,6 @@ import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.env.Environment; -//import org.opensearch.env.Environment; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; import org.ietf.jgss.GSSContext; import org.ietf.jgss.GSSCredential; import org.ietf.jgss.GSSException; @@ -56,9 +58,11 @@ import com.amazon.dlic.auth.http.kerberos.util.JaasKrbUtil; import com.amazon.dlic.auth.http.kerberos.util.KrbConstants; + import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.user.AuthCredentials; -import com.google.common.base.Strings; public class HTTPSpnegoAuthenticator implements HTTPAuthenticator { @@ -166,7 +170,8 @@ public Void run() { } @Override - public AuthCredentials extractCredentials(final RestRequest request, ThreadContext threadContext) { + @SuppressWarnings("removal") + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext threadContext) { final SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -183,7 +188,7 @@ public AuthCredentials run() { return creds; } - private AuthCredentials extractCredentials0(final RestRequest request) { + private AuthCredentials extractCredentials0(final SecurityRequest request) { if (acceptorPrincipal == null || acceptorKeyTabPath == null) { log.error("Missing acceptor principal or keytab configuration. Kerberos authentication will not work"); @@ -273,24 +278,22 @@ public GSSCredential run() throws GSSException { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - - final BytesRestResponse wwwAuthenticateResponse; - XContentBuilder response = getNegotiateResponseBody(); - - if (response != null) { - wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, response); - } else { - wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, EMPTY_STRING); + public Optional reRequestAuthentication(final SecurityRequest request, AuthCredentials creds) { + final Map headers = new HashMap<>(); + String responseBody = ""; + final String negotiateResponseBody = getNegotiateResponseBody(); + if (negotiateResponseBody != null) { + responseBody = negotiateResponseBody; + headers.putAll(SecurityResponse.CONTENT_TYPE_APP_JSON); } - if(creds == null || creds.getNativeCredentials() == null) { - wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Negotiate"); + if (creds == null || creds.getNativeCredentials() == null) { + headers.put("WWW-Authenticate", "Negotiate"); } else { - wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Negotiate "+Base64.getEncoder().encodeToString((byte[]) creds.getNativeCredentials())); + headers.put("WWW-Authenticate", "Negotiate " + Base64.getEncoder().encodeToString((byte[]) creds.getNativeCredentials())); } - channel.sendResponse(wwwAuthenticateResponse); - return true; + + return Optional.of(new SecurityResponse(SC_UNAUTHORIZED, headers, responseBody)); } @Override @@ -362,24 +365,24 @@ private static String getUsernameFromGSSContext(final GSSContext gssContext, fin return null; } - private XContentBuilder getNegotiateResponseBody() { - try { - XContentBuilder negotiateResponseBody = XContentFactory.jsonBuilder(); - negotiateResponseBody.startObject(); - negotiateResponseBody.field("error"); - negotiateResponseBody.startObject(); - negotiateResponseBody.field("header"); - negotiateResponseBody.startObject(); - negotiateResponseBody.field("WWW-Authenticate", "Negotiate"); - negotiateResponseBody.endObject(); - negotiateResponseBody.endObject(); - negotiateResponseBody.endObject(); - return negotiateResponseBody; - } catch (Exception ex) { - log.error("Can't construct response body", ex); - return null; - } - } + private String getNegotiateResponseBody() { + try { + XContentBuilder negotiateResponseBody = XContentFactory.jsonBuilder(); + negotiateResponseBody.startObject(); + negotiateResponseBody.field("error"); + negotiateResponseBody.startObject(); + negotiateResponseBody.field("header"); + negotiateResponseBody.startObject(); + negotiateResponseBody.field("WWW-Authenticate", "Negotiate"); + negotiateResponseBody.endObject(); + negotiateResponseBody.endObject(); + negotiateResponseBody.endObject(); + return negotiateResponseBody.toString(); + } catch (Exception ex) { + log.error("Can't construct response body", ex); + return null; + } + } private static String stripRealmName(String name, boolean strip){ if (strip && name != null) { diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java index c2791a02b8..6157853324 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java @@ -23,14 +23,23 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.List; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; -import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPathExpressionException; -import org.opensearch.security.DefaultObjectMapper; +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.base.Strings; +import com.onelogin.saml2.authn.SamlResponse; +import com.onelogin.saml2.exception.ValidationError; +import com.onelogin.saml2.settings.Saml2Settings; +import com.onelogin.saml2.util.Util; + import org.apache.commons.lang3.StringUtils; import org.apache.cxf.jaxrs.json.basic.JsonMapObjectReaderWriter; import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; @@ -41,32 +50,21 @@ import org.apache.cxf.rs.security.jose.jwt.JwtClaims; import org.apache.cxf.rs.security.jose.jwt.JwtToken; import org.apache.cxf.rs.security.jose.jwt.JwtUtils; -import org.apache.logging.log4j.Logger; +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.joda.time.DateTime; import org.opensearch.OpenSearchSecurityException; import org.opensearch.SpecialPermission; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.rest.RestStatus; -import org.joda.time.DateTime; -import org.xml.sax.SAXException; - +import org.opensearch.security.DefaultObjectMapper; import org.opensearch.security.dlic.rest.api.AuthTokenProcessorAction; -import com.fasterxml.jackson.core.JsonParseException; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.google.common.base.Strings; -import com.onelogin.saml2.authn.SamlResponse; -import com.onelogin.saml2.exception.SettingsException; -import com.onelogin.saml2.exception.ValidationError; -import com.onelogin.saml2.settings.Saml2Settings; -import com.onelogin.saml2.util.Util; +import org.opensearch.security.filter.SecurityResponse; class AuthTokenProcessorHandler { private static final Logger log = LogManager.getLogger(AuthTokenProcessorHandler.class); @@ -126,7 +124,8 @@ class AuthTokenProcessorHandler { } - boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exception { + @SuppressWarnings("removal") + Optional handle(RestRequest restRequest) throws Exception { try { final SecurityManager sm = System.getSecurityManager(); @@ -134,11 +133,10 @@ boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exceptio sm.checkPermission(new SpecialPermission()); } - return AccessController.doPrivileged(new PrivilegedExceptionAction() { + return AccessController.doPrivileged(new PrivilegedExceptionAction>() { @Override - public Boolean run() throws XPathExpressionException, SamlConfigException, IOException, - ParserConfigurationException, SAXException, SettingsException { - return handleLowLevel(restRequest, restChannel); + public Optional run() throws SamlConfigException, IOException { + return handleLowLevel(restRequest); } }); } catch (PrivilegedActionException e) { @@ -150,10 +148,8 @@ public Boolean run() throws XPathExpressionException, SamlConfigException, IOExc } } - private AuthTokenProcessorAction.Response handleImpl(RestRequest restRequest, RestChannel restChannel, - String samlResponseBase64, String samlRequestId, String acsEndpoint, Saml2Settings saml2Settings) - throws XPathExpressionException, ParserConfigurationException, SAXException, IOException, - SettingsException { + private AuthTokenProcessorAction.Response handleImpl(RestRequest restRequest, + String samlResponseBase64, String samlRequestId, String acsEndpoint, Saml2Settings saml2Settings) { if (token_log.isDebugEnabled()) { try { token_log.debug("SAMLResponse for {}\n{}", samlRequestId, new String(Util.base64decoder(samlResponseBase64), StandardCharsets.UTF_8)); @@ -188,8 +184,7 @@ private AuthTokenProcessorAction.Response handleImpl(RestRequest restRequest, Re } } - private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) throws SamlConfigException, - IOException, XPathExpressionException, ParserConfigurationException, SAXException, SettingsException { + private Optional handleLowLevel(RestRequest restRequest) throws SamlConfigException, IOException { try { if (restRequest.getXContentType() != XContentType.JSON) { @@ -233,27 +228,19 @@ private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) acsEndpoint = getAbsoluteAcsEndpoint(((ObjectNode) jsonRoot).get("acsEndpoint").textValue()); } - AuthTokenProcessorAction.Response responseBody = this.handleImpl(restRequest, restChannel, - samlResponseBase64, samlRequestId, acsEndpoint, saml2Settings); + AuthTokenProcessorAction.Response responseBody = this.handleImpl(restRequest, + samlResponseBase64, samlRequestId, acsEndpoint, saml2Settings); if (responseBody == null) { - return false; + return Optional.empty(); } String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody); - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.OK, "application/json", - responseBodyString); - restChannel.sendResponse(authenticateResponse); - - return true; + return Optional.of(new SecurityResponse(HttpStatus.SC_OK, SecurityResponse.CONTENT_TYPE_APP_JSON, responseBodyString)); } catch (JsonProcessingException e) { log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e); - - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.BAD_REQUEST, - "JSON could not be parsed"); - restChannel.sendResponse(authenticateResponse); - return true; + return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, null, "JSON could not be parsed")); } } diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java index 774c31e465..0ff2158232 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java @@ -15,12 +15,19 @@ package com.amazon.dlic.auth.http.saml; +import java.io.IOException; import java.net.URL; import java.nio.file.Path; import java.security.AccessController; import java.security.PrivateKey; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.Map; +import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.xml.parsers.ParserConfigurationException; import com.google.common.annotations.VisibleForTesting; import net.shibboleth.utilities.java.support.xml.BasicParserPool; @@ -32,10 +39,7 @@ import org.opensearch.SpecialPermission; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; import org.opensaml.core.config.InitializationException; import org.opensaml.core.config.InitializationService; import org.opensaml.saml.metadata.resolver.MetadataResolver; @@ -58,15 +62,17 @@ import net.shibboleth.utilities.java.support.component.ComponentInitializationException; import net.shibboleth.utilities.java.support.component.DestructableComponent; +import org.apache.http.HttpStatus; import org.opensaml.saml.metadata.resolver.impl.AbstractMetadataResolver; import org.opensaml.saml.metadata.resolver.impl.DOMMetadataResolver; import org.w3c.dom.Document; import org.w3c.dom.Element; import org.xml.sax.SAXException; -import javax.xml.parsers.ParserConfigurationException; -import java.io.IOException; -import java.util.regex.Matcher; -import java.util.regex.Pattern; + +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityRequestChannelUnsupported; +import org.opensearch.security.filter.SecurityResponse; +import org.opensearch.security.filter.OpenSearchRequest; import static org.opensearch.security.OpenSearchSecurityPlugin.LEGACY_OPENDISTRO_PREFIX; import static org.opensearch.security.OpenSearchSecurityPlugin.PLUGINS_PREFIX; @@ -152,18 +158,18 @@ public HTTPSamlAuthenticator(final Settings settings, final Path configPath) { } @Override - public AuthCredentials extractCredentials(RestRequest restRequest, ThreadContext threadContext) - throws OpenSearchSecurityException { - Matcher matcher = PATTERN_PATH_PREFIX.matcher(restRequest.path()); + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext threadContext) + throws OpenSearchSecurityException { + Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { return null; } - AuthCredentials authCredentials = this.httpJwtAuthenticator.extractCredentials(restRequest, threadContext); + AuthCredentials authCredentials = this.httpJwtAuthenticator.extractCredentials(request, threadContext); if (AUTHINFO_SUFFIX.equals(suffix)) { - this.initLogoutUrl(restRequest, threadContext, authCredentials); + this.initLogoutUrl(threadContext, authCredentials); } return authCredentials; @@ -175,27 +181,32 @@ public String getType() { } @Override - public boolean reRequestAuthentication(RestChannel restChannel, AuthCredentials authCredentials) { + public Optional reRequestAuthentication(final SecurityRequest request, final AuthCredentials authCredentials) { try { - RestRequest restRequest = restChannel.request(); - Matcher matcher = PATTERN_PATH_PREFIX.matcher(restRequest.path()); + Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; - if (API_AUTHTOKEN_SUFFIX.equals(suffix) - && this.authTokenProcessorHandler.handle(restRequest, restChannel)){ - return true; - } - - Saml2Settings saml2Settings = this.saml2SettingsProvider.getCached(); - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); - authenticateResponse.addHeader("WWW-Authenticate", getWwwAuthenticateHeader(saml2Settings)); - - restChannel.sendResponse(authenticateResponse); + if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { + // Verficiation of SAML ASC endpoint only works with RestRequests + if (!(request instanceof OpenSearchRequest)) { + throw new SecurityRequestChannelUnsupported(); + } else { + final OpenSearchRequest openSearchRequest = (OpenSearchRequest) request; + final RestRequest restRequest = openSearchRequest.breakEncapsulationForRequest(); + Optional restResponse = this.authTokenProcessorHandler.handle(restRequest); + if (restResponse.isPresent()) { + return restResponse; + } + } + } - return true; + final Saml2Settings saml2Settings = this.saml2SettingsProvider.getCached(); + return Optional.of( + new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, Map.of("WWW-Authenticate", getWwwAuthenticateHeader(saml2Settings)), "") + ); } catch (Exception e) { log.error("Error in reRequestAuthentication()", e); - return false; + return Optional.empty(); } } @@ -395,7 +406,7 @@ String buildLogoutUrl(AuthCredentials authCredentials) { } - private void initLogoutUrl(RestRequest restRequest, ThreadContext threadContext, AuthCredentials authCredentials) { + private void initLogoutUrl(ThreadContext threadContext, AuthCredentials authCredentials) { threadContext.putTransient(ConfigConstants.SSO_LOGOUT_URL, buildLogoutUrl(authCredentials)); } diff --git a/src/main/java/org/opensearch/security/auditlog/AuditLog.java b/src/main/java/org/opensearch/security/auditlog/AuditLog.java index 58740de6c0..e04b01ef0d 100644 --- a/src/main/java/org/opensearch/security/auditlog/AuditLog.java +++ b/src/main/java/org/opensearch/security/auditlog/AuditLog.java @@ -40,22 +40,25 @@ import org.opensearch.index.engine.Engine.IndexResult; import org.opensearch.index.get.GetResult; import org.opensearch.index.shard.ShardId; -import org.opensearch.rest.RestRequest; import org.opensearch.security.compliance.ComplianceConfig; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequest; public interface AuditLog extends Closeable { - //login + // login void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, TransportRequest request, Task task); - void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request); + void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request); + void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, TransportRequest request, String action, Task task); - void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request); + void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request); + + // privs + void logMissingPrivileges(String privilege, String effectiveUser, SecurityRequest request); + + void logGrantedPrivileges(String effectiveUser, SecurityRequest request); - //privs - void logMissingPrivileges(String privilege, String effectiveUser, RestRequest request); - void logGrantedPrivileges(String effectiveUser, RestRequest request); void logMissingPrivileges(String privilege, TransportRequest request, Task task); void logGrantedPrivileges(String privilege, TransportRequest request, Task task); @@ -64,12 +67,14 @@ public interface AuditLog extends Closeable { //spoof void logBadHeaders(TransportRequest request, String action, Task task); - void logBadHeaders(RestRequest request); + + void logBadHeaders(SecurityRequest request); void logSecurityIndexAttempt(TransportRequest request, String action, Task task); void logSSLException(TransportRequest request, Throwable t, String action, Task task); - void logSSLException(RestRequest request, Throwable t); + + void logSSLException(SecurityRequest request, Throwable t); void logDocumentRead(String index, String id, ShardId shardId, Map fieldNameValues); void logDocumentWritten(ShardId shardId, GetResult originalIndex, Index currentIndex, IndexResult result); diff --git a/src/main/java/org/opensearch/security/auditlog/AuditLogSslExceptionHandler.java b/src/main/java/org/opensearch/security/auditlog/AuditLogSslExceptionHandler.java index 9e00a1d2c6..3ceabfdc22 100644 --- a/src/main/java/org/opensearch/security/auditlog/AuditLogSslExceptionHandler.java +++ b/src/main/java/org/opensearch/security/auditlog/AuditLogSslExceptionHandler.java @@ -31,7 +31,8 @@ package org.opensearch.security.auditlog; import org.opensearch.OpenSearchException; -import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequest; @@ -47,7 +48,7 @@ public AuditLogSslExceptionHandler(final AuditLog auditLog) { } @Override - public void logError(Throwable t, RestRequest request, int type) { + public void logError(Throwable t, SecurityRequestChannel request, int type) { switch (type) { case 0: auditLog.logSSLException(request, t); diff --git a/src/main/java/org/opensearch/security/auditlog/NullAuditLog.java b/src/main/java/org/opensearch/security/auditlog/NullAuditLog.java index 20b1faa909..f4978204df 100644 --- a/src/main/java/org/opensearch/security/auditlog/NullAuditLog.java +++ b/src/main/java/org/opensearch/security/auditlog/NullAuditLog.java @@ -42,6 +42,7 @@ import org.opensearch.index.shard.ShardId; import org.opensearch.rest.RestRequest; import org.opensearch.security.compliance.ComplianceConfig; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequest; @@ -58,18 +59,13 @@ public void logFailedLogin(String effectiveUser, boolean securityadmin, String i } @Override - public void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request) { - //noop, intentionally left empty + public void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request) { + // noop, intentionally left empty } @Override - public void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, TransportRequest request, String action, Task task) { - //noop, intentionally left empty - } - - @Override - public void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request) { - //noop, intentionally left empty + public void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request) { + // noop, intentionally left empty } @Override @@ -93,8 +89,8 @@ public void logBadHeaders(TransportRequest request, String action, Task task) { } @Override - public void logBadHeaders(RestRequest request) { - //noop, intentionally left empty + public void logBadHeaders(SecurityRequest request) { + // noop, intentionally left empty } @Override @@ -108,18 +104,18 @@ public void logSSLException(TransportRequest request, Throwable t, String action } @Override - public void logSSLException(RestRequest request, Throwable t) { - //noop, intentionally left empty + public void logSSLException(SecurityRequest request, Throwable t) { + // noop, intentionally left empty } @Override - public void logMissingPrivileges(String privilege, String effectiveUser, RestRequest request) { - //noop, intentionally left empty + public void logMissingPrivileges(String privilege, String effectiveUser, SecurityRequest request) { + // noop, intentionally left empty } @Override - public void logGrantedPrivileges(String effectiveUser, RestRequest request) { - //noop, intentionally left empty + public void logGrantedPrivileges(String effectiveUser, SecurityRequest request) { + // noop, intentionally left empty } @Override @@ -147,4 +143,10 @@ public void setConfig(AuditConfig auditConfig) { } + @Override + public void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, + TransportRequest request, String action, Task task) { + //noop, intentionally left empty + } + } diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java index bd479c0db2..185dc0be40 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java @@ -62,7 +62,6 @@ import org.opensearch.index.engine.Engine.IndexResult; import org.opensearch.index.get.GetResult; import org.opensearch.index.shard.ShardId; -import org.opensearch.rest.RestRequest; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportRequest; @@ -70,6 +69,7 @@ import org.opensearch.security.auditlog.AuditLog; import org.opensearch.security.compliance.ComplianceConfig; import org.opensearch.security.dlic.rest.support.Utils; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.support.Base64Helper; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.user.User; @@ -151,7 +151,7 @@ public void logFailedLogin(String effectiveUser, boolean securityadmin, String i @Override - public void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request) { + public void logFailedLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request) { if(!checkRestFilter(AuditCategory.FAILED_LOGIN, effectiveUser, request)) { return; @@ -184,7 +184,7 @@ public void logSucceededLogin(String effectiveUser, boolean securityadmin, Strin } @Override - public void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, RestRequest request) { + public void logSucceededLogin(String effectiveUser, boolean securityadmin, String initiatingUser, SecurityRequest request) { if(!checkRestFilter(AuditCategory.AUTHENTICATED, effectiveUser, request)) { return; @@ -201,8 +201,8 @@ public void logSucceededLogin(String effectiveUser, boolean securityadmin, Strin } @Override - public void logMissingPrivileges(String privilege, String effectiveUser, RestRequest request) { - if(!checkRestFilter(AuditCategory.MISSING_PRIVILEGES, effectiveUser, request)) { + public void logMissingPrivileges(String privilege, String effectiveUser, SecurityRequest request) { + if (!checkRestFilter(AuditCategory.MISSING_PRIVILEGES, effectiveUser, request)) { return; } @@ -215,8 +215,8 @@ public void logMissingPrivileges(String privilege, String effectiveUser, RestReq } @Override - public void logGrantedPrivileges(String effectiveUser, RestRequest request) { - if(!checkRestFilter(AuditCategory.GRANTED_PRIVILEGES, effectiveUser, request)) { + public void logGrantedPrivileges(String effectiveUser, SecurityRequest request) { + if (!checkRestFilter(AuditCategory.GRANTED_PRIVILEGES, effectiveUser, request)) { return; } @@ -290,7 +290,7 @@ public void logBadHeaders(TransportRequest request, String action, Task task) { } @Override - public void logBadHeaders(RestRequest request) { + public void logBadHeaders(SecurityRequest request) { if(!checkRestFilter(AuditCategory.BAD_HEADERS, getUser(), request)) { return; @@ -338,7 +338,7 @@ public void logSSLException(TransportRequest request, Throwable t, String action } @Override - public void logSSLException(RestRequest request, Throwable t) { + public void logSSLException(SecurityRequest request, Throwable t) { if(!checkRestFilter(AuditCategory.SSL_EXCEPTION, getUser(), request)) { return; @@ -726,7 +726,7 @@ private boolean checkComplianceFilter(final AuditCategory category, final String } @VisibleForTesting - boolean checkRestFilter(final AuditCategory category, final String effectiveUser, RestRequest request) { + boolean checkRestFilter(final AuditCategory category, final String effectiveUser, SecurityRequest request) { final boolean isTraceEnabled = log.isTraceEnabled(); if (isTraceEnabled) { log.trace("Check for REST category:{}, effectiveUser:{}, request:{}", category, effectiveUser, request==null?null:request.path()); diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AuditLogImpl.java b/src/main/java/org/opensearch/security/auditlog/impl/AuditLogImpl.java index 1bb802f291..7b160be6dd 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AuditLogImpl.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AuditLogImpl.java @@ -34,8 +34,8 @@ import org.opensearch.index.engine.Engine.IndexResult; import org.opensearch.index.get.GetResult; import org.opensearch.index.shard.ShardId; -import org.opensearch.rest.RestRequest; import org.opensearch.security.auditlog.routing.AuditMessageRouter; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportRequest; @@ -128,47 +128,33 @@ protected void save(final AuditMessage msg) { } } - @Override - public void logFailedLogin(String effectiveUser, boolean securityAdmin, String initiatingUser, TransportRequest request, Task task) { - if (enabled) { - super.logFailedLogin(effectiveUser, securityAdmin, initiatingUser, request, task); - } - } - - @Override - public void logFailedLogin(String effectiveUser, boolean securityAdmin, String initiatingUser, RestRequest request) { - if (enabled) { - super.logFailedLogin(effectiveUser, securityAdmin, initiatingUser, request); - } - } - - @Override - public void logSucceededLogin(String effectiveUser, boolean securityAdmin, String initiatingUser, TransportRequest request, String action, Task task) { - if (enabled) { - super.logSucceededLogin(effectiveUser, securityAdmin, initiatingUser, request, action, task); - } - } + @Override + public void logFailedLogin(String effectiveUser, boolean securityAdmin, String initiatingUser, SecurityRequest request) { + if (enabled) { + super.logFailedLogin(effectiveUser, securityAdmin, initiatingUser, request); + } + } - @Override - public void logSucceededLogin(String effectiveUser, boolean securityAdmin, String initiatingUser, RestRequest request) { - if (enabled) { - super.logSucceededLogin(effectiveUser, securityAdmin, initiatingUser, request); - } - } + @Override + public void logSucceededLogin(String effectiveUser, boolean securityAdmin, String initiatingUser, SecurityRequest request) { + if (enabled) { + super.logSucceededLogin(effectiveUser, securityAdmin, initiatingUser, request); + } + } - @Override - public void logMissingPrivileges(String privilege, String effectiveUser, RestRequest request) { - if (enabled) { - super.logMissingPrivileges(privilege, effectiveUser, request); - } - } + @Override + public void logMissingPrivileges(String privilege, String effectiveUser, SecurityRequest request) { + if (enabled) { + super.logMissingPrivileges(privilege, effectiveUser, request); + } + } - @Override - public void logGrantedPrivileges(String effectiveUser, RestRequest request) { - if (enabled) { - super.logGrantedPrivileges(effectiveUser, request); - } - } + @Override + public void logGrantedPrivileges(String effectiveUser, SecurityRequest request) { + if (enabled) { + super.logGrantedPrivileges(effectiveUser, request); + } + } @Override public void logMissingPrivileges(String privilege, TransportRequest request, Task task) { @@ -184,12 +170,12 @@ public void logGrantedPrivileges(String privilege, TransportRequest request, Tas } } - @Override - public void logIndexEvent(String privilege, TransportRequest request, Task task) { - if (enabled) { - super.logIndexEvent(privilege, request, task); - } - } + @Override + public void logBadHeaders(SecurityRequest request) { + if (enabled) { + super.logBadHeaders(request); + } + } @Override public void logBadHeaders(TransportRequest request, String action, Task task) { @@ -198,19 +184,12 @@ public void logBadHeaders(TransportRequest request, String action, Task task) { } } - @Override - public void logBadHeaders(RestRequest request) { - if (enabled) { - super.logBadHeaders(request); - } - } - - @Override - public void logSecurityIndexAttempt (TransportRequest request, String action, Task task) { - if (enabled) { - super.logSecurityIndexAttempt(request, action, task); - } - } + @Override + public void logSSLException(SecurityRequest request, Throwable t) { + if (enabled) { + super.logSSLException(request, t); + } + } @Override public void logSSLException(TransportRequest request, Throwable t, String action, Task task) { @@ -219,13 +198,6 @@ public void logSSLException(TransportRequest request, Throwable t, String action } } - @Override - public void logSSLException(RestRequest request, Throwable t) { - if (enabled) { - super.logSSLException(request, t); - } - } - @Override public void logDocumentRead(String index, String id, ShardId shardId, Map fieldNameValues) { if (enabled) { diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java b/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java index def54fb041..fe99f37a9d 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java @@ -49,6 +49,8 @@ import org.opensearch.security.auditlog.AuditLog.Operation; import org.opensearch.security.auditlog.AuditLog.Origin; import org.opensearch.security.dlic.rest.support.Utils; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.OpenSearchRequest; import static org.opensearch.security.OpenSearchSecurityPlugin.LEGACY_OPENDISTRO_PREFIX; import static org.opensearch.security.OpenSearchSecurityPlugin.PLUGINS_PREFIX; @@ -368,16 +370,31 @@ void addRestMethod(final RestRequest.Method method) { } } - void addRestRequestInfo(final RestRequest request, final AuditConfig.Filter filter) { + void addRestRequestInfo(final SecurityRequest request, final AuditConfig.Filter filter) { if (request != null) { - final String path = request.path(); + final String path = request.path().toString(); addPath(path); addRestHeaders(request.getHeaders(), filter.shouldExcludeSensitiveHeaders()); addRestParams(request.params()); addRestMethod(request.method()); - if (filter.shouldLogRequestBody() && request.hasContentOrSourceParam()) { + + if (filter.shouldLogRequestBody()) { + + if (!(request instanceof OpenSearchRequest)) { + // The request body is only avaliable on some request sources + return; + } + + final OpenSearchRequest securityRestRequest = (OpenSearchRequest) request; + final RestRequest restRequest = securityRestRequest.breakEncapsulationForRequest(); + + if (!(restRequest.hasContentOrSourceParam())) { + // If there is no content, don't attempt to save any body information + return; + } + try { - final Tuple xContentTuple = request.contentOrSourceParam(); + final Tuple xContentTuple = restRequest.contentOrSourceParam(); final String requestBody = XContentHelper.convertToJson(xContentTuple.v2(), false, xContentTuple.v1()); if (path != null && requestBody != null && SENSITIVE_PATHS.matcher(path).matches() diff --git a/src/main/java/org/opensearch/security/auth/BackendRegistry.java b/src/main/java/org/opensearch/security/auth/BackendRegistry.java index d1454e353e..bfc3575831 100644 --- a/src/main/java/org/opensearch/security/auth/BackendRegistry.java +++ b/src/main/java/org/opensearch/security/auth/BackendRegistry.java @@ -37,6 +37,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.SortedSet; import java.util.concurrent.Callable; @@ -46,20 +47,29 @@ import javax.naming.ldap.LdapName; import javax.naming.ldap.Rdn; -import org.apache.logging.log4j.Logger; +import com.google.common.base.Strings; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.RemovalListener; +import com.google.common.cache.RemovalNotification; +import com.google.common.collect.Multimap; + import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.greenrobot.eventbus.Subscribe; + + import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestStatus; import org.opensearch.security.auditlog.AuditLog; import org.opensearch.security.auth.blocking.ClientBlockRegistry; import org.opensearch.security.auth.internal.NoOpAuthenticationBackend; import org.opensearch.security.configuration.AdminDNs; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.http.XFFResolver; import org.opensearch.security.securityconf.DynamicConfigModel; import org.opensearch.security.ssl.util.Utils; @@ -70,14 +80,10 @@ import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportRequest; -import org.greenrobot.eventbus.Subscribe; -import com.google.common.base.Strings; -import com.google.common.cache.Cache; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.RemovalListener; -import com.google.common.cache.RemovalNotification; -import com.google.common.collect.Multimap; +import static org.apache.http.HttpStatus.SC_FORBIDDEN; +import static org.apache.http.HttpStatus.SC_SERVICE_UNAVAILABLE; +import static org.apache.http.HttpStatus.SC_UNAUTHORIZED; public class BackendRegistry { @@ -351,15 +357,18 @@ public User authenticate(final TransportRequest request, final String sslPrincip * @return The authenticated user, null means another roundtrip * @throws OpenSearchSecurityException */ - public boolean authenticate(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { + public boolean authenticate(final SecurityRequestChannel request) { final boolean isDebugEnabled = log.isDebugEnabled(); - if (request.getHttpChannel().getRemoteAddress() instanceof InetSocketAddress && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).getAddress())) { + final boolean isBlockedBasedOnAddress = request.getRemoteAddress() + .map(InetSocketAddress::getAddress) + .map(address -> isBlocked(address)) + .orElse(false); + if (isBlockedBasedOnAddress) { if (isDebugEnabled) { - log.debug("Rejecting REST request because of blocked address: {}", request.getHttpChannel().getRemoteAddress()); + log.debug("Rejecting REST request because of blocked address: {}", request.getRemoteAddress().orElse(null)); } - - channel.sendResponse(new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed")); + request.queueForSending(new SecurityResponse(SC_UNAUTHORIZED, null, "Authentication finally failed")); return false; } @@ -379,18 +388,17 @@ public boolean authenticate(final RestRequest request, final RestChannel channel if (!isInitialized()) { log.error("Not yet initialized (you may need to run securityadmin)"); - channel.sendResponse(new BytesRestResponse(RestStatus.SERVICE_UNAVAILABLE, - "OpenSearch Security not initialized.")); + request.queueForSending(new SecurityResponse(SC_SERVICE_UNAVAILABLE, null, "OpenSearch Security not initialized.")); return false; } final TransportAddress remoteAddress = xffResolver.resolve(request); final boolean isTraceEnabled = log.isTraceEnabled(); if (isTraceEnabled) { - log.trace("Rest authentication request from {} [original: {}]", remoteAddress, request.getHttpChannel().getRemoteAddress()); - } + log.trace("Rest authentication request from {} [original: {}]", remoteAddress, request.getRemoteAddress().orElse(null)); + } - threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, remoteAddress); + threadPool.getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, remoteAddress); boolean authenticated = false; @@ -417,7 +425,7 @@ public boolean authenticate(final RestRequest request, final RestChannel channel } final AuthCredentials ac; try { - ac = httpAuthenticator.extractCredentials(request, threadContext); + ac = httpAuthenticator.extractCredentials(request, threadPool.getThreadContext()); } catch (Exception e1) { if (isDebugEnabled) { log.debug("'{}' extracting credentials from {} http authenticator", e1.toString(), httpAuthenticator.getType(), e1); @@ -441,10 +449,17 @@ public boolean authenticate(final RestRequest request, final RestChannel channel continue; } - if(authDomain.isChallenge() && httpAuthenticator.reRequestAuthentication(channel, null)) { - auditLog.logFailedLogin("", false, null, request); - log.trace("No 'Authorization' header, send 401 and 'WWW-Authenticate Basic'"); - return false; + if (authDomain.isChallenge()) { + final Optional restResponse = httpAuthenticator.reRequestAuthentication(request, null); + if (restResponse.isPresent()) { + auditLog.logFailedLogin("", false, null, request); + if (isTraceEnabled) { + log.trace("No 'Authorization' header, send 401 and 'WWW-Authenticate Basic'"); + } + notifyIpAuthFailureListeners(request, authCredenetials); + request.queueForSending(restResponse.get()); + return false; + } } else { //no reRequest possible if (isTraceEnabled) { @@ -455,9 +470,11 @@ public boolean authenticate(final RestRequest request, final RestChannel channel } else { org.apache.logging.log4j.ThreadContext.put("user", ac.getUsername()); if (!ac.isComplete()) { - //credentials found in request but we need another client challenge - if(httpAuthenticator.reRequestAuthentication(channel, ac)) { - //auditLog.logFailedLogin(ac.getUsername()+" ", request); --noauditlog + // credentials found in request but we need another client challenge + final Optional restResponse = httpAuthenticator.reRequestAuthentication(request, ac); + if (restResponse.isPresent()) { + notifyIpAuthFailureListeners(request, ac); + request.queueForSending(restResponse.get()); return false; } else { //no reRequest possible @@ -476,9 +493,10 @@ public boolean authenticate(final RestRequest request, final RestChannel channel } for (AuthFailureListener authFailureListener : this.authBackendFailureListeners.get(authDomain.getBackend().getClass().getName())) { authFailureListener.onAuthFailure( - (request.getHttpChannel().getRemoteAddress() instanceof InetSocketAddress) ? ((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).getAddress() - : null, - ac, request); + request.getRemoteAddress().map(InetSocketAddress::getAddress).orElse(null), + ac, + request + ); } continue; } @@ -486,8 +504,13 @@ public boolean authenticate(final RestRequest request, final RestChannel channel if(adminDns.isAdmin(authenticatedUser)) { log.error("Cannot authenticate rest user because admin user is not permitted to login via HTTP"); auditLog.logFailedLogin(authenticatedUser.getName(), true, null, request); - channel.sendResponse(new BytesRestResponse(RestStatus.FORBIDDEN, - "Cannot authenticate user because admin user is not permitted to login via HTTP")); + request.queueForSending( + new SecurityResponse( + SC_FORBIDDEN, + null, + "Cannot authenticate user because admin user is not permitted to login via HTTP" + ) + ); return false; } @@ -505,30 +528,42 @@ public boolean authenticate(final RestRequest request, final RestChannel channel if(authenticated) { final User impersonatedUser = impersonate(request, authenticatedUser); - threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, impersonatedUser==null?authenticatedUser:impersonatedUser); - auditLog.logSucceededLogin((impersonatedUser == null ? authenticatedUser : impersonatedUser).getName(), false, - authenticatedUser.getName(), request); + threadPool.getThreadContext() + .putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, impersonatedUser == null ? authenticatedUser : impersonatedUser); + auditLog.logSucceededLogin( + (impersonatedUser == null ? authenticatedUser : impersonatedUser).getName(), + false, + authenticatedUser.getName(), + request + ); } else { if (isDebugEnabled) { log.debug("User still not authenticated after checking {} auth domains", restAuthDomains.size()); } - if(authCredenetials == null && anonymousAuthEnabled) { - threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, User.ANONYMOUS); - auditLog.logSucceededLogin(User.ANONYMOUS.getName(), false, null, request); + if (authCredenetials == null && anonymousAuthEnabled) { + final String tenant = Utils.coalesce(request.header("securitytenant"), request.header("security_tenant")); + User anonymousUser = new User(User.ANONYMOUS.getName(), new HashSet(User.ANONYMOUS.getRoles()), null); + anonymousUser.setRequestedTenant(tenant); + + threadPool.getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, anonymousUser); + auditLog.logSucceededLogin(anonymousUser.getName(), false, null, request); if (isDebugEnabled) { log.debug("Anonymous User is authenticated"); } return true; } + Optional challengeResponse = Optional.empty(); + if(firstChallengingHttpAuthenticator != null) { if (isDebugEnabled) { log.debug("Rerequest with {}", firstChallengingHttpAuthenticator.getClass()); } - if(firstChallengingHttpAuthenticator.reRequestAuthentication(channel, null)) { + challengeResponse = firstChallengingHttpAuthenticator.reRequestAuthentication(request, null); + if (challengeResponse.isPresent()) { if (isDebugEnabled) { log.debug("Rerequest {} failed", firstChallengingHttpAuthenticator.getClass()); } @@ -545,17 +580,16 @@ public boolean authenticate(final RestRequest request, final RestChannel channel notifyIpAuthFailureListeners(request, authCredenetials); - channel.sendResponse(new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed")); + request.queueForSending( + challengeResponse.orElseGet(() -> new SecurityResponse(SC_UNAUTHORIZED, null, "Authentication finally failed")) + ); return false; } - return authenticated; } - private void notifyIpAuthFailureListeners(RestRequest request, AuthCredentials authCredentials) { - notifyIpAuthFailureListeners( - (request.getHttpChannel().getRemoteAddress() instanceof InetSocketAddress) ? ((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).getAddress() : null, - authCredentials, request); + private void notifyIpAuthFailureListeners(SecurityRequestChannel request, AuthCredentials authCredentials) { + notifyIpAuthFailureListeners(request.getRemoteAddress().map(InetSocketAddress::getAddress).orElse(null), authCredentials, request); } private void notifyIpAuthFailureListeners(InetAddress remoteAddress, AuthCredentials authCredentials, Object request) { @@ -745,7 +779,7 @@ private User impersonate(final TransportRequest tr, final User origPKIuser) thro return aU; } - private User impersonate(final RestRequest request, final User originalUser) throws OpenSearchSecurityException { + private User impersonate(final SecurityRequest request, final User originalUser) throws OpenSearchSecurityException { final String impersonatedUserHeader = request.header("opendistro_security_impersonate_as"); diff --git a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java index b0bd5033ad..f259952fa3 100644 --- a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java +++ b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java @@ -30,11 +30,13 @@ package org.opensearch.security.auth; +import java.util.Optional; + import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; - +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.user.AuthCredentials; /** @@ -72,19 +74,18 @@ public interface HTTPAuthenticator { * If the authentication flow needs another roundtrip with the request originator do not mark it as complete. * @throws OpenSearchSecurityException */ - AuthCredentials extractCredentials(RestRequest request, ThreadContext context) throws OpenSearchSecurityException; - + AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) throws OpenSearchSecurityException; + /** * If the {@code extractCredentials()} call was not successful or the authentication flow needs another roundtrip this method * will be called. If the custom HTTP authenticator does not support this method is a no-op and false should be returned. * * If the custom HTTP authenticator does support re-request authentication or supports authentication flows with multiple roundtrips - * then the response should be sent (through the channel) and true must be returned. - * - * @param channel The rest channel to sent back the response via {@code channel.sendResponse()} + * then the response will be returned which can then be sent via response channel. + * + * @param request The request to reauthenticate or not * @param credentials The credentials from the prior authentication attempt - * @return false if re-request is not supported/necessary, true otherwise. - * If true is returned {@code channel.sendResponse()} must be called so that the request completes. + * @return Optional response if is not supported/necessary, response object otherwise. */ - boolean reRequestAuthentication(final RestChannel channel, AuthCredentials credentials); + Optional reRequestAuthentication(final SecurityRequest request, AuthCredentials credentials); } diff --git a/src/main/java/org/opensearch/security/auth/UserInjector.java b/src/main/java/org/opensearch/security/auth/UserInjector.java index 1709f14ab7..eaa7505a8f 100644 --- a/src/main/java/org/opensearch/security/auth/UserInjector.java +++ b/src/main/java/org/opensearch/security/auth/UserInjector.java @@ -42,6 +42,7 @@ import org.opensearch.common.transport.TransportAddress; import org.opensearch.rest.RestRequest; import org.opensearch.security.auditlog.AuditLog; +import org.opensearch.security.filter.SecurityRequestChannel; import org.opensearch.security.http.XFFResolver; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.support.SecurityUtils; @@ -176,8 +177,7 @@ InjectedUser getInjectedUser() { return injectedUser; } - - boolean injectUser(RestRequest request) { + boolean injectUser(SecurityRequestChannel request) { InjectedUser injectedUser = getInjectedUser(); if(injectedUser == null) { return false; diff --git a/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java b/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java index 8ed778af05..b7040514ca 100644 --- a/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java +++ b/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java @@ -57,6 +57,8 @@ import org.opensearch.security.dlic.rest.validation.AbstractConfigurationValidator; import org.opensearch.security.dlic.rest.validation.AbstractConfigurationValidator.ErrorType; import org.opensearch.security.privileges.PrivilegesEvaluator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.securityconf.DynamicConfigFactory; import org.opensearch.security.securityconf.Hideable; import org.opensearch.security.securityconf.StaticDefinable; @@ -380,14 +382,15 @@ protected final RestChannelConsumer prepareRequest(RestRequest request, NodeClie final User user = (User) threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); final String userName = user == null ? null : user.getName(); + final SecurityRequest securityRequest = SecurityRequestFactory.from(request); if (authError != null) { log.error("No permission to access REST API: " + authError); - auditLog.logMissingPrivileges(authError, userName, request); + auditLog.logMissingPrivileges(authError, userName, securityRequest); // for rest request request.params().clear(); return channel -> forbidden(channel, "No permission to access REST API: " + authError); } else { - auditLog.logGrantedPrivileges(userName, request); + auditLog.logGrantedPrivileges(userName, securityRequest); } final Object originalUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); diff --git a/src/main/java/org/opensearch/security/dlic/rest/api/RestApiPrivilegesEvaluator.java b/src/main/java/org/opensearch/security/dlic/rest/api/RestApiPrivilegesEvaluator.java index 93c05b5232..344c33128e 100644 --- a/src/main/java/org/opensearch/security/dlic/rest/api/RestApiPrivilegesEvaluator.java +++ b/src/main/java/org/opensearch/security/dlic/rest/api/RestApiPrivilegesEvaluator.java @@ -37,6 +37,8 @@ import org.opensearch.rest.RestRequest.Method; import org.opensearch.security.configuration.AdminDNs; import org.opensearch.security.dlic.rest.support.Utils; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.privileges.PrivilegesEvaluator; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.util.SSLRequestHelper; @@ -47,387 +49,431 @@ // TODO: Make Singleton? public class RestApiPrivilegesEvaluator { - protected final Logger logger = LogManager.getLogger(this.getClass()); - - private final AdminDNs adminDNs; - private final PrivilegesEvaluator privilegesEvaluator; - private final PrincipalExtractor principalExtractor; - private final Path configPath; - private final ThreadPool threadPool; - private final Settings settings; - - private final Set allowedRoles = new HashSet<>(); - - // endpoints per role, read and cached from settings. Changes here require a - // node restart, so it's save to cache. - private final Map>> disabledEndpointsForRoles = new HashMap<>(); - - // endpoints per user, evaluated and cached dynamically. Changes here - // require a node restart, so it's save to cache. - private final Map>> disabledEndpointsForUsers = new HashMap<>(); - - // globally disabled endpoints and methods, will always be forbidden - Map> globallyDisabledEndpoints = new HashMap<>(); - - // all endpoints and methods, will be returned for users that do not have any access at all - Map> allEndpoints = new HashMap<>(); - - private final Boolean roleBasedAccessEnabled; - - public RestApiPrivilegesEvaluator(Settings settings, AdminDNs adminDNs, PrivilegesEvaluator privilegesEvaluator, PrincipalExtractor principalExtractor, Path configPath, - ThreadPool threadPool) { - - this.adminDNs = adminDNs; - this.privilegesEvaluator = privilegesEvaluator; - this.principalExtractor = principalExtractor; - this.configPath = configPath; - this.threadPool = threadPool; - this.settings = settings; - - // set up - - // all endpoints and methods - Map> allEndpoints = new HashMap<>(); - for(Endpoint endpoint : Endpoint.values()) { - List allMethods = new LinkedList<>(); - allMethods.addAll(Arrays.asList(Method.values())); - allEndpoints.put(endpoint, allMethods); - } - this.allEndpoints = Collections.unmodifiableMap(allEndpoints); - - // setup role based permissions - allowedRoles.addAll(settings.getAsList(ConfigConstants.SECURITY_RESTAPI_ROLES_ENABLED)); - - this.roleBasedAccessEnabled = !allowedRoles.isEmpty(); - - // globally disabled endpoints, disables access to Endpoint/Method combination for all roles - Settings globalSettings = settings.getAsSettings(ConfigConstants.SECURITY_RESTAPI_ENDPOINTS_DISABLED + ".global"); - if (!globalSettings.isEmpty()) { - globallyDisabledEndpoints = parseDisabledEndpoints(globalSettings); - } - - final boolean isDebugEnabled = logger.isDebugEnabled(); - if (isDebugEnabled) { - logger.debug("Globally disabled endpoints: {}", globallyDisabledEndpoints); - } - - for (String role : allowedRoles) { - Settings settingsForRole = settings.getAsSettings(ConfigConstants.SECURITY_RESTAPI_ENDPOINTS_DISABLED + "." + role); - if (settingsForRole.isEmpty()) { - if (isDebugEnabled) { - logger.debug("No disabled endpoints/methods for permitted role {} found, allowing all", role); - } - continue; - } - Map> disabledEndpointsForRole = parseDisabledEndpoints(settingsForRole); - if (!disabledEndpointsForRole.isEmpty()) { - disabledEndpointsForRoles.put(role, disabledEndpointsForRole); - } else { - logger.warn("Disabled endpoints/methods empty for role {}, please check configuration", role); - } - } - if (logger.isTraceEnabled()) { - logger.trace("Parsed permission set for endpoints: {}", disabledEndpointsForRoles); - } - } - - @SuppressWarnings({ "rawtypes" }) - private Map> parseDisabledEndpoints(Settings settings) { - - // Expects Setting like: 'ACTIONGROUPS=["GET", "POST"]' - if (settings == null || settings.isEmpty()) { - logger.error("Settings for disabled endpoint is null or empty: '{}', skipping.", settings); - return Collections.emptyMap(); - } - - final Map> disabledEndpoints = new HashMap>(); - - Map disabledEndpointsSettings = Utils.convertJsonToxToStructuredMap(settings); - - for (Entry value : disabledEndpointsSettings.entrySet()) { - // key is the endpoint, see if it is a valid one - String endpointString = value.getKey().toUpperCase(); - Endpoint endpoint = null; - try { - endpoint = Endpoint.valueOf(endpointString); - } catch (Exception e) { - logger.error("Unknown endpoint '{}' found in configuration, skipping.", endpointString); - continue; - } - // value must be non null - if (value.getValue() == null) { - logger.error("Disabled HTTP methods of endpoint '{}' is null, skipping.", endpointString); - continue; - } - - // value must be an array of methods - if (!(value.getValue() instanceof Collection)) { - logger.error("Disabled HTTP methods of endpoint '{}' must be an array, actually is '{}', skipping.", endpointString, (value.getValue().toString())); - } - List disabledMethods = new LinkedList<>(); - for (Object disabledMethodObj : (Collection) value.getValue()) { - if (disabledMethodObj == null) { - logger.error("Found null value in disabled HTTP methods of endpoint '{}', skipping.", endpointString); - continue; - } - - if (!(disabledMethodObj instanceof String)) { - logger.error("Found non-String value in disabled HTTP methods of endpoint '{}', skipping.", endpointString); - continue; - } - - String disabledMethodAsString = (String) disabledMethodObj; - - // Provide support for '*', means all methods - if (disabledMethodAsString.trim().equals("*")) { - disabledMethods.addAll(Arrays.asList(Method.values())); - break; - } - // no wild card, disabled method must be one of - // RestRequest.Method - Method disabledMethod = null; - try { - disabledMethod = Method.valueOf(disabledMethodAsString.toUpperCase()); - } catch (Exception e) { - logger.error("Invalid HTTP method '{}' found in disabled HTTP methods of endpoint '{}', skipping.", disabledMethodAsString.toUpperCase(), endpointString); - continue; - } - disabledMethods.add(disabledMethod); - } - - disabledEndpoints.put(endpoint, disabledMethods); - - } - return disabledEndpoints; - } - - /** - * Check if the current request is allowed to use the REST API and the - * requested end point. Using an admin certificate grants all permissions. A - * user/role can have restricted end points. - * - * @return an error message if user does not have access, null otherwise - * TODO: log failed attempt in audit log - */ - public String checkAccessPermissions(RestRequest request, Endpoint endpoint) throws IOException { - - if (logger.isDebugEnabled()) { - logger.debug("Checking admin access for endpoint {}, path {} and method {}", endpoint.name(), request.path(), request.method().name()); - } - - // Grant permission for Account endpoint. - // Return null to grant access. - if (endpoint == Endpoint.ACCOUNT) { - return null; - } - - String roleBasedAccessFailureReason = checkRoleBasedAccessPermissions(request, endpoint); - // Role based access granted - if (roleBasedAccessFailureReason == null) { - return null; - } - - String certBasedAccessFailureReason = checkAdminCertBasedAccessPermissions(request); - // TLS access granted, skip checking roles - if (certBasedAccessFailureReason == null) { - return null; - } - - - return constructAccessErrorMessage(roleBasedAccessFailureReason, certBasedAccessFailureReason); - } - - public Boolean currentUserHasRestApiAccess(Set userRoles) { - - // check if user has any role that grants access - return !Collections.disjoint(allowedRoles, userRoles); - - } - - public Map> getDisabledEndpointsForCurrentUser(String userPrincipal, Set userRoles) { - - final boolean isDebugEnabled = logger.isDebugEnabled(); - - // cache - if (disabledEndpointsForUsers.containsKey(userPrincipal)) { - return disabledEndpointsForUsers.get(userPrincipal); - } - - if (!currentUserHasRestApiAccess(userRoles)) { - return this.allEndpoints; - } - - // will contain the final list of disabled endpoints and methods - Map> finalEndpoints = new HashMap<>(); - - // List of all disabled endpoints for user. Disabled endpoints must be configured in all - // roles to take effect. If a role contains a disabled endpoint, but another role - // allows this endpoint (i.e. not contained in the disabled endpoints for this role), - // the access is allowed. - - // make list mutable - List remainingEndpoints = new LinkedList<>(Arrays.asList(Endpoint.values())); - - // only retain endpoints contained in all roles for user - boolean hasDisabledEndpoints = false; - for (String userRole : userRoles) { - Map> endpointsForRole = disabledEndpointsForRoles.get(userRole); - if (endpointsForRole == null || endpointsForRole.isEmpty()) { - continue; - } - Set disabledEndpoints = endpointsForRole.keySet(); - remainingEndpoints.retainAll(disabledEndpoints); - hasDisabledEndpoints = true; - } - - if (isDebugEnabled) { - logger.debug("Remaining endpoints for user {} after retaining all : {}", userPrincipal, remainingEndpoints); - } - - // if user does not have any disabled endpoints, only globally disabled endpoints apply - if (!hasDisabledEndpoints) { - - if (isDebugEnabled) { - logger.debug("No disabled endpoints for user {} at all, only globally disabledendpoints apply.", userPrincipal, remainingEndpoints); - } - disabledEndpointsForUsers.put(userPrincipal, addGloballyDisabledEndpoints(finalEndpoints)); - return finalEndpoints; - - } - - // one or more disabled remaining endpoints, keep only - // methods contained in all roles for each endpoint - for (Endpoint endpoint : remainingEndpoints) { - // make list mutable - List remainingMethodsForEndpoint = new LinkedList<>(Arrays.asList(Method.values())); - for (String userRole : userRoles) { - Map> endpoints = disabledEndpointsForRoles.get(userRole); - if (endpoints != null && !endpoints.isEmpty()) { - remainingMethodsForEndpoint.retainAll(endpoints.get(endpoint)); - } - } - - finalEndpoints.put(endpoint, remainingMethodsForEndpoint); - } - - if (isDebugEnabled) { - logger.debug("Disabled endpoints for user {} after retaining all : {}", userPrincipal, finalEndpoints); - } - - // add globally disabled endpoints and methods, will always be disabled - addGloballyDisabledEndpoints(finalEndpoints); - disabledEndpointsForUsers.put(userPrincipal, finalEndpoints); - - if (isDebugEnabled) { - logger.debug("Disabled endpoints for user {} after retaining all : {}", userPrincipal, disabledEndpointsForUsers.get(userPrincipal)); - } - - return disabledEndpointsForUsers.get(userPrincipal); - } - - private Map> addGloballyDisabledEndpoints(Map> endpoints) { - if(globallyDisabledEndpoints != null && !globallyDisabledEndpoints.isEmpty()) { - Set globalEndoints = globallyDisabledEndpoints.keySet(); - for(Endpoint endpoint : globalEndoints) { - endpoints.putIfAbsent(endpoint, new LinkedList<>()); - endpoints.get(endpoint).addAll(globallyDisabledEndpoints.get(endpoint)); - } - } - return endpoints; - } - - private String checkRoleBasedAccessPermissions(RestRequest request, Endpoint endpoint) { - final boolean isTraceEnabled = logger.isTraceEnabled(); - if (isTraceEnabled) { - logger.trace("Checking role based admin access for endpoint {} and method {}", endpoint.name(), request.method().name()); - } - final boolean isDebugEnabled = logger.isDebugEnabled(); - // Role based access. Check that user has role suitable for admin access - // and that the role has also access to this endpoint. - if (this.roleBasedAccessEnabled) { - - // get current user and roles + protected final Logger logger = LogManager.getLogger(this.getClass()); + + private final AdminDNs adminDNs; + private final PrivilegesEvaluator privilegesEvaluator; + private final PrincipalExtractor principalExtractor; + private final Path configPath; + private final ThreadPool threadPool; + private final Settings settings; + + private final Set allowedRoles = new HashSet<>(); + + // endpoints per role, read and cached from settings. Changes here require a + // node restart, so it's save to cache. + private final Map>> disabledEndpointsForRoles = new HashMap<>(); + + // endpoints per user, evaluated and cached dynamically. Changes here + // require a node restart, so it's save to cache. + private final Map>> disabledEndpointsForUsers = new HashMap<>(); + + // globally disabled endpoints and methods, will always be forbidden + Map> globallyDisabledEndpoints = new HashMap<>(); + + // all endpoints and methods, will be returned for users that do not have any access at all + Map> allEndpoints = new HashMap<>(); + + private final Boolean roleBasedAccessEnabled; + + public RestApiPrivilegesEvaluator( + final Settings settings, + final AdminDNs adminDNs, + final PrivilegesEvaluator privilegesEvaluator, + final PrincipalExtractor principalExtractor, + final Path configPath, + ThreadPool threadPool + ) { + + this.adminDNs = adminDNs; + this.privilegesEvaluator = privilegesEvaluator; + this.principalExtractor = principalExtractor; + this.configPath = configPath; + this.threadPool = threadPool; + this.settings = settings; + // set up + // all endpoints and methods + Map> allEndpoints = new HashMap<>(); + for (Endpoint endpoint : Endpoint.values()) { + List allMethods = new LinkedList<>(); + allMethods.addAll(Arrays.asList(Method.values())); + allEndpoints.put(endpoint, allMethods); + } + this.allEndpoints = Collections.unmodifiableMap(allEndpoints); + + // setup role based permissions + allowedRoles.addAll(settings.getAsList(ConfigConstants.SECURITY_RESTAPI_ROLES_ENABLED)); + + this.roleBasedAccessEnabled = !allowedRoles.isEmpty(); + + // globally disabled endpoints, disables access to Endpoint/Method combination for all roles + Settings globalSettings = settings.getAsSettings(ConfigConstants.SECURITY_RESTAPI_ENDPOINTS_DISABLED + ".global"); + if (!globalSettings.isEmpty()) { + globallyDisabledEndpoints = parseDisabledEndpoints(globalSettings); + } + + final boolean isDebugEnabled = logger.isDebugEnabled(); + if (isDebugEnabled) { + logger.debug("Globally disabled endpoints: {}", globallyDisabledEndpoints); + } + + for (String role : allowedRoles) { + Settings settingsForRole = settings.getAsSettings(ConfigConstants.SECURITY_RESTAPI_ENDPOINTS_DISABLED + "." + role); + if (settingsForRole.isEmpty()) { + if (isDebugEnabled) { + logger.debug("No disabled endpoints/methods for permitted role {} found, allowing all", role); + } + continue; + } + Map> disabledEndpointsForRole = parseDisabledEndpoints(settingsForRole); + if (!disabledEndpointsForRole.isEmpty()) { + disabledEndpointsForRoles.put(role, disabledEndpointsForRole); + } else { + logger.warn("Disabled endpoints/methods empty for role {}, please check configuration", role); + } + } + if (logger.isTraceEnabled()) { + logger.trace("Parsed permission set for endpoints: {}", disabledEndpointsForRoles); + } + } + + @SuppressWarnings({ "rawtypes" }) + private Map> parseDisabledEndpoints(Settings settings) { + + // Expects Setting like: 'ACTIONGROUPS=["GET", "POST"]' + if (settings == null || settings.isEmpty()) { + logger.error("Settings for disabled endpoint is null or empty: '{}', skipping.", settings); + return Collections.emptyMap(); + } + + final Map> disabledEndpoints = new HashMap>(); + + Map disabledEndpointsSettings = Utils.convertJsonToxToStructuredMap(settings); + + for (Entry value : disabledEndpointsSettings.entrySet()) { + // key is the endpoint, see if it is a valid one + String endpointString = value.getKey().toUpperCase(); + Endpoint endpoint = null; + try { + endpoint = Endpoint.valueOf(endpointString); + } catch (Exception e) { + logger.error("Unknown endpoint '{}' found in configuration, skipping.", endpointString); + continue; + } + // value must be non null + if (value.getValue() == null) { + logger.error("Disabled HTTP methods of endpoint '{}' is null, skipping.", endpointString); + continue; + } + + // value must be an array of methods + if (!(value.getValue() instanceof Collection)) { + logger.error( + "Disabled HTTP methods of endpoint '{}' must be an array, actually is '{}', skipping.", + endpointString, + (value.getValue().toString()) + ); + } + List disabledMethods = new LinkedList<>(); + for (Object disabledMethodObj : (Collection) value.getValue()) { + if (disabledMethodObj == null) { + logger.error("Found null value in disabled HTTP methods of endpoint '{}', skipping.", endpointString); + continue; + } + + if (!(disabledMethodObj instanceof String)) { + logger.error("Found non-String value in disabled HTTP methods of endpoint '{}', skipping.", endpointString); + continue; + } + + String disabledMethodAsString = (String) disabledMethodObj; + + // Provide support for '*', means all methods + if (disabledMethodAsString.trim().equals("*")) { + disabledMethods.addAll(Arrays.asList(Method.values())); + break; + } + // no wild card, disabled method must be one of + // RestRequest.Method + Method disabledMethod = null; + try { + disabledMethod = Method.valueOf(disabledMethodAsString.toUpperCase()); + } catch (Exception e) { + logger.error( + "Invalid HTTP method '{}' found in disabled HTTP methods of endpoint '{}', skipping.", + disabledMethodAsString.toUpperCase(), + endpointString + ); + continue; + } + disabledMethods.add(disabledMethod); + } + + disabledEndpoints.put(endpoint, disabledMethods); + + } + return disabledEndpoints; + } + + /** + * Check if the current request is allowed to use the REST API and the + * requested end point. Using an admin certificate grants all permissions. A + * user/role can have restricted end points. + * + * @return an error message if user does not have access, null otherwise + * TODO: log failed attempt in audit log + */ + public String checkAccessPermissions(RestRequest request, Endpoint endpoint) throws IOException { + + if (logger.isDebugEnabled()) { + logger.debug( + "Checking admin access for endpoint {}, path {} and method {}", + endpoint.name(), + request.path(), + request.method().name() + ); + } + + // Grant permission for Account endpoint. + // Return null to grant access. + if (endpoint == Endpoint.ACCOUNT) { + return null; + } + + String roleBasedAccessFailureReason = checkRoleBasedAccessPermissions(request, endpoint); + // Role based access granted + if (roleBasedAccessFailureReason == null) { + return null; + } + + String certBasedAccessFailureReason = checkAdminCertBasedAccessPermissions(request); + // TLS access granted, skip checking roles + if (certBasedAccessFailureReason == null) { + return null; + } + + return constructAccessErrorMessage(roleBasedAccessFailureReason, certBasedAccessFailureReason); + } + + public Boolean currentUserHasRestApiAccess(Set userRoles) { + + // check if user has any role that grants access + return !Collections.disjoint(allowedRoles, userRoles); + + } + + public Map> getDisabledEndpointsForCurrentUser(String userPrincipal, Set userRoles) { + + final boolean isDebugEnabled = logger.isDebugEnabled(); + + // cache + if (disabledEndpointsForUsers.containsKey(userPrincipal)) { + return disabledEndpointsForUsers.get(userPrincipal); + } + + if (!currentUserHasRestApiAccess(userRoles)) { + return this.allEndpoints; + } + + // will contain the final list of disabled endpoints and methods + Map> finalEndpoints = new HashMap<>(); + + // List of all disabled endpoints for user. Disabled endpoints must be configured in all + // roles to take effect. If a role contains a disabled endpoint, but another role + // allows this endpoint (i.e. not contained in the disabled endpoints for this role), + // the access is allowed. + + // make list mutable + List remainingEndpoints = new LinkedList<>(Arrays.asList(Endpoint.values())); + + // only retain endpoints contained in all roles for user + boolean hasDisabledEndpoints = false; + for (String userRole : userRoles) { + Map> endpointsForRole = disabledEndpointsForRoles.get(userRole); + if (endpointsForRole == null || endpointsForRole.isEmpty()) { + continue; + } + Set disabledEndpoints = endpointsForRole.keySet(); + remainingEndpoints.retainAll(disabledEndpoints); + hasDisabledEndpoints = true; + } + + if (isDebugEnabled) { + logger.debug("Remaining endpoints for user {} after retaining all : {}", userPrincipal, remainingEndpoints); + } + + // if user does not have any disabled endpoints, only globally disabled endpoints apply + if (!hasDisabledEndpoints) { + + if (isDebugEnabled) { + logger.debug( + "No disabled endpoints for user {} at all, only globally disabledendpoints apply.", + userPrincipal, + remainingEndpoints + ); + } + disabledEndpointsForUsers.put(userPrincipal, addGloballyDisabledEndpoints(finalEndpoints)); + return finalEndpoints; + + } + + // one or more disabled remaining endpoints, keep only + // methods contained in all roles for each endpoint + for (Endpoint endpoint : remainingEndpoints) { + // make list mutable + List remainingMethodsForEndpoint = new LinkedList<>(Arrays.asList(Method.values())); + for (String userRole : userRoles) { + Map> endpoints = disabledEndpointsForRoles.get(userRole); + if (endpoints != null && !endpoints.isEmpty()) { + remainingMethodsForEndpoint.retainAll(endpoints.get(endpoint)); + } + } + + finalEndpoints.put(endpoint, remainingMethodsForEndpoint); + } + + if (isDebugEnabled) { + logger.debug("Disabled endpoints for user {} after retaining all : {}", userPrincipal, finalEndpoints); + } + + // add globally disabled endpoints and methods, will always be disabled + addGloballyDisabledEndpoints(finalEndpoints); + disabledEndpointsForUsers.put(userPrincipal, finalEndpoints); + + if (isDebugEnabled) { + logger.debug( + "Disabled endpoints for user {} after retaining all : {}", + userPrincipal, + disabledEndpointsForUsers.get(userPrincipal) + ); + } + + return disabledEndpointsForUsers.get(userPrincipal); + } + + private Map> addGloballyDisabledEndpoints(Map> endpoints) { + if (globallyDisabledEndpoints != null && !globallyDisabledEndpoints.isEmpty()) { + Set globalEndoints = globallyDisabledEndpoints.keySet(); + for (Endpoint endpoint : globalEndoints) { + endpoints.putIfAbsent(endpoint, new LinkedList<>()); + endpoints.get(endpoint).addAll(globallyDisabledEndpoints.get(endpoint)); + } + } + return endpoints; + } + + private String checkRoleBasedAccessPermissions(RestRequest request, Endpoint endpoint) { + final boolean isTraceEnabled = logger.isTraceEnabled(); + if (isTraceEnabled) { + logger.trace("Checking role based admin access for endpoint {} and method {}", endpoint.name(), request.method().name()); + } + final boolean isDebugEnabled = logger.isDebugEnabled(); + // Role based access. Check that user has role suitable for admin access + // and that the role has also access to this endpoint. + if (this.roleBasedAccessEnabled) { + + // get current user and roles final User user = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); final TransportAddress remoteAddress = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS); - // map the users Security roles - Set userRoles = privilegesEvaluator.mapRoles(user, remoteAddress); - - // check if user has any role that grants access - if (currentUserHasRestApiAccess(userRoles)) { - // yes, calculate disabled end points. Since a user can have - // multiple roles, the endpoint - // needs to be disabled in all roles. - - Map> disabledEndpointsForUser = getDisabledEndpointsForCurrentUser(user.getName(), userRoles); - - if (isDebugEnabled) { - logger.debug("Disabled endpoints for user {} : {} ", user, disabledEndpointsForUser); - } - - // check if we have any disabled methods for this endpoint - List disabledMethodsForEndpoint = disabledEndpointsForUser.get(endpoint); - - // no settings, all methods for this endpoint allowed - if (disabledMethodsForEndpoint == null || disabledMethodsForEndpoint.isEmpty()) { - if (isDebugEnabled) { - logger.debug("No disabled methods for user {} and endpoint {}, access allowed ", user, endpoint); - } - return null; - } - - // some methods disabled, check requested method - if (!disabledMethodsForEndpoint.contains(request.method())) { - if (isDebugEnabled) { - logger.debug("Request method {} for user {} and endpoint {} not restricted, access allowed ", request.method(), user, endpoint); - } - return null; - } - - logger.info("User {} with Security roles {} does not have access to endpoint {} and method {}, checking admin TLS certificate now.", user, userRoles, - endpoint.name(), request.method()); - return "User " + user.getName() + " with Security roles " + userRoles + " does not have any access to endpoint " + endpoint.name() + " and method " - + request.method().name(); - } else { - // no, but maybe the request contains a client certificate. - // Remember error reason for better response message later on. - logger.info("User {} with Security roles {} does not have any role privileged for admin access.", user, userRoles); - return "User " + user.getName() + " with Security roles " + userRoles + " does not have any role privileged for admin access"; - } - } - return "Role based access not enabled."; - } - - private String checkAdminCertBasedAccessPermissions(RestRequest request) throws IOException { - if (logger.isTraceEnabled()) { - logger.trace("Checking certificate based admin access for path {} and method {}", request.path(), request.method().name()); - } - - // Certificate based access, Check if we have an admin TLS certificate - SSLRequestHelper.SSLInfo sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, request, principalExtractor); - - if (sslInfo == null) { - // here we log on error level, since authentication finally failed - logger.warn("No ssl info found in request."); - return "No ssl info found in request."; - } - - X509Certificate[] certs = sslInfo.getX509Certs(); - - if (certs == null || certs.length == 0) { - logger.warn("No client TLS certificate found in request"); - return "No client TLS certificate found in request"; - } - - if (!adminDNs.isAdminDN(sslInfo.getPrincipal())) { - logger.warn("Security admin permissions required but {} is not an admin", sslInfo.getPrincipal()); - return "Security admin permissions required but " + sslInfo.getPrincipal() + " is not an admin"; - } - return null; - } - - private String constructAccessErrorMessage(String roleBasedAccessFailure, String certBasedAccessFailure) { - return roleBasedAccessFailure + ". " + certBasedAccessFailure; - } + // map the users Security roles + Set userRoles = privilegesEvaluator.mapRoles(user, remoteAddress); + + // check if user has any role that grants access + if (currentUserHasRestApiAccess(userRoles)) { + // yes, calculate disabled end points. Since a user can have + // multiple roles, the endpoint + // needs to be disabled in all roles. + Map> disabledEndpointsForUser = getDisabledEndpointsForCurrentUser(user.getName(), userRoles); + + if (isDebugEnabled) { + logger.debug("Disabled endpoints for user {} : {} ", user, disabledEndpointsForUser); + } + + // check if we have any disabled methods for this endpoint + List disabledMethodsForEndpoint = disabledEndpointsForUser.get(endpoint); + + // no settings, all methods for this endpoint allowed + if (disabledMethodsForEndpoint == null || disabledMethodsForEndpoint.isEmpty()) { + if (isDebugEnabled) { + logger.debug("No disabled methods for user {} and endpoint {}, access allowed ", user, endpoint); + } + return null; + } + + // some methods disabled, check requested method + if (!disabledMethodsForEndpoint.contains(request.method())) { + if (isDebugEnabled) { + logger.debug( + "Request method {} for user {} and endpoint {} not restricted, access allowed ", + request.method(), + user, + endpoint + ); + } + return null; + } + + logger.info( + "User {} with Security roles {} does not have access to endpoint {} and method {}, checking admin TLS certificate now.", + user, + userRoles, + endpoint.name(), + request.method() + ); + return "User " + + user.getName() + + " with Security roles " + + userRoles + + " does not have any access to endpoint " + + endpoint.name() + + " and method " + + request.method().name(); + } else { + // no, but maybe the request contains a client certificate. + // Remember error reason for better response message later on. + logger.info("User {} with Security roles {} does not have any role privileged for admin access.", user, userRoles); + return "User " + + user.getName() + + " with Security roles " + + userRoles + + " does not have any role privileged for admin access"; + } + } + return "Role based access not enabled."; + } + + private String checkAdminCertBasedAccessPermissions(RestRequest request) throws IOException { + if (logger.isTraceEnabled()) { + logger.trace("Checking certificate based admin access for path {} and method {}", request.path(), request.method().name()); + } + + // Certificate based access, Check if we have an admin TLS certificate + final SecurityRequest securityRequest = SecurityRequestFactory.from(request); + SSLRequestHelper.SSLInfo sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, securityRequest, principalExtractor); + + if (sslInfo == null) { + // here we log on error level, since authentication finally failed + logger.warn("No ssl info found in request."); + return "No ssl info found in request."; + } + + X509Certificate[] certs = sslInfo.getX509Certs(); + + if (certs == null || certs.length == 0) { + logger.warn("No client TLS certificate found in request"); + return "No client TLS certificate found in request"; + } + + if (!adminDNs.isAdminDN(sslInfo.getPrincipal())) { + logger.warn("Security admin permissions required but {} is not an admin", sslInfo.getPrincipal()); + return "Security admin permissions required but " + sslInfo.getPrincipal() + " is not an admin"; + } + return null; + } + + private String constructAccessErrorMessage(String roleBasedAccessFailure, String certBasedAccessFailure) { + return roleBasedAccessFailure + ". " + certBasedAccessFailure; + } } diff --git a/src/main/java/org/opensearch/security/filter/OpenSearchRequest.java b/src/main/java/org/opensearch/security/filter/OpenSearchRequest.java new file mode 100644 index 0000000000..e1123ef7ee --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/OpenSearchRequest.java @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.net.InetSocketAddress; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import javax.net.ssl.SSLEngine; + +import org.opensearch.http.netty4.Netty4HttpChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestRequest.Method; + +import io.netty.handler.ssl.SslHandler; + +/** + * Wraps the functionality of RestRequest for use in the security plugin + */ +public class OpenSearchRequest implements SecurityRequest { + + protected final RestRequest underlyingRequest; + + OpenSearchRequest(final RestRequest request) { + underlyingRequest = request; + } + + @Override + public Map> getHeaders() { + return underlyingRequest.getHeaders(); + } + + @Override + public SSLEngine getSSLEngine() { + if (underlyingRequest == null + || underlyingRequest.getHttpChannel() == null + || !(underlyingRequest.getHttpChannel() instanceof Netty4HttpChannel)) { + return null; + } + + // We look for Ssl_handler called `ssl_http` in the outbound pipeline of Netty channel first, and if its not + // present we look for it in inbound channel. If its present in neither we return null, else we return the sslHandler. + final Netty4HttpChannel httpChannel = (Netty4HttpChannel) underlyingRequest.getHttpChannel(); + SslHandler sslhandler = (SslHandler) httpChannel.getNettyChannel().pipeline().get("ssl_http"); + return sslhandler != null ? sslhandler.engine() : null; + } + + @Override + public String path() { + return underlyingRequest.path(); + } + + @Override + public Method method() { + return underlyingRequest.method(); + } + + @Override + public Optional getRemoteAddress() { + return Optional.ofNullable(this.underlyingRequest.getHttpChannel().getRemoteAddress()); + } + + @Override + public String uri() { + return underlyingRequest.uri(); + } + + @Override + public Map params() { + return underlyingRequest.params(); + } + + /** Gets access to the underlying request object */ + public RestRequest breakEncapsulationForRequest() { + return underlyingRequest; + } +} diff --git a/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java b/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java new file mode 100644 index 0000000000..293b2af31e --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java @@ -0,0 +1,97 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.rest.RestStatus; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; + +public class OpenSearchRequestChannel extends OpenSearchRequest implements SecurityRequestChannel { + + private final Logger log = LogManager.getLogger(OpenSearchRequest.class); + + private final AtomicReference responseRef = new AtomicReference(null); + private final AtomicBoolean hasCompleted = new AtomicBoolean(false); + private final RestChannel underlyingChannel; + + OpenSearchRequestChannel(final RestRequest request, final RestChannel channel) { + super(request); + underlyingChannel = channel; + } + + /** Gets access to the underlying channel object */ + public RestChannel breakEncapsulationForChannel() { + return underlyingChannel; + } + + @Override + public void queueForSending(final SecurityResponse response) { + if (underlyingChannel == null) { + throw new UnsupportedOperationException("Channel was not defined"); + } + + if (hasCompleted.get()) { + throw new UnsupportedOperationException("This channel has already completed"); + } + + if (getQueuedResponse().isPresent()) { + throw new UnsupportedOperationException("Another response was already queued"); + } + + responseRef.set(response); + } + + @Override + public Optional getQueuedResponse() { + return Optional.ofNullable(responseRef.get()); + } + + @Override + public boolean sendResponse() { + if (underlyingChannel == null) { + throw new UnsupportedOperationException("Channel was not defined"); + } + + if (hasCompleted.get()) { + throw new UnsupportedOperationException("This channel has already completed"); + } + + if (getQueuedResponse().isEmpty()) { + throw new UnsupportedOperationException("No response has been associated with this channel"); + } + + final SecurityResponse response = responseRef.get(); + + try { + final BytesRestResponse restResponse = new BytesRestResponse(RestStatus.fromCode(response.getStatus()), response.getBody()); + if (response.getHeaders() != null) { + response.getHeaders().forEach(restResponse::addHeader); + } + underlyingChannel.sendResponse(restResponse); + + return true; + } catch (final Exception e) { + log.error("Error when attempting to send response", e); + throw new RuntimeException(e); + } finally { + hasCompleted.set(true); + } + + } +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityFilter.java b/src/main/java/org/opensearch/security/filter/SecurityFilter.java index ca8bebbeee..4139150718 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityFilter.java @@ -370,6 +370,7 @@ public void onFailure(Exception e) { String.format("no permissions for %s and %s", pres.getMissingPrivileges(), user); } log.debug(err); + listener.onFailure(new OpenSearchSecurityException(err, RestStatus.FORBIDDEN)); } } catch (OpenSearchException e) { diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequest.java b/src/main/java/org/opensearch/security/filter/SecurityRequest.java new file mode 100644 index 0000000000..7e6e94e0a6 --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityRequest.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.net.InetSocketAddress; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Stream; + +import javax.net.ssl.SSLEngine; + +import org.opensearch.rest.RestRequest.Method; + +/** How the security plugin interacts with requests */ +public interface SecurityRequest { + + /** Collection of headers associated with the request */ + public Map> getHeaders(); + + /** The SSLEngine associated with the request */ + public SSLEngine getSSLEngine(); + + /** The path of the request */ + public String path(); + + /** The method type of this request */ + public Method method(); + + /** The remote address of the request, possible null */ + public Optional getRemoteAddress(); + + /** The full uri of the request */ + public String uri(); + + /** Finds the first value of the matching header or null */ + default public String header(final String headerName) { + final Optional>> headersMap = Optional.ofNullable(getHeaders()); + return headersMap.map(headers -> headers.get(headerName)).map(List::stream).flatMap(Stream::findFirst).orElse(null); + } + + /** The parameters associated with this request */ + public Map params(); +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java b/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java new file mode 100644 index 0000000000..1eec754c08 --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.util.Optional; + +/** + * When a request is recieved by the security plugin this governs getting information about the request and complete with with a response + */ +public interface SecurityRequestChannel extends SecurityRequest { + + /** Associate a response with this channel */ + public void queueForSending(final SecurityResponse response); + + /** Acess the queued response */ + public Optional getQueuedResponse(); + + /** Send the response through the channel */ + public boolean sendResponse(); +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestChannelUnsupported.java b/src/main/java/org/opensearch/security/filter/SecurityRequestChannelUnsupported.java new file mode 100644 index 0000000000..bcacc2cf7a --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestChannelUnsupported.java @@ -0,0 +1,17 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +/** Thrown when a security rest channel is not supported */ +public class SecurityRequestChannelUnsupported extends RuntimeException { + +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java b/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java new file mode 100644 index 0000000000..de74df01ff --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java @@ -0,0 +1,31 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; + +/** + * Generates wrapped versions of requests for use in the security plugin + */ +public class SecurityRequestFactory { + + /** Creates a security request from a RestRequest */ + public static SecurityRequest from(final RestRequest request) { + return new OpenSearchRequest(request); + } + + /** Creates a security request channel from a RestRequest & RestChannel */ + public static SecurityRequestChannel from(final RestRequest request, final RestChannel channel) { + return new OpenSearchRequestChannel(request, channel); + } +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityResponse.java b/src/main/java/org/opensearch/security/filter/SecurityResponse.java new file mode 100644 index 0000000000..8618be3aab --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityResponse.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.util.Map; + +import org.apache.http.HttpHeaders; + +public class SecurityResponse { + + public static final Map CONTENT_TYPE_APP_JSON = Map.of(HttpHeaders.CONTENT_TYPE, "application/json"); + + private final int status; + private final Map headers; + private final String body; + + public SecurityResponse(final int status, final Map headers, final String body) { + this.status = status; + this.headers = headers; + this.body = body; + } + + public int getStatus() { + return status; + } + + public Map getHeaders() { + return headers; + } + + public String getBody() { + return body; + } + +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java index c07c7d918a..4f479d2228 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java @@ -34,24 +34,19 @@ import javax.net.ssl.SSLPeerUnverifiedException; -import org.opensearch.security.configuration.AdminDNs; -import org.opensearch.security.dlic.rest.api.WhitelistApiAction; -import org.opensearch.security.securityconf.impl.WhitelistingSettings; -import org.apache.logging.log4j.Logger; +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchException; -import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; -import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; -import org.opensearch.rest.RestStatus; import org.opensearch.security.auditlog.AuditLog; import org.opensearch.security.auditlog.AuditLog.Origin; +import org.opensearch.security.configuration.AdminDNs; import org.opensearch.security.configuration.CompatConfig; +import org.opensearch.security.securityconf.impl.WhitelistingSettings; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.util.ExceptionUtils; import org.opensearch.security.ssl.util.SSLRequestHelper; @@ -59,10 +54,12 @@ import org.opensearch.security.support.HTTPHelper; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.security.ssl.util.SSLRequestHelper.SSLInfo;; +import org.opensearch.security.ssl.util.SSLRequestHelper.SSLInfo; import org.opensearch.security.auth.BackendRegistry; import org.opensearch.security.user.User; import org.greenrobot.eventbus.Subscribe; + +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -115,18 +112,35 @@ public SecurityRestFilter(final BackendRegistry registry, final AuditLog auditLo * SuperAdmin is identified by credentials, which can be passed in the curl request. */ public RestHandler wrap(RestHandler original, AdminDNs adminDNs) { - return new RestHandler() { - - @Override - public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { - org.apache.logging.log4j.ThreadContext.clearAll(); - if (!checkAndAuthenticateRequest(request, channel, client)) { - User user = threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); - if (userIsSuperAdmin(user, adminDNs) || whitelistingSettings.checkRequestIsAllowed(request, channel, client)) { - original.handleRequest(request, channel, client); - } - } + return (request, channel, client) -> { + org.apache.logging.log4j.ThreadContext.clearAll(); + final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(request, channel); + + // Authenticate request + checkAndAuthenticateRequest(requestChannel); + if (requestChannel.getQueuedResponse().isPresent()) { + requestChannel.sendResponse(); + return; + } + + // Authorize Request + final User user = threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); + if (userIsSuperAdmin(user, adminDNs)) { + // Super admins are always authorized + original.handleRequest(request, channel, client); + return; + } + + final Optional deniedResponse = whitelistingSettings.checkRequestIsAllowed(requestChannel); + + if (deniedResponse.isPresent()) { + requestChannel.queueForSending(deniedResponse.orElseThrow()); + requestChannel.sendResponse(); + return; } + + // Caller was authorized, forward the request to the handler + original.handleRequest(request, channel, client); }; } @@ -137,31 +151,31 @@ private boolean userIsSuperAdmin(User user, AdminDNs adminDNs) { return user != null && adminDNs.isAdmin(user); } - private boolean checkAndAuthenticateRequest(RestRequest request, RestChannel channel, - NodeClient client) throws Exception { - + public void checkAndAuthenticateRequest(SecurityRequestChannel requestChannel) throws Exception { threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN, Origin.REST.toString()); - - if(HTTPHelper.containsBadHeader(request)) { + + if (HTTPHelper.containsBadHeader(requestChannel)) { final OpenSearchException exception = ExceptionUtils.createBadHeaderException(); log.error(exception.toString()); - auditLog.logBadHeaders(request); - channel.sendResponse(new BytesRestResponse(channel, RestStatus.FORBIDDEN, exception)); - return true; + auditLog.logBadHeaders(requestChannel); + + requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, exception.toString())); + return; } if(SSLRequestHelper.containsBadHeader(threadContext, ConfigConstants.OPENDISTRO_SECURITY_CONFIG_PREFIX)) { final OpenSearchException exception = ExceptionUtils.createBadHeaderException(); log.error(exception.toString()); - auditLog.logBadHeaders(request); - channel.sendResponse(new BytesRestResponse(channel, RestStatus.FORBIDDEN, exception)); - return true; + auditLog.logBadHeaders(requestChannel); + + requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, exception.toString())); + return; } final SSLInfo sslInfo; try { - if((sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, request, principalExtractor)) != null) { - if(sslInfo.getPrincipal() != null) { + if ((sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, requestChannel, principalExtractor)) != null) { + if (sslInfo.getPrincipal() != null) { threadContext.putTransient("_opendistro_security_ssl_principal", sslInfo.getPrincipal()); } @@ -173,29 +187,28 @@ private boolean checkAndAuthenticateRequest(RestRequest request, RestChannel cha } } catch (SSLPeerUnverifiedException e) { log.error("No ssl info", e); - auditLog.logSSLException(request, e); - channel.sendResponse(new BytesRestResponse(channel, RestStatus.FORBIDDEN, e)); - return true; + auditLog.logSSLException(requestChannel, e); + requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, null)); + return; } - - if(!compatConfig.restAuthEnabled()) { - return false; + + if (!compatConfig.restAuthEnabled()) { + // Authentication is disabled + return; } - Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); + Matcher matcher = PATTERN_PATH_PREFIX.matcher(requestChannel.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; - if(request.method() != Method.OPTIONS && !(HEALTH_SUFFIX.equals(suffix))) { - if (!registry.authenticate(request, channel, threadContext)) { + if (requestChannel.method() != Method.OPTIONS && !(HEALTH_SUFFIX.equals(suffix))) { + if (!registry.authenticate(requestChannel)) { // another roundtrip org.apache.logging.log4j.ThreadContext.remove("user"); - return true; + return; } else { // make it possible to filter logs by username org.apache.logging.log4j.ThreadContext.put("user", ((User)threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER)).getName()); } } - - return false; } @Subscribe diff --git a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java index f52f7744bd..d47d9f6859 100644 --- a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java @@ -31,17 +31,17 @@ package org.opensearch.security.http; import java.nio.file.Path; +import java.util.Map; +import java.util.Optional; -import org.apache.logging.log4j.Logger; +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; - import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.support.HTTPHelper; import org.opensearch.security.user.AuthCredentials; @@ -55,11 +55,11 @@ public HTTPBasicAuthenticator(final Settings settings, final Path configPath) { } @Override - public AuthCredentials extractCredentials(final RestRequest request, ThreadContext threadContext) { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext threadContext) { - final boolean forceLogin = request.paramAsBoolean("force_login", false); - - if(forceLogin) { + final boolean forceLogin = Boolean.getBoolean(request.params().get("force_login")); + + if (forceLogin) { return null; } @@ -69,11 +69,10 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, "Unauthorized"); - wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Basic realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + public Optional reRequestAuthentication(final SecurityRequest request, AuthCredentials creds) { + return Optional.of( + new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, Map.of("WWW-Authenticate", "Basic realm=\"OpenSearch Security\""), "") + ); } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java index f3ee65f052..9bb72e9a62 100644 --- a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java @@ -34,6 +34,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; import javax.naming.InvalidNameException; import javax.naming.ldap.LdapName; @@ -44,10 +45,9 @@ import org.opensearch.common.Strings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; - import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.user.AuthCredentials; @@ -61,7 +61,7 @@ public HTTPClientCertAuthenticator(final Settings settings, final Path configPat } @Override - public AuthCredentials extractCredentials(final RestRequest request, final ThreadContext threadContext) { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext threadContext) { final String principal = threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_SSL_PRINCIPAL); @@ -102,8 +102,8 @@ public AuthCredentials extractCredentials(final RestRequest request, final Threa } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public Optional reRequestAuthentication(final SecurityRequest response, AuthCredentials creds) { + return Optional.empty(); } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java index 4320bd6009..7a1caf310a 100644 --- a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java @@ -31,6 +31,7 @@ package org.opensearch.security.http; import java.nio.file.Path; +import java.util.Optional; import java.util.regex.Pattern; import org.apache.logging.log4j.Logger; @@ -39,10 +40,9 @@ import org.opensearch.common.Strings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; - import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.user.AuthCredentials; import com.google.common.base.Predicates; @@ -60,9 +60,9 @@ public HTTPProxyAuthenticator(Settings settings, final Path configPath) { } @Override - public AuthCredentials extractCredentials(final RestRequest request, ThreadContext context) { - - if(context.getTransient(ConfigConstants.OPENDISTRO_SECURITY_XFF_DONE) != Boolean.TRUE) { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) { + + if (context.getTransient(ConfigConstants.OPENDISTRO_SECURITY_XFF_DONE) != Boolean.TRUE) { throw new OpenSearchSecurityException("xff not done"); } @@ -94,8 +94,8 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public Optional reRequestAuthentication(final SecurityRequest response, AuthCredentials creds) { + return Optional.empty(); } @Override diff --git a/src/main/java/org/opensearch/security/http/RemoteIpDetector.java b/src/main/java/org/opensearch/security/http/RemoteIpDetector.java index 0edb3552ea..296501045d 100644 --- a/src/main/java/org/opensearch/security/http/RemoteIpDetector.java +++ b/src/main/java/org/opensearch/security/http/RemoteIpDetector.java @@ -47,6 +47,7 @@ package org.opensearch.security.http; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.util.LinkedList; import java.util.List; @@ -55,8 +56,7 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.LogManager; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestRequest; - +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.support.ConfigConstants; final class RemoteIpDetector { @@ -117,8 +117,12 @@ public String getRemoteIpHeader() { return remoteIpHeader; } - String detect(RestRequest request, ThreadContext threadContext){ - final String originalRemoteAddr = ((InetSocketAddress)request.getHttpChannel().getRemoteAddress()).getAddress().getHostAddress(); + String detect(SecurityRequest request, ThreadContext threadContext) { + + final String originalRemoteAddr = request.getRemoteAddress() + .map(InetSocketAddress::getAddress) + .map(InetAddress::getHostAddress) + .orElseThrow(); final boolean isTraceEnabled = log.isTraceEnabled(); if (isTraceEnabled) { @@ -176,8 +180,17 @@ String detect(RestRequest request, ThreadContext threadContext){ if (remoteIp != null) { if (isTraceEnabled) { - final String originalRemoteHost = ((InetSocketAddress)request.getHttpChannel().getRemoteAddress()).getAddress().getHostName(); - log.trace("Incoming request {} with originalRemoteAddr '{}', originalRemoteHost='{}', will be seen as newRemoteAddr='{}'", request.uri(), originalRemoteAddr, originalRemoteHost, remoteIp); + final String originalRemoteHost = request.getRemoteAddress() + .map(InetSocketAddress::getAddress) + .map(InetAddress::getHostName) + .orElseThrow(); + log.trace( + "Incoming request {} with originalRemoteAddr '{}', originalRemoteHost='{}', will be seen as newRemoteAddr='{}'", + request.uri(), + originalRemoteAddr, + originalRemoteHost, + remoteIp + ); } threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_XFF_DONE, Boolean.TRUE); @@ -189,7 +202,11 @@ String detect(RestRequest request, ThreadContext threadContext){ } else { if (isTraceEnabled) { - log.trace("Skip RemoteIpDetector for request {} with originalRemoteAddr '{}' cause no internal proxy matches", request.uri(), request.getHttpChannel().getRemoteAddress()); + log.trace( + "Skip RemoteIpDetector for request {} with originalRemoteAddr '{}' cause no internal proxy matches", + request.uri(), + request.getRemoteAddress().orElse(null) + ); } } diff --git a/src/main/java/org/opensearch/security/http/XFFResolver.java b/src/main/java/org/opensearch/security/http/XFFResolver.java index e85e943a41..6b2cbbc7ee 100644 --- a/src/main/java/org/opensearch/security/http/XFFResolver.java +++ b/src/main/java/org/opensearch/security/http/XFFResolver.java @@ -36,9 +36,11 @@ import org.apache.logging.log4j.LogManager; import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.transport.TransportAddress; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.http.netty4.Netty4HttpChannel; import org.opensearch.rest.RestRequest; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.OpenSearchRequest; import org.opensearch.security.securityconf.DynamicConfigModel; import org.opensearch.security.support.ConfigConstants; import org.opensearch.threadpool.ThreadPool; @@ -56,38 +58,49 @@ public XFFResolver(final ThreadPool threadPool) { this.threadContext = threadPool.getThreadContext(); } - public TransportAddress resolve(final RestRequest request) throws OpenSearchSecurityException { + public TransportAddress resolve(final SecurityRequest request) throws OpenSearchSecurityException { final boolean isTraceEnabled = log.isTraceEnabled(); if (isTraceEnabled) { - log.trace("resolve {}", request.getHttpChannel().getRemoteAddress()); + log.trace("resolve {}", request.getRemoteAddress().orElse(null)); } - if(enabled && request.getHttpChannel().getRemoteAddress() instanceof InetSocketAddress && request.getHttpChannel() instanceof Netty4HttpChannel) { + boolean requestFromNetty = false; + if (request instanceof OpenSearchRequest) { + final OpenSearchRequest securityRequestChannel = (OpenSearchRequest) request; + final RestRequest restRequest = securityRequestChannel.breakEncapsulationForRequest(); - final InetSocketAddress isa = new InetSocketAddress(detector.detect(request, threadContext), ((InetSocketAddress)request.getHttpChannel().getRemoteAddress()).getPort()); - - if(isa.isUnresolved()) { - throw new OpenSearchSecurityException("Cannot resolve address "+isa.getHostString()); + requestFromNetty = restRequest.getHttpChannel() instanceof Netty4HttpChannel; + } + + if (enabled && request.getRemoteAddress().isPresent() && requestFromNetty) { + final InetSocketAddress remoteAddress = request.getRemoteAddress().get(); + final InetSocketAddress isa = new InetSocketAddress(detector.detect(request, threadContext), remoteAddress.getPort()); + + if (isa.isUnresolved()) { + throw new OpenSearchSecurityException("Cannot resolve address " + isa.getHostString()); } if (isTraceEnabled) { - if(threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_XFF_DONE) == Boolean.TRUE) { - log.trace("xff resolved {} to {}", request.getHttpChannel().getRemoteAddress(), isa); + if (threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_XFF_DONE) == Boolean.TRUE) { + log.trace("xff resolved {} to {}", remoteAddress, isa); } else { log.trace("no xff done for {}",request.getClass()); } } return new TransportAddress(isa); - } else if(request.getHttpChannel().getRemoteAddress() instanceof InetSocketAddress){ - + } else if (request.getRemoteAddress().isPresent()) { if (isTraceEnabled) { - log.trace("no xff done (enabled or no netty request) {},{},{},{}",enabled, request.getClass()); - + log.trace("no xff done (enabled or no netty request) {},{},{},{}", enabled, request.getClass()); } - return new TransportAddress((InetSocketAddress)request.getHttpChannel().getRemoteAddress()); + return new TransportAddress((InetSocketAddress) request.getRemoteAddress().get()); } else { - throw new OpenSearchSecurityException("Cannot handle this request. Remote address is "+request.getHttpChannel().getRemoteAddress()+" with request class "+request.getClass()); + throw new OpenSearchSecurityException( + "Cannot handle this request. Remote address is " + + request.getRemoteAddress().orElse(null) + + " with request class " + + request.getClass() + ); } } diff --git a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java index a73cf0a233..c1db89e5cf 100644 --- a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java @@ -32,6 +32,7 @@ import java.nio.file.Path; import java.util.List; +import java.util.Optional; import java.util.Map.Entry; import org.apache.logging.log4j.Logger; @@ -39,9 +40,8 @@ import org.opensearch.common.Strings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; - +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.http.HTTPProxyAuthenticator; import org.opensearch.security.user.AuthCredentials; import com.google.common.base.Joiner; @@ -59,12 +59,12 @@ public HTTPExtendedProxyAuthenticator(Settings settings, final Path configPath) } @Override - public AuthCredentials extractCredentials(final RestRequest request, ThreadContext context) { - AuthCredentials credentials = super.extractCredentials(request, context); - if(credentials == null) { - return null; - } - + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) { + AuthCredentials credentials = super.extractCredentials(request, context); + if (credentials == null) { + return null; + } + String attrHeaderPrefix = settings.get("attr_header_prefix"); if(Strings.isNullOrEmpty(attrHeaderPrefix)) { log.debug("attr_header_prefix is null. Skipping additional attribute extraction"); @@ -89,8 +89,8 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public Optional reRequestAuthentication(final SecurityRequest channel, AuthCredentials creds) { + return Optional.empty(); } @Override diff --git a/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java b/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java index 37ed930a40..14ae972685 100644 --- a/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java +++ b/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java @@ -15,15 +15,17 @@ package org.opensearch.security.securityconf.impl; -import org.opensearch.client.node.NodeClient; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; + +import org.apache.http.HttpStatus; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.rest.RestStatus; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; public class WhitelistingSettings { private boolean enabled; @@ -79,7 +81,7 @@ public String toString() { * GET /_cluster/settings - OK * GET /_cluster/settings/ - OK */ - private boolean requestIsWhitelisted(RestRequest request){ + private boolean requestIsWhitelisted(final SecurityRequest request) { //ALSO ALLOWS REQUEST TO HAVE TRAILING '/' //pathWithoutTrailingSlash stores the endpoint path without extra '/'. eg: /_cat/nodes @@ -111,17 +113,30 @@ private boolean requestIsWhitelisted(RestRequest request){ * then all PUT /_opendistro/_security/api/rolesmapping/{resource_name} work. * Currently, each resource_name has to be whitelisted separately */ - public boolean checkRequestIsAllowed(RestRequest request, RestChannel channel, - NodeClient client) throws IOException { + public Optional checkRequestIsAllowed(final SecurityRequest request) { // if whitelisting is enabled but the request is not whitelisted, then return false, otherwise true. - if (this.enabled && !requestIsWhitelisted(request)){ - channel.sendResponse(new BytesRestResponse(RestStatus.FORBIDDEN, channel.newErrorBuilder().startObject() - .field("error", request.method() + " " + request.path() + " API not whitelisted") - .field("status", RestStatus.FORBIDDEN) - .endObject() - )); - return false; + if (this.enabled && !requestIsWhitelisted(request)) { + return Optional.of( + new SecurityResponse(HttpStatus.SC_FORBIDDEN, SecurityResponse.CONTENT_TYPE_APP_JSON, generateFailureMessage(request)) + ); + } + return Optional.empty(); + } + + protected String getVerb() { + return "whitelisted"; + } + + protected String generateFailureMessage(final SecurityRequest request) { + try { + return XContentFactory.jsonBuilder() + .startObject() + .field("error", request.method() + " " + request.path() + " API not " + getVerb()) + .field("status", RestStatus.FORBIDDEN) + .endObject() + .toString(); + } catch (final IOException ioe) { + throw new RuntimeException(ioe); } - return true; } } diff --git a/src/main/java/org/opensearch/security/ssl/SslExceptionHandler.java b/src/main/java/org/opensearch/security/ssl/SslExceptionHandler.java index 531711dc54..adcd1588af 100644 --- a/src/main/java/org/opensearch/security/ssl/SslExceptionHandler.java +++ b/src/main/java/org/opensearch/security/ssl/SslExceptionHandler.java @@ -17,14 +17,14 @@ package org.opensearch.security.ssl; -import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequestChannel; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequest; public interface SslExceptionHandler { - - default void logError(Throwable t, RestRequest request, int type) { - //no-op + + default void logError(Throwable t, SecurityRequestChannel request, int type) { + // no-op } default void logError(Throwable t, boolean isRest) { diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/ValidatingDispatcher.java b/src/main/java/org/opensearch/security/ssl/http/netty/ValidatingDispatcher.java index c4129c08cf..5eb80dfcab 100644 --- a/src/main/java/org/opensearch/security/ssl/http/netty/ValidatingDispatcher.java +++ b/src/main/java/org/opensearch/security/ssl/http/netty/ValidatingDispatcher.java @@ -32,7 +32,8 @@ import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestStatus; - +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.security.ssl.util.ExceptionUtils; import org.opensearch.security.ssl.util.SSLRequestHelper; @@ -59,19 +60,19 @@ public ValidatingDispatcher(final ThreadContext threadContext, final Dispatcher @Override public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) { - checkRequest(request, channel); + checkRequest(SecurityRequestFactory.from(request, channel)); originalDispatcher.dispatchRequest(request, channel, threadContext); } @Override public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext, Throwable cause) { - checkRequest(channel.request(), channel); + checkRequest(SecurityRequestFactory.from(channel.request(), channel)); originalDispatcher.dispatchBadRequest(channel, threadContext, cause); } - - protected void checkRequest(final RestRequest request, final RestChannel channel) { - - if(SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) { + + protected void checkRequest(final SecurityRequestChannel request) { + + if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) { final OpenSearchException exception = ExceptionUtils.createBadHeaderException(); errorHandler.logError(exception, request, 1); throw exception; diff --git a/src/main/java/org/opensearch/security/ssl/rest/SecuritySSLInfoAction.java b/src/main/java/org/opensearch/security/ssl/rest/SecuritySSLInfoAction.java index b20e4084f7..38134fc36a 100644 --- a/src/main/java/org/opensearch/security/ssl/rest/SecuritySSLInfoAction.java +++ b/src/main/java/org/opensearch/security/ssl/rest/SecuritySSLInfoAction.java @@ -39,7 +39,7 @@ import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.rest.RestStatus; - +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.ssl.SecurityKeyStore; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.util.SSLRequestHelper; @@ -82,10 +82,14 @@ public void accept(RestChannel channel) throws Exception { BytesRestResponse response = null; try { - - SSLInfo sslInfo = SSLRequestHelper.getSSLInfo(settings, configPath, request, principalExtractor); - X509Certificate[] certs = sslInfo == null?null:sslInfo.getX509Certs(); - X509Certificate[] localCerts = sslInfo == null?null:sslInfo.getLocalCertificates(); + SSLInfo sslInfo = SSLRequestHelper.getSSLInfo( + settings, + configPath, + SecurityRequestFactory.from(request), + principalExtractor + ); + X509Certificate[] certs = sslInfo == null ? null : sslInfo.getX509Certs(); + X509Certificate[] localCerts = sslInfo == null ? null : sslInfo.getLocalCertificates(); builder.startObject(); diff --git a/src/main/java/org/opensearch/security/ssl/util/SSLRequestHelper.java b/src/main/java/org/opensearch/security/ssl/util/SSLRequestHelper.java index 26d35e3dfb..f33c98d552 100644 --- a/src/main/java/org/opensearch/security/ssl/util/SSLRequestHelper.java +++ b/src/main/java/org/opensearch/security/ssl/util/SSLRequestHelper.java @@ -38,6 +38,7 @@ import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; +import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.LogManager; import org.opensearch.OpenSearchException; @@ -45,9 +46,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.env.Environment; -import org.opensearch.http.netty4.Netty4HttpChannel; -import org.opensearch.rest.RestRequest; - +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.transport.PrincipalExtractor.Type; @@ -103,19 +102,18 @@ public String toString() { } - public static SSLInfo getSSLInfo(final Settings settings, final Path configPath, final RestRequest request, PrincipalExtractor principalExtractor) throws SSLPeerUnverifiedException { - - if(request == null || request.getHttpChannel() == null || !(request.getHttpChannel() instanceof Netty4HttpChannel)) { + @SuppressWarnings("removal") + public static SSLInfo getSSLInfo( + final Settings settings, + final Path configPath, + final SecurityRequest request, + PrincipalExtractor principalExtractor + ) throws SSLPeerUnverifiedException { + final SSLEngine engine = request.getSSLEngine(); + if (engine == null) { return null; } - final SslHandler sslhandler = (SslHandler) ((Netty4HttpChannel)request.getHttpChannel()).getNettyChannel().pipeline().get("ssl_http"); - - if(sslhandler == null) { - return null; - } - - final SSLEngine engine = sslhandler.engine(); final SSLSession session = engine.getSession(); X509Certificate[] x509Certs = null; diff --git a/src/main/java/org/opensearch/security/support/HTTPHelper.java b/src/main/java/org/opensearch/security/support/HTTPHelper.java index 0e08d86764..809c74774c 100644 --- a/src/main/java/org/opensearch/security/support/HTTPHelper.java +++ b/src/main/java/org/opensearch/security/support/HTTPHelper.java @@ -36,8 +36,7 @@ import java.util.Map; import org.apache.logging.log4j.Logger; -import org.opensearch.rest.RestRequest; - +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.user.AuthCredentials; public class HTTPHelper { @@ -87,9 +86,9 @@ public static AuthCredentials extractCredentials(String authorizationHeader, Log return null; } } - - public static boolean containsBadHeader(final RestRequest request) { - + + public static boolean containsBadHeader(final SecurityRequest request) { + final Map> headers; if (request != null && ( headers = request.getHeaders()) != null) { diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java index 24ccd41aac..6d96536a34 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java @@ -62,7 +62,7 @@ public void testNoKey() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer "+jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNull(creds); } @@ -79,7 +79,7 @@ public void testEmptyKey() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer "+jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNull(creds); } @@ -96,7 +96,7 @@ public void testBadKey() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer "+jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNull(creds); } @@ -110,8 +110,12 @@ public void testTokenMissing() throws Exception { HTTPJwtAuthenticator jwtAuth = new HTTPJwtAuthenticator(settings, null); Map headers = new HashMap(); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); - Assert.assertNull(creds); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); + + Assert.assertNull(credentials); } @Test @@ -127,8 +131,11 @@ public void testInvalid() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer "+jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); - Assert.assertNull(creds); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); + Assert.assertNull(credentials); } @Test @@ -144,11 +151,15 @@ public void testBearer() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer "+jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); - Assert.assertNotNull(creds); - Assert.assertEquals("Leonard McCoy", creds.getUsername()); - Assert.assertEquals(0, creds.getBackendRoles().size()); - Assert.assertEquals(2, creds.getAttributes().size()); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); + + Assert.assertNotNull(credentials); + Assert.assertEquals("Leonard McCoy", credentials.getUsername()); + Assert.assertEquals(0, credentials.getBackendRoles().size()); + Assert.assertEquals(2, credentials.getAttributes().size()); } @Test @@ -164,9 +175,10 @@ public void testBearerWrongPosition() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken + "Bearer " + " 123"); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); - Assert.assertNull(creds); - } + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); @Test public void testNonBearer() throws Exception { @@ -181,7 +193,7 @@ public void testNonBearer() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); Assert.assertEquals("Leonard McCoy", creds.getUsername()); Assert.assertEquals(0, creds.getBackendRoles().size()); @@ -200,9 +212,11 @@ public void testBasicAuthHeader() throws Exception { String basicAuth = BaseEncoding.base64().encode("user:password".getBytes(StandardCharsets.UTF_8)); Map headers = Collections.singletonMap(HttpHeaders.AUTHORIZATION, "Basic " + basicAuth); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, Collections.emptyMap()), null); - Assert.assertNull(creds); - Mockito.verifyNoInteractions(jwtParser); + AuthCredentials credentials = jwtAuth.extractCredentials( + new FakeRestRequest(headers, Collections.emptyMap()).asSecurityRequest(), + null + ); + Assert.assertNull(credentials); } @Test @@ -224,7 +238,7 @@ public void testRoles() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); Assert.assertEquals("Leonard McCoy", creds.getUsername()); Assert.assertEquals(2, creds.getBackendRoles().size()); @@ -249,7 +263,7 @@ public void testNullClaim() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); Assert.assertEquals("Leonard McCoy", creds.getUsername()); Assert.assertEquals(0, creds.getBackendRoles().size()); @@ -274,7 +288,7 @@ public void testNonStringClaim() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); Assert.assertEquals("Leonard McCoy", creds.getUsername()); Assert.assertEquals(1, creds.getBackendRoles().size()); @@ -299,7 +313,7 @@ public void testRolesMissing() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); Assert.assertEquals("Leonard McCoy", creds.getUsername()); Assert.assertEquals(0, creds.getBackendRoles().size()); @@ -324,7 +338,7 @@ public void testWrongSubjectKey() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNull(creds); } @@ -348,7 +362,7 @@ public void testAlternativeSubject() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); Assert.assertEquals("Dr. Who", creds.getUsername()); Assert.assertEquals(0, creds.getBackendRoles().size()); @@ -374,7 +388,7 @@ public void testNonStringAlternativeSubject() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); Assert.assertEquals("false", creds.getUsername()); Assert.assertEquals(0, creds.getBackendRoles().size()); @@ -399,10 +413,11 @@ public void testUrlParam() throws Exception { FakeRestRequest req = new FakeRestRequest(headers, new HashMap()); req.params().put("abc", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(req, null); - Assert.assertNotNull(creds); - Assert.assertEquals("Leonard McCoy", creds.getUsername()); - Assert.assertEquals(0, creds.getBackendRoles().size()); + AuthCredentials credentials = jwtAuth.extractCredentials(req.asSecurityRequest(), null); + + Assert.assertNotNull(credentials); + Assert.assertEquals("Leonard McCoy", credentials.getUsername()); + Assert.assertEquals(0, credentials.getBackendRoles().size()); } @Test @@ -423,7 +438,7 @@ public void testExp() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNull(creds); } @@ -445,7 +460,7 @@ public void testNbf() throws Exception { Map headers = new HashMap(); headers.put("Authorization", jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNull(creds); } @@ -465,7 +480,11 @@ public void testRS256() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer "+jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); + Assert.assertNotNull(creds); Assert.assertEquals("Leonard McCoy", creds.getUsername()); Assert.assertEquals(0, creds.getBackendRoles().size()); @@ -487,7 +506,11 @@ public void testES512() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer "+jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), + null + ); + Assert.assertNotNull(creds); Assert.assertEquals("Leonard McCoy", creds.getUsername()); Assert.assertEquals(0, creds.getBackendRoles().size()); @@ -514,7 +537,7 @@ public void rolesArray() throws Exception { Map headers = new HashMap(); headers.put("Authorization", "Bearer "+jwsToken); - AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()), null); + AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(headers, new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); Assert.assertEquals("John Doe", creds.getUsername()); Assert.assertEquals(3, creds.getBackendRoles().size()); diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java index bdc882f968..3317ddd7d7 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java @@ -53,7 +53,7 @@ public void basicTest() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest( - ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()), null); + ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); @@ -69,7 +69,7 @@ public void testEscapeKid() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest( - ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1_INVALID_KID), new HashMap()), null); + ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1_INVALID_KID).asSecurityRequest(), new HashMap()), null); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); @@ -86,7 +86,7 @@ public void bearerTest() { AuthCredentials creds = jwtAuth.extractCredentials( new FakeRestRequest(ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1), - new HashMap()), + new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); @@ -104,7 +104,7 @@ public void testRoles() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest( - ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()), null); + ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); @@ -119,7 +119,7 @@ public void testExp() throws Exception { AuthCredentials creds = jwtAuth.extractCredentials( new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_EXPIRED_SIGNED_OCT_1), - new HashMap()), + new HashMap()).asSecurityRequest(), null); Assert.assertNull(creds); @@ -142,7 +142,7 @@ public void testExpInSkew() throws Exception { ImmutableMap.of( "Authorization", "bearer "+TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), - new HashMap()), + new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); @@ -165,7 +165,7 @@ public void testNbf() throws Exception { ImmutableMap.of( "Authorization", "bearer "+TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), - new HashMap()), + new HashMap()).asSecurityRequest(), null); Assert.assertNull(creds); @@ -186,7 +186,7 @@ public void testNbfInSkew() throws Exception { AuthCredentials creds = jwtAuth.extractCredentials( new FakeRestRequest( ImmutableMap.of("Authorization", "bearer "+TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), - new HashMap()), + new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); @@ -201,7 +201,7 @@ public void testRS256() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest( - ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()), null); + ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); @@ -218,7 +218,7 @@ public void testBadSignature() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest( - ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_X), new HashMap()), null); + ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_X), new HashMap()).asSecurityRequest(), null); Assert.assertNull(creds); } @@ -230,7 +230,7 @@ public void testPeculiarJsonEscaping() { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.PeculiarEscaping.MC_COY_SIGNED_RSA_1), new HashMap()), + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.PeculiarEscaping.MC_COY_SIGNED_RSA_1), new HashMap()).asSecurityRequest(), null); Assert.assertNotNull(creds); diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java index 4b69d93c74..1878e591a1 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java @@ -35,10 +35,11 @@ public void basicTest() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); - AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), - new HashMap()), - null); + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()) + .asSecurityRequest(), + null + ); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); @@ -63,10 +64,11 @@ public void wrongSigTest() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); - AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_X), - new HashMap()), - null); + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_X), new HashMap()) + .asSecurityRequest(), + null + ); Assert.assertNull(creds); @@ -87,10 +89,11 @@ public void noAlgTest() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); - AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), - new HashMap()), - null); + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()) + .asSecurityRequest(), + null + ); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); @@ -114,10 +117,11 @@ public void mismatchedAlgTest() throws Exception { HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); - AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), - new HashMap()), - null); + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), new HashMap()) + .asSecurityRequest(), + null + ); Assert.assertNull(creds); @@ -138,12 +142,12 @@ public void keyExchangeTest() throws Exception { Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); - - try { - AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), - new HashMap()), - null); + try { + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), new HashMap()) + .asSecurityRequest(), + null + ); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); @@ -152,24 +156,24 @@ public void keyExchangeTest() throws Exception { Assert.assertEquals(3, creds.getAttributes().size()); creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_2), - new HashMap()), - null); - + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_2), new HashMap()) + .asSecurityRequest(), + null + ); Assert.assertNull(creds); creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_X), - new HashMap()), - null); - + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_X), new HashMap()) + .asSecurityRequest(), + null + ); Assert.assertNull(creds); creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), - new HashMap()), - null); - + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), new HashMap()) + .asSecurityRequest(), + null + ); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud")); @@ -187,12 +191,12 @@ public void keyExchangeTest() throws Exception { mockIdpServer = new MockIpdServer(TestJwk.Jwks.RSA_2); settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); //port changed jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); - - try { - AuthCredentials creds = jwtAuth.extractCredentials( - new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_2), - new HashMap()), - null); + try { + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_2), new HashMap()) + .asSecurityRequest(), + null + ); Assert.assertNotNull(creds); Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); diff --git a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java index e31121b358..d8a40f7341 100644 --- a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java @@ -29,6 +29,7 @@ import java.util.Base64; import java.util.HashMap; import java.util.List; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -37,25 +38,25 @@ import org.opensearch.security.DefaultObjectMapper; import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer; import org.apache.cxf.rs.security.jose.jwt.JwtToken; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.opensaml.saml.saml2.core.NameIDType; + import org.opensearch.common.bytes.BytesArray; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.rest.RestResponse; import org.opensearch.rest.RestStatus; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; -import org.opensaml.saml.saml2.core.NameIDType; - +import org.opensearch.security.filter.SecurityRequestFactory; +import org.opensearch.security.filter.SecurityResponse; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.test.helper.file.FileHelper; import org.opensearch.security.user.AuthCredentials; import org.opensearch.security.util.FakeRestRequest; @@ -64,6 +65,7 @@ import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.IDP_METADATA_CONTENT; import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.IDP_METADATA_URL; +import static org.hamcrest.MatcherAssert.assertThat; public class HTTPSamlAuthenticatorTest { protected MockSamlIdpServer mockSamlIdpServer; @@ -137,14 +139,13 @@ public void basicTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); - HashMap response = DefaultObjectMapper.objectMapper.readValue(responseJson, - new TypeReference>() { - }); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, + new TypeReference>() { + } + ); String authorization = (String) response.get("authorization"); Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); @@ -155,6 +156,17 @@ public void basicTest() throws Exception { Assert.assertEquals("horst", jwt.getClaim("sub")); } + private Optional sendToAuthenticator(HTTPSamlAuthenticator samlAuthenticator, RestRequest request) { + final SecurityRequest tokenRestChannel = SecurityRequestFactory.from(request); + + return samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + } + + private String getResponse(HTTPSamlAuthenticator samlAuthenticator, RestRequest request) throws Exception { + SecurityResponse response = sendToAuthenticator(samlAuthenticator, request).orElseThrow(); + return response.getBody(); + } + @Test public void decryptAssertionsTest() throws Exception { mockSamlIdpServer.setAuthenticateUser("horst"); @@ -175,14 +187,12 @@ public void decryptAssertionsTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); - HashMap response = DefaultObjectMapper.objectMapper.readValue(responseJson, - new TypeReference>() { - }); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, + new TypeReference>() { + } + ); String authorization = (String) response.get("authorization"); Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); @@ -214,14 +224,12 @@ public void shouldUnescapeSamlEntitiesTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); - HashMap response = DefaultObjectMapper.objectMapper.readValue(responseJson, - new TypeReference>() { - }); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, + new TypeReference>() { + } + ); String authorization = (String) response.get("authorization"); Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); @@ -256,14 +264,12 @@ public void shouldUnescapeSamlEntitiesTest2() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); - HashMap response = DefaultObjectMapper.objectMapper.readValue(responseJson, - new TypeReference>() { - }); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, + new TypeReference>() { + } + ); String authorization = (String) response.get("authorization"); Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); @@ -298,14 +304,12 @@ public void shouldNotEscapeSamlEntities() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); - HashMap response = DefaultObjectMapper.objectMapper.readValue(responseJson, - new TypeReference>() { - }); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, + new TypeReference>() { + } + ); String authorization = (String) response.get("authorization"); Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); @@ -342,12 +346,9 @@ public void testMetadataBody() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); - HashMap response = DefaultObjectMapper.objectMapper.readValue(responseJson, + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, new TypeReference>() { }); String authorization = (String) response.get("authorization"); @@ -392,16 +393,17 @@ public void unsolicitedSsoTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.createUnsolicitedSamlResponse(); - RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, null, - "/opendistrosecurity/saml/acs/idpinitiated"); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); - HashMap response = DefaultObjectMapper.objectMapper.readValue(responseJson, - new TypeReference>() { - }); + RestRequest tokenRestRequest = buildTokenExchangeRestRequest( + encodedSamlResponse, + null, + "/opendistrosecurity/saml/acs/idpinitiated" + ); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, + new TypeReference>() { + } + ); String authorization = (String) response.get("authorization"); Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); @@ -431,13 +433,14 @@ public void badUnsolicitedSsoTest() throws Exception { AuthenticateHeaders authenticateHeaders = new AuthenticateHeaders("http://wherever/opendistrosecurity/saml/acs/", "wrong_request_id"); - RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders, - "/opendistrosecurity/saml/acs/idpinitiated"); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + RestRequest tokenRestRequest = buildTokenExchangeRestRequest( + encodedSamlResponse, + authenticateHeaders, + "/opendistrosecurity/saml/acs/idpinitiated" + ); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow(); - Assert.assertEquals(RestStatus.UNAUTHORIZED, tokenRestChannel.response.status()); + Assert.assertEquals(RestStatus.UNAUTHORIZED.getStatus(), response.getStatus()); } @Test @@ -460,11 +463,9 @@ public void wrongCertTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow(); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - - Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); + Assert.assertEquals(401, response.getStatus()); } @Test @@ -484,11 +485,9 @@ public void noSignatureTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow(); - Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); + Assert.assertEquals(401, response.getStatus()); } @SuppressWarnings("unchecked") @@ -511,14 +510,12 @@ public void rolesTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); - HashMap response = DefaultObjectMapper.objectMapper.readValue(responseJson, - new TypeReference>() { - }); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, + new TypeReference>() { + } + ); String authorization = (String) response.get("authorization"); Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); @@ -549,14 +546,12 @@ public void idpEndpointWithQueryStringTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); - HashMap response = DefaultObjectMapper.objectMapper.readValue(responseJson, - new TypeReference>() { - }); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, + new TypeReference>() { + } + ); String authorization = (String) response.get("authorization"); Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); @@ -598,14 +593,12 @@ private void commaSeparatedRoles(final String rolesAsString, final Settings.Buil String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); - HashMap response = DefaultObjectMapper.objectMapper.readValue(responseJson, - new TypeReference>() { - }); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, + new TypeReference>() { + } + ); String authorization = (String) response.get("authorization"); Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); @@ -684,10 +677,9 @@ public void initialConnectionFailureTest() throws Exception { HTTPSamlAuthenticator samlAuthenticator = new HTTPSamlAuthenticator(settings, null); RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); - TestRestChannel restChannel = new TestRestChannel(restRequest); - samlAuthenticator.reRequestAuthentication(restChannel, null); + Optional maybeResponse = sendToAuthenticator(samlAuthenticator, restRequest); - Assert.assertNull(restChannel.response); + assertThat(maybeResponse.isPresent(), Matchers.equalTo(false)); mockSamlIdpServer.start(); @@ -703,14 +695,12 @@ public void initialConnectionFailureTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - - String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); - HashMap response = DefaultObjectMapper.objectMapper.readValue(responseJson, - new TypeReference>() { - }); + String responseJson = getResponse(samlAuthenticator, tokenRestRequest); + HashMap response = DefaultObjectMapper.objectMapper.readValue( + responseJson, + new TypeReference>() { + } + ); String authorization = (String) response.get("authorization"); Assert.assertNotNull("Expected authorization attribute in JSON: " + responseJson, authorization); @@ -724,17 +714,11 @@ public void initialConnectionFailureTest() throws Exception { private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuthenticator) { RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); - TestRestChannel restChannel = new TestRestChannel(restRequest); - - samlAuthenticator.reRequestAuthentication(restChannel, null); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, restRequest).orElseThrow(); - List wwwAuthenticateHeaders = restChannel.response.getHeaders().get("WWW-Authenticate"); + String wwwAuthenticateHeader = response.getHeaders().get("WWW-Authenticate"); - Assert.assertNotNull(wwwAuthenticateHeaders); - Assert.assertEquals("More than one WWW-Authenticate header: " + wwwAuthenticateHeaders, 1, - wwwAuthenticateHeaders.size()); - - String wwwAuthenticateHeader = wwwAuthenticateHeaders.get(0); + Assert.assertNotNull(wwwAuthenticateHeader); Matcher wwwAuthenticateHeaderMatcher = WWW_AUTHENTICATE_PATTERN.matcher(wwwAuthenticateHeader); @@ -796,58 +780,6 @@ public static void initSpSigningKeys() { } } - static class TestRestChannel implements RestChannel { - - final RestRequest restRequest; - RestResponse response; - - TestRestChannel(RestRequest restRequest) { - this.restRequest = restRequest; - } - - @Override - public XContentBuilder newBuilder() throws IOException { - return null; - } - - @Override - public XContentBuilder newErrorBuilder() throws IOException { - return null; - } - - @Override - public XContentBuilder newBuilder(XContentType xContentType, boolean useFiltering) throws IOException { - return null; - } - - @Override - public BytesStreamOutput bytesOutput() { - return null; - } - - @Override - public RestRequest request() { - return restRequest; - } - - @Override - public boolean detailedErrorsEnabled() { - return false; - } - - @Override - public void sendResponse(RestResponse response) { - this.response = response; - - } - - @Override - public XContentBuilder newBuilder(XContentType xContentType, XContentType responseContentType, boolean useFiltering) throws IOException { - return null; - } - - } - static class AuthenticateHeaders { final String location; final String requestId; diff --git a/src/test/java/org/opensearch/security/auditlog/helper/MockRestRequest.java b/src/test/java/org/opensearch/security/auditlog/helper/MockRestRequest.java index d84885f66f..e87f9a77dd 100644 --- a/src/test/java/org/opensearch/security/auditlog/helper/MockRestRequest.java +++ b/src/test/java/org/opensearch/security/auditlog/helper/MockRestRequest.java @@ -20,6 +20,8 @@ import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestFactory; public class MockRestRequest extends RestRequest { @@ -48,4 +50,8 @@ public boolean hasContent() { public BytesReference content() { return null; } -} \ No newline at end of file + + public SecurityRequestChannel asSecurityRequest() { + return SecurityRequestFactory.from(this, null); + } +} diff --git a/src/test/java/org/opensearch/security/auditlog/impl/AuditlogTest.java b/src/test/java/org/opensearch/security/auditlog/impl/AuditlogTest.java index c272646613..1885beefcc 100644 --- a/src/test/java/org/opensearch/security/auditlog/impl/AuditlogTest.java +++ b/src/test/java/org/opensearch/security/auditlog/impl/AuditlogTest.java @@ -25,9 +25,11 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.rest.RestRequest; +import org.opensearch.security.auditlog.AuditTestUtils; import org.opensearch.security.auditlog.helper.RetrySink; import org.opensearch.security.auditlog.integration.TestAuditlogImpl; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityRequestChannel; import org.opensearch.security.support.ConfigConstants; import org.opensearch.transport.TransportRequest; import org.junit.Assert; @@ -90,7 +92,7 @@ public void testSslException() { .build(); AbstractAuditLog al = AuditTestUtils.createAuditLog(settings, null, null, AbstractSecurityUnitTest.MOCK_POOL, null, cs); TestAuditlogImpl.clear(); - al.logSSLException(null, new Exception("test rest")); + al.logSSLException((SecurityRequest)null, new Exception("test rest")); al.logSSLException(null, new Exception("test rest"), null, null); System.out.println(TestAuditlogImpl.sb.toString()); Assert.assertEquals(2, TestAuditlogImpl.messages.size()); @@ -110,7 +112,7 @@ public void testRetry() { .put(ConfigConstants.SECURITY_AUDIT_RETRY_DELAY_MS, 500) .build(); AbstractAuditLog al = AuditTestUtils.createAuditLog(settings, null, null, AbstractSecurityUnitTest.MOCK_POOL, null, cs); - al.logSSLException(null, new Exception("test retry")); + al.logSSLException((SecurityRequest)null, new Exception("test retry")); Assert.assertNotNull(RetrySink.getMsg()); Assert.assertTrue(RetrySink.getMsg().toJson().contains("test retry")); } @@ -129,18 +131,16 @@ public void testNoRetry() { .put(ConfigConstants.SECURITY_AUDIT_RETRY_DELAY_MS, 500) .build(); AbstractAuditLog al = AuditTestUtils.createAuditLog(settings, null, null, AbstractSecurityUnitTest.MOCK_POOL, null, cs); - al.logSSLException(null, new Exception("test retry")); + al.logSSLException((SecurityRequest)null, new Exception("test retry")); Assert.assertNull(RetrySink.getMsg()); } @Test public void testRestFilterEnabledCheck() { - final Settings settings = Settings.builder() - .put(ConfigConstants.OPENDISTRO_SECURITY_AUDIT_ENABLE_REST, false) - .build(); - final AbstractAuditLog al = AuditTestUtils.createAuditLog(settings, null, null, AbstractSecurityUnitTest.MOCK_POOL, null, cs); - for (AuditCategory category: AuditCategory.values()) { - Assert.assertFalse(al.checkRestFilter(category, "user", mock(RestRequest.class))); + final Settings settings = Settings.builder().put(ConfigConstants.OPENDISTRO_SECURITY_AUDIT_ENABLE_REST, false).build(); + final AbstractAuditLog al = AuditTestUtils.createAuditLog(settings, null, null, AbstractSecurityUnitTest.MOCK_POOL, null, cs); + for (AuditCategory category : AuditCategory.values()) { + Assert.assertFalse(al.checkRestFilter(category, "user", mock(SecurityRequestChannel.class))); } } diff --git a/src/test/java/org/opensearch/security/auditlog/impl/DisabledCategoriesTest.java b/src/test/java/org/opensearch/security/auditlog/impl/DisabledCategoriesTest.java index 454a6a43c2..45aa459cc5 100644 --- a/src/test/java/org/opensearch/security/auditlog/impl/DisabledCategoriesTest.java +++ b/src/test/java/org/opensearch/security/auditlog/impl/DisabledCategoriesTest.java @@ -221,12 +221,11 @@ protected void logTransportSucceededLogin(AuditLog auditLog) { auditLog.logSucceededLogin("testuser.transport.succeededlogin", false, "testuser.transport.succeededlogin", new TransportRequest.Empty(), "test/action", new Task(0, "x", "ac", "", null, null)); } - protected void logRestFailedLogin(AuditLog auditLog) { - auditLog.logFailedLogin("testuser.rest.failedlogin", false, "testuser.rest.failedlogin", new MockRestRequest()); - } + auditLog.logFailedLogin("testuser.rest.failedlogin", false, "testuser.rest.failedlogin", new MockRestRequest().asSecurityRequest()); + } - protected void logTransportFailedLogin(AuditLog auditLog) { + protected void logTransportFailedLogin(AuditLog auditLog) { auditLog.logFailedLogin("testuser.transport.failedlogin", false, "testuser.transport.failedlogin", new TransportRequest.Empty(), null); } @@ -239,7 +238,7 @@ protected void logTransportBadHeaders(AuditLog auditLog) { } protected void logRestBadHeaders(AuditLog auditLog) { - auditLog.logBadHeaders(new MockRestRequest()); + auditLog.logBadHeaders(new MockRestRequest().asSecurityRequest()); } protected void logSecurityIndexAttempt(AuditLog auditLog) { @@ -247,7 +246,7 @@ protected void logSecurityIndexAttempt(AuditLog auditLog) { } protected void logRestSSLException(AuditLog auditLog) { - auditLog.logSSLException(new MockRestRequest(), new Exception()); + auditLog.logSSLException(new MockRestRequest().asSecurityRequest(), new Exception()); } protected void logTransportSSLException(AuditLog auditLog) { diff --git a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java index f6b79428c8..77b4917443 100644 --- a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java +++ b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java @@ -16,14 +16,14 @@ package org.opensearch.security.cache; import java.nio.file.Path; +import java.util.Optional; import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; - import org.opensearch.security.auth.HTTPAuthenticator; +import org.opensearch.security.filter.SecurityRequest; +import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.user.AuthCredentials; public class DummyHTTPAuthenticator implements HTTPAuthenticator { @@ -39,14 +39,15 @@ public String getType() { } @Override - public AuthCredentials extractCredentials(RestRequest request, ThreadContext context) throws OpenSearchSecurityException { + public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context) + throws OpenSearchSecurityException { count++; return new AuthCredentials("dummy").markComplete(); } @Override - public boolean reRequestAuthentication(RestChannel channel, AuthCredentials credentials) { - return false; + public Optional reRequestAuthentication(SecurityRequest channel, AuthCredentials credentials) { + return Optional.empty(); } public static long getCount() { diff --git a/src/test/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticatorTest.java b/src/test/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticatorTest.java index 49a5795548..e1b50b3f9c 100644 --- a/src/test/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticatorTest.java +++ b/src/test/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticatorTest.java @@ -40,6 +40,8 @@ import java.util.List; import java.util.Map; +import org.junit.Before; +import org.junit.Test; import org.opensearch.OpenSearchSecurityException; import org.opensearch.action.ActionListener; import org.opensearch.common.bytes.BytesReference; @@ -52,9 +54,8 @@ import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.rest.RestStatus; -import org.junit.Before; -import org.junit.Test; - +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestFactory; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.user.AuthCredentials; import com.google.common.collect.ImmutableSet; @@ -83,19 +84,19 @@ public void testGetType() { @Test(expected = OpenSearchSecurityException.class) public void testThrowsExceptionWhenMissingXFFDone() { authenticator = new HTTPExtendedProxyAuthenticator(Settings.EMPTY, null); - authenticator.extractCredentials(new TestRestRequest(), new ThreadContext(Settings.EMPTY)); + authenticator.extractCredentials(new TestRestRequest().asSecurityRequest(), new ThreadContext(Settings.EMPTY)); } @Test public void testReturnsNullWhenUserHeaderIsUnconfigured() { authenticator = new HTTPExtendedProxyAuthenticator(Settings.EMPTY, null); - assertNull(authenticator.extractCredentials(new TestRestRequest(), context)); + assertNull(authenticator.extractCredentials(new TestRestRequest().asSecurityRequest(), context)); } @Test public void testReturnsNullWhenUserHeaderIsMissing() { - - assertNull(authenticator.extractCredentials(new TestRestRequest(), context)); + + assertNull(authenticator.extractCredentials(new TestRestRequest().asSecurityRequest(), context)); } @Test @@ -107,10 +108,10 @@ public void testReturnsCredentials() { headers.get("proxy_uid").add("123"); headers.get("proxy_uid").add("456"); headers.get("proxy_other").add("someothervalue"); - - settings = Settings.builder().put(settings).put("attr_header_prefix","proxy_").build(); - authenticator = new HTTPExtendedProxyAuthenticator(settings,null); - AuthCredentials creds = authenticator.extractCredentials(new TestRestRequest(headers), context); + + settings = Settings.builder().put(settings).put("attr_header_prefix", "proxy_").build(); + authenticator = new HTTPExtendedProxyAuthenticator(settings, null); + AuthCredentials creds = authenticator.extractCredentials(new TestRestRequest(headers).asSecurityRequest(), context); assertNotNull(creds); assertEquals("aValidUser", creds.getUsername()); assertEquals("123,456", creds.getAttributes().get("attr.proxy.uid")); @@ -124,13 +125,10 @@ public void testTrimOnRoles() { headers.put("roles", new ArrayList<>()); headers.get("user").add("aValidUser"); headers.get("roles").add("role1, role2,\t"); - - settings = Settings.builder().put(settings) - .put("roles_header","roles") - .put("roles_separator", ",") - .build(); - authenticator = new HTTPExtendedProxyAuthenticator(settings,null); - AuthCredentials creds = authenticator.extractCredentials(new TestRestRequest(headers), context); + + settings = Settings.builder().put(settings).put("roles_header", "roles").put("roles_separator", ",").build(); + authenticator = new HTTPExtendedProxyAuthenticator(settings, null); + AuthCredentials creds = authenticator.extractCredentials(new TestRestRequest(headers).asSecurityRequest(), context); assertNotNull(creds); assertEquals("aValidUser", creds.getUsername()); assertEquals(ImmutableSet.of("role1", "role2"), creds.getBackendRoles()); @@ -165,6 +163,9 @@ public boolean hasContent() { return false; } + public SecurityRequestChannel asSecurityRequest() { + return SecurityRequestFactory.from(this, null); + } } static class HttpRequestImpl implements HttpRequest { diff --git a/src/test/java/org/opensearch/security/util/FakeRestRequest.java b/src/test/java/org/opensearch/security/util/FakeRestRequest.java index 4c60f7d67c..d0cc452fc0 100644 --- a/src/test/java/org/opensearch/security/util/FakeRestRequest.java +++ b/src/test/java/org/opensearch/security/util/FakeRestRequest.java @@ -22,6 +22,8 @@ import org.opensearch.common.bytes.BytesReference; import org.opensearch.rest.RestRequest; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestFactory; public class FakeRestRequest extends RestRequest { @@ -118,4 +120,7 @@ private static Map> convert(Map headers) { return ret; } -} \ No newline at end of file + public SecurityRequestChannel asSecurityRequest() { + return SecurityRequestFactory.from(this, null); + } +}