diff --git a/praw/models/reddit/subreddit.py b/praw/models/reddit/subreddit.py index 643e809d6..326b3e783 100644 --- a/praw/models/reddit/subreddit.py +++ b/praw/models/reddit/subreddit.py @@ -13,10 +13,9 @@ from warnings import warn from xml.etree.ElementTree import XML -import websocket +from niquests.exceptions import HTTPError, ReadTimeout, Timeout from prawcore import Redirect from prawcore.exceptions import ServerError -from requests.exceptions import HTTPError from ...const import API_PATH, JPEG_HEADER from ...exceptions import ( @@ -3067,30 +3066,40 @@ def _submit_media( """ response = self._reddit.post(API_PATH["submit"], data=data) websocket_url = response["json"]["data"]["websocket_url"] - connection = None + ws_response = None if websocket_url is not None and not without_websockets: try: - connection = websocket.create_connection(websocket_url, timeout=timeout) + ws_response = self._reddit._core._requestor._http.get( + websocket_url, + timeout=timeout, + ).raise_for_status() except ( - OSError, - websocket.WebSocketException, - BlockingIOError, + HTTPError, + Timeout, ) as ws_exception: msg = "Error establishing websocket connection." raise WebSocketException(msg, ws_exception) from None - if connection is None: + if ws_response is None: return None try: - ws_update = loads(connection.recv()) - connection.close() - except (OSError, websocket.WebSocketException, BlockingIOError) as ws_exception: + ws_update = loads(ws_response.extension.next_payload()) + except ( + ReadTimeout, + HTTPError, + ) as ws_exception: msg = "Websocket error. Check your media file. Your post may still have been created." raise WebSocketException( msg, ws_exception, ) from None + finally: + if ( + ws_response.extension is not None + and ws_response.extension.closed is False + ): + ws_response.extension.close() if ws_update.get("type") == "failed": raise MediaPostFailed url = ws_update["payload"]["redirect"] diff --git a/pyproject.toml b/pyproject.toml index 271541237..760171e97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ dependencies = [ "prawcore@git+https://github.com/Ousret/prawcore@feat-niquests", "update_checker >=0.18", - "websocket-client >=0.54.0" + "niquests[ws]>=3.10,<4" ] dynamic = ["version", "description"] keywords = ["reddit", "api", "wrapper"] diff --git a/tests/integration/models/reddit/test_subreddit.py b/tests/integration/models/reddit/test_subreddit.py index f916d0409..98003c0a2 100644 --- a/tests/integration/models/reddit/test_subreddit.py +++ b/tests/integration/models/reddit/test_subreddit.py @@ -5,6 +5,7 @@ from unittest import mock from unittest.mock import MagicMock +import niquests import pytest import requests import websocket @@ -1223,6 +1224,10 @@ class WebsocketMock: def make_dict(cls, post_id): return {"payload": {"redirect": cls.POST_URL.format(post_id)}} + @property + def closed(self) -> bool: + return False + def __call__(self, *args, **kwargs): return self @@ -1233,14 +1238,27 @@ def __init__(self, *post_ids): def close(self, *args, **kwargs): pass - def recv(self): + def next_payload(self): if not self.post_ids: - raise websocket.WebSocketTimeoutException() + raise niquests.ReadTimeout assert 0 <= self.i + 1 < len(self.post_ids) self.i += 1 return dumps(self.make_dict(self.post_ids[self.i])) +class ResponseWithWebSocketExtMock: + + def __init__(self, fake_extension: WebsocketMock): + self.extension = fake_extension + + @property + def status_code(self) -> int: + return 101 + + def raise_for_status(self): + return self + + class WebsocketMockException: def __init__(self, close_exc=None, recv_exc=None): """Initialize a WebsocketMockException. @@ -1548,9 +1566,11 @@ def test_submit_gallery__flair(self, image_path, reddit): assert submission.link_flair_text == flair_text @mock.patch( - "websocket.create_connection", + "niquests.Session.get", new=MagicMock( - return_value=WebsocketMock("16xb01r", "16xb06z", "16xb0aa") + return_value=ResponseWithWebSocketExtMock( + WebsocketMock("16xb01r", "16xb06z", "16xb0aa") + ) ), # update with cassette ) def test_submit_image(self, image_path, reddit): @@ -1565,7 +1585,8 @@ def test_submit_image(self, image_path, reddit): @pytest.mark.cassette_name("TestSubreddit.test_submit_image") @mock.patch( - "websocket.create_connection", new=MagicMock(return_value=WebsocketMock()) + "niquests.Session.get", + new=MagicMock(return_value=ResponseWithWebSocketExtMock(WebsocketMock())), ) def test_submit_image__bad_websocket(self, image_path, reddit): reddit.read_only = False @@ -1576,8 +1597,10 @@ def test_submit_image__bad_websocket(self, image_path, reddit): subreddit.submit_image("Test Title", image) @mock.patch( - "websocket.create_connection", - new=MagicMock(return_value=WebsocketMock("ah3gqo")), + "niquests.Session.get", + new=MagicMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("ah3gqo")) + ), ) # update with cassette def test_submit_image__flair(self, image_path, reddit): flair_id = "6bd28436-1aa7-11e9-9902-0e05ab0fad46" @@ -1628,7 +1651,7 @@ def patch_request(url, *args, **kwargs): reddit._core._requestor._http.post = _post @mock.patch( - "websocket.create_connection", new=MagicMock(side_effect=BlockingIOError) + "niquests.Session.get", new=MagicMock(side_effect=niquests.Timeout) ) # happens with timeout=0 @pytest.mark.cassette_name("TestSubreddit.test_submit_image") def test_submit_image__timeout_1(self, image_path, reddit): @@ -1638,69 +1661,6 @@ def test_submit_image__timeout_1(self, image_path, reddit): with pytest.raises(WebSocketException): subreddit.submit_image("Test Title", image) - @mock.patch( - "websocket.create_connection", - new=MagicMock( - side_effect=socket.timeout - # happens with timeout=0.00001 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_image") - def test_submit_image__timeout_2(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - image = image_path("test.jpg") - with pytest.raises(WebSocketException): - subreddit.submit_image("Test Title", image) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMockException( - recv_exc=websocket.WebSocketTimeoutException() - ), # happens with timeout=0.1 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_image") - def test_submit_image__timeout_3(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - image = image_path("test.jpg") - with pytest.raises(WebSocketException): - subreddit.submit_image("Test Title", image) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMockException( - close_exc=websocket.WebSocketTimeoutException() - ), # could happen, and PRAW should handle it - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_image") - def test_submit_image__timeout_4(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - image = image_path("test.jpg") - with pytest.raises(WebSocketException): - subreddit.submit_image("Test Title", image) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMockException( - recv_exc=websocket.WebSocketConnectionClosedException() - ), # from issue #1124 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_image") - def test_submit_image__timeout_5(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - image = image_path("test.jpg") - with pytest.raises(WebSocketException): - subreddit.submit_image("Test Title", image) - def test_submit_image__without_websockets(self, image_path, reddit): reddit.read_only = False subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) @@ -1712,8 +1672,10 @@ def test_submit_image__without_websockets(self, image_path, reddit): assert submission is None @mock.patch( - "websocket.create_connection", - new=MagicMock(return_value=WebsocketMock("k5s3b3")), + "niquests.Session.get", + new=MagicMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("k5s3b3")) + ), ) # update with cassette def test_submit_image_chat(self, image_path, reddit): reddit.read_only = False @@ -1785,9 +1747,11 @@ def test_submit_poll__live_chat(self, reddit): assert submission.discussion_type == "CHAT" @mock.patch( - "websocket.create_connection", + "niquests.Session.get", new=MagicMock( - return_value=WebsocketMock("k5rsq3", "k5rt9d"), # update with cassette + return_value=ResponseWithWebSocketExtMock( + WebsocketMock("k5rsq3", "k5rt9d") + ), # update with cassette ), ) def test_submit_video(self, image_path, reddit): @@ -1802,7 +1766,8 @@ def test_submit_video(self, image_path, reddit): @pytest.mark.cassette_name("TestSubreddit.test_submit_video") @mock.patch( - "websocket.create_connection", new=MagicMock(return_value=WebsocketMock()) + "niquests.Session.get", + new=MagicMock(return_value=ResponseWithWebSocketExtMock(WebsocketMock())), ) def test_submit_video__bad_websocket(self, image_path, reddit): reddit.read_only = False @@ -1813,8 +1778,10 @@ def test_submit_video__bad_websocket(self, image_path, reddit): subreddit.submit_video("Test Title", video) @mock.patch( - "websocket.create_connection", - new=MagicMock(return_value=WebsocketMock("ahells")), + "niquests.Session.get", + new=MagicMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("ahells")) + ), ) # update with cassette def test_submit_video__flair(self, image_path, reddit): flair_id = "6bd28436-1aa7-11e9-9902-0e05ab0fad46" @@ -1830,9 +1797,9 @@ def test_submit_video__flair(self, image_path, reddit): assert submission.link_flair_text == flair_text @mock.patch( - "websocket.create_connection", + "niquests.Session.get", new=MagicMock( - return_value=WebsocketMock("k5rvt5", "k5rwbo") + return_value=ResponseWithWebSocketExtMock(WebsocketMock("k5rvt5", "k5rwbo")) ), # update with cassette ) def test_submit_video__thumbnail(self, image_path, reddit): @@ -1852,7 +1819,7 @@ def test_submit_video__thumbnail(self, image_path, reddit): assert submission.title == "Test Title" @mock.patch( - "websocket.create_connection", new=MagicMock(side_effect=BlockingIOError) + "niquests.Session.get", new=MagicMock(side_effect=niquests.Timeout) ) # happens with timeout=0 @pytest.mark.cassette_name("TestSubreddit.test_submit_video") def test_submit_video__timeout_1(self, image_path, reddit): @@ -1863,72 +1830,11 @@ def test_submit_video__timeout_1(self, image_path, reddit): subreddit.submit_video("Test Title", video) @mock.patch( - "websocket.create_connection", - new=MagicMock( - side_effect=socket.timeout - # happens with timeout=0.00001 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_video") - def test_submit_video__timeout_2(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - video = image_path("test.mov") - with pytest.raises(WebSocketException): - subreddit.submit_video("Test Title", video) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMockException( - recv_exc=websocket.WebSocketTimeoutException() - ), # happens with timeout=0.1 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_video") - def test_submit_video__timeout_3(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - video = image_path("test.mov") - with pytest.raises(WebSocketException): - subreddit.submit_video("Test Title", video) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMockException( - close_exc=websocket.WebSocketTimeoutException() - ), # could happen, and PRAW should handle it - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_video") - def test_submit_video__timeout_4(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - video = image_path("test.mov") - with pytest.raises(WebSocketException): - subreddit.submit_video("Test Title", video) - - @mock.patch( - "websocket.create_connection", + "niquests.Session.get", new=MagicMock( - return_value=WebsocketMockException( - close_exc=websocket.WebSocketConnectionClosedException() - ), # from issue #1124 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_video") - def test_submit_video__timeout_5(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - video = image_path("test.mov") - with pytest.raises(WebSocketException): - subreddit.submit_video("Test Title", video) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMock("k5s10u", "k5s11v"), # update with cassette + return_value=ResponseWithWebSocketExtMock( + WebsocketMock("k5s10u", "k5s11v") + ), # update with cassette ), ) def test_submit_video__videogif(self, image_path, reddit): @@ -1952,8 +1858,10 @@ def test_submit_video__without_websockets(self, image_path, reddit): assert submission is None @mock.patch( - "websocket.create_connection", - new=MagicMock(return_value=WebsocketMock("flnyhf")), + "niquests.Session.get", + new=MagicMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("flnyhf")) + ), ) # update with cassette def test_submit_video_chat(self, image_path, reddit): reddit.read_only = False diff --git a/tests/unit/models/reddit/test_subreddit.py b/tests/unit/models/reddit/test_subreddit.py index 2259ab651..46e392e7c 100644 --- a/tests/unit/models/reddit/test_subreddit.py +++ b/tests/unit/models/reddit/test_subreddit.py @@ -1,6 +1,7 @@ import json import pickle from unittest import mock +from unittest.mock import MagicMock import pytest @@ -53,7 +54,7 @@ def test_hash(self, reddit): assert hash(subreddit2) != hash(subreddit3) assert hash(subreddit1) != hash(subreddit3) - @mock.patch("websocket.create_connection") + @mock.patch("niquests.Session.get") @mock.patch( "praw.models.Subreddit._upload_media", return_value=("fake_media_url", "fake_websocket_url"), @@ -64,14 +65,23 @@ def test_hash(self, reddit): def test_invalid_media( self, _mock_post, _mock_upload_media, connection_mock, reddit ): - connection_mock().recv.return_value = json.dumps( - {"payload": {}, "type": "failed"} + connection_mock.return_value = MagicMock( + status_code=101, + raise_for_status=MagicMock( + return_value=MagicMock( + extension=MagicMock( + next_payload=MagicMock( + return_value=json.dumps({"payload": {}, "type": "failed"}) + ) + ), + ) + ), ) with pytest.raises(MediaPostFailed): reddit.subreddit("test").submit_image("Test", "dummy path") @mock.patch("praw.models.Subreddit._read_and_post_media") - @mock.patch("websocket.create_connection") + @mock.patch("niquests.Session.get") @mock.patch( "praw.Reddit.post", return_value={