Skip to content

Commit

Permalink
fix: keycloak to jwt token migration
Browse files Browse the repository at this point in the history
  • Loading branch information
igorvargaextvi committed Feb 7, 2025
1 parent 0ed21f8 commit 50af1d2
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 100 deletions.
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,11 @@
<artifactId>spring-security-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-junit-jupiter</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public Collection<? extends GrantedAuthority> mapAuthorities(
return mapAuthorities(roleNames);
}

private Set<GrantedAuthority> mapAuthorities(Set<String> roleNames) {
Set<GrantedAuthority> mapAuthorities(Set<String> roleNames) {
return roleNames.parallelStream()
.map(UserRole::getRoleByValue)
.flatMap(Optional::stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import java.util.Optional;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.keycloak.KeycloakSecurityContext;
import org.keycloak.adapters.springsecurity.token.KeycloakAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.stereotype.Component;

@AllArgsConstructor
Expand All @@ -18,30 +19,36 @@ public class AccessTokenTenantResolver implements TenantResolver {

@Override
public Optional<Long> resolve(HttpServletRequest request) {
return resolveTenantIdFromTokenClaims(request);
return resolveTenantIdFromTokenClaims();
}

private Optional<Long> resolveTenantIdFromTokenClaims(HttpServletRequest request) {
Map<String, Object> claimMap = getClaimMap(request);
private Optional<Long> resolveTenantIdFromTokenClaims() {
Map<String, Object> claimMap = getClaimMap();
log.debug("Found tenantId in claim : " + claimMap.toString());
return getUserTenantIdAttribute(claimMap);
}

private Optional<Long> getUserTenantIdAttribute(Map<String, Object> claimMap) {
if (claimMap.containsKey(TENANT_ID)) {
Integer tenantId = (Integer) claimMap.get(TENANT_ID);
return Optional.of(Long.valueOf(tenantId));
} else {
return Optional.empty();
Object tenantIdObject = claimMap.get(TENANT_ID);
if (tenantIdObject instanceof Long tenantId) {
return Optional.of(tenantId);
}
if (tenantIdObject instanceof Integer tenantId) {
return Optional.of(Long.valueOf(tenantId));
}
}
return Optional.empty();
}

private Map<String, Object> getClaimMap(HttpServletRequest request) {
KeycloakSecurityContext keycloakSecContext =
((KeycloakAuthenticationToken) request.getUserPrincipal())
.getAccount()
.getKeycloakSecurityContext();
return keycloakSecContext.getToken().getOtherClaims();
private Map<String, Object> getClaimMap() {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication != null) {
var jwt = (Jwt) authentication.getPrincipal();
return jwt.getClaims();
} else {
return Map.of();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
package de.caritas.cob.uploadservice.api.tenant;

import com.google.common.collect.Lists;
import de.caritas.cob.uploadservice.api.authorization.UserRole;
import jakarta.servlet.http.HttpServletRequest;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.keycloak.adapters.springsecurity.token.KeycloakAuthenticationToken;
import org.keycloak.representations.AccessToken;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
import org.springframework.stereotype.Component;

@Component
Expand All @@ -14,13 +21,34 @@ public Optional<Long> resolve(HttpServletRequest request) {
return isTechnicalUserRole(request) ? Optional.of(0L) : Optional.empty();
}

public Collection<String> extractRealmRoles(Jwt jwt) {
Map<String, Object> realmAccess = (Map<String, Object>) jwt.getClaims().get("realm_access");
if (realmAccess != null) {
var roles = (List<String>) realmAccess.get("roles");
System.out.println("Extracted roles: " + roles); // Debug logging
if (roles != null) {
return roles;
}
}
return Lists.newArrayList();
}

private boolean containsAnyRole(HttpServletRequest request, String... expectedRoles) {
JwtAuthenticationToken token = ((JwtAuthenticationToken) request.getUserPrincipal());
var roles = extractRealmRoles(token.getToken());
if (!roles.isEmpty()) {
return containsAny(roles, expectedRoles);
} else {
return false;
}
}

private boolean containsAny(Collection<String> roles, String... expectedRoles) {
return Arrays.stream(expectedRoles).anyMatch(roles::contains);
}

private boolean isTechnicalUserRole(HttpServletRequest request) {
AccessToken token =
((KeycloakAuthenticationToken) request.getUserPrincipal())
.getAccount()
.getKeycloakSecurityContext()
.getToken();
return hasRoles(token) && token.getRealmAccess().getRoles().contains("technical");
return containsAnyRole(request, UserRole.TECHNICAL.getValue());
}

private boolean hasRoles(AccessToken accessToken) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
import de.caritas.cob.uploadservice.filter.HttpTenantFilter;
import de.caritas.cob.uploadservice.filter.StatelessCsrfFilter;
import lombok.RequiredArgsConstructor;
import org.keycloak.adapters.springboot.KeycloakSpringBootConfigResolver;
import org.keycloak.adapters.springsecurity.KeycloakConfiguration;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Primary;
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
Expand Down Expand Up @@ -63,16 +61,6 @@ public class SecurityConfig implements WebMvcConfigurer {
"/actuator/health/**"
};

/**
* Tells Keycloak to use Spring Boot properties (application.yml/application.properties) rather
* than a keycloak.json.
*/
@Bean
@Primary
public KeycloakSpringBootConfigResolver keycloakSpringBootConfigResolver() {
return new KeycloakSpringBootConfigResolver();
}

@Autowired AuthorisationService authorisationService;

@Autowired JwtAuthConverterProperties jwtAuthConverterProperties;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,60 +1,39 @@
package de.caritas.cob.uploadservice.api.authorization;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
import static org.mockito.Mockito.mock;
import static org.assertj.core.api.Assertions.assertThat;

import java.security.Principal;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.keycloak.adapters.RefreshableKeycloakSecurityContext;
import org.keycloak.adapters.spi.KeycloakAccount;
import org.keycloak.adapters.springsecurity.account.SimpleKeycloakAccount;
import org.keycloak.adapters.springsecurity.authentication.KeycloakAuthenticationProvider;
import org.keycloak.adapters.springsecurity.token.KeycloakAuthenticationToken;
import org.mockito.internal.util.collections.Sets;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.SimpleGrantedAuthority;

@ExtendWith(MockitoExtension.class)
public class RoleAuthorizationAuthorityMapperTest {
class RoleAuthorizationAuthorityMapperTest {

private final KeycloakAuthenticationProvider provider = new KeycloakAuthenticationProvider();
private final Set<String> roles =
Sets.newSet(
UserRole.USER.getValue(),
UserRole.CONSULTANT.getValue(),
UserRole.PEER_CONSULTANT.getValue(),
UserRole.ANONYMOUS.getValue());
Stream.of(UserRole.values()).map(UserRole::getValue).collect(Collectors.toSet());

@Test
public void roleAuthorizationAuthorityMapper_Should_GrantCorrectAuthorities() {

Principal principal = mock(Principal.class);
RefreshableKeycloakSecurityContext securityContext =
mock(RefreshableKeycloakSecurityContext.class);
KeycloakAccount account = new SimpleKeycloakAccount(principal, roles, securityContext);

KeycloakAuthenticationToken token = new KeycloakAuthenticationToken(account, false);
void roleAuthorizationAuthorityMapper_Should_GrantCorrectAuthorities() {

RoleAuthorizationAuthorityMapper roleAuthorizationAuthorityMapper =
new RoleAuthorizationAuthorityMapper();
provider.setGrantedAuthoritiesMapper(roleAuthorizationAuthorityMapper);

Authentication result = provider.authenticate(token);
var result = roleAuthorizationAuthorityMapper.mapAuthorities(roles);

Set<SimpleGrantedAuthority> expectedGrantendAuthorities = new HashSet<>();
roles.forEach(
roleName ->
expectedGrantendAuthorities.addAll(
Authority.getAuthoritiesByUserRole(UserRole.getRoleByValue(roleName).get()).stream()
.map(SimpleGrantedAuthority::new)
.collect(Collectors.toSet())));

assertThat(expectedGrantendAuthorities, containsInAnyOrder(result.getAuthorities().toArray()));
roleName -> {
expectedGrantendAuthorities.addAll(
Authority.getAuthoritiesByUserRole(UserRole.getRoleByValue(roleName).get()).stream()
.map(SimpleGrantedAuthority::new)
.collect(Collectors.toSet()));
});

assertThat(expectedGrantendAuthorities).isEqualTo(result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,48 @@

import com.google.common.collect.Maps;
import jakarta.servlet.http.HttpServletRequest;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.keycloak.adapters.springsecurity.token.KeycloakAuthenticationToken;
import org.mockito.Answers;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.jwt.Jwt;

@ExtendWith(MockitoExtension.class)
class AccessTokenTenantResolverTest {

@InjectMocks AccessTokenTenantResolver accessTokenTenantResolver;

@Mock SecurityContext mockSecurityContext;

@Mock Authentication mockAuthentication;

@Mock HttpServletRequest authenticatedRequest;

@Mock(answer = Answers.RETURNS_DEEP_STUBS)
KeycloakAuthenticationToken token;
@AfterEach
public void tearDown() {
SecurityContextHolder.clearContext();
}

@InjectMocks AccessTokenTenantResolver accessTokenTenantResolver;
private void givenUserIsAuthenticated() {
SecurityContextHolder.setContext(mockSecurityContext);
when(mockSecurityContext.getAuthentication()).thenReturn(mockAuthentication);
Jwt jwt = buildJwt();
when(mockAuthentication.getPrincipal()).thenReturn(jwt);
}

@Test
void resolve_Should_ResolveTenantId_When_TenantIdInAccessTokenClaim() {
// given
when(authenticatedRequest.getUserPrincipal()).thenReturn(token);

HashMap<String, Object> claimMap = givenClaimMapContainingTenantId(1);
when(token.getAccount().getKeycloakSecurityContext().getToken().getOtherClaims())
.thenReturn(claimMap);
givenUserIsAuthenticated();

// when
Optional<Long> resolvedTenantId = accessTokenTenantResolver.resolve(authenticatedRequest);
Expand All @@ -41,6 +55,14 @@ void resolve_Should_ResolveTenantId_When_TenantIdInAccessTokenClaim() {
assertThat(resolvedTenantId).isEqualTo(Optional.of(1L));
}

private Jwt buildJwt() {
Map<String, Object> headers = new HashMap<>();
headers.put("alg", "HS256"); // Signature algorithm
headers.put("typ", "JWT"); // Token type
return new Jwt(
"token", Instant.now(), Instant.now(), headers, givenClaimMapContainingTenantId(1));
}

private HashMap<String, Object> givenClaimMapContainingTenantId(Integer tenantId) {
HashMap<String, Object> claimMap = Maps.newHashMap();
claimMap.put("tenantId", tenantId);
Expand Down
Loading

0 comments on commit 50af1d2

Please sign in to comment.