diff --git a/authentik/sources/oauth/types/azure_ad.py b/authentik/sources/oauth/types/azure_ad.py index 7d7f4e15920b..f121228f6686 100644 --- a/authentik/sources/oauth/types/azure_ad.py +++ b/authentik/sources/oauth/types/azure_ad.py @@ -2,6 +2,7 @@ from typing import Any +from requests import RequestException from structlog.stdlib import get_logger from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient @@ -17,14 +18,41 @@ class AzureADOAuthRedirect(OAuthRedirect): def get_additional_parameters(self, source): # pragma: no cover return { - "scope": ["openid", "https://graph.microsoft.com/User.Read"], + "scope": [ + "openid", + "https://graph.microsoft.com/User.Read", + "https://graph.microsoft.com/GroupMember.Read.All", + ], } +class AzureADClient(UserprofileHeaderAuthClient): + """Fetch AzureAD group information""" + + def get_profile_info(self, token): + profile_data = super().get_profile_info(token) + group_response = self.session.request( + "get", + "https://graph.microsoft.com/v1.0/me/memberOf", + headers={"Authorization": f"{token['token_type']} {token['access_token']}"}, + ) + try: + group_response.raise_for_status() + except RequestException as exc: + LOGGER.warning( + "Unable to fetch user profile", + exc=exc, + response=exc.response.text if exc.response else str(exc), + ) + return None + profile_data["raw_groups"] = group_response.json() + return profile_data + + class AzureADOAuthCallback(OpenIDConnectOAuth2Callback): """AzureAD OAuth2 Callback""" - client_class = UserprofileHeaderAuthClient + client_class = AzureADClient def get_user_id(self, info: dict[str, str]) -> str: # Default try to get `id` for the Graph API endpoint @@ -53,8 +81,24 @@ class AzureADType(SourceType): def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: mail = info.get("mail", None) or info.get("otherMails", [None])[0] + # Format group info + groups = [] + group_id_dict = {} + for group in info.get("raw_groups").get("value", []): + if group["@odata.type"] != "#microsoft.graph.group": + continue + groups.append(group["id"]) + group_id_dict[group["id"]] = group + info["raw_groups"] = group_id_dict return { "username": info.get("userPrincipalName"), "email": mail, "name": info.get("displayName"), + "groups": groups, + } + + def get_base_group_properties(self, source, group_id, **kwargs): + raw_group = kwargs["info"]["raw_groups"][group_id] + return { + "name": raw_group["displayName"], }