From 2f8ff309bcd87f5ba15846195f7a8c7de79a489b Mon Sep 17 00:00:00 2001 From: David Leifker Date: Wed, 23 Oct 2024 11:34:48 -0500 Subject: [PATCH] refactor(datahub-frontend): upgrade frontend pac4j * Upgrade play components * scala 2.13 * pac4j 6.0.6 * oauth2-oidc-sdk/jose-jwt * java 11 -> java 17 --- build.gradle | 30 ++- datahub-frontend/app/auth/AuthModule.java | 91 +++---- datahub-frontend/app/auth/AuthUtils.java | 1 + datahub-frontend/app/auth/CookieConfigs.java | 18 +- datahub-frontend/app/auth/JAASConfigs.java | 6 +- .../app/auth/NativeAuthenticationConfigs.java | 12 +- datahub-frontend/app/auth/sso/SsoConfigs.java | 53 ++--- datahub-frontend/app/auth/sso/SsoManager.java | 35 ++- .../app/auth/sso/SsoProvider.java | 9 +- .../sso/oidc/OidcAuthorizationGenerator.java | 21 +- .../app/auth/sso/oidc/OidcCallbackLogic.java | 224 ++++++++++++------ .../app/auth/sso/oidc/OidcConfigs.java | 13 +- .../app/auth/sso/oidc/OidcProvider.java | 71 +++--- .../sso/oidc/OidcResponseErrorHandler.java | 19 +- .../oidc/custom/CustomOidcAuthenticator.java | 94 +++++--- .../sso/oidc/custom/CustomOidcClient.java | 37 ++- .../CustomOidcRedirectionActionBuilder.java | 21 +- .../app/client/KafkaTrackingProducer.java | 27 ++- .../app/config/ConfigurationProvider.java | 3 - .../app/controllers/Application.java | 120 +++++----- .../controllers/AuthenticationController.java | 154 ++++++------ .../controllers/CentralLogoutController.java | 6 +- .../app/controllers/RedirectController.java | 5 +- .../controllers/SsoCallbackController.java | 88 +++---- .../app/controllers/TrackingController.java | 22 +- datahub-frontend/build.gradle | 6 + datahub-frontend/conf/application.conf | 1 + datahub-frontend/play.gradle | 40 ++-- .../test/app/ApplicationTest.java | 196 ++++++++++++--- .../test/oidc/OidcCallbackLogicTest.java | 91 +++---- .../test/security/OidcConfigurationTest.java | 19 +- 31 files changed, 938 insertions(+), 595 deletions(-) diff --git a/build.gradle b/build.gradle index 67968ce3ee290..77f8395ac898e 100644 --- a/build.gradle +++ b/build.gradle @@ -45,7 +45,9 @@ buildscript { ext.elasticsearchVersion = '2.11.1' // ES 7.10, Opensearch 1.x, 2.x ext.jacksonVersion = '2.15.3' ext.jettyVersion = '11.0.21' + // see also datahub-frontend/play.gradle ext.playVersion = '2.8.22' + ext.playScalaVersion = '2.13' ext.log4jVersion = '2.23.1' ext.slf4jVersion = '1.7.36' ext.logbackClassic = '1.4.14' @@ -103,7 +105,7 @@ project.ext.spec = [ ] project.ext.externalDependency = [ - 'akkaHttp': 'com.typesafe.akka:akka-http-core_2.12:10.2.10', + 'akkaHttp': "com.typesafe.akka:akka-http-core_$playScalaVersion:10.2.10", 'antlr4Runtime': 'org.antlr:antlr4-runtime:4.9.3', 'antlr4': 'org.antlr:antlr4:4.9.3', 'assertJ': 'org.assertj:assertj-core:3.11.1', @@ -212,18 +214,18 @@ project.ext.externalDependency = [ 'parquet': 'org.apache.parquet:parquet-avro:1.12.3', 'parquetHadoop': 'org.apache.parquet:parquet-hadoop:1.13.1', 'picocli': 'info.picocli:picocli:4.5.0', - 'playCache': "com.typesafe.play:play-cache_2.12:$playVersion", - 'playCaffeineCache': "com.typesafe.play:play-caffeine-cache_2.12:$playVersion", - 'playWs': 'com.typesafe.play:play-ahc-ws-standalone_2.12:2.1.10', - 'playDocs': "com.typesafe.play:play-docs_2.12:$playVersion", - 'playGuice': "com.typesafe.play:play-guice_2.12:$playVersion", - 'playJavaJdbc': "com.typesafe.play:play-java-jdbc_2.12:$playVersion", - 'playAkkaHttpServer': "com.typesafe.play:play-akka-http-server_2.12:$playVersion", - 'playServer': "com.typesafe.play:play-server_2.12:$playVersion", - 'playTest': "com.typesafe.play:play-test_2.12:$playVersion", - 'playFilters': "com.typesafe.play:filters-helpers_2.12:$playVersion", - 'pac4j': 'org.pac4j:pac4j-oidc:4.5.8', - 'playPac4j': 'org.pac4j:play-pac4j_2.12:9.0.2', + 'playCache': "com.typesafe.play:play-cache_$playScalaVersion:$playVersion", + 'playCaffeineCache': "com.typesafe.play:play-caffeine-cache_$playScalaVersion:$playVersion", + 'playWs': "com.typesafe.play:play-ahc-ws-standalone_$playScalaVersion:2.1.10", + 'playDocs': "com.typesafe.play:play-docs_$playScalaVersion:$playVersion", + 'playGuice': "com.typesafe.play:play-guice_$playScalaVersion:$playVersion", + 'playJavaJdbc': "com.typesafe.play:play-java-jdbc_$playScalaVersion:$playVersion", + 'playAkkaHttpServer': "com.typesafe.play:play-akka-http-server_$playScalaVersion:$playVersion", + 'playServer': "com.typesafe.play:play-server_$playScalaVersion:$playVersion", + 'playTest': "com.typesafe.play:play-test_$playScalaVersion:$playVersion", + 'playFilters': "com.typesafe.play:filters-helpers_$playScalaVersion:$playVersion", + 'pac4j': 'org.pac4j:pac4j-oidc:6.0.6', + 'playPac4j': "org.pac4j:play-pac4j_$playScalaVersion:12.0.0-PLAY2.8", 'postgresql': 'org.postgresql:postgresql:42.3.9', 'protobuf': 'com.google.protobuf:protobuf-java:3.25.5', 'grpcProtobuf': 'io.grpc:grpc-protobuf:1.53.0', @@ -407,6 +409,8 @@ subprojects { googleJavaFormat() target project.fileTree(project.projectDir) { include 'src/**/*.java' + include 'app/**/*.java' + include 'test/**/*.java' exclude 'src/**/resources/' exclude 'src/**/generated/' exclude 'src/**/mainGeneratedDataTemplate/' diff --git a/datahub-frontend/app/auth/AuthModule.java b/datahub-frontend/app/auth/AuthModule.java index 7bb2547890701..d51795330f5ce 100644 --- a/datahub-frontend/app/auth/AuthModule.java +++ b/datahub-frontend/app/auth/AuthModule.java @@ -21,27 +21,28 @@ import com.linkedin.util.Configuration; import config.ConfigurationProvider; import controllers.SsoCallbackController; -import io.datahubproject.metadata.context.ValidationContext; -import java.nio.charset.StandardCharsets; -import java.util.Collections; - import io.datahubproject.metadata.context.ActorContext; import io.datahubproject.metadata.context.AuthorizationContext; import io.datahubproject.metadata.context.EntityRegistryContext; import io.datahubproject.metadata.context.OperationContext; import io.datahubproject.metadata.context.OperationContextConfig; import io.datahubproject.metadata.context.SearchContext; +import io.datahubproject.metadata.context.ValidationContext; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import javax.annotation.Nonnull; import lombok.extern.slf4j.Slf4j; import org.apache.commons.codec.digest.DigestUtils; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClients; import org.pac4j.core.config.Config; import org.pac4j.core.context.session.SessionStore; +import org.pac4j.core.profile.ProfileManager; +import org.pac4j.core.util.serializer.JavaSerializer; import org.pac4j.play.LogoutController; import org.pac4j.play.http.PlayHttpActionAdapter; import org.pac4j.play.store.PlayCacheSessionStore; import org.pac4j.play.store.PlayCookieSessionStore; -import org.pac4j.play.store.PlaySessionStore; import org.pac4j.play.store.ShiroAesDataEncrypter; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import play.Environment; @@ -63,14 +64,16 @@ public class AuthModule extends AbstractModule { private static final String PAC4J_SESSIONSTORE_PROVIDER_CONF = "pac4j.sessionStore.provider"; private static final String ENTITY_CLIENT_RETRY_INTERVAL = "entityClient.retryInterval"; private static final String ENTITY_CLIENT_NUM_RETRIES = "entityClient.numRetries"; - private static final String ENTITY_CLIENT_RESTLI_GET_BATCH_SIZE = "entityClient.restli.get.batchSize"; - private static final String ENTITY_CLIENT_RESTLI_GET_BATCH_CONCURRENCY = "entityClient.restli.get.batchConcurrency"; + private static final String ENTITY_CLIENT_RESTLI_GET_BATCH_SIZE = + "entityClient.restli.get.batchSize"; + private static final String ENTITY_CLIENT_RESTLI_GET_BATCH_CONCURRENCY = + "entityClient.restli.get.batchConcurrency"; private static final String GET_SSO_SETTINGS_ENDPOINT = "auth/getSsoSettings"; - private final com.typesafe.config.Config _configs; + private final com.typesafe.config.Config configs; public AuthModule(final Environment environment, final com.typesafe.config.Config configs) { - _configs = configs; + this.configs = configs; } @Override @@ -84,13 +87,13 @@ protected void configure() { * the response will be rejected by the browser. Default to PlayCacheCookieStore so that * datahub-frontend container remains as a stateless service */ - String sessionStoreProvider = _configs.getString(PAC4J_SESSIONSTORE_PROVIDER_CONF); + String sessionStoreProvider = configs.getString(PAC4J_SESSIONSTORE_PROVIDER_CONF); if (sessionStoreProvider.equals("PlayCacheSessionStore")) { final PlayCacheSessionStore playCacheSessionStore = new PlayCacheSessionStore(getProvider(SyncCacheApi.class)); bind(SessionStore.class).toInstance(playCacheSessionStore); - bind(PlaySessionStore.class).toInstance(playCacheSessionStore); + bind(PlayCacheSessionStore.class).toInstance(playCacheSessionStore); } else { PlayCookieSessionStore playCacheCookieStore; try { @@ -98,17 +101,18 @@ protected void configure() { // hash the input to generate a fixed-length string. Then, we convert // it to hex and slice the first 16 bytes, because AES key length must strictly // have a specific length. - final String aesKeyBase = _configs.getString(PAC4J_AES_KEY_BASE_CONF); + final String aesKeyBase = configs.getString(PAC4J_AES_KEY_BASE_CONF); final String aesKeyHash = DigestUtils.sha256Hex(aesKeyBase.getBytes(StandardCharsets.UTF_8)); final String aesEncryptionKey = aesKeyHash.substring(0, 16); playCacheCookieStore = new PlayCookieSessionStore(new ShiroAesDataEncrypter(aesEncryptionKey.getBytes())); + playCacheCookieStore.setSerializer(new JavaSerializer()); } catch (Exception e) { throw new RuntimeException("Failed to instantiate Pac4j cookie session store!", e); } bind(SessionStore.class).toInstance(playCacheCookieStore); - bind(PlaySessionStore.class).toInstance(playCacheCookieStore); + bind(PlayCookieSessionStore.class).toInstance(playCacheCookieStore); } try { @@ -133,9 +137,12 @@ protected void configure() { @Provides @Singleton - protected Config provideConfig() { + protected Config provideConfig(@Nonnull SessionStore sessionStore) { Config config = new Config(); + config.setSessionStoreFactory(parameters -> sessionStore); config.setHttpActionAdapter(new PlayHttpActionAdapter()); + config.setProfileManagerFactory(ProfileManager::new); + return config; } @@ -145,7 +152,7 @@ protected SsoManager provideSsoManager( Authentication systemAuthentication, CloseableHttpClient httpClient) { SsoManager manager = new SsoManager( - _configs, systemAuthentication, getSsoSettingsRequestUrl(_configs), httpClient); + configs, systemAuthentication, getSsoSettingsRequestUrl(configs), httpClient); manager.initializeSsoProvider(); return manager; } @@ -155,8 +162,8 @@ protected SsoManager provideSsoManager( protected Authentication provideSystemAuthentication() { // Returns an instance of Authentication used to authenticate system initiated calls to Metadata // Service. - String systemClientId = _configs.getString(SYSTEM_CLIENT_ID_CONFIG_PATH); - String systemSecret = _configs.getString(SYSTEM_CLIENT_SECRET_CONFIG_PATH); + String systemClientId = configs.getString(SYSTEM_CLIENT_ID_CONFIG_PATH); + String systemSecret = configs.getString(SYSTEM_CLIENT_SECRET_CONFIG_PATH); final Actor systemActor = new Actor(ActorType.USER, systemClientId); // TODO: Change to service actor once supported. return new Authentication( @@ -169,27 +176,25 @@ protected Authentication provideSystemAuthentication() { @Singleton @Named("systemOperationContext") protected OperationContext provideOperationContext( - final Authentication systemAuthentication, - final ConfigurationProvider configurationProvider) { + final Authentication systemAuthentication, + final ConfigurationProvider configurationProvider) { ActorContext systemActorContext = - ActorContext.builder() - .systemAuth(true) - .authentication(systemAuthentication) - .build(); - OperationContextConfig systemConfig = OperationContextConfig.builder() + ActorContext.builder().systemAuth(true).authentication(systemAuthentication).build(); + OperationContextConfig systemConfig = + OperationContextConfig.builder() .viewAuthorizationConfiguration(configurationProvider.getAuthorization().getView()) .allowSystemAuthentication(true) .build(); return OperationContext.builder() - .operationContextConfig(systemConfig) - .systemActorContext(systemActorContext) - // Authorizer.EMPTY is fine since it doesn't actually apply to system auth - .authorizationContext(AuthorizationContext.builder().authorizer(Authorizer.EMPTY).build()) - .searchContext(SearchContext.EMPTY) - .entityRegistryContext(EntityRegistryContext.builder().build(EmptyEntityRegistry.EMPTY)) - .validationContext(ValidationContext.builder().alternateValidation(false).build()) - .build(systemAuthentication); + .operationContextConfig(systemConfig) + .systemActorContext(systemActorContext) + // Authorizer.EMPTY is fine since it doesn't actually apply to system auth + .authorizationContext(AuthorizationContext.builder().authorizer(Authorizer.EMPTY).build()) + .searchContext(SearchContext.EMPTY) + .entityRegistryContext(EntityRegistryContext.builder().build(EmptyEntityRegistry.EMPTY)) + .validationContext(ValidationContext.builder().alternateValidation(false).build()) + .build(systemAuthentication); } @Provides @@ -208,11 +213,11 @@ protected SystemEntityClient provideEntityClient( return new SystemRestliEntityClient( buildRestliClient(), - new ExponentialBackoff(_configs.getInt(ENTITY_CLIENT_RETRY_INTERVAL)), - _configs.getInt(ENTITY_CLIENT_NUM_RETRIES), + new ExponentialBackoff(configs.getInt(ENTITY_CLIENT_RETRY_INTERVAL)), + configs.getInt(ENTITY_CLIENT_NUM_RETRIES), configurationProvider.getCache().getClient().getEntityClient(), - Math.max(1, _configs.getInt(ENTITY_CLIENT_RESTLI_GET_BATCH_SIZE)), - Math.max(1, _configs.getInt(ENTITY_CLIENT_RESTLI_GET_BATCH_CONCURRENCY))); + Math.max(1, configs.getInt(ENTITY_CLIENT_RESTLI_GET_BATCH_SIZE)), + Math.max(1, configs.getInt(ENTITY_CLIENT_RESTLI_GET_BATCH_CONCURRENCY))); } @Provides @@ -220,11 +225,11 @@ protected SystemEntityClient provideEntityClient( protected AuthServiceClient provideAuthClient( Authentication systemAuthentication, CloseableHttpClient httpClient) { // Init a GMS auth client - final String metadataServiceHost = getMetadataServiceHost(_configs); + final String metadataServiceHost = getMetadataServiceHost(configs); - final int metadataServicePort = getMetadataServicePort(_configs); + final int metadataServicePort = getMetadataServicePort(configs); - final boolean metadataServiceUseSsl = doesMetadataServiceUseSsl(_configs); + final boolean metadataServiceUseSsl = doesMetadataServiceUseSsl(configs); return new AuthServiceClient( metadataServiceHost, @@ -243,22 +248,22 @@ protected CloseableHttpClient provideHttpClient() { private com.linkedin.restli.client.Client buildRestliClient() { final String metadataServiceHost = utils.ConfigUtil.getString( - _configs, + configs, METADATA_SERVICE_HOST_CONFIG_PATH, utils.ConfigUtil.DEFAULT_METADATA_SERVICE_HOST); final int metadataServicePort = utils.ConfigUtil.getInt( - _configs, + configs, utils.ConfigUtil.METADATA_SERVICE_PORT_CONFIG_PATH, utils.ConfigUtil.DEFAULT_METADATA_SERVICE_PORT); final boolean metadataServiceUseSsl = utils.ConfigUtil.getBoolean( - _configs, + configs, utils.ConfigUtil.METADATA_SERVICE_USE_SSL_CONFIG_PATH, ConfigUtil.DEFAULT_METADATA_SERVICE_USE_SSL); final String metadataServiceSslProtocol = utils.ConfigUtil.getString( - _configs, + configs, utils.ConfigUtil.METADATA_SERVICE_SSL_PROTOCOL_CONFIG_PATH, ConfigUtil.DEFAULT_METADATA_SERVICE_SSL_PROTOCOL); return DefaultRestliClientFactory.getRestLiClient( diff --git a/datahub-frontend/app/auth/AuthUtils.java b/datahub-frontend/app/auth/AuthUtils.java index 51bb784c61b3b..490f52bece651 100644 --- a/datahub-frontend/app/auth/AuthUtils.java +++ b/datahub-frontend/app/auth/AuthUtils.java @@ -75,6 +75,7 @@ public class AuthUtils { public static final String RESPONSE_MODE = "responseMode"; public static final String USE_NONCE = "useNonce"; public static final String READ_TIMEOUT = "readTimeout"; + public static final String CONNECT_TIMEOUT = "connectTimeout"; public static final String EXTRACT_JWT_ACCESS_TOKEN_CLAIMS = "extractJwtAccessTokenClaims"; // Retained for backwards compatibility public static final String PREFERRED_JWS_ALGORITHM = "preferredJwsAlgorithm"; diff --git a/datahub-frontend/app/auth/CookieConfigs.java b/datahub-frontend/app/auth/CookieConfigs.java index 63b2ce61aaf9b..e77e200144835 100644 --- a/datahub-frontend/app/auth/CookieConfigs.java +++ b/datahub-frontend/app/auth/CookieConfigs.java @@ -10,34 +10,34 @@ public class CookieConfigs { public static final String AUTH_COOKIE_SECURE = "play.http.session.secure"; public static final boolean DEFAULT_AUTH_COOKIE_SECURE = false; - private final int _ttlInHours; - private final String _authCookieSameSite; - private final boolean _authCookieSecure; + private final int ttlInHours; + private final String authCookieSameSite; + private final boolean authCookieSecure; public CookieConfigs(final Config configs) { - _ttlInHours = + ttlInHours = configs.hasPath(SESSION_TTL_CONFIG_PATH) ? configs.getInt(SESSION_TTL_CONFIG_PATH) : DEFAULT_SESSION_TTL_HOURS; - _authCookieSameSite = + authCookieSameSite = configs.hasPath(AUTH_COOKIE_SAME_SITE) ? configs.getString(AUTH_COOKIE_SAME_SITE) : DEFAULT_AUTH_COOKIE_SAME_SITE; - _authCookieSecure = + authCookieSecure = configs.hasPath(AUTH_COOKIE_SECURE) ? configs.getBoolean(AUTH_COOKIE_SECURE) : DEFAULT_AUTH_COOKIE_SECURE; } public int getTtlInHours() { - return _ttlInHours; + return ttlInHours; } public String getAuthCookieSameSite() { - return _authCookieSameSite; + return authCookieSameSite; } public boolean getAuthCookieSecure() { - return _authCookieSecure; + return authCookieSecure; } } diff --git a/datahub-frontend/app/auth/JAASConfigs.java b/datahub-frontend/app/auth/JAASConfigs.java index 529bf98e1fdcf..dee4ded68808a 100644 --- a/datahub-frontend/app/auth/JAASConfigs.java +++ b/datahub-frontend/app/auth/JAASConfigs.java @@ -8,16 +8,16 @@ public class JAASConfigs { public static final String JAAS_ENABLED_CONFIG_PATH = "auth.jaas.enabled"; - private Boolean _isEnabled = true; + private Boolean isEnabled = true; public JAASConfigs(final com.typesafe.config.Config configs) { if (configs.hasPath(JAAS_ENABLED_CONFIG_PATH) && !configs.getBoolean(JAAS_ENABLED_CONFIG_PATH)) { - _isEnabled = false; + isEnabled = false; } } public boolean isJAASEnabled() { - return _isEnabled; + return isEnabled; } } diff --git a/datahub-frontend/app/auth/NativeAuthenticationConfigs.java b/datahub-frontend/app/auth/NativeAuthenticationConfigs.java index 772c2c8f92f28..a7b8a8bc80067 100644 --- a/datahub-frontend/app/auth/NativeAuthenticationConfigs.java +++ b/datahub-frontend/app/auth/NativeAuthenticationConfigs.java @@ -7,17 +7,17 @@ public class NativeAuthenticationConfigs { public static final String NATIVE_AUTHENTICATION_ENFORCE_VALID_EMAIL_ENABLED_CONFIG_PATH = "auth.native.signUp.enforceValidEmail"; - private Boolean _isEnabled = true; - private Boolean _isEnforceValidEmailEnabled = true; + private Boolean isEnabled = true; + private Boolean isEnforceValidEmailEnabled = true; public NativeAuthenticationConfigs(final com.typesafe.config.Config configs) { if (configs.hasPath(NATIVE_AUTHENTICATION_ENABLED_CONFIG_PATH)) { - _isEnabled = + isEnabled = Boolean.parseBoolean( configs.getValue(NATIVE_AUTHENTICATION_ENABLED_CONFIG_PATH).toString()); } if (configs.hasPath(NATIVE_AUTHENTICATION_ENFORCE_VALID_EMAIL_ENABLED_CONFIG_PATH)) { - _isEnforceValidEmailEnabled = + isEnforceValidEmailEnabled = Boolean.parseBoolean( configs .getValue(NATIVE_AUTHENTICATION_ENFORCE_VALID_EMAIL_ENABLED_CONFIG_PATH) @@ -26,10 +26,10 @@ public NativeAuthenticationConfigs(final com.typesafe.config.Config configs) { } public boolean isNativeAuthenticationEnabled() { - return _isEnabled; + return isEnabled; } public boolean isEnforceValidEmailEnabled() { - return _isEnforceValidEmailEnabled; + return isEnforceValidEmailEnabled; } } diff --git a/datahub-frontend/app/auth/sso/SsoConfigs.java b/datahub-frontend/app/auth/sso/SsoConfigs.java index 976d0826f2277..46a2b7bfd27e8 100644 --- a/datahub-frontend/app/auth/sso/SsoConfigs.java +++ b/datahub-frontend/app/auth/sso/SsoConfigs.java @@ -1,10 +1,9 @@ package auth.sso; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; - import static auth.AuthUtils.*; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; /** * Class responsible for extracting and validating top-level SSO related configurations. TODO: @@ -25,72 +24,72 @@ public class SsoConfigs { private static final String DEFAULT_SUCCESS_REDIRECT_PATH = "/"; - private final String _authBaseUrl; - private final String _authBaseCallbackPath; - private final String _authSuccessRedirectPath; - private final Boolean _oidcEnabled; + private final String authBaseUrl; + private final String authBaseCallbackPath; + private final String authSuccessRedirectPath; + private final Boolean oidcEnabled; public SsoConfigs(Builder builder) { - _authBaseUrl = builder._authBaseUrl; - _authBaseCallbackPath = builder._authBaseCallbackPath; - _authSuccessRedirectPath = builder._authSuccessRedirectPath; - _oidcEnabled = builder._oidcEnabled; + authBaseUrl = builder.authBaseUrl; + authBaseCallbackPath = builder.authBaseCallbackPath; + authSuccessRedirectPath = builder.authSuccessRedirectPath; + oidcEnabled = builder.oidcEnabled; } public String getAuthBaseUrl() { - return _authBaseUrl; + return authBaseUrl; } public String getAuthBaseCallbackPath() { - return _authBaseCallbackPath; + return authBaseCallbackPath; } public String getAuthSuccessRedirectPath() { - return _authSuccessRedirectPath; + return authSuccessRedirectPath; } public Boolean isOidcEnabled() { - return _oidcEnabled; + return oidcEnabled; } public static class Builder> { - protected String _authBaseUrl = null; - private String _authBaseCallbackPath = DEFAULT_BASE_CALLBACK_PATH; - private String _authSuccessRedirectPath = DEFAULT_SUCCESS_REDIRECT_PATH; - protected Boolean _oidcEnabled = false; - private final ObjectMapper _objectMapper = new ObjectMapper(); + protected String authBaseUrl = null; + private String authBaseCallbackPath = DEFAULT_BASE_CALLBACK_PATH; + private String authSuccessRedirectPath = DEFAULT_SUCCESS_REDIRECT_PATH; + protected Boolean oidcEnabled = false; + private final ObjectMapper objectMapper = new ObjectMapper(); protected JsonNode jsonNode = null; // No need to check if changes are made since this method is only called at start-up. public Builder from(final com.typesafe.config.Config configs) { if (configs.hasPath(AUTH_BASE_URL_CONFIG_PATH)) { - _authBaseUrl = configs.getString(AUTH_BASE_URL_CONFIG_PATH); + authBaseUrl = configs.getString(AUTH_BASE_URL_CONFIG_PATH); } if (configs.hasPath(AUTH_BASE_CALLBACK_PATH_CONFIG_PATH)) { - _authBaseCallbackPath = configs.getString(AUTH_BASE_CALLBACK_PATH_CONFIG_PATH); + authBaseCallbackPath = configs.getString(AUTH_BASE_CALLBACK_PATH_CONFIG_PATH); } if (configs.hasPath(OIDC_ENABLED_CONFIG_PATH)) { - _oidcEnabled = + oidcEnabled = Boolean.TRUE.equals(Boolean.parseBoolean(configs.getString(OIDC_ENABLED_CONFIG_PATH))); } if (configs.hasPath(AUTH_SUCCESS_REDIRECT_PATH_CONFIG_PATH)) { - _authSuccessRedirectPath = configs.getString(AUTH_SUCCESS_REDIRECT_PATH_CONFIG_PATH); + authSuccessRedirectPath = configs.getString(AUTH_SUCCESS_REDIRECT_PATH_CONFIG_PATH); } return this; } public Builder from(String ssoSettingsJsonStr) { try { - jsonNode = _objectMapper.readTree(ssoSettingsJsonStr); + jsonNode = objectMapper.readTree(ssoSettingsJsonStr); } catch (Exception e) { throw new RuntimeException( String.format("Failed to parse ssoSettingsJsonStr %s into JSON", ssoSettingsJsonStr)); } if (jsonNode.has(BASE_URL)) { - _authBaseUrl = jsonNode.get(BASE_URL).asText(); + authBaseUrl = jsonNode.get(BASE_URL).asText(); } if (jsonNode.has(OIDC_ENABLED)) { - _oidcEnabled = jsonNode.get(OIDC_ENABLED).asBoolean(); + oidcEnabled = jsonNode.get(OIDC_ENABLED).asBoolean(); } return this; diff --git a/datahub-frontend/app/auth/sso/SsoManager.java b/datahub-frontend/app/auth/sso/SsoManager.java index 8377eb40e237f..468a5e4723981 100644 --- a/datahub-frontend/app/auth/sso/SsoManager.java +++ b/datahub-frontend/app/auth/sso/SsoManager.java @@ -26,22 +26,21 @@ public class SsoManager { private SsoProvider _provider; // Only one active provider at a time. - private final Authentication - _authentication; // Authentication used to fetch SSO settings from GMS - private final String _ssoSettingsRequestUrl; // SSO settings request URL. - private final CloseableHttpClient _httpClient; // HTTP client for making requests to GMS. - private com.typesafe.config.Config _configs; + private final Authentication authentication; // Authentication used to fetch SSO settings from GMS + private final String ssoSettingsRequestUrl; // SSO settings request URL. + private final CloseableHttpClient httpClient; // HTTP client for making requests to GMS. + private com.typesafe.config.Config configs; public SsoManager( com.typesafe.config.Config configs, Authentication authentication, String ssoSettingsRequestUrl, CloseableHttpClient httpClient) { - _configs = configs; - _authentication = Objects.requireNonNull(authentication, "authentication cannot be null"); - _ssoSettingsRequestUrl = + this.configs = configs; + this.authentication = Objects.requireNonNull(authentication, "authentication cannot be null"); + this.ssoSettingsRequestUrl = Objects.requireNonNull(ssoSettingsRequestUrl, "ssoSettingsRequestUrl cannot be null"); - _httpClient = Objects.requireNonNull(httpClient, "httpClient cannot be null"); + this.httpClient = Objects.requireNonNull(httpClient, "httpClient cannot be null"); _provider = null; } @@ -66,7 +65,7 @@ public void setSsoProvider(final SsoProvider provider) { } public void setConfigs(final com.typesafe.config.Config configs) { - _configs = configs; + this.configs = configs; } public void clearSsoProvider() { @@ -87,19 +86,19 @@ public SsoProvider getSsoProvider() { public void initializeSsoProvider() { SsoConfigs ssoConfigs = null; try { - ssoConfigs = new SsoConfigs.Builder().from(_configs).build(); + ssoConfigs = new SsoConfigs.Builder().from(configs).build(); } catch (Exception e) { // Debug-level logging since this is expected to fail if SSO has not been configured. - log.debug(String.format("Missing SSO settings in static configs %s", _configs), e); + log.debug(String.format("Missing SSO settings in static configs %s", configs), e); } if (ssoConfigs != null && ssoConfigs.isOidcEnabled()) { try { - OidcConfigs oidcConfigs = new OidcConfigs.Builder().from(_configs).build(); + OidcConfigs oidcConfigs = new OidcConfigs.Builder().from(configs).build(); maybeUpdateOidcProvider(oidcConfigs); } catch (Exception e) { // Error-level logging since this is unexpected to fail if SSO has been configured. - log.error(String.format("Error building OidcConfigs from static configs %s", _configs), e); + log.error(String.format("Error building OidcConfigs from static configs %s", configs), e); } } else { // Clear the SSO Provider since no SSO is enabled. @@ -132,7 +131,7 @@ private void refreshSsoProvider() { if (ssoConfigs != null && ssoConfigs.isOidcEnabled()) { try { OidcConfigs oidcConfigs = - new OidcConfigs.Builder().from(_configs, ssoSettingsJsonStr).build(); + new OidcConfigs.Builder().from(configs, ssoSettingsJsonStr).build(); maybeUpdateOidcProvider(oidcConfigs); } catch (Exception e) { log.error( @@ -166,15 +165,15 @@ private void maybeUpdateOidcProvider(OidcConfigs oidcConfigs) { private Optional getDynamicSsoSettings() { CloseableHttpResponse response = null; try { - final HttpPost request = new HttpPost(_ssoSettingsRequestUrl); + final HttpPost request = new HttpPost(ssoSettingsRequestUrl); // Build JSON request to verify credentials for a native user. request.setEntity(new StringEntity("")); // Add authorization header with DataHub frontend system id and secret. - request.addHeader(Http.HeaderNames.AUTHORIZATION, _authentication.getCredentials()); + request.addHeader(Http.HeaderNames.AUTHORIZATION, authentication.getCredentials()); - response = _httpClient.execute(request); + response = httpClient.execute(request); final HttpEntity entity = response.getEntity(); if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK && entity != null) { // Successfully received the SSO settings diff --git a/datahub-frontend/app/auth/sso/SsoProvider.java b/datahub-frontend/app/auth/sso/SsoProvider.java index a0947b52b92ae..2a6fff728966e 100644 --- a/datahub-frontend/app/auth/sso/SsoProvider.java +++ b/datahub-frontend/app/auth/sso/SsoProvider.java @@ -1,7 +1,6 @@ package auth.sso; import org.pac4j.core.client.Client; -import org.pac4j.core.credentials.Credentials; /** A thin interface over a Pac4j {@link Client} object and its associated configurations. */ public interface SsoProvider { @@ -12,14 +11,14 @@ enum SsoProtocol { // SAML -- not yet supported. // Common name appears in the Callback URL itself. - private final String _commonName; + private final String commonName; public String getCommonName() { - return _commonName; + return commonName; } SsoProtocol(String commonName) { - _commonName = commonName; + this.commonName = commonName; } } @@ -30,5 +29,5 @@ public String getCommonName() { SsoProtocol protocol(); /** Retrieves an initialized Pac4j {@link Client}. */ - Client client(); + Client client(); } diff --git a/datahub-frontend/app/auth/sso/oidc/OidcAuthorizationGenerator.java b/datahub-frontend/app/auth/sso/oidc/OidcAuthorizationGenerator.java index fa676d2d16c90..3a4433b0ca81e 100644 --- a/datahub-frontend/app/auth/sso/oidc/OidcAuthorizationGenerator.java +++ b/datahub-frontend/app/auth/sso/oidc/OidcAuthorizationGenerator.java @@ -5,7 +5,7 @@ import java.util.Map.Entry; import java.util.Optional; import org.pac4j.core.authorization.generator.AuthorizationGenerator; -import org.pac4j.core.context.WebContext; +import org.pac4j.core.context.CallContext; import org.pac4j.core.profile.AttributeLocation; import org.pac4j.core.profile.CommonProfile; import org.pac4j.core.profile.UserProfile; @@ -18,24 +18,31 @@ public class OidcAuthorizationGenerator implements AuthorizationGenerator { private static final Logger logger = LoggerFactory.getLogger(OidcAuthorizationGenerator.class); - private final ProfileDefinition profileDef; - + private final ProfileDefinition profileDef; private final OidcConfigs oidcConfigs; public OidcAuthorizationGenerator( - final ProfileDefinition profileDef, final OidcConfigs oidcConfigs) { + final ProfileDefinition profileDef, final OidcConfigs oidcConfigs) { this.profileDef = profileDef; this.oidcConfigs = oidcConfigs; } @Override - public Optional generate(WebContext context, UserProfile profile) { + public Optional generate(final CallContext context, final UserProfile profile) { + if (!(profile instanceof OidcProfile oidcProfile)) { + return Optional.of(profile); + } + if (oidcConfigs.getExtractJwtAccessTokenClaims().orElse(false)) { try { - final JWT jwt = JWTParser.parse(((OidcProfile) profile).getAccessToken().getValue()); + final JWT jwt = JWTParser.parse(oidcProfile.getAccessToken().getValue()); CommonProfile commonProfile = new CommonProfile(); + // Copy existing attributes + profile.getAttributes().forEach(commonProfile::addAttribute); + + // Add JWT claims for (final Entry entry : jwt.getJWTClaimsSet().getClaims().entrySet()) { final String claimName = entry.getKey(); @@ -51,6 +58,6 @@ public Optional generate(WebContext context, UserProfile profile) { } } - return Optional.ofNullable(profile); + return Optional.of(profile); } } diff --git a/datahub-frontend/app/auth/sso/oidc/OidcCallbackLogic.java b/datahub-frontend/app/auth/sso/oidc/OidcCallbackLogic.java index 510804ba17f1a..ef5833f607efd 100644 --- a/datahub-frontend/app/auth/sso/oidc/OidcCallbackLogic.java +++ b/datahub-frontend/app/auth/sso/oidc/OidcCallbackLogic.java @@ -39,6 +39,8 @@ import com.linkedin.metadata.utils.GenericRecordUtils; import com.linkedin.mxe.MetadataChangeProposal; import com.linkedin.r2.RemoteInvocationException; +import com.linkedin.util.Pair; +import io.datahubproject.metadata.context.OperationContext; import java.io.UnsupportedEncodingException; import java.net.MalformedURLException; import java.net.URI; @@ -51,27 +53,36 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; - -import io.datahubproject.metadata.context.OperationContext; +import javax.annotation.Nonnull; import lombok.extern.slf4j.Slf4j; +import org.pac4j.core.client.BaseClient; +import org.pac4j.core.client.Client; +import org.pac4j.core.client.Clients; import org.pac4j.core.config.Config; +import org.pac4j.core.context.CallContext; import org.pac4j.core.context.Cookie; +import org.pac4j.core.context.FrameworkParameters; +import org.pac4j.core.context.WebContext; +import org.pac4j.core.credentials.Credentials; import org.pac4j.core.engine.DefaultCallbackLogic; +import org.pac4j.core.exception.http.HttpAction; import org.pac4j.core.http.adapter.HttpActionAdapter; import org.pac4j.core.profile.CommonProfile; import org.pac4j.core.profile.ProfileManager; import org.pac4j.core.profile.UserProfile; +import org.pac4j.core.util.CommonHelper; import org.pac4j.core.util.Pac4jConstants; -import org.pac4j.play.PlayWebContext; +import org.pac4j.play.store.PlayCookieSessionStore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import play.mvc.Result; -import javax.annotation.Nonnull; - /** * This class contains the logic that is executed when an OpenID Connect Identity Provider redirects * back to D DataHub after an authentication attempt. @@ -83,7 +94,8 @@ * if the user does not already exist. */ @Slf4j -public class OidcCallbackLogic extends DefaultCallbackLogic { +public class OidcCallbackLogic extends DefaultCallbackLogic { + private static final Logger LOGGER = LoggerFactory.getLogger(OidcCallbackLogic.class); private final SsoManager ssoManager; private final SystemEntityClient systemEntityClient; @@ -105,67 +117,129 @@ public OidcCallbackLogic( } @Override - public Result perform( - PlayWebContext context, + public Object perform( Config config, - HttpActionAdapter httpActionAdapter, - String defaultUrl, - Boolean saveInSession, - Boolean multiProfile, - Boolean renewSession, - String defaultClient) { - - setContextRedirectUrl(context); - - final Result result = - super.perform( - context, - config, - httpActionAdapter, - defaultUrl, - saveInSession, - multiProfile, - renewSession, - defaultClient); + String inputDefaultUrl, + Boolean inputRenewSession, + String defaultClient, + FrameworkParameters parameters) { + + final Pair ctxResult = + superPerform(config, inputDefaultUrl, inputRenewSession, defaultClient, parameters); + + CallContext ctx = ctxResult.getFirst(); + Result result = (Result) ctxResult.getSecond(); + + setContextRedirectUrl(ctx); // Handle OIDC authentication errors. - if (OidcResponseErrorHandler.isError(context)) { - return OidcResponseErrorHandler.handleError(context); + if (OidcResponseErrorHandler.isError(ctx)) { + return OidcResponseErrorHandler.handleError(ctx); } // By this point, we know that OIDC is the enabled provider. final OidcConfigs oidcConfigs = (OidcConfigs) ssoManager.getSsoProvider().configs(); - return handleOidcCallback(systemOperationContext, oidcConfigs, result, getProfileManager(context)); + return handleOidcCallback(systemOperationContext, ctx, oidcConfigs, result); + } + + /** Overriding this to be able to intercept the CallContext being created */ + private Pair superPerform( + Config config, + String inputDefaultUrl, + Boolean inputRenewSession, + String defaultClient, + FrameworkParameters parameters) { + LOGGER.debug("=== CALLBACK ==="); + CallContext ctx = this.buildContext(config, parameters); + WebContext webContext = ctx.webContext(); + HttpActionAdapter httpActionAdapter = config.getHttpActionAdapter(); + CommonHelper.assertNotNull("httpActionAdapter", httpActionAdapter); + + HttpAction action; + try { + CommonHelper.assertNotNull("clientFinder", getClientFinder()); + String defaultUrl = (String) Objects.requireNonNullElse(inputDefaultUrl, "/"); + boolean renewSession = inputRenewSession == null || inputRenewSession; + CommonHelper.assertNotBlank("defaultUrl", defaultUrl); + Clients clients = config.getClients(); + CommonHelper.assertNotNull("clients", clients); + List foundClients = getClientFinder().find(clients, webContext, defaultClient); + CommonHelper.assertTrue( + foundClients != null && foundClients.size() == 1, + "unable to find one indirect client for the callback: check the callback URL for a client name parameter or suffix path or ensure that your configuration defaults to one indirect client"); + Client foundClient = (Client) foundClients.get(0); + LOGGER.debug("foundClient: {}", foundClient); + CommonHelper.assertNotNull("foundClient", foundClient); + Credentials credentials = (Credentials) foundClient.getCredentials(ctx).orElse(null); + LOGGER.debug("extracted credentials: {}", credentials); + credentials = (Credentials) foundClient.validateCredentials(ctx, credentials).orElse(null); + LOGGER.debug("validated credentials: {}", credentials); + if (credentials != null && !credentials.isForAuthentication()) { + action = foundClient.processLogout(ctx, credentials); + } else { + if (credentials != null) { + Optional optProfile = foundClient.getUserProfile(ctx, credentials); + LOGGER.debug("optProfile: {}", optProfile); + if (optProfile.isPresent()) { + UserProfile profile = (UserProfile) optProfile.get(); + Boolean saveProfileInSession = + ((BaseClient) foundClient).getSaveProfileInSession(webContext, profile); + boolean multiProfile = ((BaseClient) foundClient).isMultiProfile(webContext, profile); + LOGGER.debug( + "saveProfileInSession: {} / multiProfile: {}", saveProfileInSession, multiProfile); + this.saveUserProfile( + ctx, config, profile, saveProfileInSession, multiProfile, renewSession); + } + } + + action = this.redirectToOriginallyRequestedUrl(ctx, defaultUrl); + } + } catch (RuntimeException var20) { + RuntimeException e = var20; + return Pair.of(ctx, this.handleException(e, httpActionAdapter, webContext)); + } + + return Pair.of(ctx, httpActionAdapter.adapt(action, webContext)); } - @SuppressWarnings("unchecked") - private void setContextRedirectUrl(PlayWebContext context) { + private void setContextRedirectUrl(CallContext ctx) { + WebContext context = ctx.webContext(); + PlayCookieSessionStore sessionStore = (PlayCookieSessionStore) ctx.sessionStore(); + Optional redirectUrl = context.getRequestCookies().stream() .filter(cookie -> REDIRECT_URL_COOKIE_NAME.equals(cookie.getName())) .findFirst(); redirectUrl.ifPresent( cookie -> - context - .getSessionStore() - .set( - context, - Pac4jConstants.REQUESTED_URL, - JAVA_SER_HELPER.deserializeFromBytes( + sessionStore.set( + context, + Pac4jConstants.REQUESTED_URL, + sessionStore + .getSerializer() + .deserializeFromBytes( uncompressBytes(Base64.getDecoder().decode(cookie.getValue()))))); } private Result handleOidcCallback( final OperationContext opContext, + final CallContext ctx, final OidcConfigs oidcConfigs, - final Result result, - final ProfileManager profileManager) { + final Result result) { log.debug("Beginning OIDC Callback Handling..."); + ProfileManager profileManager = + ctx.profileManagerFactory().apply(ctx.webContext(), ctx.sessionStore()); + if (profileManager.isAuthenticated()) { // If authenticated, the user should have a profile. - final CommonProfile profile = (CommonProfile) profileManager.get(true).get(); + final Optional optProfile = profileManager.getProfile(); + if (optProfile.isEmpty()) { + return internalServerError( + "Failed to authenticate current user. Cannot find valid identity provider profile in session."); + } + final CommonProfile profile = (CommonProfile) optProfile.get(); log.debug( String.format( "Found authenticated user with profile %s", profile.getAttributes().toString())); @@ -196,7 +270,8 @@ private Result handleOidcCallback( } // Update user status to active on login. // If we want to prevent certain users from logging in, here's where we'll want to do it. - setUserStatus(opContext, + setUserStatus( + opContext, corpUserUrn, new CorpUserStatus() .setStatus(Constants.CORP_USER_STATUS_ACTIVE) @@ -307,29 +382,32 @@ private CorpUserSnapshot extractUser(CorpuserUrn urn, CommonProfile profile) { return corpUserSnapshot; } - public static Collection getGroupNames(CommonProfile profile, Object groupAttribute, String groupsClaimName) { - Collection groupNames = Collections.emptyList(); - try { - if (groupAttribute instanceof Collection) { - // List of group names - groupNames = (Collection) profile.getAttribute(groupsClaimName, Collection.class); - } else if (groupAttribute instanceof String) { - String groupString = (String) groupAttribute; - ObjectMapper objectMapper = new ObjectMapper(); - try { - // Json list of group names - groupNames = objectMapper.readValue(groupString, new TypeReference>(){}); - } catch (Exception e) { - groupNames = Arrays.asList(groupString.split(",")); - } + public static Collection getGroupNames( + CommonProfile profile, Object groupAttribute, String groupsClaimName) { + Collection groupNames = Collections.emptyList(); + try { + if (groupAttribute instanceof Collection) { + // List of group names + groupNames = (Collection) profile.getAttribute(groupsClaimName, Collection.class); + } else if (groupAttribute instanceof String) { + String groupString = (String) groupAttribute; + ObjectMapper objectMapper = new ObjectMapper(); + try { + // Json list of group names + groupNames = objectMapper.readValue(groupString, new TypeReference>() {}); + } catch (Exception e) { + groupNames = Arrays.asList(groupString.split(",")); } - } catch (Exception e) { - log.error(String.format( - "Failed to parse group names: Expected to find a list of strings for attribute with name %s, found %s", - groupsClaimName, profile.getAttribute(groupsClaimName).getClass())); } - return groupNames; + } catch (Exception e) { + log.error( + String.format( + "Failed to parse group names: Expected to find a list of strings for attribute with name %s, found %s", + groupsClaimName, profile.getAttribute(groupsClaimName).getClass())); + } + return groupNames; } + private List extractGroups(CommonProfile profile) { log.debug( @@ -350,7 +428,8 @@ private List extractGroups(CommonProfile profile) { if (profile.containsAttribute(groupsClaimName)) { try { final List groupSnapshots = new ArrayList<>(); - Collection groupNames = getGroupNames(profile, profile.getAttribute(groupsClaimName), groupsClaimName); + Collection groupNames = + getGroupNames(profile, profile.getAttribute(groupsClaimName), groupsClaimName); for (String groupName : groupNames) { // Create a basic CorpGroupSnapshot from the information. @@ -405,7 +484,8 @@ private GroupMembership createGroupMembership(final List extr return groupMembershipAspect; } - private void tryProvisionUser(@Nonnull OperationContext opContext, CorpUserSnapshot corpUserSnapshot) { + private void tryProvisionUser( + @Nonnull OperationContext opContext, CorpUserSnapshot corpUserSnapshot) { log.debug(String.format("Attempting to provision user with urn %s", corpUserSnapshot.getUrn())); @@ -439,7 +519,8 @@ private void tryProvisionUser(@Nonnull OperationContext opContext, CorpUserSnaps } } - private void tryProvisionGroups(@Nonnull OperationContext opContext, List corpGroups) { + private void tryProvisionGroups( + @Nonnull OperationContext opContext, List corpGroups) { log.debug( String.format( @@ -450,8 +531,7 @@ private void tryProvisionGroups(@Nonnull OperationContext opContext, List urnsToFetch = corpGroups.stream().map(CorpGroupSnapshot::getUrn).collect(Collectors.toSet()); - final Map existingGroups = - systemEntityClient.batchGet(opContext, urnsToFetch); + final Map existingGroups = systemEntityClient.batchGet(opContext, urnsToFetch); log.debug(String.format("Fetched GMS groups with urns %s", existingGroups.keySet())); @@ -489,7 +569,8 @@ private void tryProvisionGroups(@Nonnull OperationContext opContext, List new Entity().setValue(Snapshot.create(groupSnapshot))) .collect(Collectors.toSet())); @@ -505,7 +586,8 @@ private void tryProvisionGroups(@Nonnull OperationContext opContext, List useNonce; private final Optional customParamResource; private final String readTimeout; + private final String connectTimeout; private final Optional extractJwtAccessTokenClaims; private final Optional preferredJwsAlgorithm; private final Optional grantType; @@ -100,6 +103,7 @@ public OidcConfigs(Builder builder) { this.useNonce = builder.useNonce; this.customParamResource = builder.customParamResource; this.readTimeout = builder.readTimeout; + this.connectTimeout = builder.connectTimeout; this.extractJwtAccessTokenClaims = builder.extractJwtAccessTokenClaims; this.preferredJwsAlgorithm = builder.preferredJwsAlgorithm; this.acrValues = builder.acrValues; @@ -127,6 +131,7 @@ public static class Builder extends SsoConfigs.Builder { private Optional useNonce = Optional.empty(); private Optional customParamResource = Optional.empty(); private String readTimeout = DEFAULT_OIDC_READ_TIMEOUT; + private String connectTimeout = DEFAULT_OIDC_CONNECT_TIMEOUT; private Optional extractJwtAccessTokenClaims = Optional.empty(); private Optional preferredJwsAlgorithm = Optional.empty(); private Optional grantType = Optional.empty(); @@ -173,6 +178,7 @@ public Builder from(final com.typesafe.config.Config configs) { useNonce = getOptional(configs, OIDC_USE_NONCE).map(Boolean::parseBoolean); customParamResource = getOptional(configs, OIDC_CUSTOM_PARAM_RESOURCE); readTimeout = getOptional(configs, OIDC_READ_TIMEOUT, DEFAULT_OIDC_READ_TIMEOUT); + connectTimeout = getOptional(configs, OIDC_CONNECT_TIMEOUT, DEFAULT_OIDC_CONNECT_TIMEOUT); extractJwtAccessTokenClaims = getOptional(configs, OIDC_EXTRACT_JWT_ACCESS_TOKEN_CLAIMS).map(Boolean::parseBoolean); preferredJwsAlgorithm = @@ -232,6 +238,9 @@ public Builder from(final com.typesafe.config.Config configs, final String ssoSe if (jsonNode.has(READ_TIMEOUT)) { readTimeout = jsonNode.get(READ_TIMEOUT).asText(); } + if (jsonNode.has(CONNECT_TIMEOUT)) { + connectTimeout = jsonNode.get(CONNECT_TIMEOUT).asText(); + } if (jsonNode.has(EXTRACT_JWT_ACCESS_TOKEN_CLAIMS)) { extractJwtAccessTokenClaims = Optional.of(jsonNode.get(EXTRACT_JWT_ACCESS_TOKEN_CLAIMS).asBoolean()); @@ -250,11 +259,11 @@ public Builder from(final com.typesafe.config.Config configs, final String ssoSe } public OidcConfigs build() { - Objects.requireNonNull(_oidcEnabled, "oidcEnabled is required"); + Objects.requireNonNull(oidcEnabled, "oidcEnabled is required"); Objects.requireNonNull(clientId, "clientId is required"); Objects.requireNonNull(clientSecret, "clientSecret is required"); Objects.requireNonNull(discoveryUri, "discoveryUri is required"); - Objects.requireNonNull(_authBaseUrl, "authBaseUrl is required"); + Objects.requireNonNull(authBaseUrl, "authBaseUrl is required"); return new OidcConfigs(this); } diff --git a/datahub-frontend/app/auth/sso/oidc/OidcProvider.java b/datahub-frontend/app/auth/sso/oidc/OidcProvider.java index a8a3205e8299c..7fcaa5a9683cb 100644 --- a/datahub-frontend/app/auth/sso/oidc/OidcProvider.java +++ b/datahub-frontend/app/auth/sso/oidc/OidcProvider.java @@ -2,15 +2,16 @@ import auth.sso.SsoProvider; import auth.sso.oidc.custom.CustomOidcClient; -import com.google.common.collect.ImmutableMap; +import com.nimbusds.jose.JWSAlgorithm; import java.util.HashMap; import java.util.Map; import lombok.extern.slf4j.Slf4j; import org.pac4j.core.client.Client; import org.pac4j.core.http.callback.PathParameterCallbackUrlResolver; import org.pac4j.oidc.config.OidcConfiguration; -import org.pac4j.oidc.credentials.OidcCredentials; import org.pac4j.oidc.profile.OidcProfileDefinition; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Implementation of {@link SsoProvider} supporting the OIDC protocol. @@ -24,25 +25,25 @@ */ @Slf4j public class OidcProvider implements SsoProvider { - + private static final Logger logger = LoggerFactory.getLogger(OidcProvider.class); private static final String OIDC_CLIENT_NAME = "oidc"; - private final OidcConfigs _oidcConfigs; - private final Client _oidcClient; // Used primarily for redirecting to IdP. + private final OidcConfigs oidcConfigs; + private final Client oidcClient; // Used primarily for redirecting to IdP. public OidcProvider(final OidcConfigs configs) { - _oidcConfigs = configs; - _oidcClient = createPac4jClient(); + oidcConfigs = configs; + oidcClient = createPac4jClient(); } @Override - public Client client() { - return _oidcClient; + public Client client() { + return oidcClient; } @Override public OidcConfigs configs() { - return _oidcConfigs; + return oidcConfigs; } @Override @@ -50,50 +51,52 @@ public SsoProtocol protocol() { return SsoProtocol.OIDC; } - private Client createPac4jClient() { + private Client createPac4jClient() { final OidcConfiguration oidcConfiguration = new OidcConfiguration(); - oidcConfiguration.setClientId(_oidcConfigs.getClientId()); - oidcConfiguration.setSecret(_oidcConfigs.getClientSecret()); - oidcConfiguration.setDiscoveryURI(_oidcConfigs.getDiscoveryUri()); + oidcConfiguration.setClientId(oidcConfigs.getClientId()); + oidcConfiguration.setSecret(oidcConfigs.getClientSecret()); + oidcConfiguration.setDiscoveryURI(oidcConfigs.getDiscoveryUri()); oidcConfiguration.setClientAuthenticationMethodAsString( - _oidcConfigs.getClientAuthenticationMethod()); - oidcConfiguration.setScope(_oidcConfigs.getScope()); + oidcConfigs.getClientAuthenticationMethod()); + oidcConfiguration.setScope(oidcConfigs.getScope()); try { - oidcConfiguration.setReadTimeout(Integer.parseInt(_oidcConfigs.getReadTimeout())); + oidcConfiguration.setConnectTimeout(Integer.parseInt(oidcConfigs.getConnectTimeout())); + } catch (NumberFormatException e) { + log.warn("Invalid connect timeout configuration, defaulting to 1000ms"); + } + try { + oidcConfiguration.setReadTimeout(Integer.parseInt(oidcConfigs.getReadTimeout())); } catch (NumberFormatException e) { log.warn("Invalid read timeout configuration, defaulting to 5000ms"); } - _oidcConfigs.getResponseType().ifPresent(oidcConfiguration::setResponseType); - _oidcConfigs.getResponseMode().ifPresent(oidcConfiguration::setResponseMode); - _oidcConfigs.getUseNonce().ifPresent(oidcConfiguration::setUseNonce); + oidcConfigs.getResponseType().ifPresent(oidcConfiguration::setResponseType); + oidcConfigs.getResponseMode().ifPresent(oidcConfiguration::setResponseMode); + oidcConfigs.getUseNonce().ifPresent(oidcConfiguration::setUseNonce); Map customParamsMap = new HashMap<>(); - _oidcConfigs - .getCustomParamResource() - .ifPresent(value -> customParamsMap.put("resource", value)); - _oidcConfigs - .getGrantType() - .ifPresent(value -> customParamsMap.put("grant_type", value)); - _oidcConfigs - .getAcrValues() - .ifPresent(value -> customParamsMap.put("acr_values", value)); + oidcConfigs.getCustomParamResource().ifPresent(value -> customParamsMap.put("resource", value)); + oidcConfigs.getGrantType().ifPresent(value -> customParamsMap.put("grant_type", value)); + oidcConfigs.getAcrValues().ifPresent(value -> customParamsMap.put("acr_values", value)); if (!customParamsMap.isEmpty()) { oidcConfiguration.setCustomParams(customParamsMap); } - _oidcConfigs + oidcConfigs .getPreferredJwsAlgorithm() .ifPresent( preferred -> { log.info("Setting preferredJwsAlgorithm: " + preferred); - oidcConfiguration.setPreferredJwsAlgorithm(preferred); + oidcConfiguration.setPreferredJwsAlgorithm(JWSAlgorithm.parse(preferred)); }); + // Enable state parameter validation + oidcConfiguration.setWithState(true); + final CustomOidcClient oidcClient = new CustomOidcClient(oidcConfiguration); oidcClient.setName(OIDC_CLIENT_NAME); - oidcClient.setCallbackUrl( - _oidcConfigs.getAuthBaseUrl() + _oidcConfigs.getAuthBaseCallbackPath()); + oidcClient.setCallbackUrl(oidcConfigs.getAuthBaseUrl() + oidcConfigs.getAuthBaseCallbackPath()); oidcClient.setCallbackUrlResolver(new PathParameterCallbackUrlResolver()); oidcClient.addAuthorizationGenerator( - new OidcAuthorizationGenerator(new OidcProfileDefinition(), _oidcConfigs)); + new OidcAuthorizationGenerator(new OidcProfileDefinition(), oidcConfigs)); + return oidcClient; } } diff --git a/datahub-frontend/app/auth/sso/oidc/OidcResponseErrorHandler.java b/datahub-frontend/app/auth/sso/oidc/OidcResponseErrorHandler.java index 9881b5e095b78..2843beee61610 100644 --- a/datahub-frontend/app/auth/sso/oidc/OidcResponseErrorHandler.java +++ b/datahub-frontend/app/auth/sso/oidc/OidcResponseErrorHandler.java @@ -4,7 +4,8 @@ import static play.mvc.Results.unauthorized; import java.util.Optional; -import org.pac4j.play.PlayWebContext; +import org.pac4j.core.context.CallContext; +import org.pac4j.core.context.WebContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import play.mvc.Result; @@ -13,14 +14,14 @@ public class OidcResponseErrorHandler { private OidcResponseErrorHandler() {} - private static final Logger _logger = LoggerFactory.getLogger("OidcResponseErrorHandler"); + private static final Logger logger = LoggerFactory.getLogger("OidcResponseErrorHandler"); private static final String ERROR_FIELD_NAME = "error"; private static final String ERROR_DESCRIPTION_FIELD_NAME = "error_description"; - public static Result handleError(final PlayWebContext context) { - - _logger.warn( + public static Result handleError(final CallContext ctx) { + WebContext context = ctx.webContext(); + logger.warn( "OIDC responded with an error: '{}'. Error description: '{}'", getError(context), getErrorDescription(context)); @@ -44,15 +45,15 @@ public static Result handleError(final PlayWebContext context) { getError(context).orElse(""), getErrorDescription(context).orElse(""))); } - public static boolean isError(final PlayWebContext context) { - return getError(context).isPresent() && !getError(context).get().isEmpty(); + public static boolean isError(final CallContext ctx) { + return getError(ctx.webContext()).isPresent() && !getError(ctx.webContext()).get().isEmpty(); } - public static Optional getError(final PlayWebContext context) { + public static Optional getError(final WebContext context) { return context.getRequestParameter(ERROR_FIELD_NAME); } - public static Optional getErrorDescription(final PlayWebContext context) { + public static Optional getErrorDescription(final WebContext context) { return context.getRequestParameter(ERROR_DESCRIPTION_FIELD_NAME); } } diff --git a/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcAuthenticator.java b/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcAuthenticator.java index 01f8f16171d13..2288547cf6ed1 100644 --- a/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcAuthenticator.java +++ b/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcAuthenticator.java @@ -2,7 +2,6 @@ import com.nimbusds.oauth2.sdk.AuthorizationCode; import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant; -import com.nimbusds.oauth2.sdk.AuthorizationGrant; import com.nimbusds.oauth2.sdk.ParseException; import com.nimbusds.oauth2.sdk.TokenErrorResponse; import com.nimbusds.oauth2.sdk.TokenRequest; @@ -18,6 +17,7 @@ import com.nimbusds.oauth2.sdk.pkce.CodeVerifier; import com.nimbusds.openid.connect.sdk.OIDCTokenResponse; import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser; +import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; import com.nimbusds.openid.connect.sdk.token.OIDCTokens; import java.io.IOException; import java.net.URI; @@ -25,9 +25,11 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.Objects; import java.util.Optional; +import org.pac4j.core.context.CallContext; import org.pac4j.core.context.WebContext; -import org.pac4j.core.credentials.authenticator.Authenticator; +import org.pac4j.core.credentials.Credentials; import org.pac4j.core.exception.TechnicalException; import org.pac4j.core.util.CommonHelper; import org.pac4j.oidc.client.OidcClient; @@ -37,9 +39,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class CustomOidcAuthenticator implements Authenticator { +public class CustomOidcAuthenticator extends OidcAuthenticator { - private static final Logger logger = LoggerFactory.getLogger(OidcAuthenticator.class); + private static final Logger logger = LoggerFactory.getLogger(CustomOidcAuthenticator.class); private static final Collection SUPPORTED_METHODS = Arrays.asList( @@ -47,21 +49,24 @@ public class CustomOidcAuthenticator implements Authenticator { ClientAuthenticationMethod.CLIENT_SECRET_BASIC, ClientAuthenticationMethod.NONE); - protected OidcConfiguration configuration; - - protected OidcClient client; - private final ClientAuthentication clientAuthentication; - public CustomOidcAuthenticator(final OidcClient client) { - CommonHelper.assertNotNull("configuration", client.getConfiguration()); - CommonHelper.assertNotNull("client", client); - this.configuration = client.getConfiguration(); - this.client = client; + public CustomOidcAuthenticator(final OidcClient client) { + super(client.getConfiguration(), client); // check authentication methods - final List metadataMethods = - configuration.findProviderMetadata().getTokenEndpointAuthMethods(); + OIDCProviderMetadata providerMetadata; + try { + providerMetadata = loadWithRetry(); + } catch (TechnicalException e) { + logger.error( + "Could not resolve identity provider's remote configuration from DiscoveryURI: {}", + configuration.getDiscoveryURI()); + throw e; + } + + List metadataMethods = + providerMetadata.getTokenEndpointAuthMethods(); final ClientAuthenticationMethod preferredMethod = getPreferredAuthenticationMethod(configuration); @@ -146,8 +151,11 @@ private static ClientAuthenticationMethod firstSupportedMethod( } @Override - public void validate(final OidcCredentials credentials, final WebContext context) { - final AuthorizationCode code = credentials.getCode(); + public Optional validate(CallContext ctx, Credentials cred) { + OidcCredentials credentials = (OidcCredentials) cred; + WebContext context = ctx.webContext(); + + final AuthorizationCode code = credentials.toAuthorizationCode(); // if we have a code if (code != null) { try { @@ -156,7 +164,7 @@ public void validate(final OidcCredentials credentials, final WebContext context (CodeVerifier) configuration .getValueRetriever() - .retrieve(client.getCodeVerifierSessionAttributeName(), client, context) + .retrieve(ctx, client.getCodeVerifierSessionAttributeName(), client) .orElse(null); // Token request final TokenRequest request = @@ -182,27 +190,49 @@ public void validate(final OidcCredentials credentials, final WebContext context // save tokens in credentials final OIDCTokens oidcTokens = tokenSuccessResponse.getOIDCTokens(); - credentials.setAccessToken(oidcTokens.getAccessToken()); - credentials.setRefreshToken(oidcTokens.getRefreshToken()); - credentials.setIdToken(oidcTokens.getIDToken()); + credentials.setAccessTokenObject(oidcTokens.getAccessToken()); + + // Only set refresh token if it exists + if (oidcTokens.getRefreshToken() != null) { + credentials.setRefreshTokenObject(oidcTokens.getRefreshToken()); + } + + if (oidcTokens.getIDToken() != null) { + credentials.setIdToken(oidcTokens.getIDToken().getParsedString()); + } } catch (final URISyntaxException | IOException | ParseException e) { throw new TechnicalException(e); } } + + return Optional.ofNullable(cred); } - private TokenRequest createTokenRequest(final AuthorizationGrant grant) { - if (clientAuthentication != null) { - return new TokenRequest( - configuration.findProviderMetadata().getTokenEndpointURI(), - this.clientAuthentication, - grant); - } else { - return new TokenRequest( - configuration.findProviderMetadata().getTokenEndpointURI(), - new ClientID(configuration.getClientId()), - grant); + // Simple retry with exponential backoff + public OIDCProviderMetadata loadWithRetry() { + int maxAttempts = 3; + long initialDelay = 1000; // 1 second + + for (int attempt = 1; attempt <= maxAttempts; attempt++) { + try { + OIDCProviderMetadata providerMetadata = configuration.getOpMetadataResolver().load(); + return Objects.requireNonNull(providerMetadata); + } catch (RuntimeException e) { + if (attempt == maxAttempts) { + throw e; // Rethrow on final attempt + } + try { + // Exponential backoff + Thread.sleep(initialDelay * (long) Math.pow(2, attempt - 1)); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Retry interrupted", ie); + } + logger.warn("Retry attempt {} of {} failed", attempt, maxAttempts, e); + } } + throw new RuntimeException( + "Failed to load provider metadata after " + maxAttempts + " attempts"); } } diff --git a/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcClient.java b/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcClient.java index 3a0a247cb761e..28c927c2aa541 100644 --- a/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcClient.java +++ b/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcClient.java @@ -3,25 +3,38 @@ import org.pac4j.core.util.CommonHelper; import org.pac4j.oidc.client.OidcClient; import org.pac4j.oidc.config.OidcConfiguration; -import org.pac4j.oidc.credentials.extractor.OidcExtractor; +import org.pac4j.oidc.credentials.extractor.OidcCredentialsExtractor; import org.pac4j.oidc.logout.OidcLogoutActionBuilder; +import org.pac4j.oidc.logout.processor.OidcLogoutProcessor; import org.pac4j.oidc.profile.creator.OidcProfileCreator; -import org.pac4j.oidc.redirect.OidcRedirectionActionBuilder; -public class CustomOidcClient extends OidcClient { +public class CustomOidcClient extends OidcClient { - public CustomOidcClient(final OidcConfiguration configuration) { - setConfiguration(configuration); + public CustomOidcClient(OidcConfiguration configuration) { + super(configuration); } @Override - protected void clientInit() { + protected void internalInit(final boolean forceReinit) { + // Validate configuration CommonHelper.assertNotNull("configuration", getConfiguration()); - getConfiguration().init(); - defaultRedirectionActionBuilder(new CustomOidcRedirectionActionBuilder(getConfiguration(), this)); - defaultCredentialsExtractor(new OidcExtractor(getConfiguration(), this)); - defaultAuthenticator(new CustomOidcAuthenticator(this)); - defaultProfileCreator(new OidcProfileCreator<>(getConfiguration(), this)); - defaultLogoutActionBuilder(new OidcLogoutActionBuilder(getConfiguration())); + + // Initialize configuration + getConfiguration().init(forceReinit); + + // Initialize client components + setRedirectionActionBuilderIfUndefined( + new CustomOidcRedirectionActionBuilder(getConfiguration(), this)); + setCredentialsExtractorIfUndefined(new OidcCredentialsExtractor(getConfiguration(), this)); + + // Initialize default authenticator if not set + if (getAuthenticator() == null || forceReinit) { + setAuthenticatorIfUndefined(new CustomOidcAuthenticator(this)); + } + + setProfileCreatorIfUndefined(new OidcProfileCreator(getConfiguration(), this)); + setLogoutProcessorIfUndefined( + new OidcLogoutProcessor(getConfiguration(), findSessionLogoutHandler())); + setLogoutActionBuilderIfUndefined(new OidcLogoutActionBuilder(getConfiguration())); } } diff --git a/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcRedirectionActionBuilder.java b/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcRedirectionActionBuilder.java index bdeeacc895af3..ea5315972344d 100644 --- a/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcRedirectionActionBuilder.java +++ b/datahub-frontend/app/auth/sso/oidc/custom/CustomOidcRedirectionActionBuilder.java @@ -2,29 +2,35 @@ import java.util.Map; import java.util.Optional; +import org.pac4j.core.context.CallContext; import org.pac4j.core.context.WebContext; import org.pac4j.core.exception.http.RedirectionAction; -import org.pac4j.core.exception.http.RedirectionActionHelper; +import org.pac4j.core.util.HttpActionHelper; import org.pac4j.oidc.client.OidcClient; import org.pac4j.oidc.config.OidcConfiguration; import org.pac4j.oidc.redirect.OidcRedirectionActionBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - public class CustomOidcRedirectionActionBuilder extends OidcRedirectionActionBuilder { private static final Logger logger = LoggerFactory.getLogger(OidcRedirectionActionBuilder.class); + + private final OidcConfiguration configuration; + public CustomOidcRedirectionActionBuilder(OidcConfiguration configuration, OidcClient client) { - super(configuration, client); + super(client); + this.configuration = configuration; } @Override - public Optional getRedirectionAction(WebContext context) { - Map params = this.buildParams(); + public Optional getRedirectionAction(CallContext ctx) { + WebContext context = ctx.webContext(); + + Map params = this.buildParams(context); String computedCallbackUrl = this.client.computeFinalCallbackUrl(context); params.put("redirect_uri", computedCallbackUrl); - this.addStateAndNonceParameters(context, params); + this.addStateAndNonceParameters(ctx, params); if (this.configuration.getMaxAge() != null) { params.put("max_age", this.configuration.getMaxAge().toString()); } @@ -40,7 +46,6 @@ public Optional getRedirectionAction(WebContext context) { } logger.debug("Authentication request url: {}", location); - return Optional.of(RedirectionActionHelper.buildRedirectUrlAction(context, location)); + return Optional.of(HttpActionHelper.buildRedirectUrlAction(context, location)); } - } diff --git a/datahub-frontend/app/client/KafkaTrackingProducer.java b/datahub-frontend/app/client/KafkaTrackingProducer.java index 058e75100c24a..a29fe3bb7aef7 100644 --- a/datahub-frontend/app/client/KafkaTrackingProducer.java +++ b/datahub-frontend/app/client/KafkaTrackingProducer.java @@ -27,7 +27,8 @@ @Singleton public class KafkaTrackingProducer { - private final Logger _logger = LoggerFactory.getLogger(KafkaTrackingProducer.class.getName()); + private static final Logger logger = + LoggerFactory.getLogger(KafkaTrackingProducer.class.getName()); private static final List KAFKA_SSL_PROTOCOLS = Collections.unmodifiableList( Arrays.asList( @@ -35,38 +36,38 @@ public class KafkaTrackingProducer { SecurityProtocol.SASL_SSL.name(), SecurityProtocol.SASL_PLAINTEXT.name())); - private final Boolean _isEnabled; - private final KafkaProducer _producer; + private final Boolean isEnabled; + private final KafkaProducer producer; @Inject public KafkaTrackingProducer( @Nonnull Config config, ApplicationLifecycle lifecycle, final ConfigurationProvider configurationProvider) { - _isEnabled = !config.hasPath("analytics.enabled") || config.getBoolean("analytics.enabled"); + isEnabled = !config.hasPath("analytics.enabled") || config.getBoolean("analytics.enabled"); - if (_isEnabled) { - _logger.debug("Analytics tracking is enabled"); - _producer = createKafkaProducer(config, configurationProvider.getKafka()); + if (isEnabled) { + logger.debug("Analytics tracking is enabled"); + producer = createKafkaProducer(config, configurationProvider.getKafka()); lifecycle.addStopHook( () -> { - _producer.flush(); - _producer.close(); + producer.flush(); + producer.close(); return CompletableFuture.completedFuture(null); }); } else { - _logger.debug("Analytics tracking is disabled"); - _producer = null; + logger.debug("Analytics tracking is disabled"); + producer = null; } } public Boolean isEnabled() { - return _isEnabled; + return isEnabled; } public void send(ProducerRecord record) { - _producer.send(record); + producer.send(record); } private static KafkaProducer createKafkaProducer( diff --git a/datahub-frontend/app/config/ConfigurationProvider.java b/datahub-frontend/app/config/ConfigurationProvider.java index d447b28cdcc46..97e916769a6c4 100644 --- a/datahub-frontend/app/config/ConfigurationProvider.java +++ b/datahub-frontend/app/config/ConfigurationProvider.java @@ -6,12 +6,9 @@ import com.linkedin.metadata.config.kafka.KafkaConfiguration; import com.linkedin.metadata.spring.YamlPropertySourceFactory; import lombok.Data; -import org.springframework.boot.autoconfigure.kafka.KafkaProperties; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties; -import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.PropertySource; -import org.springframework.stereotype.Component; /** * Minimal sharing between metadata-service and frontend Does not use the factories module to avoid diff --git a/datahub-frontend/app/controllers/Application.java b/datahub-frontend/app/controllers/Application.java index 017847367de05..1c3786f4af526 100644 --- a/datahub-frontend/app/controllers/Application.java +++ b/datahub-frontend/app/controllers/Application.java @@ -36,8 +36,8 @@ import play.libs.ws.StandaloneWSClient; import play.libs.ws.ahc.StandaloneAhcWSClient; import play.mvc.Controller; -import play.mvc.Http.Cookie; import play.mvc.Http; +import play.mvc.Http.Cookie; import play.mvc.ResponseHeader; import play.mvc.Result; import play.mvc.Security; @@ -48,16 +48,16 @@ import utils.ConfigUtil; public class Application extends Controller { - private final Logger _logger = LoggerFactory.getLogger(Application.class.getName()); - private final Config _config; - private final StandaloneWSClient _ws; - private final Environment _environment; + private static final Logger logger = LoggerFactory.getLogger(Application.class.getName()); + private final Config config; + private final StandaloneWSClient ws; + private final Environment environment; @Inject public Application(Environment environment, @Nonnull Config config) { - _config = config; - _ws = createWsClient(); - _environment = environment; + this.config = config; + ws = createWsClient(); + this.environment = environment; } /** @@ -69,10 +69,10 @@ public Application(Environment environment, @Nonnull Config config) { @Nonnull private Result serveAsset(@Nullable String path) { try { - InputStream indexHtml = _environment.resourceAsStream("public/index.html"); + InputStream indexHtml = environment.resourceAsStream("public/index.html"); return ok(indexHtml).withHeader("Cache-Control", "no-cache").as("text/html"); } catch (Exception e) { - _logger.warn("Cannot load public/index.html resource. Static assets or assets jar missing?"); + logger.warn("Cannot load public/index.html resource. Static assets or assets jar missing?"); return notFound().withHeader("Cache-Control", "no-cache").as("text/html"); } } @@ -106,17 +106,17 @@ public CompletableFuture proxy(String path, Http.Request request) final String metadataServiceHost = ConfigUtil.getString( - _config, + config, ConfigUtil.METADATA_SERVICE_HOST_CONFIG_PATH, ConfigUtil.DEFAULT_METADATA_SERVICE_HOST); final int metadataServicePort = ConfigUtil.getInt( - _config, + config, ConfigUtil.METADATA_SERVICE_PORT_CONFIG_PATH, ConfigUtil.DEFAULT_METADATA_SERVICE_PORT); final boolean metadataServiceUseSsl = ConfigUtil.getBoolean( - _config, + config, ConfigUtil.METADATA_SERVICE_USE_SSL_CONFIG_PATH, ConfigUtil.DEFAULT_METADATA_SERVICE_USE_SSL); @@ -139,7 +139,7 @@ public CompletableFuture proxy(String path, Http.Request request) // Get the current time to measure the duration of the request Instant start = Instant.now(); - return _ws.url( + return ws.url( String.format( "%s://%s:%s%s", protocol, metadataServiceHost, metadataServicePort, resolvedUri)) .setMethod(request.method()) @@ -167,9 +167,10 @@ AuthenticationConstants.LEGACY_X_DATAHUB_ACTOR_HEADER, getDataHubActorHeader(req .execute() .thenApply( apiResponse -> { - // Log the query if it takes longer than the configured threshold and verbose logging is enabled - boolean verboseGraphQLLogging = _config.getBoolean("graphql.verbose.logging"); - int verboseGraphQLLongQueryMillis = _config.getInt("graphql.verbose.slowQueryMillis"); + // Log the query if it takes longer than the configured threshold and verbose logging + // is enabled + boolean verboseGraphQLLogging = config.getBoolean("graphql.verbose.logging"); + int verboseGraphQLLongQueryMillis = config.getInt("graphql.verbose.slowQueryMillis"); Instant finish = Instant.now(); long timeElapsed = Duration.between(start, finish).toMillis(); if (verboseGraphQLLogging && timeElapsed >= verboseGraphQLLongQueryMillis) { @@ -206,32 +207,32 @@ AuthenticationConstants.LEGACY_X_DATAHUB_ACTOR_HEADER, getDataHubActorHeader(req public Result appConfig() { final ObjectNode config = Json.newObject(); config.put("application", "datahub-frontend"); - config.put("appVersion", _config.getString("app.version")); - config.put("isInternal", _config.getBoolean("linkedin.internal")); - config.put("shouldShowDatasetLineage", _config.getBoolean("linkedin.show.dataset.lineage")); + config.put("appVersion", this.config.getString("app.version")); + config.put("isInternal", this.config.getBoolean("linkedin.internal")); + config.put("shouldShowDatasetLineage", this.config.getBoolean("linkedin.show.dataset.lineage")); config.put( "suggestionConfidenceThreshold", - Integer.valueOf(_config.getString("linkedin.suggestion.confidence.threshold"))); + Integer.valueOf(this.config.getString("linkedin.suggestion.confidence.threshold"))); config.set("wikiLinks", wikiLinks()); config.set("tracking", trackingInfo()); // In a staging environment, we can trigger this flag to be true so that the UI can handle based // on // such config and alert users that their changes will not affect production data - config.put("isStagingBanner", _config.getBoolean("ui.show.staging.banner")); - config.put("isLiveDataWarning", _config.getBoolean("ui.show.live.data.banner")); - config.put("showChangeManagement", _config.getBoolean("ui.show.CM.banner")); + config.put("isStagingBanner", this.config.getBoolean("ui.show.staging.banner")); + config.put("isLiveDataWarning", this.config.getBoolean("ui.show.live.data.banner")); + config.put("showChangeManagement", this.config.getBoolean("ui.show.CM.banner")); // Flag to enable people entity elements - config.put("showPeople", _config.getBoolean("ui.show.people")); - config.put("changeManagementLink", _config.getString("ui.show.CM.link")); + config.put("showPeople", this.config.getBoolean("ui.show.people")); + config.put("changeManagementLink", this.config.getString("ui.show.CM.link")); // Flag set in order to warn users that search is experiencing issues - config.put("isStaleSearch", _config.getBoolean("ui.show.stale.search")); - config.put("showAdvancedSearch", _config.getBoolean("ui.show.advanced.search")); + config.put("isStaleSearch", this.config.getBoolean("ui.show.stale.search")); + config.put("showAdvancedSearch", this.config.getBoolean("ui.show.advanced.search")); // Flag to use the new api for browsing datasets - config.put("useNewBrowseDataset", _config.getBoolean("ui.new.browse.dataset")); + config.put("useNewBrowseDataset", this.config.getBoolean("ui.new.browse.dataset")); // show lineage graph in relationships tabs - config.put("showLineageGraph", _config.getBoolean("ui.show.lineage.graph")); + config.put("showLineageGraph", this.config.getBoolean("ui.show.lineage.graph")); // show institutional memory for available entities - config.put("showInstitutionalMemory", _config.getBoolean("ui.show.institutional.memory")); + config.put("showInstitutionalMemory", this.config.getBoolean("ui.show.institutional.memory")); // Insert properties for user profile operations config.set("userEntityProps", userEntityProps()); @@ -250,8 +251,8 @@ public Result appConfig() { @Nonnull private ObjectNode userEntityProps() { final ObjectNode props = Json.newObject(); - props.put("aviUrlPrimary", _config.getString("linkedin.links.avi.urlPrimary")); - props.put("aviUrlFallback", _config.getString("linkedin.links.avi.urlFallback")); + props.put("aviUrlPrimary", config.getString("linkedin.links.avi.urlPrimary")); + props.put("aviUrlFallback", config.getString("linkedin.links.avi.urlFallback")); return props; } @@ -261,19 +262,19 @@ private ObjectNode userEntityProps() { @Nonnull private ObjectNode wikiLinks() { final ObjectNode wikiLinks = Json.newObject(); - wikiLinks.put("appHelp", _config.getString("links.wiki.appHelp")); - wikiLinks.put("gdprPii", _config.getString("links.wiki.gdprPii")); - wikiLinks.put("tmsSchema", _config.getString("links.wiki.tmsSchema")); - wikiLinks.put("gdprTaxonomy", _config.getString("links.wiki.gdprTaxonomy")); - wikiLinks.put("staleSearchIndex", _config.getString("links.wiki.staleSearchIndex")); - wikiLinks.put("dht", _config.getString("links.wiki.dht")); - wikiLinks.put("purgePolicies", _config.getString("links.wiki.purgePolicies")); - wikiLinks.put("jitAcl", _config.getString("links.wiki.jitAcl")); - wikiLinks.put("metadataCustomRegex", _config.getString("links.wiki.metadataCustomRegex")); - wikiLinks.put("exportPolicy", _config.getString("links.wiki.exportPolicy")); - wikiLinks.put("metadataHealth", _config.getString("links.wiki.metadataHealth")); - wikiLinks.put("purgeKey", _config.getString("links.wiki.purgeKey")); - wikiLinks.put("datasetDecommission", _config.getString("links.wiki.datasetDecommission")); + wikiLinks.put("appHelp", config.getString("links.wiki.appHelp")); + wikiLinks.put("gdprPii", config.getString("links.wiki.gdprPii")); + wikiLinks.put("tmsSchema", config.getString("links.wiki.tmsSchema")); + wikiLinks.put("gdprTaxonomy", config.getString("links.wiki.gdprTaxonomy")); + wikiLinks.put("staleSearchIndex", config.getString("links.wiki.staleSearchIndex")); + wikiLinks.put("dht", config.getString("links.wiki.dht")); + wikiLinks.put("purgePolicies", config.getString("links.wiki.purgePolicies")); + wikiLinks.put("jitAcl", config.getString("links.wiki.jitAcl")); + wikiLinks.put("metadataCustomRegex", config.getString("links.wiki.metadataCustomRegex")); + wikiLinks.put("exportPolicy", config.getString("links.wiki.exportPolicy")); + wikiLinks.put("metadataHealth", config.getString("links.wiki.metadataHealth")); + wikiLinks.put("purgeKey", config.getString("links.wiki.purgeKey")); + wikiLinks.put("datasetDecommission", config.getString("links.wiki.datasetDecommission")); return wikiLinks; } @@ -283,8 +284,8 @@ private ObjectNode wikiLinks() { @Nonnull private ObjectNode trackingInfo() { final ObjectNode piwik = Json.newObject(); - piwik.put("piwikSiteId", Integer.valueOf(_config.getString("tracking.piwik.siteid"))); - piwik.put("piwikUrl", _config.getString("tracking.piwik.url")); + piwik.put("piwikSiteId", Integer.valueOf(config.getString("tracking.piwik.siteid"))); + piwik.put("piwikUrl", config.getString("tracking.piwik.url")); final ObjectNode trackers = Json.newObject(); trackers.set("piwik", piwik); @@ -376,9 +377,10 @@ private String mapPath(@Nonnull final String path) { return path; } - /** - * Called if verbose logging is enabled and request takes longer that the slow query milliseconds defined in the config + * Called if verbose logging is enabled and request takes longer that the slow query milliseconds + * defined in the config + * * @param request GraphQL request that was made * @param resolvedUri URI that was requested * @param duration How long the query took to complete @@ -393,16 +395,16 @@ private void logSlowQuery(Http.Request request, String resolvedUri, float durati JsonNode jsonNode = request.body().asJson(); ((ObjectNode) jsonNode).remove("query"); jsonBody.append(mapper.writerWithDefaultPrettyPrinter().writeValueAsString(jsonNode)); - } - catch (Exception e) { - _logger.info("GraphQL Request Received: {}, Unable to parse JSON body", resolvedUri); + } catch (Exception e) { + logger.info("GraphQL Request Received: {}, Unable to parse JSON body", resolvedUri); } String jsonBodyStr = jsonBody.toString(); - _logger.info("Slow GraphQL Request Received: {}, Request query string: {}, Request actor: {}, Request JSON: {}, Request completed in {} ms", - resolvedUri, - request.queryString(), - actorValue, - jsonBodyStr, - duration); + logger.info( + "Slow GraphQL Request Received: {}, Request query string: {}, Request actor: {}, Request JSON: {}, Request completed in {} ms", + resolvedUri, + request.queryString(), + actorValue, + jsonBodyStr, + duration); } } diff --git a/datahub-frontend/app/controllers/AuthenticationController.java b/datahub-frontend/app/controllers/AuthenticationController.java index 87c4b5ba06793..3c8ace9aee29f 100644 --- a/datahub-frontend/app/controllers/AuthenticationController.java +++ b/datahub-frontend/app/controllers/AuthenticationController.java @@ -26,12 +26,14 @@ import org.apache.commons.httpclient.InvalidRedirectLocationException; import org.apache.commons.lang3.StringUtils; import org.pac4j.core.client.Client; +import org.pac4j.core.context.CallContext; import org.pac4j.core.context.Cookie; +import org.pac4j.core.context.WebContext; import org.pac4j.core.exception.http.FoundAction; import org.pac4j.core.exception.http.RedirectionAction; import org.pac4j.play.PlayWebContext; import org.pac4j.play.http.PlayHttpActionAdapter; -import org.pac4j.play.store.PlaySessionStore; +import org.pac4j.play.store.PlayCookieSessionStore; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import play.data.validation.Constraints; @@ -51,26 +53,27 @@ public class AuthenticationController extends Controller { private static final String SSO_NO_REDIRECT_MESSAGE = "SSO is configured, however missing redirect from idp"; - private final Logger _logger = LoggerFactory.getLogger(AuthenticationController.class.getName()); - private final CookieConfigs _cookieConfigs; - private final JAASConfigs _jaasConfigs; - private final NativeAuthenticationConfigs _nativeAuthenticationConfigs; - private final boolean _verbose; + private static final Logger logger = + LoggerFactory.getLogger(AuthenticationController.class.getName()); + private final CookieConfigs cookieConfigs; + private final JAASConfigs jaasConfigs; + private final NativeAuthenticationConfigs nativeAuthenticationConfigs; + private final boolean verbose; - @Inject private org.pac4j.core.config.Config _ssoConfig; + @Inject private org.pac4j.core.config.Config ssoConfig; - @Inject private PlaySessionStore _playSessionStore; + @Inject private PlayCookieSessionStore playCookieSessionStore; - @Inject private SsoManager _ssoManager; + @Inject private SsoManager ssoManager; - @Inject AuthServiceClient _authClient; + @Inject AuthServiceClient authClient; @Inject public AuthenticationController(@Nonnull Config configs) { - _cookieConfigs = new CookieConfigs(configs); - _jaasConfigs = new JAASConfigs(configs); - _nativeAuthenticationConfigs = new NativeAuthenticationConfigs(configs); - _verbose = configs.hasPath(AUTH_VERBOSE_LOGGING) && configs.getBoolean(AUTH_VERBOSE_LOGGING); + cookieConfigs = new CookieConfigs(configs); + jaasConfigs = new JAASConfigs(configs); + nativeAuthenticationConfigs = new NativeAuthenticationConfigs(configs); + verbose = configs.hasPath(AUTH_VERBOSE_LOGGING) && configs.getBoolean(AUTH_VERBOSE_LOGGING); } /** @@ -92,11 +95,14 @@ public Result authenticate(Http.Request request) { try { URI redirectUri = new URI(redirectPath); if (redirectUri.getScheme() != null || redirectUri.getAuthority() != null) { - throw new InvalidRedirectLocationException("Redirect location must be relative to the base url, cannot " - + "redirect to other domains: " + redirectPath, redirectPath); + throw new InvalidRedirectLocationException( + "Redirect location must be relative to the base url, cannot " + + "redirect to other domains: " + + redirectPath, + redirectPath); } } catch (URISyntaxException | InvalidRedirectLocationException e) { - _logger.warn(e.getMessage()); + logger.warn(e.getMessage()); redirectPath = "/"; } @@ -105,7 +111,7 @@ public Result authenticate(Http.Request request) { } // 1. If SSO is enabled, redirect to IdP if not authenticated. - if (_ssoManager.isSsoEnabled()) { + if (ssoManager.isSsoEnabled()) { return redirectToIdentityProvider(request, redirectPath) .orElse( Results.redirect( @@ -114,8 +120,8 @@ public Result authenticate(Http.Request request) { } // 2. If either JAAS auth or Native auth is enabled, fallback to it - if (_jaasConfigs.isJAASEnabled() - || _nativeAuthenticationConfigs.isNativeAuthenticationEnabled()) { + if (jaasConfigs.isJAASEnabled() + || nativeAuthenticationConfigs.isNativeAuthenticationEnabled()) { return Results.redirect( LOGIN_ROUTE + String.format("?%s=%s", AUTH_REDIRECT_URI_PARAM, encodeRedirectUri(redirectPath))); @@ -123,21 +129,21 @@ public Result authenticate(Http.Request request) { // 3. If no auth enabled, fallback to using default user account & redirect. // Generate GMS session token, TODO: - final String accessToken = _authClient.generateSessionTokenForUser(DEFAULT_ACTOR_URN.getId()); + final String accessToken = authClient.generateSessionTokenForUser(DEFAULT_ACTOR_URN.getId()); return Results.redirect(redirectPath) .withSession(createSessionMap(DEFAULT_ACTOR_URN.toString(), accessToken)) .withCookies( createActorCookie( DEFAULT_ACTOR_URN.toString(), - _cookieConfigs.getTtlInHours(), - _cookieConfigs.getAuthCookieSameSite(), - _cookieConfigs.getAuthCookieSecure())); + cookieConfigs.getTtlInHours(), + cookieConfigs.getAuthCookieSameSite(), + cookieConfigs.getAuthCookieSecure())); } /** Redirect to the identity provider for authentication. */ @Nonnull public Result sso(Http.Request request) { - if (_ssoManager.isSsoEnabled()) { + if (ssoManager.isSsoEnabled()) { return redirectToIdentityProvider(request, "/") .orElse( Results.redirect( @@ -156,11 +162,11 @@ public Result sso(Http.Request request) { */ @Nonnull public Result logIn(Http.Request request) { - boolean jaasEnabled = _jaasConfigs.isJAASEnabled(); - _logger.debug(String.format("Jaas authentication enabled: %b", jaasEnabled)); + boolean jaasEnabled = jaasConfigs.isJAASEnabled(); + logger.debug(String.format("Jaas authentication enabled: %b", jaasEnabled)); boolean nativeAuthenticationEnabled = - _nativeAuthenticationConfigs.isNativeAuthenticationEnabled(); - _logger.debug(String.format("Native authentication enabled: %b", nativeAuthenticationEnabled)); + nativeAuthenticationConfigs.isNativeAuthenticationEnabled(); + logger.debug(String.format("Native authentication enabled: %b", nativeAuthenticationEnabled)); boolean noAuthEnabled = !jaasEnabled && !nativeAuthenticationEnabled; if (noAuthEnabled) { String message = "Neither JAAS nor native authentication is enabled on the server."; @@ -182,13 +188,13 @@ public Result logIn(Http.Request request) { boolean loginSucceeded = tryLogin(username, password); if (!loginSucceeded) { - _logger.info("Login failed for user: {}", username); + logger.info("Login failed for user: {}", username); return Results.badRequest(invalidCredsJson); } final Urn actorUrn = new CorpuserUrn(username); - _logger.info("Login successful for user: {}, urn: {}", username, actorUrn); - final String accessToken = _authClient.generateSessionTokenForUser(actorUrn.getId()); + logger.info("Login successful for user: {}, urn: {}", username, actorUrn); + final String accessToken = authClient.generateSessionTokenForUser(actorUrn.getId()); return createSession(actorUrn.toString(), accessToken); } @@ -199,8 +205,8 @@ public Result logIn(Http.Request request) { @Nonnull public Result signUp(Http.Request request) { boolean nativeAuthenticationEnabled = - _nativeAuthenticationConfigs.isNativeAuthenticationEnabled(); - _logger.debug(String.format("Native authentication enabled: %b", nativeAuthenticationEnabled)); + nativeAuthenticationConfigs.isNativeAuthenticationEnabled(); + logger.debug(String.format("Native authentication enabled: %b", nativeAuthenticationEnabled)); if (!nativeAuthenticationEnabled) { String message = "Native authentication is not enabled on the server."; final ObjectNode error = Json.newObject(); @@ -224,7 +230,7 @@ public Result signUp(Http.Request request) { JsonNode invalidCredsJson = Json.newObject().put("message", "Email must not be empty."); return Results.badRequest(invalidCredsJson); } - if (_nativeAuthenticationConfigs.isEnforceValidEmailEnabled()) { + if (nativeAuthenticationConfigs.isEnforceValidEmailEnabled()) { Constraints.EmailValidator emailValidator = new Constraints.EmailValidator(); if (!emailValidator.isValid(email)) { JsonNode invalidCredsJson = Json.newObject().put("message", "Email must not be empty."); @@ -250,9 +256,9 @@ public Result signUp(Http.Request request) { final Urn userUrn = new CorpuserUrn(email); final String userUrnString = userUrn.toString(); - _authClient.signUp(userUrnString, fullName, email, title, password, inviteToken); - _logger.info("Signed up user {} using invite tokens", userUrnString); - final String accessToken = _authClient.generateSessionTokenForUser(userUrn.getId()); + authClient.signUp(userUrnString, fullName, email, title, password, inviteToken); + logger.info("Signed up user {} using invite tokens", userUrnString); + final String accessToken = authClient.generateSessionTokenForUser(userUrn.getId()); return createSession(userUrnString, accessToken); } @@ -260,8 +266,8 @@ public Result signUp(Http.Request request) { @Nonnull public Result resetNativeUserCredentials(Http.Request request) { boolean nativeAuthenticationEnabled = - _nativeAuthenticationConfigs.isNativeAuthenticationEnabled(); - _logger.debug(String.format("Native authentication enabled: %b", nativeAuthenticationEnabled)); + nativeAuthenticationConfigs.isNativeAuthenticationEnabled(); + logger.debug(String.format("Native authentication enabled: %b", nativeAuthenticationEnabled)); if (!nativeAuthenticationEnabled) { String message = "Native authentication is not enabled on the server."; final ObjectNode error = Json.newObject(); @@ -291,26 +297,27 @@ public Result resetNativeUserCredentials(Http.Request request) { final Urn userUrn = new CorpuserUrn(email); final String userUrnString = userUrn.toString(); - _authClient.resetNativeUserCredentials(userUrnString, password, resetToken); - final String accessToken = _authClient.generateSessionTokenForUser(userUrn.getId()); + authClient.resetNativeUserCredentials(userUrnString, password, resetToken); + final String accessToken = authClient.generateSessionTokenForUser(userUrn.getId()); return createSession(userUrnString, accessToken); } private Optional redirectToIdentityProvider( Http.RequestHeader request, String redirectPath) { - final PlayWebContext playWebContext = new PlayWebContext(request, _playSessionStore); - final Client client = _ssoManager.getSsoProvider().client(); - configurePac4jSessionStore(playWebContext, client, redirectPath); + CallContext ctx = buildCallContext(request); + + final Client client = ssoManager.getSsoProvider().client(); + configurePac4jSessionStore(ctx, client, redirectPath); try { - final Optional action = client.getRedirectionAction(playWebContext); - return action.map(act -> new PlayHttpActionAdapter().adapt(act, playWebContext)); + final Optional action = client.getRedirectionAction(ctx); + return action.map(act -> new PlayHttpActionAdapter().adapt(act, ctx.webContext())); } catch (Exception e) { - if (_verbose) { - _logger.error( + if (verbose) { + logger.error( "Caught exception while attempting to redirect to SSO identity provider! It's likely that SSO integration is mis-configured", e); } else { - _logger.error( + logger.error( "Caught exception while attempting to redirect to SSO identity provider! It's likely that SSO integration is mis-configured"); } return Optional.of( @@ -324,22 +331,33 @@ private Optional redirectToIdentityProvider( } } - private void configurePac4jSessionStore( - PlayWebContext context, Client client, String redirectPath) { + private CallContext buildCallContext(Http.RequestHeader request) { + // First create PlayWebContext from the request + PlayWebContext webContext = new PlayWebContext(request); + + // Then create CallContext using the web context and session store + return new CallContext(webContext, playCookieSessionStore); + } + + private void configurePac4jSessionStore(CallContext ctx, Client client, String redirectPath) { + WebContext context = ctx.webContext(); + // Set the originally requested path for post-auth redirection. We split off into a separate // cookie from the session // to reduce size of the session cookie FoundAction foundAction = new FoundAction(redirectPath); - byte[] javaSerBytes = JAVA_SER_HELPER.serializeToBytes(foundAction); + byte[] javaSerBytes = + ((PlayCookieSessionStore) ctx.sessionStore()).getSerializer().serializeToBytes(foundAction); String serialized = Base64.getEncoder().encodeToString(compressBytes(javaSerBytes)); context.addResponseCookie(new Cookie(REDIRECT_URL_COOKIE_NAME, serialized)); // This is to prevent previous login attempts from being cached. // We replicate the logic here, which is buried in the Pac4j client. - if (_playSessionStore.get(context, client.getName() + ATTEMPTED_AUTHENTICATION_SUFFIX) - != null) { - _logger.debug( + Optional attempt = + playCookieSessionStore.get(context, client.getName() + ATTEMPTED_AUTHENTICATION_SUFFIX); + if (attempt.isPresent() && !"".equals(attempt.get())) { + logger.debug( "Found previous login attempt. Removing it manually to prevent unexpected errors."); - _playSessionStore.set(context, client.getName() + ATTEMPTED_AUTHENTICATION_SUFFIX, ""); + playCookieSessionStore.set(context, client.getName() + ATTEMPTED_AUTHENTICATION_SUFFIX, ""); } } @@ -351,27 +369,27 @@ private boolean tryLogin(String username, String password) { boolean loginSucceeded = false; // First try jaas login, if enabled - if (_jaasConfigs.isJAASEnabled()) { + if (jaasConfigs.isJAASEnabled()) { try { - _logger.debug("Attempting JAAS authentication for user: {}", username); + logger.debug("Attempting JAAS authentication for user: {}", username); AuthenticationManager.authenticateJaasUser(username, password); - _logger.debug("JAAS authentication successful. Login succeeded"); + logger.debug("JAAS authentication successful. Login succeeded"); loginSucceeded = true; } catch (Exception e) { - if (_verbose) { - _logger.debug("JAAS authentication error. Login failed", e); + if (verbose) { + logger.debug("JAAS authentication error. Login failed", e); } else { - _logger.debug("JAAS authentication error. Login failed"); + logger.debug("JAAS authentication error. Login failed"); } } } // If jaas login fails or is disabled, try native auth login - if (_nativeAuthenticationConfigs.isNativeAuthenticationEnabled() && !loginSucceeded) { + if (nativeAuthenticationConfigs.isNativeAuthenticationEnabled() && !loginSucceeded) { final Urn userUrn = new CorpuserUrn(username); final String userUrnString = userUrn.toString(); loginSucceeded = - loginSucceeded || _authClient.verifyNativeUserCredentials(userUrnString, password); + loginSucceeded || authClient.verifyNativeUserCredentials(userUrnString, password); } return loginSucceeded; @@ -383,8 +401,8 @@ private Result createSession(String userUrnString, String accessToken) { .withCookies( createActorCookie( userUrnString, - _cookieConfigs.getTtlInHours(), - _cookieConfigs.getAuthCookieSameSite(), - _cookieConfigs.getAuthCookieSecure())); + cookieConfigs.getTtlInHours(), + cookieConfigs.getAuthCookieSameSite(), + cookieConfigs.getAuthCookieSecure())); } } diff --git a/datahub-frontend/app/controllers/CentralLogoutController.java b/datahub-frontend/app/controllers/CentralLogoutController.java index eea1c662ebf89..d284720bab118 100644 --- a/datahub-frontend/app/controllers/CentralLogoutController.java +++ b/datahub-frontend/app/controllers/CentralLogoutController.java @@ -15,11 +15,11 @@ public class CentralLogoutController extends LogoutController { private static final String AUTH_URL_CONFIG_PATH = "/login"; private static final String DEFAULT_BASE_URL_PATH = "/"; - private static Boolean _isOidcEnabled = false; + private static Boolean isOidcEnabled = false; @Inject public CentralLogoutController(Config config) { - _isOidcEnabled = config.hasPath("auth.oidc.enabled") && config.getBoolean("auth.oidc.enabled"); + isOidcEnabled = config.hasPath("auth.oidc.enabled") && config.getBoolean("auth.oidc.enabled"); setDefaultUrl(DEFAULT_BASE_URL_PATH); setLogoutUrlPattern(DEFAULT_BASE_URL_PATH + ".*"); @@ -29,7 +29,7 @@ public CentralLogoutController(Config config) { /** logout() method should not be called if oidc is not enabled */ public Result executeLogout(Http.Request request) { - if (_isOidcEnabled) { + if (isOidcEnabled) { try { return logout(request).toCompletableFuture().get().withNewSession(); } catch (Exception e) { diff --git a/datahub-frontend/app/controllers/RedirectController.java b/datahub-frontend/app/controllers/RedirectController.java index 17f86b7fbffae..a86584e24ca29 100644 --- a/datahub-frontend/app/controllers/RedirectController.java +++ b/datahub-frontend/app/controllers/RedirectController.java @@ -16,7 +16,10 @@ public Result favicon(Http.Request request) { if (config.getVisualConfig().getAssets().getFaviconUrl().startsWith("http")) { return permanentRedirect(config.getVisualConfig().getAssets().getFaviconUrl()); } else { - final String prefix = config.getVisualConfig().getAssets().getFaviconUrl().startsWith("/") ? "/public" : "/public/"; + final String prefix = + config.getVisualConfig().getAssets().getFaviconUrl().startsWith("/") + ? "/public" + : "/public/"; return ok(Application.class.getResourceAsStream( prefix + config.getVisualConfig().getAssets().getFaviconUrl())) .as("image/x-icon"); diff --git a/datahub-frontend/app/controllers/SsoCallbackController.java b/datahub-frontend/app/controllers/SsoCallbackController.java index 750886570bf40..385b02a56ba23 100644 --- a/datahub-frontend/app/controllers/SsoCallbackController.java +++ b/datahub-frontend/app/controllers/SsoCallbackController.java @@ -5,8 +5,8 @@ import auth.sso.SsoProvider; import auth.sso.oidc.OidcCallbackLogic; import client.AuthServiceClient; -import com.datahub.authentication.Authentication; import com.linkedin.entity.client.SystemEntityClient; +import io.datahubproject.metadata.context.OperationContext; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -16,16 +16,15 @@ import javax.annotation.Nonnull; import javax.inject.Inject; import javax.inject.Named; - -import io.datahubproject.metadata.context.OperationContext; import lombok.extern.slf4j.Slf4j; +import org.pac4j.core.adapter.FrameworkAdapter; import org.pac4j.core.client.Client; import org.pac4j.core.client.Clients; import org.pac4j.core.config.Config; +import org.pac4j.core.context.FrameworkParameters; import org.pac4j.core.engine.CallbackLogic; -import org.pac4j.core.http.adapter.HttpActionAdapter; import org.pac4j.play.CallbackController; -import org.pac4j.play.PlayWebContext; +import org.pac4j.play.context.PlayFrameworkParameters; import play.mvc.Http; import play.mvc.Result; import play.mvc.Results; @@ -40,8 +39,9 @@ @Slf4j public class SsoCallbackController extends CallbackController { - private final SsoManager _ssoManager; - private final Config _config; + private final SsoManager ssoManager; + private final Config config; + private final CallbackLogic callbackLogic; @Inject public SsoCallbackController( @@ -51,23 +51,41 @@ public SsoCallbackController( @Nonnull AuthServiceClient authClient, @Nonnull Config config, @Nonnull com.typesafe.config.Config configs) { - _ssoManager = ssoManager; - _config = config; + this.ssoManager = ssoManager; + this.config = config; setDefaultUrl("/"); // By default, redirects to Home Page on log in. - setSaveInSession(false); - setCallbackLogic( + + callbackLogic = new SsoCallbackLogic( ssoManager, - systemOperationContext, + systemOperationContext, entityClient, authClient, - new CookieConfigs(configs))); + new CookieConfigs(configs)); + } + + @Override + public CompletionStage callback(Http.Request request) { + FrameworkAdapter.INSTANCE.applyDefaultSettingsIfUndefined(this.config); + + return CompletableFuture.supplyAsync( + () -> { + return (Result) + callbackLogic.perform( + this.config, + getDefaultUrl(), + getRenewSession(), + getDefaultClient(), + new PlayFrameworkParameters(request)); + }, + this.ec.current()); } public CompletionStage handleCallback(String protocol, Http.Request request) { if (shouldHandleCallback(protocol)) { - log.debug("Handling SSO callback. Protocol: {}", - _ssoManager.getSsoProvider().protocol().getCommonName()); + log.debug( + "Handling SSO callback. Protocol: {}", + ssoManager.getSsoProvider().protocol().getCommonName()); return callback(request) .handle( (res, e) -> { @@ -94,9 +112,9 @@ public CompletionStage handleCallback(String protocol, Http.Request requ } /** Logic responsible for delegating to protocol-specific callback logic. */ - public class SsoCallbackLogic implements CallbackLogic { + public class SsoCallbackLogic implements CallbackLogic { - private final OidcCallbackLogic _oidcCallbackLogic; + private final OidcCallbackLogic oidcCallbackLogic; SsoCallbackLogic( final SsoManager ssoManager, @@ -104,31 +122,21 @@ public class SsoCallbackLogic implements CallbackLogic { final SystemEntityClient entityClient, final AuthServiceClient authClient, final CookieConfigs cookieConfigs) { - _oidcCallbackLogic = + oidcCallbackLogic = new OidcCallbackLogic( ssoManager, systemOperationContext, entityClient, authClient, cookieConfigs); } @Override - public Result perform( - PlayWebContext context, + public Object perform( Config config, - HttpActionAdapter httpActionAdapter, - String defaultUrl, - Boolean saveInSession, - Boolean multiProfile, - Boolean renewSession, - String defaultClient) { - if (SsoProvider.SsoProtocol.OIDC.equals(_ssoManager.getSsoProvider().protocol())) { - return _oidcCallbackLogic.perform( - context, - config, - httpActionAdapter, - defaultUrl, - saveInSession, - multiProfile, - renewSession, - defaultClient); + String inputDefaultUrl, + Boolean inputRenewSession, + String defaultClient, + FrameworkParameters parameters) { + if (SsoProvider.SsoProtocol.OIDC.equals(ssoManager.getSsoProvider().protocol())) { + return oidcCallbackLogic.perform( + config, inputDefaultUrl, inputRenewSession, defaultClient, parameters); } // Should never occur. throw new UnsupportedOperationException( @@ -137,18 +145,18 @@ public Result perform( } private boolean shouldHandleCallback(final String protocol) { - if (!_ssoManager.isSsoEnabled()) { + if (!ssoManager.isSsoEnabled()) { return false; } updateConfig(); - return _ssoManager.getSsoProvider().protocol().getCommonName().equals(protocol); + return ssoManager.getSsoProvider().protocol().getCommonName().equals(protocol); } private void updateConfig() { final Clients clients = new Clients(); final List clientList = new ArrayList<>(); - clientList.add(_ssoManager.getSsoProvider().client()); + clientList.add(ssoManager.getSsoProvider().client()); clients.setClients(clientList); - _config.setClients(clients); + config.setClients(clients); } } diff --git a/datahub-frontend/app/controllers/TrackingController.java b/datahub-frontend/app/controllers/TrackingController.java index 254a8cc640d0c..5d12c96ed77cb 100644 --- a/datahub-frontend/app/controllers/TrackingController.java +++ b/datahub-frontend/app/controllers/TrackingController.java @@ -22,23 +22,23 @@ @Singleton public class TrackingController extends Controller { - private final Logger _logger = LoggerFactory.getLogger(TrackingController.class.getName()); + private static final Logger logger = LoggerFactory.getLogger(TrackingController.class.getName()); - private final String _topic; + private final String topic; - @Inject KafkaTrackingProducer _producer; + @Inject KafkaTrackingProducer producer; - @Inject AuthServiceClient _authClient; + @Inject AuthServiceClient authClient; @Inject public TrackingController(@Nonnull Config config) { - _topic = config.getString("analytics.tracking.topic"); + topic = config.getString("analytics.tracking.topic"); } @Security.Authenticated(Authenticator.class) @Nonnull public Result track(Http.Request request) throws Exception { - if (!_producer.isEnabled()) { + if (!producer.isEnabled()) { // If tracking is disabled, simply return a 200. return status(200); } @@ -51,15 +51,15 @@ public Result track(Http.Request request) throws Exception { } final String actor = request.session().data().get(ACTOR); try { - _logger.debug( + logger.debug( String.format("Emitting product analytics event. actor: %s, event: %s", actor, event)); final ProducerRecord record = - new ProducerRecord<>(_topic, actor, event.toString()); - _producer.send(record); - _authClient.track(event.toString()); + new ProducerRecord<>(topic, actor, event.toString()); + producer.send(record); + authClient.track(event.toString()); return ok(); } catch (Exception e) { - _logger.error( + logger.error( String.format( "Failed to emit product analytics event. actor: %s, event: %s", actor, event)); return internalServerError(e.getMessage()); diff --git a/datahub-frontend/build.gradle b/datahub-frontend/build.gradle index ab4ce405a5541..7750e169b11fb 100644 --- a/datahub-frontend/build.gradle +++ b/datahub-frontend/build.gradle @@ -12,6 +12,12 @@ ext { docker_dir = 'datahub-frontend' } +java { + toolchain { + languageVersion = JavaLanguageVersion.of(jdkVersion(project)) + } +} + model { // Must specify the dependency here as "stage" is added by rule based model. tasks.myTar { diff --git a/datahub-frontend/conf/application.conf b/datahub-frontend/conf/application.conf index be57a33b13564..db982b595e248 100644 --- a/datahub-frontend/conf/application.conf +++ b/datahub-frontend/conf/application.conf @@ -184,6 +184,7 @@ auth.oidc.responseMode = ${?AUTH_OIDC_RESPONSE_MODE} auth.oidc.useNonce = ${?AUTH_OIDC_USE_NONCE} auth.oidc.customParam.resource = ${?AUTH_OIDC_CUSTOM_PARAM_RESOURCE} auth.oidc.readTimeout = ${?AUTH_OIDC_READ_TIMEOUT} +auth.oidc.connectTimeout = ${?AUTH_OIDC_CONNECT_TIMEOUT} auth.oidc.extractJwtAccessTokenClaims = ${?AUTH_OIDC_EXTRACT_JWT_ACCESS_TOKEN_CLAIMS} # Whether to extract claims from JWT access token. Defaults to false. auth.oidc.preferredJwsAlgorithm = ${?AUTH_OIDC_PREFERRED_JWS_ALGORITHM} # Which jws algorithm to use auth.oidc.acrValues = ${?AUTH_OIDC_ACR_VALUES} diff --git a/datahub-frontend/play.gradle b/datahub-frontend/play.gradle index ff43e4a93a80f..266962721a80a 100644 --- a/datahub-frontend/play.gradle +++ b/datahub-frontend/play.gradle @@ -11,16 +11,29 @@ configurations { play } +ext { + nimbusJoseJwtVersion = "9.41.2" + oauth2OidcSdkVersion = "11.20.1" +} + dependencies { implementation project(':datahub-web-react') constraints { + play(externalDependency.pac4j) + play(externalDependency.playPac4j) + play("com.nimbusds:oauth2-oidc-sdk:$oauth2OidcSdkVersion") + play("com.nimbusds:nimbus-jose-jwt:$nimbusJoseJwtVersion") + implementation(externalDependency.pac4j) + implementation(externalDependency.playPac4j) + implementation("com.nimbusds:nimbus-jose-jwt:$nimbusJoseJwtVersion") + testImplementation("com.nimbusds:oauth2-oidc-sdk:$oauth2OidcSdkVersion") + play(externalDependency.jacksonDataBind) - play('com.nimbusds:oauth2-oidc-sdk:8.36.2') - play('com.nimbusds:nimbus-jose-jwt:8.18') - play('com.typesafe.akka:akka-actor_2.12:2.6.20') + play("com.typesafe.akka:akka-actor_$playScalaVersion:2.6.20") play(externalDependency.jsonSmart) play('io.netty:netty-all:4.1.114.Final') + implementation(externalDependency.commonsText) { because("previous versions are vulnerable to CVE-2022-42889") } @@ -46,14 +59,11 @@ dependencies { implementation externalDependency.jerseyCore implementation externalDependency.jerseyGuava - implementation(externalDependency.pac4j) { - exclude group: "net.minidev", module: "json-smart" - exclude group: "com.nimbusds", module: "nimbus-jose-jwt" - } - - implementation 'com.nimbusds:nimbus-jose-jwt:8.18' - implementation externalDependency.jsonSmart + implementation externalDependency.pac4j implementation externalDependency.playPac4j + implementation "com.nimbusds:nimbus-jose-jwt:$nimbusJoseJwtVersion" + implementation externalDependency.jsonSmart + implementation externalDependency.shiroCore implementation externalDependency.playCache @@ -69,7 +79,7 @@ dependencies { testImplementation externalDependency.mockito testImplementation externalDependency.playTest testImplementation 'org.awaitility:awaitility:4.2.0' - testImplementation 'no.nav.security:mock-oauth2-server:0.3.1' + testImplementation 'no.nav.security:mock-oauth2-server:2.1.9' testImplementation 'org.junit-pioneer:junit-pioneer:1.9.1' testImplementation externalDependency.junitJupiterApi testRuntimeOnly externalDependency.junitJupiterEngine @@ -78,7 +88,7 @@ dependencies { compileOnly externalDependency.lombok runtimeOnly externalDependency.guicePlay runtimeOnly (externalDependency.playDocs) { - exclude group: 'com.typesafe.akka', module: 'akka-http-core_2.12' + exclude group: 'com.typesafe.akka', module: "akka-http-core_$playScalaVersion" } runtimeOnly externalDependency.playGuice implementation externalDependency.log4j2Api @@ -90,9 +100,9 @@ dependencies { play { platform { - playVersion = '2.8.21' - scalaVersion = '2.12' - javaVersion = JavaVersion.VERSION_11 + playVersion = "2.8.22" // see also top level build.gradle + scalaVersion = "2.13" + javaVersion = JavaVersion.VERSION_17 } injectedRoutesGenerator = true diff --git a/datahub-frontend/test/app/ApplicationTest.java b/datahub-frontend/test/app/ApplicationTest.java index 534cffb5cc7fe..3ad9e22857168 100644 --- a/datahub-frontend/test/app/ApplicationTest.java +++ b/datahub-frontend/test/app/ApplicationTest.java @@ -12,17 +12,27 @@ import com.nimbusds.jwt.JWTParser; import controllers.routes; import java.io.IOException; +import java.net.HttpURLConnection; import java.net.InetAddress; +import java.net.URL; import java.text.ParseException; import java.util.Date; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import no.nav.security.mock.oauth2.MockOAuth2Server; +import no.nav.security.mock.oauth2.http.OAuth2HttpRequest; +import no.nav.security.mock.oauth2.http.OAuth2HttpResponse; +import no.nav.security.mock.oauth2.http.Route; import no.nav.security.mock.oauth2.token.DefaultOAuth2TokenCallback; +import okhttp3.Headers; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.awaitility.Awaitility; import org.awaitility.Durations; +import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -30,6 +40,8 @@ import org.junitpioneer.jupiter.SetEnvironmentVariable; import org.openqa.selenium.Cookie; import org.openqa.selenium.htmlunit.HtmlUnitDriver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import play.Application; import play.Environment; import play.Mode; @@ -48,7 +60,9 @@ @SetEnvironmentVariable(key = "AUTH_OIDC_JIT_PROVISIONING_ENABLED", value = "false") @SetEnvironmentVariable(key = "AUTH_OIDC_CLIENT_ID", value = "testclient") @SetEnvironmentVariable(key = "AUTH_OIDC_CLIENT_SECRET", value = "testsecret") +@SetEnvironmentVariable(key = "AUTH_VERBOSE_LOGGING", value = "true") public class ApplicationTest extends WithBrowser { + private static final Logger logger = LoggerFactory.getLogger(ApplicationTest.class); private static final String ISSUER_ID = "testIssuer"; @Override @@ -80,41 +94,34 @@ public int gmsServerPort() { return providePort() + 2; } - private MockOAuth2Server _oauthServer; - private MockWebServer _gmsServer; + private MockOAuth2Server oauthServer; + private Thread oauthServerThread; + private CompletableFuture oauthServerStarted; - private String _wellKnownUrl; + private MockWebServer gmsServer; + + private String wellKnownUrl; private static final String TEST_USER = "urn:li:corpuser:testUser@myCompany.com"; private static final String TEST_TOKEN = "faketoken_YCpYIrjQH4sD3_rAc3VPPFg4"; @BeforeAll public void init() throws IOException { - _gmsServer = new MockWebServer(); - _gmsServer.enqueue(new MockResponse().setResponseCode(404)); // dynamic settings - not tested - _gmsServer.enqueue(new MockResponse().setResponseCode(404)); // dynamic settings - not tested - _gmsServer.enqueue(new MockResponse().setResponseCode(404)); // dynamic settings - not tested - _gmsServer.enqueue(new MockResponse().setBody(String.format("{\"value\":\"%s\"}", TEST_USER))); - _gmsServer.enqueue( + // Start Mock GMS + gmsServer = new MockWebServer(); + gmsServer.enqueue(new MockResponse().setResponseCode(404)); // dynamic settings - not tested + gmsServer.enqueue(new MockResponse().setResponseCode(404)); // dynamic settings - not tested + gmsServer.enqueue(new MockResponse().setResponseCode(404)); // dynamic settings - not tested + gmsServer.enqueue(new MockResponse().setBody(String.format("{\"value\":\"%s\"}", TEST_USER))); + gmsServer.enqueue( new MockResponse().setBody(String.format("{\"accessToken\":\"%s\"}", TEST_TOKEN))); - _gmsServer.start(gmsServerPort()); - - _oauthServer = new MockOAuth2Server(); - _oauthServer.enqueueCallback( - new DefaultOAuth2TokenCallback( - ISSUER_ID, - "testUser", - List.of(), - Map.of( - "email", "testUser@myCompany.com", - "groups", "myGroup"), - 600)); - _oauthServer.start(InetAddress.getByName("localhost"), oauthServerPort()); - - // Discovery url to authorization server metadata - _wellKnownUrl = _oauthServer.wellKnownUrl(ISSUER_ID).toString(); + gmsServer.start(gmsServerPort()); + // Start Mock Identity Provider + startMockOauthServer(); + // Start Play Frontend startServer(); + // Start Browser createBrowser(); Awaitility.await().timeout(Durations.TEN_SECONDS).until(() -> app != null); @@ -122,13 +129,131 @@ public void init() throws IOException { @AfterAll public void shutdown() throws IOException { - if (_gmsServer != null) { - _gmsServer.shutdown(); - } - if (_oauthServer != null) { - _oauthServer.shutdown(); + if (gmsServer != null) { + logger.info("Shutdown Mock GMS"); + gmsServer.shutdown(); } + logger.info("Shutdown Play Frontend"); stopServer(); + if (oauthServer != null) { + logger.info("Shutdown MockOAuth2Server"); + oauthServer.shutdown(); + } + if (oauthServerThread != null && oauthServerThread.isAlive()) { + logger.info("Shutdown MockOAuth2Server thread"); + oauthServerThread.interrupt(); + try { + oauthServerThread.join(2000); // Wait up to 2 seconds for thread to finish + } catch (InterruptedException e) { + logger.warn("Shutdown MockOAuth2Server thread failed to join."); + } + } + } + + private void startMockOauthServer() { + // Configure HEAD responses + Route[] routes = + new Route[] { + new Route() { + @Override + public boolean match(@NotNull OAuth2HttpRequest oAuth2HttpRequest) { + return "HEAD".equals(oAuth2HttpRequest.getMethod()) + && (String.format("/%s/.well-known/openid-configuration", ISSUER_ID) + .equals(oAuth2HttpRequest.getUrl().url().getPath()) + || String.format("/%s/token", ISSUER_ID) + .equals(oAuth2HttpRequest.getUrl().url().getPath())); + } + + @Override + public OAuth2HttpResponse invoke(OAuth2HttpRequest oAuth2HttpRequest) { + return new OAuth2HttpResponse( + Headers.of( + Map.of( + "Content-Type", "application/json", + "Cache-Control", "no-store", + "Pragma", "no-cache", + "Content-Length", "-1")), + 200, + null, + null); + } + } + }; + oauthServer = new MockOAuth2Server(routes); + oauthServerStarted = new CompletableFuture<>(); + + // Create and start server in separate thread + oauthServerThread = + new Thread( + () -> { + try { + // Configure mock responses + oauthServer.enqueueCallback( + new DefaultOAuth2TokenCallback( + ISSUER_ID, + "testUser", + "JWT", + List.of(), + Map.of( + "email", "testUser@myCompany.com", + "groups", "myGroup"), + 600)); + + oauthServer.start(InetAddress.getByName("localhost"), oauthServerPort()); + + oauthServerStarted.complete(null); + + // Keep thread alive until server is stopped + while (!Thread.currentThread().isInterrupted() && testServer.isRunning()) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } + } + } catch (Exception e) { + oauthServerStarted.completeExceptionally(e); + } + }); + + oauthServerThread.setDaemon(true); // Ensure thread doesn't prevent JVM shutdown + oauthServerThread.start(); + + // Wait for server to start with timeout + oauthServerStarted + .orTimeout(10, TimeUnit.SECONDS) + .whenComplete( + (result, throwable) -> { + if (throwable != null) { + if (throwable instanceof TimeoutException) { + throw new RuntimeException( + "MockOAuth2Server failed to start within timeout", throwable); + } + throw new RuntimeException("MockOAuth2Server failed to start", throwable); + } + }); + + // Discovery url to authorization server metadata + wellKnownUrl = oauthServer.wellKnownUrl(ISSUER_ID).toString(); + + // Wait for server to return configuration + // Validate mock server returns data + try { + URL url = new URL(wellKnownUrl); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("GET"); + int responseCode = conn.getResponseCode(); + logger.info("Well-known endpoint response code: {}", responseCode); + + if (responseCode != 200) { + throw new RuntimeException( + "MockOAuth2Server not accessible. Response code: " + responseCode); + } + logger.info("Successfully started MockOAuth2Server."); + } catch (Exception e) { + throw new RuntimeException("Failed to connect to MockOAuth2Server", e); + } } @Test @@ -158,7 +283,7 @@ public void testIndexNotFound() { public void testOpenIdConfig() { assertEquals( "http://localhost:" + oauthServerPort() + "/testIssuer/.well-known/openid-configuration", - _wellKnownUrl); + wellKnownUrl); } @Test @@ -188,10 +313,10 @@ public void testHappyPathOidc() throws ParseException { @Test public void testAPI() throws ParseException { testHappyPathOidc(); - int requestCount = _gmsServer.getRequestCount(); + int requestCount = gmsServer.getRequestCount(); browser.goTo("/api/v2/graphql/"); - assertEquals(++requestCount, _gmsServer.getRequestCount()); + assertEquals(++requestCount, gmsServer.getRequestCount()); } @Test @@ -201,8 +326,9 @@ public void testOidcRedirectToRequestedUrl() { } /** - * The Redirect Uri parameter is used to store a previous relative location within the app to be able to - * take a user back to their expected page. Redirecting to other domains should be blocked. + * The Redirect Uri parameter is used to store a previous relative location within the app to be + * able to take a user back to their expected page. Redirecting to other domains should be + * blocked. */ @Test public void testInvalidRedirectUrl() { diff --git a/datahub-frontend/test/oidc/OidcCallbackLogicTest.java b/datahub-frontend/test/oidc/OidcCallbackLogicTest.java index f4784c29e91f2..9eb3833cbc897 100644 --- a/datahub-frontend/test/oidc/OidcCallbackLogicTest.java +++ b/datahub-frontend/test/oidc/OidcCallbackLogicTest.java @@ -1,64 +1,65 @@ package oidc; -import auth.sso.oidc.OidcConfigs; - -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; - import static auth.sso.oidc.OidcCallbackLogic.getGroupNames; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.util.Arrays; +import java.util.Collection; import org.junit.jupiter.api.Test; -import org.mockito.Mockito; import org.pac4j.core.profile.CommonProfile; public class OidcCallbackLogicTest { - @Test - public void testGetGroupsClaimNamesJsonArray() { - CommonProfile profile = createMockProfileWithAttribute("[\"group1\", \"group2\"]", "groupsClaimName"); - Collection result = getGroupNames(profile, "[\"group1\", \"group2\"]", "groupsClaimName"); - assertEquals(Arrays.asList("group1", "group2"), result); - } - @Test - public void testGetGroupNamesWithSingleGroup() { - CommonProfile profile = createMockProfileWithAttribute("group1", "groupsClaimName"); - Collection result = getGroupNames(profile, "group1", "groupsClaimName"); - assertEquals(Arrays.asList("group1"), result); - } + @Test + public void testGetGroupsClaimNamesJsonArray() { + CommonProfile profile = + createMockProfileWithAttribute("[\"group1\", \"group2\"]", "groupsClaimName"); + Collection result = + getGroupNames(profile, "[\"group1\", \"group2\"]", "groupsClaimName"); + assertEquals(Arrays.asList("group1", "group2"), result); + } - @Test - public void testGetGroupNamesWithCommaSeparated() { - CommonProfile profile = createMockProfileWithAttribute("group1,group2", "groupsClaimName"); - Collection result = getGroupNames(profile, "group1,group2", "groupsClaimName"); - assertEquals(Arrays.asList("group1", "group2"), result); - } + @Test + public void testGetGroupNamesWithSingleGroup() { + CommonProfile profile = createMockProfileWithAttribute("group1", "groupsClaimName"); + Collection result = getGroupNames(profile, "group1", "groupsClaimName"); + assertEquals(Arrays.asList("group1"), result); + } - @Test - public void testGetGroupNamesWithCollection() { - CommonProfile profile = createMockProfileWithAttribute(Arrays.asList("group1", "group2"), "groupsClaimName"); - Collection result = getGroupNames(profile, Arrays.asList("group1", "group2"), "groupsClaimName"); - assertEquals(Arrays.asList("group1", "group2"), result); - } - // Helper method to create a mock CommonProfile with given attribute - private CommonProfile createMockProfileWithAttribute(Object attribute, String attributeName) { - CommonProfile profile = mock(CommonProfile.class); + @Test + public void testGetGroupNamesWithCommaSeparated() { + CommonProfile profile = createMockProfileWithAttribute("group1,group2", "groupsClaimName"); + Collection result = getGroupNames(profile, "group1,group2", "groupsClaimName"); + assertEquals(Arrays.asList("group1", "group2"), result); + } + + @Test + public void testGetGroupNamesWithCollection() { + CommonProfile profile = + createMockProfileWithAttribute(Arrays.asList("group1", "group2"), "groupsClaimName"); + Collection result = + getGroupNames(profile, Arrays.asList("group1", "group2"), "groupsClaimName"); + assertEquals(Arrays.asList("group1", "group2"), result); + } - // Mock for getAttribute(String) - when(profile.getAttribute(attributeName)).thenReturn(attribute); + // Helper method to create a mock CommonProfile with given attribute + private CommonProfile createMockProfileWithAttribute(Object attribute, String attributeName) { + CommonProfile profile = mock(CommonProfile.class); - // Mock for getAttribute(String, Class) - if (attribute instanceof Collection) { - when(profile.getAttribute(attributeName, Collection.class)).thenReturn((Collection) attribute); - } else if (attribute instanceof String) { - when(profile.getAttribute(attributeName, String.class)).thenReturn((String) attribute); - } - // Add more conditions here if needed for other types + // Mock for getAttribute(String) + when(profile.getAttribute(attributeName)).thenReturn(attribute); - return profile; + // Mock for getAttribute(String, Class) + if (attribute instanceof Collection) { + when(profile.getAttribute(attributeName, Collection.class)) + .thenReturn((Collection) attribute); + } else if (attribute instanceof String) { + when(profile.getAttribute(attributeName, String.class)).thenReturn((String) attribute); } + // Add more conditions here if needed for other types + + return profile; + } } diff --git a/datahub-frontend/test/security/OidcConfigurationTest.java b/datahub-frontend/test/security/OidcConfigurationTest.java index 1c52d45af5f9e..ec19979c56120 100644 --- a/datahub-frontend/test/security/OidcConfigurationTest.java +++ b/datahub-frontend/test/security/OidcConfigurationTest.java @@ -23,9 +23,9 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; +import org.json.JSONObject; import org.junit.jupiter.api.Test; import org.pac4j.oidc.client.OidcClient; -import org.json.JSONObject; public class OidcConfigurationTest { @@ -328,17 +328,28 @@ public void readPreferredJwsAlgorithmPropagationFromConfig() { oidcConfigsBuilder.from(CONFIG, SSO_SETTINGS_JSON_STR); OidcConfigs oidcConfigs = new OidcConfigs(oidcConfigsBuilder); OidcProvider oidcProvider = new OidcProvider(oidcConfigs); - assertEquals("RS256", ((OidcClient) oidcProvider.client()).getConfiguration().getPreferredJwsAlgorithm().toString()); + assertEquals( + "RS256", + ((OidcClient) oidcProvider.client()) + .getConfiguration() + .getPreferredJwsAlgorithm() + .toString()); } @Test public void readPreferredJwsAlgorithmPropagationFromJSON() { - final String SSO_SETTINGS_JSON_STR = new JSONObject().put(PREFERRED_JWS_ALGORITHM, "HS256").toString(); + final String SSO_SETTINGS_JSON_STR = + new JSONObject().put(PREFERRED_JWS_ALGORITHM, "HS256").toString(); CONFIG.withValue(OIDC_PREFERRED_JWS_ALGORITHM, ConfigValueFactory.fromAnyRef("RS256")); OidcConfigs.Builder oidcConfigsBuilder = new OidcConfigs.Builder(); oidcConfigsBuilder.from(CONFIG, SSO_SETTINGS_JSON_STR); OidcConfigs oidcConfigs = new OidcConfigs(oidcConfigsBuilder); OidcProvider oidcProvider = new OidcProvider(oidcConfigs); - assertEquals("HS256", ((OidcClient) oidcProvider.client()).getConfiguration().getPreferredJwsAlgorithm().toString()); + assertEquals( + "HS256", + ((OidcClient) oidcProvider.client()) + .getConfiguration() + .getPreferredJwsAlgorithm() + .toString()); } }