Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(API): only return proof.predictions in detail endpoint #605

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion open_prices/api/prices/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_queryset(self):
elif self.request.method in ["PATCH", "DELETE"]:
# only return prices owned by the current user
if self.request.user.is_authenticated:
return Price.objects.filter(owner=self.request.user.user_id)
return self.queryset.filter(owner=self.request.user.user_id)
return self.queryset

def get_serializer_class(self):
Expand Down
8 changes: 8 additions & 0 deletions open_prices/api/proofs/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ class Meta:
exclude = ["location", "source"]


class ProofHalfFullSerializer(ProofSerializer):
location = LocationSerializer()

class Meta:
model = Proof
exclude = ["source"] # ProofSerializer.Meta.exclude


class ProofFullSerializer(ProofSerializer):
location = LocationSerializer()
predictions = ProofPredictionSerializer(many=True, read_only=True)
Expand Down
30 changes: 16 additions & 14 deletions open_prices/api/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,18 @@ def setUpTestData(cls):

def test_proof_list(self):
# anonymous
# thanks to select_related and prefetch_related, we only have 3
# queries:
# thanks to select_related, we only have 2 queries:
# - 1 to count the number of proofs of the user
# - 1 to get the proofs and their associated locations (select_related)
# - 1 to get the associated proof predictions (prefetch_related)
with self.assertNumQueries(3):
with self.assertNumQueries(2):
response = self.client.get(self.url)
self.assertEqual(response.status_code, 200)
data = response.data
self.assertEqual(data["total"], 3)
self.assertEqual(len(data["items"]), 3)
item = data["items"][0]
self.assertEqual(item["id"], self.proof.id) # default order
self.assertIn("predictions", item)
self.assertEqual(len(item["predictions"]), 1)
prediction = item["predictions"][0]
self.assertEqual(prediction["type"], self.proof_prediction.type)
self.assertEqual(prediction["model_name"], self.proof_prediction.model_name)
self.assertEqual(
prediction["model_version"], self.proof_prediction.model_version
)
self.assertNotIn("predictions", item) # not returned in "list"


class ProofListOrderApiTest(TestCase):
Expand All @@ -100,7 +91,7 @@ def setUpTestData(cls):
cls.proof = ProofFactory(
**PROOF, price_count=15, owner=cls.user_session.user.user_id
)
ProofFactory(price_count=0)
ProofFactory(type=proof_constants.TYPE_PRICE_TAG, price_count=0)
ProofFactory(
type=proof_constants.TYPE_PRICE_TAG,
price_count=50,
Expand All @@ -122,7 +113,7 @@ def setUpTestData(cls):
cls.proof = ProofFactory(
**PROOF, price_count=15, owner=cls.user_session.user.user_id
)
ProofFactory(price_count=0)
ProofFactory(type=proof_constants.TYPE_PRICE_TAG, price_count=0)
ProofFactory(
type=proof_constants.TYPE_PRICE_TAG,
price_count=50,
Expand Down Expand Up @@ -150,6 +141,9 @@ def setUpTestData(cls):
cls.proof = ProofFactory(
**PROOF, price_count=15, owner=cls.user_session_1.user.user_id
)
cls.proof_prediction = ProofPredictionFactory(
proof=cls.proof, type="CLASSIFICATION"
)
cls.url = reverse("api:proofs-detail", args=[cls.proof.id])

def test_proof_detail(self):
Expand All @@ -162,6 +156,14 @@ def test_proof_detail(self):
response = self.client.get(self.url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["id"], self.proof.id)
self.assertIn("predictions", response.data) # returned in "detail"
self.assertEqual(len(response.data["predictions"]), 1)
prediction = response.data["predictions"][0]
self.assertEqual(prediction["type"], self.proof_prediction.type)
self.assertEqual(prediction["model_name"], self.proof_prediction.model_name)
self.assertEqual(
prediction["model_version"], self.proof_prediction.model_version
)


class ProofCreateApiTest(TestCase):
Expand Down
18 changes: 9 additions & 9 deletions open_prices/api/proofs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from open_prices.api.proofs.serializers import (
ProofCreateSerializer,
ProofFullSerializer,
ProofHalfFullSerializer,
ProofProcessWithGeminiSerializer,
ProofUpdateSerializer,
ProofUploadSerializer,
Expand Down Expand Up @@ -41,23 +42,22 @@ class ProofViewSet(
ordering = ["created"]

def get_queryset(self):
queryset = self.queryset
if self.request.method in ["GET"]:
# Select all proofs along with their locations using a select
# related query (1 single query)
# Then prefetch all the predictions related to the proof using
# a prefetch related query (only 1 query for all proofs)
return self.queryset.select_related("location").prefetch_related(
"predictions"
)
queryset = queryset.select_related("location")
if self.action == "retrieve":
queryset = queryset.prefetch_related("predictions")
elif self.request.method in ["PATCH", "DELETE"]:
# only return proofs owned by the current user
if self.request.user.is_authenticated:
return self.queryset.filter(owner=self.request.user.user_id)
return self.queryset
queryset = queryset.filter(owner=self.request.user.user_id)
return queryset

def get_serializer_class(self):
if self.request.method == "PATCH":
return ProofUpdateSerializer
elif self.action == "list":
return ProofHalfFullSerializer
return self.serializer_class

def destroy(self, request: Request, *args, **kwargs) -> Response:
Expand Down
Loading