Skip to content

Commit

Permalink
feat: multitenancy (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
mghilardelli authored Jul 16, 2024
1 parent 62933d6 commit b2d2f36
Show file tree
Hide file tree
Showing 10 changed files with 299 additions and 72 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package ch.sbb.playgroundbackend.auth;

import java.util.ArrayList;
import java.util.List;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;

/**
* Contains the tenant configurations from the application.yml
*/
@Configuration
@EnableConfigurationProperties
@ConfigurationProperties("auth")
public class TenantConfig {

private List<Tenant> tenants = new ArrayList<>();

public List<Tenant> getTenants() {
return tenants;
}

public void setTenants(List<Tenant> tenants) {
this.tenants = tenants;
}

public static class Tenant {

private String name;
private String jwkSetUri;
private String issuerUri;

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public String getJwkSetUri() {
return jwkSetUri;
}

public void setJwkSetUri(String jwkSetUri) {
this.jwkSetUri = jwkSetUri;
}

public String getIssuerUri() {
return issuerUri;
}

public void setIssuerUri(String issuerUri) {
this.issuerUri = issuerUri;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package ch.sbb.playgroundbackend.auth;

import ch.sbb.playgroundbackend.auth.TenantConfig.Tenant;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.proc.JWSAlgorithmFamilyJWSKeySelector;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.JWTClaimsSetAwareJWSKeySelector;
import java.net.URL;
import java.security.Key;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.security.oauth2.jwt.JwtClaimNames;
import org.springframework.stereotype.Component;

/**
* This class (which is used by the bean jwtProcessor, see SecurityConfig.java) provides the functionality to choose which key selector to use based on the iss claim in the JWT. It uses a cache for
* JWKKeySelectors, keyed by tenant identifier. Looking up the tenant is more secure than simply calculating the JWK Set endpoint on the fly - the lookup acts as a list of allowed tenants.
* <p>
* For a more detailed description see <a href="https://docs.spring.io/spring-security/reference/servlet/oauth2/resource-server/multitenancy.html#_parsing_the_claim_only_once">Spring Security
* Documentation</a>.
*/
@Component
public class TenantJwsKeySelector implements JWTClaimsSetAwareJWSKeySelector<SecurityContext> {

private final TenantService tenantService;
private final Map<String, JWSKeySelector<SecurityContext>> selectors = new ConcurrentHashMap<>();

public TenantJwsKeySelector(TenantService tenantService) {
this.tenantService = tenantService;
}

@Override
public List<? extends Key> selectKeys(JWSHeader jwsHeader, JWTClaimsSet jwtClaimsSet, SecurityContext securityContext)
throws KeySourceException {
return this.selectors.computeIfAbsent(toTenant(jwtClaimsSet), this::fromTenant)
.selectJWSKeys(jwsHeader, securityContext);
}

private String toTenant(JWTClaimsSet claimSet) {
return (String) claimSet.getClaim(JwtClaimNames.ISS);
}

private JWSKeySelector<SecurityContext> fromTenant(String issuerUri) {
final Tenant tenant = tenantService.getByIssuerUri(issuerUri);
return fromUri(tenant.getJwkSetUri());
}

private JWSKeySelector<SecurityContext> fromUri(String uri) {
try {
return JWSAlgorithmFamilyJWSKeySelector.fromJWKSetURL(new URL(uri));
} catch (Exception ex) {
throw new IllegalArgumentException(ex);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package ch.sbb.playgroundbackend.auth;

import ch.sbb.playgroundbackend.auth.TenantConfig.Tenant;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.stereotype.Service;

/**
* Service providing tenant information based on the iss claim of the JWT token.
*/
@Service
public class TenantService {

private static final Logger logger = LogManager.getLogger(TenantService.class);

private final TenantConfig tenantConfig;

public TenantService(TenantConfig tenantConfig) {
this.tenantConfig = tenantConfig;
}

public Tenant getByIssuerUri(String issuerUri) {
Tenant tenant = tenantConfig.getTenants().stream().filter(t ->
issuerUri.equals(t.getIssuerUri())
).findAny()
.orElseThrow(() -> new IllegalArgumentException("unknown tenant"));

logger.info(String.format("Got tenant '%s' with issuer URI '%s'", tenant.getName(), tenant.getIssuerUri()));

return tenant;
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
package ch.sbb.playgroundbackend.config;

import ch.sbb.playgroundbackend.auth.TenantJwsKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.function.Predicate;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.security.oauth2.resource.OAuth2ResourceServerProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider;
Expand All @@ -18,38 +27,30 @@
import org.springframework.security.oauth2.jwt.JwtClaimNames;
import org.springframework.security.oauth2.jwt.JwtClaimValidator;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtIssuerValidator;
import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.function.Predicate;

@Configuration
public class ApplicationConfiguration {

private final OAuth2ResourceServerProperties.Jwt properties;

// The audience is important because the JWT token is accepted only if the aud claim in the JWT token received by the server is the same as the client ID of the server.
@Value("${spring.security.oauth2.resourceserver.jwt.audience}")
String[] audiences;

public ApplicationConfiguration(OAuth2ResourceServerProperties properties) {
this.properties = properties.getJwt();
}

@Bean
JwtDecoder jwtDecoder() {
NimbusJwtDecoder nimbusJwtDecoder = NimbusJwtDecoder.withJwkSetUri(properties.getJwkSetUri()).build();
JwtDecoder jwtDecoder(TenantJwsKeySelector keySelector) {
NimbusJwtDecoder nimbusJwtDecoder = new NimbusJwtDecoder(jwtProcessor(keySelector));
nimbusJwtDecoder.setJwtValidator(jwtValidator());
return nimbusJwtDecoder;
}

@Bean
public JWTProcessor<SecurityContext> jwtProcessor(TenantJwsKeySelector keySelector) {
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWTClaimsSetAwareJWSKeySelector(keySelector);
return jwtProcessor;
}

@Bean
public OAuth2AuthorizedClientManager authorizedClientManager(
ClientRegistrationRepository clientRegistrationRepository,
Expand Down Expand Up @@ -77,10 +78,6 @@ private DefaultJwtBearerTokenResponseClient oAuth2AccessTokenResponseClient() {

private OAuth2TokenValidator<Jwt> jwtValidator() {
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
String issuerUri = properties.getIssuerUri();
if (StringUtils.hasText(issuerUri)) {
validators.add(new JwtIssuerValidator(issuerUri));
}
if (audiences != null && audiences.length > 0) {
validators.add(new JwtClaimValidator<>(JwtClaimNames.AUD, audiencePredicate()));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package ch.sbb.playgroundbackend.config;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.ClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimNames;

@Configuration
public class DynamicClientRegistrationRepository implements ClientRegistrationRepository {

@Value("${auth.exchange.client-id}")
private String clientId;

@Value("${auth.exchange.client-secret}")
private String clientSecret;

@Value("${auth.exchange.scope}")
private String scope;

@Override
public ClientRegistration findByRegistrationId(String registrationId) {
String issuerUri = (String) ((Jwt) (SecurityContextHolder.getContext().getAuthentication().getPrincipal())).getClaims().get(JwtClaimNames.ISS);
return ClientRegistrations.fromIssuerLocation(issuerUri)
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.clientId(clientId)
.clientSecret(clientSecret)
.scope(scope)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class SwaggerConfiguration {
private String authorizationUrl;

@Bean
public OpenAPI gleisspiegelOpenAPIConfiguration() {
public OpenAPI openAPIConfiguration() {
return new OpenAPI()
.components(new Components().addSecuritySchemes(OAUTH_2, addOAuthSecurityScheme()))
.security(Collections.singletonList(new SecurityRequirement().addList(OAUTH_2)))
Expand All @@ -43,7 +43,7 @@ private Info apiInfo() {
return new Info()
.title("Playground Backend" + versionInformation)
.contact(new Contact()
.name("Team Zug")
.name("DAS")
.url("https://github.com/SchweizerischeBundesbahnen/DAS"));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,57 @@
package ch.sbb.playgroundbackend.config;

import static org.springframework.security.config.Customizer.withDefaults;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Profile;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter;
import org.springframework.security.oauth2.server.resource.authentication.JwtGrantedAuthoritiesConverter;
import org.springframework.security.web.SecurityFilterChain;

import static org.springframework.security.config.Customizer.withDefaults;

@Configuration
@EnableWebSecurity
@Profile("!test")
public class WebSecurityConfiguration {

private static final String ROLES_KEY = "roles";
private static final String ROLE_PREFIX = "ROLE_";

@Bean
public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
http
.cors(withDefaults())
.authorizeHttpRequests(authConfig -> {
authConfig.requestMatchers("/swagger-ui/**").permitAll();
authConfig.requestMatchers("/v3/api-docs/**").permitAll();
authConfig.requestMatchers("/actuator/health/*").permitAll();
authConfig.requestMatchers("/actuator/info").permitAll();
authConfig.requestMatchers("/**").authenticated();
}
)
// Disable csrf for now as it makes unauthenticated requests return 401/403
.csrf(AbstractHttpConfigurer::disable)
.oauth2ResourceServer(oauth2 ->
oauth2.jwt(withDefaults())
);
.cors(withDefaults())
.authorizeHttpRequests(authConfig -> {
authConfig.requestMatchers("/swagger-ui/**").permitAll();
authConfig.requestMatchers("/v3/api-docs/**").permitAll();
authConfig.requestMatchers("/actuator/health/*").permitAll();
authConfig.requestMatchers("/actuator/info").permitAll();
authConfig.requestMatchers("/admin/**").hasRole("admin");
authConfig.requestMatchers("/**").authenticated();
}
)
// Disable csrf for now as it makes unauthenticated requests return 401/403
.csrf(AbstractHttpConfigurer::disable)
.oauth2ResourceServer(oauth2 ->
oauth2.jwt(jwtConfigurer -> jwtConfigurer.jwtAuthenticationConverter(jwtAuthenticationConverter()))
);
return http.build();
}

@Bean
public JwtAuthenticationConverter jwtAuthenticationConverter() {
// We define a custom role converter to extract the roles from the Entra ID's JWT token and convert them to granted authorities.
// This allows us to do role-based access control on our endpoints.
JwtGrantedAuthoritiesConverter roleConverter = new JwtGrantedAuthoritiesConverter();
roleConverter.setAuthoritiesClaimName(ROLES_KEY);
roleConverter.setAuthorityPrefix(ROLE_PREFIX);

JwtAuthenticationConverter jwtAuthenticationConverter = new JwtAuthenticationConverter();
jwtAuthenticationConverter.setJwtGrantedAuthoritiesConverter(roleConverter);

return jwtAuthenticationConverter;
}
}
Loading

0 comments on commit b2d2f36

Please sign in to comment.