Skip to content

Commit

Permalink
Fix issue with saml tests
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Nied <petern@amazon.com>
  • Loading branch information
peternied committed Oct 5, 2023
1 parent e323823 commit 177ad11
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
Expand All @@ -36,6 +38,10 @@
import com.onelogin.saml2.exception.ValidationError;
import com.onelogin.saml2.settings.Saml2Settings;
import com.onelogin.saml2.util.Util;

import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaderNames;

import org.apache.commons.lang3.StringUtils;
import org.apache.cxf.jaxrs.json.basic.JsonMapObjectReaderWriter;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
Expand All @@ -46,6 +52,7 @@
import org.apache.cxf.rs.security.jose.jwt.JwtClaims;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;
import org.apache.cxf.rs.security.jose.jwt.JwtUtils;
import org.apache.http.HttpStatus;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.joda.time.DateTime;
Expand All @@ -63,6 +70,7 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.security.DefaultObjectMapper;
import org.opensearch.security.dlic.rest.api.AuthTokenProcessorAction;
import org.opensearch.security.filter.SecurityResponse;

class AuthTokenProcessorHandler {
private static final Logger log = LogManager.getLogger(AuthTokenProcessorHandler.class);
Expand Down Expand Up @@ -122,19 +130,18 @@ class AuthTokenProcessorHandler {
}

@SuppressWarnings("removal")
boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exception {
Optional<SecurityResponse> handle(RestRequest restRequest) throws Exception {
try {
final SecurityManager sm = System.getSecurityManager();

if (sm != null) {
sm.checkPermission(new SpecialPermission());
}

return AccessController.doPrivileged(new PrivilegedExceptionAction<Boolean>() {
return AccessController.doPrivileged(new PrivilegedExceptionAction<Optional<SecurityResponse>>() {
@Override
public Boolean run() throws XPathExpressionException, SamlConfigException, IOException, ParserConfigurationException,
SAXException, SettingsException {
return handleLowLevel(restRequest, restChannel);
public Optional<SecurityResponse> run() throws SamlConfigException, IOException {
return handleLowLevel(restRequest);
}
});
} catch (PrivilegedActionException e) {
Expand All @@ -147,13 +154,11 @@ public Boolean run() throws XPathExpressionException, SamlConfigException, IOExc
}

private AuthTokenProcessorAction.Response handleImpl(
RestRequest restRequest,
RestChannel restChannel,
String samlResponseBase64,
String samlRequestId,
String acsEndpoint,
Saml2Settings saml2Settings
) throws XPathExpressionException, ParserConfigurationException, SAXException, IOException, SettingsException {
) {
if (token_log.isDebugEnabled()) {
try {
token_log.debug(
Expand Down Expand Up @@ -188,8 +193,7 @@ private AuthTokenProcessorAction.Response handleImpl(
}
}

private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) throws SamlConfigException, IOException,
XPathExpressionException, ParserConfigurationException, SAXException, SettingsException {
private Optional<SecurityResponse> handleLowLevel(RestRequest restRequest) throws SamlConfigException, IOException {
try {

if (restRequest.getMediaType() != XContentType.JSON) {
Expand Down Expand Up @@ -234,31 +238,18 @@ private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel)
acsEndpoint = getAbsoluteAcsEndpoint(((ObjectNode) jsonRoot).get("acsEndpoint").textValue());
}

AuthTokenProcessorAction.Response responseBody = this.handleImpl(
restRequest,
restChannel,
samlResponseBase64,
samlRequestId,
acsEndpoint,
saml2Settings
);
AuthTokenProcessorAction.Response responseBody = this.handleImpl(samlResponseBase64, samlRequestId, acsEndpoint, saml2Settings);

if (responseBody == null) {
return false;
return Optional.empty();
}

String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody);

BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString);
restChannel.sendResponse(authenticateResponse);

return true;
return Optional.of(new SecurityResponse(HttpStatus.SC_OK, Map.of(HttpHeaderNames.CONTENT_TYPE.toString(), "application/json"), responseBodyString));
} catch (JsonProcessingException e) {
log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e);

BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed");
restChannel.sendResponse(authenticateResponse);
return true;
return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, null, "JSON could not be parsed"));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import org.opensearch.security.filter.SecurityRequest;
import org.opensearch.security.filter.SecurityRequestChannelUnsupported;
import org.opensearch.security.filter.SecurityResponse;
import org.opensearch.security.filter.OpenSearchRequest;
import org.opensearch.security.filter.OpenSearchRequestChannel;
import org.opensearch.security.support.ConfigConstants;
import org.opensearch.security.support.PemKeyReader;
Expand Down Expand Up @@ -184,16 +185,14 @@ public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest

if (API_AUTHTOKEN_SUFFIX.equals(suffix)) {
// Verficiation of SAML ASC endpoint only works with RestRequests
if (!(request instanceof OpenSearchRequestChannel)) {
if (!(request instanceof OpenSearchRequest)) {
throw new SecurityRequestChannelUnsupported();
} else {
final OpenSearchRequestChannel securityRequestChannel = (OpenSearchRequestChannel) request;
final RestRequest restRequest = securityRequestChannel.breakEncapsulationForRequest();
final RestChannel channel = securityRequestChannel.breakEncapsulationForChannel();
if (this.authTokenProcessorHandler.handle(restRequest, channel)) {
// The ACS response was accepted
securityRequestChannel.markCompleted();
return Optional.empty();
final OpenSearchRequest openSearchRequest = (OpenSearchRequest) request;
final RestRequest restRequest = openSearchRequest.breakEncapsulationForRequest();
Optional<SecurityResponse> restResponse = this.authTokenProcessorHandler.handle(restRequest);
if (restResponse.isPresent()) {
return restResponse;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand All @@ -34,6 +35,7 @@
import com.google.common.collect.ImmutableMap;
import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
Expand All @@ -42,7 +44,6 @@
import org.opensaml.saml.saml2.core.NameIDType;

import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.MediaType;
Expand All @@ -54,13 +55,15 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.security.DefaultObjectMapper;
import org.opensearch.security.filter.SecurityRequestFactory;
import org.opensearch.security.filter.OpenSearchRequestChannel;
import org.opensearch.security.filter.SecurityResponse;
import org.opensearch.security.filter.SecurityRequest;
import org.opensearch.security.test.helper.file.FileHelper;
import org.opensearch.security.user.AuthCredentials;
import org.opensearch.security.util.FakeRestRequest;

import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.IDP_METADATA_CONTENT;
import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.IDP_METADATA_URL;
import static org.hamcrest.MatcherAssert.assertThat;

public class HTTPSamlAuthenticatorTest {
protected MockSamlIdpServer mockSamlIdpServer;
Expand Down Expand Up @@ -158,17 +161,15 @@ public void basicTest() throws Exception {
Assert.assertEquals("horst", jwt.getClaim("sub"));
}

private TestRestChannel sendToAuthenticator(HTTPSamlAuthenticator samlAuthenticator, RestRequest request) {
TestRestChannel testChannel = new TestRestChannel(request);
OpenSearchRequestChannel tokenRestChannel = (OpenSearchRequestChannel) SecurityRequestFactory.from(request, testChannel);
private Optional<SecurityResponse> sendToAuthenticator(HTTPSamlAuthenticator samlAuthenticator, RestRequest request) {
final SecurityRequest tokenRestChannel = SecurityRequestFactory.from(request);

samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
return testChannel;
return samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
}

private String getResponse(HTTPSamlAuthenticator samlAuthenticator, RestRequest request) throws Exception {
TestRestChannel testChannel = sendToAuthenticator(samlAuthenticator, request);
return new String(BytesReference.toBytes(testChannel.response.content()));
SecurityResponse response = sendToAuthenticator(samlAuthenticator, request).orElseThrow();
return response.getBody();
}

@Test
Expand Down Expand Up @@ -534,9 +535,9 @@ public void badUnsolicitedSsoTest() throws Exception {
authenticateHeaders,
"/opendistrosecurity/saml/acs/idpinitiated"
);
TestRestChannel tokenRestChannel = sendToAuthenticator(samlAuthenticator, tokenRestRequest);
SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow();

Assert.assertEquals(RestStatus.UNAUTHORIZED, tokenRestChannel.response.status());
Assert.assertEquals(RestStatus.UNAUTHORIZED.getStatus(), response.getStatus());
}

@Test
Expand Down Expand Up @@ -564,9 +565,9 @@ public void wrongCertTest() throws Exception {
String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location);

RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
TestRestChannel tokenRestChannel = sendToAuthenticator(samlAuthenticator, tokenRestRequest);
SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow();

Assert.assertEquals(401, tokenRestChannel.response.status().getStatus());
Assert.assertEquals(401, response.getStatus());
}

@Test
Expand All @@ -591,9 +592,9 @@ public void noSignatureTest() throws Exception {
String encodedSamlResponse = mockSamlIdpServer.handleSsoGetRequestURI(authenticateHeaders.location);

RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
TestRestChannel tokenRestChannel = sendToAuthenticator(samlAuthenticator, tokenRestRequest);
SecurityResponse response = sendToAuthenticator(samlAuthenticator, tokenRestRequest).orElseThrow();

Assert.assertEquals(401, tokenRestChannel.response.status().getStatus());
Assert.assertEquals(401, response.getStatus());
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -815,9 +816,9 @@ public void initialConnectionFailureTest() throws Exception {
HTTPSamlAuthenticator samlAuthenticator = new HTTPSamlAuthenticator(settings, null);

RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap<String, String>());
TestRestChannel restChannel = sendToAuthenticator(samlAuthenticator, restRequest);
Optional<SecurityResponse> maybeResponse = sendToAuthenticator(samlAuthenticator, restRequest);

Assert.assertNull(restChannel.response);
assertThat(maybeResponse.isPresent(), Matchers.equalTo(false));

mockSamlIdpServer.start();

Expand Down Expand Up @@ -852,14 +853,11 @@ public void initialConnectionFailureTest() throws Exception {

private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuthenticator) {
RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap<String, String>());
TestRestChannel restChannel = sendToAuthenticator(samlAuthenticator, restRequest);
SecurityResponse response = sendToAuthenticator(samlAuthenticator, restRequest).orElseThrow();

List<String> wwwAuthenticateHeaders = restChannel.response.getHeaders().get("WWW-Authenticate");
String wwwAuthenticateHeader = response.getHeaders().get("WWW-Authenticate");

Assert.assertNotNull(wwwAuthenticateHeaders);
Assert.assertEquals("More than one WWW-Authenticate header: " + wwwAuthenticateHeaders, 1, wwwAuthenticateHeaders.size());

String wwwAuthenticateHeader = wwwAuthenticateHeaders.get(0);
Assert.assertNotNull(wwwAuthenticateHeader);

Matcher wwwAuthenticateHeaderMatcher = WWW_AUTHENTICATE_PATTERN.matcher(wwwAuthenticateHeader);

Expand Down

0 comments on commit 177ad11

Please sign in to comment.