diff --git a/helusers/models.py b/helusers/models.py index a7121d4..eb9d693 100644 --- a/helusers/models.py +++ b/helusers/models.py @@ -1,5 +1,7 @@ import logging import uuid +from collections import defaultdict +from itertools import chain from django.contrib.auth.models import AbstractUser as DjangoAbstractUser from django.contrib.auth.models import Group @@ -87,16 +89,18 @@ def sync_groups_from_ad(self): """Determine which Django groups to add or remove based on AD groups.""" ad_list = ADGroupMapping.objects.values_list("ad_group", "group") - mappings = {ad_group: group for ad_group, group in ad_list} + mappings = defaultdict(list) + for ad_group, group in ad_list: + mappings[ad_group].append(group) user_ad_groups = set( self.ad_groups.filter(groups__isnull=False).values_list(flat=True) ) - all_mapped_groups = set(mappings.values()) + all_mapped_groups = set(chain(*mappings.values())) old_groups = set( self.groups.filter(id__in=all_mapped_groups).values_list(flat=True) ) - new_groups = set([mappings[x] for x in user_ad_groups]) + new_groups = set(chain(*[mappings[x] for x in user_ad_groups])) groups_to_delete = old_groups - new_groups if groups_to_delete: diff --git a/helusers/tests/test_models.py b/helusers/tests/test_models.py index 9c8829e..26c8954 100644 --- a/helusers/tests/test_models.py +++ b/helusers/tests/test_models.py @@ -160,6 +160,32 @@ class TestUserAdGroups: [], id="all_removed", ), + # many ad-group for 1 user group + pytest.param( + ( + ("ad_group_1", "group_1"), + ("ad_group_2", "group_1"), + ("ad_group_3", "group_1"), + ), + [], + [], + ["ad_group_1", "ad_group_2", "ad_group_3"], + ["group_1"], + id="one-ad-for-many-group", + ), + # 1 ad-group for many user groups + pytest.param( + ( + ("ad_group_1", "group_1"), + ("ad_group_1", "group_2"), + ("ad_group_1", "group_3"), + ), + [], + [], + ["ad_group_1"], + ["group_1", "group_2", "group_3"], + id="many-ad-for-one-group", + ), ], ) def test_update_ad_groups(