Skip to content

Commit

Permalink
❇️ Sync with asyncpraw and use Niquests native websocket capabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
Ousret committed Nov 14, 2024
1 parent 995854c commit 54f0a48
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 164 deletions.
31 changes: 20 additions & 11 deletions praw/models/reddit/subreddit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
from warnings import warn
from xml.etree.ElementTree import XML

import websocket
from niquests.exceptions import HTTPError, ReadTimeout, Timeout
from prawcore import Redirect
from prawcore.exceptions import ServerError
from requests.exceptions import HTTPError

from ...const import API_PATH, JPEG_HEADER
from ...exceptions import (
Expand Down Expand Up @@ -3067,30 +3066,40 @@ def _submit_media(
"""
response = self._reddit.post(API_PATH["submit"], data=data)
websocket_url = response["json"]["data"]["websocket_url"]
connection = None
ws_response = None
if websocket_url is not None and not without_websockets:
try:
connection = websocket.create_connection(websocket_url, timeout=timeout)
ws_response = self._reddit._core._requestor._http.get(
websocket_url,
timeout=timeout,
).raise_for_status()
except (
OSError,
websocket.WebSocketException,
BlockingIOError,
HTTPError,
Timeout,
) as ws_exception:
msg = "Error establishing websocket connection."
raise WebSocketException(msg, ws_exception) from None

if connection is None:
if ws_response is None:
return None

try:
ws_update = loads(connection.recv())
connection.close()
except (OSError, websocket.WebSocketException, BlockingIOError) as ws_exception:
ws_update = loads(ws_response.extension.next_payload())
except (
ReadTimeout,
HTTPError,
) as ws_exception:
msg = "Websocket error. Check your media file. Your post may still have been created."
raise WebSocketException(
msg,
ws_exception,
) from None
finally:
if (
ws_response.extension is not None
and ws_response.extension.closed is False
):
ws_response.extension.close()
if ws_update.get("type") == "failed":
raise MediaPostFailed
url = ws_update["payload"]["redirect"]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
dependencies = [
"prawcore@git+https://github.com/Ousret/prawcore@feat-niquests",
"update_checker >=0.18",
"websocket-client >=0.54.0"
"niquests[ws]>=3.10,<4"
]
dynamic = ["version", "description"]
keywords = ["reddit", "api", "wrapper"]
Expand Down
204 changes: 56 additions & 148 deletions tests/integration/models/reddit/test_subreddit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from unittest import mock
from unittest.mock import MagicMock

import niquests
import pytest
import requests
import websocket
Expand Down Expand Up @@ -1223,6 +1224,10 @@ class WebsocketMock:
def make_dict(cls, post_id):
return {"payload": {"redirect": cls.POST_URL.format(post_id)}}

@property
def closed(self) -> bool:
return False

def __call__(self, *args, **kwargs):
return self

Expand All @@ -1233,14 +1238,27 @@ def __init__(self, *post_ids):
def close(self, *args, **kwargs):
pass

def recv(self):
def next_payload(self):
if not self.post_ids:
raise websocket.WebSocketTimeoutException()
raise niquests.ReadTimeout
assert 0 <= self.i + 1 < len(self.post_ids)
self.i += 1
return dumps(self.make_dict(self.post_ids[self.i]))


class ResponseWithWebSocketExtMock:

def __init__(self, fake_extension: WebsocketMock):
self.extension = fake_extension

@property
def status_code(self) -> int:
return 101

def raise_for_status(self):
return self


class WebsocketMockException:
def __init__(self, close_exc=None, recv_exc=None):
"""Initialize a WebsocketMockException.
Expand Down Expand Up @@ -1548,9 +1566,11 @@ def test_submit_gallery__flair(self, image_path, reddit):
assert submission.link_flair_text == flair_text

@mock.patch(
"websocket.create_connection",
"niquests.Session.get",
new=MagicMock(
return_value=WebsocketMock("16xb01r", "16xb06z", "16xb0aa")
return_value=ResponseWithWebSocketExtMock(
WebsocketMock("16xb01r", "16xb06z", "16xb0aa")
)
), # update with cassette
)
def test_submit_image(self, image_path, reddit):
Expand All @@ -1565,7 +1585,8 @@ def test_submit_image(self, image_path, reddit):

@pytest.mark.cassette_name("TestSubreddit.test_submit_image")
@mock.patch(
"websocket.create_connection", new=MagicMock(return_value=WebsocketMock())
"niquests.Session.get",
new=MagicMock(return_value=ResponseWithWebSocketExtMock(WebsocketMock())),
)
def test_submit_image__bad_websocket(self, image_path, reddit):
reddit.read_only = False
Expand All @@ -1576,8 +1597,10 @@ def test_submit_image__bad_websocket(self, image_path, reddit):
subreddit.submit_image("Test Title", image)

@mock.patch(
"websocket.create_connection",
new=MagicMock(return_value=WebsocketMock("ah3gqo")),
"niquests.Session.get",
new=MagicMock(
return_value=ResponseWithWebSocketExtMock(WebsocketMock("ah3gqo"))
),
) # update with cassette
def test_submit_image__flair(self, image_path, reddit):
flair_id = "6bd28436-1aa7-11e9-9902-0e05ab0fad46"
Expand Down Expand Up @@ -1628,7 +1651,7 @@ def patch_request(url, *args, **kwargs):
reddit._core._requestor._http.post = _post

@mock.patch(
"websocket.create_connection", new=MagicMock(side_effect=BlockingIOError)
"niquests.Session.get", new=MagicMock(side_effect=niquests.Timeout)
) # happens with timeout=0
@pytest.mark.cassette_name("TestSubreddit.test_submit_image")
def test_submit_image__timeout_1(self, image_path, reddit):
Expand All @@ -1638,69 +1661,6 @@ def test_submit_image__timeout_1(self, image_path, reddit):
with pytest.raises(WebSocketException):
subreddit.submit_image("Test Title", image)

@mock.patch(
"websocket.create_connection",
new=MagicMock(
side_effect=socket.timeout
# happens with timeout=0.00001
),
)
@pytest.mark.cassette_name("TestSubreddit.test_submit_image")
def test_submit_image__timeout_2(self, image_path, reddit):
reddit.read_only = False
subreddit = reddit.subreddit(pytest.placeholders.test_subreddit)
image = image_path("test.jpg")
with pytest.raises(WebSocketException):
subreddit.submit_image("Test Title", image)

@mock.patch(
"websocket.create_connection",
new=MagicMock(
return_value=WebsocketMockException(
recv_exc=websocket.WebSocketTimeoutException()
), # happens with timeout=0.1
),
)
@pytest.mark.cassette_name("TestSubreddit.test_submit_image")
def test_submit_image__timeout_3(self, image_path, reddit):
reddit.read_only = False
subreddit = reddit.subreddit(pytest.placeholders.test_subreddit)
image = image_path("test.jpg")
with pytest.raises(WebSocketException):
subreddit.submit_image("Test Title", image)

@mock.patch(
"websocket.create_connection",
new=MagicMock(
return_value=WebsocketMockException(
close_exc=websocket.WebSocketTimeoutException()
), # could happen, and PRAW should handle it
),
)
@pytest.mark.cassette_name("TestSubreddit.test_submit_image")
def test_submit_image__timeout_4(self, image_path, reddit):
reddit.read_only = False
subreddit = reddit.subreddit(pytest.placeholders.test_subreddit)
image = image_path("test.jpg")
with pytest.raises(WebSocketException):
subreddit.submit_image("Test Title", image)

@mock.patch(
"websocket.create_connection",
new=MagicMock(
return_value=WebsocketMockException(
recv_exc=websocket.WebSocketConnectionClosedException()
), # from issue #1124
),
)
@pytest.mark.cassette_name("TestSubreddit.test_submit_image")
def test_submit_image__timeout_5(self, image_path, reddit):
reddit.read_only = False
subreddit = reddit.subreddit(pytest.placeholders.test_subreddit)
image = image_path("test.jpg")
with pytest.raises(WebSocketException):
subreddit.submit_image("Test Title", image)

def test_submit_image__without_websockets(self, image_path, reddit):
reddit.read_only = False
subreddit = reddit.subreddit(pytest.placeholders.test_subreddit)
Expand All @@ -1712,8 +1672,10 @@ def test_submit_image__without_websockets(self, image_path, reddit):
assert submission is None

@mock.patch(
"websocket.create_connection",
new=MagicMock(return_value=WebsocketMock("k5s3b3")),
"niquests.Session.get",
new=MagicMock(
return_value=ResponseWithWebSocketExtMock(WebsocketMock("k5s3b3"))
),
) # update with cassette
def test_submit_image_chat(self, image_path, reddit):
reddit.read_only = False
Expand Down Expand Up @@ -1785,9 +1747,11 @@ def test_submit_poll__live_chat(self, reddit):
assert submission.discussion_type == "CHAT"

@mock.patch(
"websocket.create_connection",
"niquests.Session.get",
new=MagicMock(
return_value=WebsocketMock("k5rsq3", "k5rt9d"), # update with cassette
return_value=ResponseWithWebSocketExtMock(
WebsocketMock("k5rsq3", "k5rt9d")
), # update with cassette
),
)
def test_submit_video(self, image_path, reddit):
Expand All @@ -1802,7 +1766,8 @@ def test_submit_video(self, image_path, reddit):

@pytest.mark.cassette_name("TestSubreddit.test_submit_video")
@mock.patch(
"websocket.create_connection", new=MagicMock(return_value=WebsocketMock())
"niquests.Session.get",
new=MagicMock(return_value=ResponseWithWebSocketExtMock(WebsocketMock())),
)
def test_submit_video__bad_websocket(self, image_path, reddit):
reddit.read_only = False
Expand All @@ -1813,8 +1778,10 @@ def test_submit_video__bad_websocket(self, image_path, reddit):
subreddit.submit_video("Test Title", video)

@mock.patch(
"websocket.create_connection",
new=MagicMock(return_value=WebsocketMock("ahells")),
"niquests.Session.get",
new=MagicMock(
return_value=ResponseWithWebSocketExtMock(WebsocketMock("ahells"))
),
) # update with cassette
def test_submit_video__flair(self, image_path, reddit):
flair_id = "6bd28436-1aa7-11e9-9902-0e05ab0fad46"
Expand All @@ -1830,9 +1797,9 @@ def test_submit_video__flair(self, image_path, reddit):
assert submission.link_flair_text == flair_text

@mock.patch(
"websocket.create_connection",
"niquests.Session.get",
new=MagicMock(
return_value=WebsocketMock("k5rvt5", "k5rwbo")
return_value=ResponseWithWebSocketExtMock(WebsocketMock("k5rvt5", "k5rwbo"))
), # update with cassette
)
def test_submit_video__thumbnail(self, image_path, reddit):
Expand All @@ -1852,7 +1819,7 @@ def test_submit_video__thumbnail(self, image_path, reddit):
assert submission.title == "Test Title"

@mock.patch(
"websocket.create_connection", new=MagicMock(side_effect=BlockingIOError)
"niquests.Session.get", new=MagicMock(side_effect=niquests.Timeout)
) # happens with timeout=0
@pytest.mark.cassette_name("TestSubreddit.test_submit_video")
def test_submit_video__timeout_1(self, image_path, reddit):
Expand All @@ -1863,72 +1830,11 @@ def test_submit_video__timeout_1(self, image_path, reddit):
subreddit.submit_video("Test Title", video)

@mock.patch(
"websocket.create_connection",
new=MagicMock(
side_effect=socket.timeout
# happens with timeout=0.00001
),
)
@pytest.mark.cassette_name("TestSubreddit.test_submit_video")
def test_submit_video__timeout_2(self, image_path, reddit):
reddit.read_only = False
subreddit = reddit.subreddit(pytest.placeholders.test_subreddit)
video = image_path("test.mov")
with pytest.raises(WebSocketException):
subreddit.submit_video("Test Title", video)

@mock.patch(
"websocket.create_connection",
new=MagicMock(
return_value=WebsocketMockException(
recv_exc=websocket.WebSocketTimeoutException()
), # happens with timeout=0.1
),
)
@pytest.mark.cassette_name("TestSubreddit.test_submit_video")
def test_submit_video__timeout_3(self, image_path, reddit):
reddit.read_only = False
subreddit = reddit.subreddit(pytest.placeholders.test_subreddit)
video = image_path("test.mov")
with pytest.raises(WebSocketException):
subreddit.submit_video("Test Title", video)

@mock.patch(
"websocket.create_connection",
new=MagicMock(
return_value=WebsocketMockException(
close_exc=websocket.WebSocketTimeoutException()
), # could happen, and PRAW should handle it
),
)
@pytest.mark.cassette_name("TestSubreddit.test_submit_video")
def test_submit_video__timeout_4(self, image_path, reddit):
reddit.read_only = False
subreddit = reddit.subreddit(pytest.placeholders.test_subreddit)
video = image_path("test.mov")
with pytest.raises(WebSocketException):
subreddit.submit_video("Test Title", video)

@mock.patch(
"websocket.create_connection",
"niquests.Session.get",
new=MagicMock(
return_value=WebsocketMockException(
close_exc=websocket.WebSocketConnectionClosedException()
), # from issue #1124
),
)
@pytest.mark.cassette_name("TestSubreddit.test_submit_video")
def test_submit_video__timeout_5(self, image_path, reddit):
reddit.read_only = False
subreddit = reddit.subreddit(pytest.placeholders.test_subreddit)
video = image_path("test.mov")
with pytest.raises(WebSocketException):
subreddit.submit_video("Test Title", video)

@mock.patch(
"websocket.create_connection",
new=MagicMock(
return_value=WebsocketMock("k5s10u", "k5s11v"), # update with cassette
return_value=ResponseWithWebSocketExtMock(
WebsocketMock("k5s10u", "k5s11v")
), # update with cassette
),
)
def test_submit_video__videogif(self, image_path, reddit):
Expand All @@ -1952,8 +1858,10 @@ def test_submit_video__without_websockets(self, image_path, reddit):
assert submission is None

@mock.patch(
"websocket.create_connection",
new=MagicMock(return_value=WebsocketMock("flnyhf")),
"niquests.Session.get",
new=MagicMock(
return_value=ResponseWithWebSocketExtMock(WebsocketMock("flnyhf"))
),
) # update with cassette
def test_submit_video_chat(self, image_path, reddit):
reddit.read_only = False
Expand Down
Loading

0 comments on commit 54f0a48

Please sign in to comment.