From 8da1a2ad53b44f8c275b13e5529dc7f4a65d3c6f Mon Sep 17 00:00:00 2001 From: Gresille&Siffle <39056254+GresilleSiffle@users.noreply.github.com> Date: Thu, 4 Jan 2024 15:24:49 +0100 Subject: [PATCH 1/4] [back] feat: add entity contexts to unconnected entities API --- backend/tournesol/views/mixins/poll.py | 11 ++++++ .../tournesol/views/unconnected_entities.py | 6 ++-- frontend/scripts/openapi.yaml | 36 ++++++++++++++++++- .../entity_selector/EntitySelectButton.tsx | 2 +- 4 files changed, 51 insertions(+), 4 deletions(-) diff --git a/backend/tournesol/views/mixins/poll.py b/backend/tournesol/views/mixins/poll.py index 87529ed99e..2007583b09 100644 --- a/backend/tournesol/views/mixins/poll.py +++ b/backend/tournesol/views/mixins/poll.py @@ -13,10 +13,17 @@ class PollScopedViewMixin: view or a subclass of it. """ + entity_contexts = None + enable_entity_contexts = False + poll_parameter = "poll_name" # used to avoid multiple similar database queries in a single HTTP request poll_from_url: Poll + @staticmethod + def get_entity_contexts(poll: Poll): + return poll.all_entity_contexts.prefetch_related("texts").all() + def poll_from_kwargs_or_404(self, request_kwargs): poll_name = request_kwargs[self.poll_parameter] try: @@ -33,7 +40,11 @@ def initial(self, request, *args, **kwargs): # make the requested poll available at any time in the view self.poll_from_url = self.poll_from_kwargs_or_404(kwargs) + if self.enable_entity_contexts: + self.entity_contexts = self.get_entity_contexts(self.poll_from_url) + def get_serializer_context(self): context = super().get_serializer_context() context["poll"] = self.poll_from_url + context["entity_contexts"] = self.entity_contexts return context diff --git a/backend/tournesol/views/unconnected_entities.py b/backend/tournesol/views/unconnected_entities.py index 3b374a7b82..9e5b89131b 100644 --- a/backend/tournesol/views/unconnected_entities.py +++ b/backend/tournesol/views/unconnected_entities.py @@ -12,7 +12,7 @@ from rest_framework.permissions import IsAuthenticated from tournesol.models import Comparison, Entity -from tournesol.serializers.entity import EntityNoExtraFieldSerializer +from tournesol.serializers.unconnected_entities import UnconnectedEntitySerializer from tournesol.views.mixins.poll import PollScopedViewMixin @@ -56,10 +56,12 @@ class UnconnectedEntitiesView(PollScopedViewMixin, generics.ListAPIView): List unconnected entities. """ - serializer_class = EntityNoExtraFieldSerializer + serializer_class = UnconnectedEntitySerializer permission_classes = [IsAuthenticated] pagination_class = SortedEntityIdLimitOffsetPagination + enable_entity_contexts = True + def get_queryset(self): # Get related entities from source entity source_node = get_object_or_404(Entity, uid=self.kwargs.get("uid")) diff --git a/frontend/scripts/openapi.yaml b/frontend/scripts/openapi.yaml index b587f479da..18866429e1 100644 --- a/frontend/scripts/openapi.yaml +++ b/frontend/scripts/openapi.yaml @@ -2231,7 +2231,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/PaginatedEntityNoExtraFieldList' + $ref: '#/components/schemas/PaginatedUnconnectedEntityList' description: '' /users/me/vouchers/: post: @@ -3680,6 +3680,26 @@ components: type: array items: $ref: '#/components/schemas/TalkEntry' + PaginatedUnconnectedEntityList: + type: object + properties: + count: + type: integer + example: 123 + next: + type: string + nullable: true + format: uri + example: http://api.example.org/accounts/?offset=400&limit=100 + previous: + type: string + nullable: true + format: uri + example: http://api.example.org/accounts/?offset=200&limit=100 + results: + type: array + items: + $ref: '#/components/schemas/UnconnectedEntity' PaginatedVideoSerializerWithCriteriaList: type: object properties: @@ -4338,6 +4358,20 @@ components: description: |- * `video` - Video * `candidate_fr_2022` - Candidate (FR 2022) + UnconnectedEntity: + type: object + properties: + entity: + $ref: '#/components/schemas/EntityNoExtraField' + entity_contexts: + type: array + items: + $ref: '#/components/schemas/EntityContext' + readOnly: true + default: [] + required: + - entity + - entity_contexts UnsafeStatus: type: object properties: diff --git a/frontend/src/features/entity_selector/EntitySelectButton.tsx b/frontend/src/features/entity_selector/EntitySelectButton.tsx index d0b7ccceaa..f2fef64400 100644 --- a/frontend/src/features/entity_selector/EntitySelectButton.tsx +++ b/frontend/src/features/entity_selector/EntitySelectButton.tsx @@ -156,7 +156,7 @@ const VideoInput = ({ limit: 20, strict: false, }); - return (response.results ?? []).map((entity) => ({ entity })); + return response.results ?? []; }, disabled: !isLoggedIn || !otherUid, }, From 32148cd8d7b1981cc7b4d70c9a454a02c941503d Mon Sep 17 00:00:00 2001 From: Gresille&Siffle <39056254+GresilleSiffle@users.noreply.github.com> Date: Thu, 4 Jan 2024 15:26:04 +0100 Subject: [PATCH 2/4] [back] fix: add missing file --- .../serializers/unconnected_entities.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 backend/tournesol/serializers/unconnected_entities.py diff --git a/backend/tournesol/serializers/unconnected_entities.py b/backend/tournesol/serializers/unconnected_entities.py new file mode 100644 index 0000000000..134a1ab3d3 --- /dev/null +++ b/backend/tournesol/serializers/unconnected_entities.py @@ -0,0 +1,22 @@ +from rest_framework import serializers + +from tournesol.serializers.entity import EntityNoExtraFieldSerializer +from tournesol.serializers.entity_context import EntityContextSerializer + + +class UnconnectedEntitySerializer(serializers.Serializer): + entity = EntityNoExtraFieldSerializer(source="*") + entity_contexts = EntityContextSerializer(read_only=True, many=True, default=[]) + + def to_representation(self, instance): + ret = super().to_representation(instance) + + poll = self.context["poll"] + ent_contexts = self.context.get("entity_contexts") + + if ent_contexts is not None: + ret["entity_contexts"] = EntityContextSerializer( + poll.get_entity_contexts(ret["entity"]["metadata"], ent_contexts), many=True + ).data + + return ret From 3e00a31eaded6c1a56bce949ea8514e9ebdf19cd Mon Sep 17 00:00:00 2001 From: Gresille&Siffle <39056254+GresilleSiffle@users.noreply.github.com> Date: Thu, 4 Jan 2024 15:38:04 +0100 Subject: [PATCH 3/4] [back] refactor: comparisons API now uses ent contexts from PollScopedViewMixin --- backend/tournesol/serializers/comparison.py | 4 ++-- backend/tournesol/views/comparison.py | 13 +------------ 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/backend/tournesol/serializers/comparison.py b/backend/tournesol/serializers/comparison.py index 18e9f1ec04..20da7474da 100644 --- a/backend/tournesol/serializers/comparison.py +++ b/backend/tournesol/serializers/comparison.py @@ -89,7 +89,7 @@ def to_representation(self, instance): poll = self.context.get("poll") ent_contexts = self.context.get("entity_contexts") - if poll is not None: + if poll is not None and ent_contexts is not None: ret["entity_a_contexts"] = self.format_entity_contexts( poll, ent_contexts, ret["entity_a"]["metadata"] ) @@ -164,7 +164,7 @@ def to_representation(self, instance): poll = self.context.get("poll") ent_contexts = self.context.get("entity_contexts") - if poll is not None: + if poll is not None and ent_contexts is not None: ret["entity_a_contexts"] = self.format_entity_contexts( poll, ent_contexts, ret["entity_a"]["metadata"] ) diff --git a/backend/tournesol/views/comparison.py b/backend/tournesol/views/comparison.py index ce89d95aa2..9cd601420b 100644 --- a/backend/tournesol/views/comparison.py +++ b/backend/tournesol/views/comparison.py @@ -23,18 +23,7 @@ class InactivePollError(exceptions.PermissionDenied): class ComparisonApiMixin: """A mixin used to factorize behaviours common to all API views.""" - entity_contexts = None - - def initial(self, request, *args, **kwargs): - super().initial(request, *args, **kwargs) - self.entity_contexts = self.poll_from_url.all_entity_contexts.prefetch_related( - "texts" - ).all() - - def get_serializer_context(self): - context = super().get_serializer_context() - context["entity_contexts"] = self.entity_contexts - return context + enable_entity_contexts = True def comparison_already_exists(self, poll_id, request): """Return True if the comparison already exist, False instead.""" From c9ccc2c9263fb5a46aad0e364dc02604b82eb9b5 Mon Sep 17 00:00:00 2001 From: Gresille&Siffle <39056254+GresilleSiffle@users.noreply.github.com> Date: Thu, 4 Jan 2024 17:19:22 +0100 Subject: [PATCH 4/4] [back] tests: fix the broken tests of unconnected entities --- .../tests/test_api_unconnected_entities.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/backend/tournesol/tests/test_api_unconnected_entities.py b/backend/tournesol/tests/test_api_unconnected_entities.py index 52d46a3f98..b8f02812c8 100644 --- a/backend/tournesol/tests/test_api_unconnected_entities.py +++ b/backend/tournesol/tests/test_api_unconnected_entities.py @@ -264,13 +264,15 @@ def test_non_strict_must_return_sorted_by_max_distance(self): format="json", ) + results = response.data["results"] + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["count"], 2) # video 2 and 4 are both at distance 2 and should appear first, video 3 and 5 should # not appear because it has already been compared with the source entity. video 4 should # be first because it has a single comparison by the user against video 2 which has two. - self.assertEqual(response.data["results"][0]["uid"], self.video_4.uid) - self.assertEqual(response.data["results"][1]["uid"], self.video_2.uid) + self.assertEqual(results[0]["entity"]["uid"], self.video_4.uid) + self.assertEqual(results[1]["entity"]["uid"], self.video_2.uid) class ConnectedGraphReducibleDistanceTestCase(TestCase): @@ -331,11 +333,12 @@ def test_must_return_non_connected_entities(self): f"{self.user_base_url}/{self.video_source.uid}/", format="json", ) + results = response.data["results"] self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["count"], 2) self.assertEqual( - {entity["uid"] for entity in response.data["results"]}, + {res["entity"]["uid"] for res in results}, {entity.uid for entity in self.unrelated_video}, ) @@ -346,15 +349,16 @@ def test_non_strict_must_return_non_connected_then_connected_sorted_by_distance( f"{self.user_base_url}/{self.video_source.uid}/?strict=false", format="json", ) + results = response.data["results"] self.assertEqual(response.status_code, status.HTTP_200_OK) - # The response must contain firstly the two unconnected entities, and - # and then the 4 other entities sorted by decreasing distance to the - # source entity + # The response must contain firstly the two unconnected entities, and + # the 4 other entities sorted by decreasing distance to the source + # entity. self.assertEqual(response.data["count"], 5) - entities = [entity["uid"] for entity in response.data["results"]] - self.assertEqual(set(entities[:2]), {entity.uid for entity in self.unrelated_video}) - self.assertEqual([self.video_5.uid,self.video_4.uid,self.video_2.uid], entities[2:]) + uids = [res["entity"]["uid"] for res in results] + self.assertEqual(set(uids[:2]), {entity.uid for entity in self.unrelated_video}) + self.assertEqual([self.video_5.uid,self.video_4.uid,self.video_2.uid], uids[2:]) def test_non_connected_entities_ordering(self): """ @@ -387,18 +391,18 @@ def test_non_connected_entities_ordering(self): f"{self.user_base_url}/{self.video_source.uid}/", format="json", ) - results = response.data["results"] + # The first entity must be `video_a`, as it has 1 comparison. - self.assertEqual(results[0]["uid"], video_a.uid) + self.assertEqual(results[0]["entity"]["uid"], video_a.uid) # The 2nd and 3rd entities must be in the list of entities having 2 # comparisons. - self.assertIn(results[1]["uid"], [video_b.uid, self.unrelated_video[1].uid]) - self.assertIn(results[2]["uid"], [video_b.uid, self.unrelated_video[1].uid]) + self.assertIn(results[1]["entity"]["uid"], [video_b.uid, self.unrelated_video[1].uid]) + self.assertIn(results[2]["entity"]["uid"], [video_b.uid, self.unrelated_video[1].uid]) # The fourth entity must be `self.unrelated_video[0]` with 3 comparisons. - self.assertEqual(results[3]["uid"], self.unrelated_video[0].uid) + self.assertEqual(results[3]["entity"]["uid"], self.unrelated_video[0].uid) class TwoIsolatedGraphComponentsTestCase(TestCase): @@ -455,12 +459,13 @@ def test_must_return_non_connected_entities(self): f"{self.user_base_url}/{self.video_source.uid}/", format="json", ) + results = response.data["results"] self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["count"], 2) self.assertEqual( - {x["uid"] for x in response.data["results"]}, - {x.uid for x in self.unrelated_video}, + {res["entity"]["uid"] for res in results}, + {entity.uid for entity in self.unrelated_video}, ) @@ -469,7 +474,7 @@ def test_must_return_non_connected_entities(self): 100, 1000, 2000, - # Disabling the large test cases because they take a few minutes to run, but are useful for + # Disabling the large test cases because they take a few minutes to run, but are useful for # checking the performance of the api call. # 5000, # 10000,