diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java index 3fee4a9444..0ef063f5e1 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java @@ -19,6 +19,8 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -36,6 +38,10 @@ import com.onelogin.saml2.exception.ValidationError; import com.onelogin.saml2.settings.Saml2Settings; import com.onelogin.saml2.util.Util; + +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaderNames; + import org.apache.commons.lang3.StringUtils; import org.apache.cxf.jaxrs.json.basic.JsonMapObjectReaderWriter; import org.apache.cxf.rs.security.jose.jwk.JsonWebKey; @@ -46,6 +52,7 @@ import org.apache.cxf.rs.security.jose.jwt.JwtClaims; import org.apache.cxf.rs.security.jose.jwt.JwtToken; import org.apache.cxf.rs.security.jose.jwt.JwtUtils; +import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.joda.time.DateTime; @@ -63,6 +70,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.security.DefaultObjectMapper; import org.opensearch.security.dlic.rest.api.AuthTokenProcessorAction; +import org.opensearch.security.filter.SecurityResponse; class AuthTokenProcessorHandler { private static final Logger log = LogManager.getLogger(AuthTokenProcessorHandler.class); @@ -122,7 +130,7 @@ class AuthTokenProcessorHandler { } @SuppressWarnings("removal") - boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exception { + Optional handle(RestRequest restRequest) throws Exception { try { final SecurityManager sm = System.getSecurityManager(); @@ -130,11 +138,10 @@ boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exceptio sm.checkPermission(new SpecialPermission()); } - return AccessController.doPrivileged(new PrivilegedExceptionAction() { + return AccessController.doPrivileged(new PrivilegedExceptionAction>() { @Override - public Boolean run() throws XPathExpressionException, SamlConfigException, IOException, ParserConfigurationException, - SAXException, SettingsException { - return handleLowLevel(restRequest, restChannel); + public Optional run() throws SamlConfigException, IOException { + return handleLowLevel(restRequest); } }); } catch (PrivilegedActionException e) { @@ -147,13 +154,11 @@ public Boolean run() throws XPathExpressionException, SamlConfigException, IOExc } private AuthTokenProcessorAction.Response handleImpl( - RestRequest restRequest, - RestChannel restChannel, String samlResponseBase64, String samlRequestId, String acsEndpoint, Saml2Settings saml2Settings - ) throws XPathExpressionException, ParserConfigurationException, SAXException, IOException, SettingsException { + ) { if (token_log.isDebugEnabled()) { try { token_log.debug( @@ -188,8 +193,7 @@ private AuthTokenProcessorAction.Response handleImpl( } } - private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) throws SamlConfigException, IOException, - XPathExpressionException, ParserConfigurationException, SAXException, SettingsException { + private Optional handleLowLevel(RestRequest restRequest) throws SamlConfigException, IOException { try { if (restRequest.getMediaType() != XContentType.JSON) { @@ -234,31 +238,18 @@ private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) acsEndpoint = getAbsoluteAcsEndpoint(((ObjectNode) jsonRoot).get("acsEndpoint").textValue()); } - AuthTokenProcessorAction.Response responseBody = this.handleImpl( - restRequest, - restChannel, - samlResponseBase64, - samlRequestId, - acsEndpoint, - saml2Settings - ); + AuthTokenProcessorAction.Response responseBody = this.handleImpl(samlResponseBase64, samlRequestId, acsEndpoint, saml2Settings); if (responseBody == null) { - return false; + return Optional.empty(); } String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody); - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString); - restChannel.sendResponse(authenticateResponse); - - return true; + return Optional.of(new SecurityResponse(HttpStatus.SC_OK, Map.of(HttpHeaderNames.CONTENT_TYPE.toString(), "application/json"), responseBodyString)); } catch (JsonProcessingException e) { log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e); - - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed"); - restChannel.sendResponse(authenticateResponse); - return true; + return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, null, "JSON could not be parsed")); } } diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java index e511ba7bca..ecc2448c09 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java @@ -65,6 +65,7 @@ import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.filter.SecurityRequestChannelUnsupported; import org.opensearch.security.filter.SecurityResponse; +import org.opensearch.security.filter.OpenSearchRequest; import org.opensearch.security.filter.OpenSearchRequestChannel; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.support.PemKeyReader; @@ -184,16 +185,14 @@ public Optional reRequestAuthentication(final SecurityRequest if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { // Verficiation of SAML ASC endpoint only works with RestRequests - if (!(request instanceof OpenSearchRequestChannel)) { + if (!(request instanceof OpenSearchRequest)) { throw new SecurityRequestChannelUnsupported(); } else { - final OpenSearchRequestChannel securityRequestChannel = (OpenSearchRequestChannel) request; - final RestRequest restRequest = securityRequestChannel.breakEncapsulationForRequest(); - final RestChannel channel = securityRequestChannel.breakEncapsulationForChannel(); - if (this.authTokenProcessorHandler.handle(restRequest, channel)) { - // The ACS response was accepted - securityRequestChannel.markCompleted(); - return Optional.empty(); + final OpenSearchRequest openSearchRequest = (OpenSearchRequest) request; + final RestRequest restRequest = openSearchRequest.breakEncapsulationForRequest(); + Optional restResponse = this.authTokenProcessorHandler.handle(restRequest); + if (restResponse.isPresent()) { + return restResponse; } } } diff --git a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java index dd030da5ed..17a2148fa5 100644 --- a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java @@ -25,6 +25,7 @@ import java.util.Base64; import java.util.HashMap; import java.util.List; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -34,6 +35,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer; import org.apache.cxf.rs.security.jose.jwt.JwtToken; +import org.hamcrest.Matchers; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -42,7 +44,6 @@ import org.opensaml.saml.saml2.core.NameIDType; import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.MediaType; @@ -54,13 +55,15 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.security.DefaultObjectMapper; import org.opensearch.security.filter.SecurityRequestFactory; -import org.opensearch.security.filter.OpenSearchRequestChannel; +import org.opensearch.security.filter.SecurityResponse; +import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.test.helper.file.FileHelper; import org.opensearch.security.user.AuthCredentials; import org.opensearch.security.util.FakeRestRequest; import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.IDP_METADATA_CONTENT; import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.IDP_METADATA_URL; +import static org.hamcrest.MatcherAssert.assertThat; public class HTTPSamlAuthenticatorTest { protected MockSamlIdpServer mockSamlIdpServer; @@ -158,17 +161,15 @@ public void basicTest() throws Exception { Assert.assertEquals("horst", jwt.getClaim("sub")); } - private TestRestChannel sendToAuthenticator(HTTPSamlAuthenticator samlAuthenticator, RestRequest request) { - TestRestChannel testChannel = new TestRestChannel(request); - OpenSearchRequestChannel tokenRestChannel = (OpenSearchRequestChannel) SecurityRequestFactory.from(request, testChannel); + private Optional sendToAuthenticator(HTTPSamlAuthenticator samlAuthenticator, RestRequest request) { + final SecurityRequest tokenRestChannel = SecurityRequestFactory.from(request); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); - return testChannel; + return samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); } private String getResponse(HTTPSamlAuthenticator samlAuthenticator, RestRequest request) throws Exception { - TestRestChannel testChannel = sendToAuthenticator(samlAuthenticator, request); - return new String(BytesReference.toBytes(testChannel.response.content())); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, request).orElseThrow(); + return response.getBody(); } @Test @@ -534,9 +535,9 @@ public void badUnsolicitedSsoTest() throws Exception { authenticateHeaders, "/opendistrosecurity/saml/acs/idpinitiated" ); - TestRestChannel tokenRestChannel = sendToAuthenticator(samlAuthenticator, tokenRestRequest); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow(); - Assert.assertEquals(RestStatus.UNAUTHORIZED, tokenRestChannel.response.status()); + Assert.assertEquals(RestStatus.UNAUTHORIZED.getStatus(), response.getStatus()); } @Test @@ -564,9 +565,9 @@ public void wrongCertTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = sendToAuthenticator(samlAuthenticator, tokenRestRequest); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow(); - Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); + Assert.assertEquals(401, response.getStatus()); } @Test @@ -591,9 +592,9 @@ public void noSignatureTest() throws Exception { String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location); RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); - TestRestChannel tokenRestChannel = sendToAuthenticator(samlAuthenticator, tokenRestRequest); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow(); - Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); + Assert.assertEquals(401, response.getStatus()); } @SuppressWarnings("unchecked") @@ -815,9 +816,9 @@ public void initialConnectionFailureTest() throws Exception { HTTPSamlAuthenticator samlAuthenticator = new HTTPSamlAuthenticator(settings, null); RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); - TestRestChannel restChannel = sendToAuthenticator(samlAuthenticator, restRequest); + Optional maybeResponse = sendToAuthenticator(samlAuthenticator, restRequest); - Assert.assertNull(restChannel.response); + assertThat(maybeResponse.isPresent(), Matchers.equalTo(false)); mockSamlIdpServer.start(); @@ -852,14 +853,11 @@ public void initialConnectionFailureTest() throws Exception { private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuthenticator) { RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); - TestRestChannel restChannel = sendToAuthenticator(samlAuthenticator, restRequest); + SecurityResponse response = sendToAuthenticator(samlAuthenticator, restRequest).orElseThrow(); - List wwwAuthenticateHeaders = restChannel.response.getHeaders().get("WWW-Authenticate"); + String wwwAuthenticateHeader = response.getHeaders().get("WWW-Authenticate"); - Assert.assertNotNull(wwwAuthenticateHeaders); - Assert.assertEquals("More than one WWW-Authenticate header: " + wwwAuthenticateHeaders, 1, wwwAuthenticateHeaders.size()); - - String wwwAuthenticateHeader = wwwAuthenticateHeaders.get(0); + Assert.assertNotNull(wwwAuthenticateHeader); Matcher wwwAuthenticateHeaderMatcher = WWW_AUTHENTICATE_PATTERN.matcher(wwwAuthenticateHeader);