Skip to content

Commit

Permalink
Remove mediaDownloadHeaders and imageDownloadheaders from the body if…
Browse files Browse the repository at this point in the history
… they are None
  • Loading branch information
wanliAlex authored Oct 31, 2024
1 parent 5932044 commit fbf4d7b
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 15 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"tox"
],
name="marqo",
version="3.9.0",
version="3.9.1",
author="marqo org",
author_email="org@marqo.io",
description="Tensor search for humans",
Expand Down
19 changes: 8 additions & 11 deletions src/marqo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,12 @@ def embed(self, content: Union[Union[str, Dict[str, float]], List[Union[str, Dic
body = {
"content": content,
"content_type": content_type,
"mediaDownloadHeaders": media_download_headers,
}

if media_download_headers is not None:
body["mediaDownloadHeaders"] = media_download_headers
if image_download_headers is not None:
body["image_download_headers"] = image_download_headers
body["imageDownloadHeaders"] = image_download_headers
if model_auth is not None:
body["modelAuth"] = model_auth

Expand Down Expand Up @@ -532,13 +533,6 @@ def add_documents(
Returns:
Response body outlining indexing result
"""

if image_download_headers is None:
image_download_headers = dict()

if media_download_headers is None:
media_download_headers = dict()

return self._add_docs_organiser(
documents=documents,
client_batch_size=client_batch_size, device=device, tensor_fields=tensor_fields,
Expand Down Expand Up @@ -579,11 +573,14 @@ def _add_docs_organiser(

base_body = {
"useExistingTensors": use_existing_tensors,
"imageDownloadHeaders": image_download_headers,
"mediaDownloadHeaders": media_download_headers,
"mappings": mappings,
"modelAuth": model_auth,
}
if image_download_headers is not None:
base_body["imageDownloadHeaders"] = image_download_headers

if media_download_headers is not None:
base_body["mediaDownloadHeaders"] = media_download_headers

if tensor_fields is not None:
base_body['tensorFields'] = tensor_fields
Expand Down
35 changes: 33 additions & 2 deletions tests/v2_tests/test_add_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ def run():
assert "device" not in kwargs0["path"]

assert kwargs0["body"]["useExistingTensors"] == False
assert kwargs0["body"]["imageDownloadHeaders"] == {}
assert kwargs0["body"]["mappings"] is None
assert kwargs0["body"]["modelAuth"] is None

Expand Down Expand Up @@ -804,4 +803,36 @@ def test_add_multimodal_field_document(self):

self.assertIn('_tensor_facets', doc['results'][0])
self.assertIn('_embedding', doc['results'][0]['_tensor_facets'][0])
self.assertEqual(len(doc['results'][0]['_tensor_facets'][0]['_embedding']), 768)
self.assertEqual(len(doc['results'][0]['_tensor_facets'][0]['_embedding']), 768)

def test_media_download_headers_is_not_included(self):
"""Ensure newly added attributes mediaDownloadHeaders is not included in the request body."""
mock__post = mock.MagicMock()

@mock.patch("marqo._httprequests.HttpRequests.post", mock__post)
def run():
self.client.index(index_name=self.generic_test_index_name).add_documents(
documents=["something"], tensor_fields=None)
args, kwargs = mock__post.call_args
self.assertNotIn("mediaDownloadHeaders", kwargs["body"])
self.assertNotIn("imageDownloadHeaders", kwargs["body"])
return True
run()

def test_media_download_headers_is_included_if_explicitly_set(self):
"""Ensure newly added attributes mediaDownloadHeaders is included if explicitly set."""
mock__post = mock.MagicMock()

@mock.patch("marqo._httprequests.HttpRequests.post", mock__post)
def run():
self.client.index(index_name=self.generic_test_index_name).add_documents(
documents=["something"], tensor_fields=None,
media_download_headers={"key": "value-1"},
image_download_headers={"key": "value-2"}
)
args, kwargs = mock__post.call_args
self.assertIn("imageDownloadHeaders", kwargs["body"])
self.assertEqual({"key": "value-2"}, kwargs["body"]["imageDownloadHeaders"])
self.assertEqual({"key": "value-1"}, kwargs["body"]["mediaDownloadHeaders"])
return True
run()
36 changes: 35 additions & 1 deletion tests/v2_tests/test_embed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest import mock

import numpy as np
from pytest import mark

Expand Down Expand Up @@ -184,4 +186,36 @@ def test_embed_non_numeric_weight_fails(self):
with self.assertRaises(MarqoWebError) as e:
self.client.index(test_index_name).embed(content={"text to embed": "not a number"})

self.assertIn("not a valid float", str(e.exception))
self.assertIn("not a valid float", str(e.exception))

def test_media_download_headers_is_not_included(self):
"""Ensure newly added attributes mediaDownloadHeaders is not included in the request body."""
mock__post = mock.MagicMock()

@mock.patch("marqo._httprequests.HttpRequests.post", mock__post)
def run():
self.client.index(index_name=self.generic_test_index_name).embed(
content=["something"])
args, kwargs = mock__post.call_args
self.assertNotIn("mediaDownloadHeaders", kwargs["body"])
self.assertNotIn("imageDownloadHeaders", kwargs["body"])
return True
run()

def test_media_download_headers_is_included_if_explicitly_set(self):
"""Ensure newly added attributes mediaDownloadHeaders is included if explicitly set."""
mock__post = mock.MagicMock()

@mock.patch("marqo._httprequests.HttpRequests.post", mock__post)
def run():
self.client.index(index_name=self.generic_test_index_name).embed(
content=["something"], media_download_headers={"key": "value-1"},
image_download_headers={"key": "value-2"}
)
args, kwargs = mock__post.call_args
self.assertIn("mediaDownloadHeaders", kwargs["body"])
self.assertEqual({"key": "value-2"}, kwargs["body"]["imageDownloadHeaders"])
self.assertEqual({"key": "value-1"}, kwargs["body"]["mediaDownloadHeaders"])
return True

run()
32 changes: 32 additions & 0 deletions tests/v2_tests/test_tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,3 +608,35 @@ def test_special_characters(self):
assert list(search1_res['hits'][0]['_highlights'][0].keys()) == [field_to_search, ]
assert set(k for k in search1_res['hits'][0].keys() if not k.startswith('_')) == {field_to_not_search}

def test_media_download_headers_is_not_included(self):
"""Ensure newly added attributes mediaDownloadHeaders is not included in the request body."""
mock__post = mock.MagicMock()

@mock.patch("marqo._httprequests.HttpRequests.post", mock__post)
def run():
self.client.index(index_name=self.generic_test_index_name).search(
q="test")
args, kwargs = mock__post.call_args
self.assertNotIn("mediaDownloadHeaders", kwargs["body"])
self.assertNotIn("imageDownloadHeaders", kwargs["body"])
return True

run()

def test_media_download_headers_is_included_if_explicitly_set(self):
"""Ensure newly added attributes mediaDownloadHeaders is included if explicitly set."""
mock__post = mock.MagicMock()

@mock.patch("marqo._httprequests.HttpRequests.post", mock__post)
def run():
self.client.index(index_name=self.generic_test_index_name).search(
q="test", media_download_headers={"key": "value-1"},
image_download_headers={"key": "value-2"}
)
args, kwargs = mock__post.call_args
self.assertIn("mediaDownloadHeaders", kwargs["body"])
self.assertEqual({"key": "value-2"}, kwargs["body"]["imageDownloadHeaders"])
self.assertEqual({"key": "value-1"}, kwargs["body"]["mediaDownloadHeaders"])
return True

run()

0 comments on commit fbf4d7b

Please sign in to comment.