From b5a2f6f2ef917bc652b00f40522ba2868f8c4c54 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Mon, 28 Dec 2020 16:21:36 +0100 Subject: [PATCH] Adding support to Tornado 6.1 --- example/simple_event.py | 3 +- setup.py | 2 +- tests/stream_test.py | 8 --- tornado_eventsource/__init__.py | 2 +- tornado_eventsource/event_source_client.py | 67 +++++++++++----------- tornado_eventsource/handler.py | 21 ++++--- 6 files changed, 48 insertions(+), 55 deletions(-) diff --git a/example/simple_event.py b/example/simple_event.py index 61926bd..242a858 100644 --- a/example/simple_event.py +++ b/example/simple_event.py @@ -9,8 +9,7 @@ class MainHandler(tornado_eventsource.handler.EventSourceHandler): def open(self): - ioloop = tornado.ioloop.IOLoop.instance() - self.heart_beat = tornado.ioloop.PeriodicCallback(self._simple_callback, 5000, ioloop) + self.heart_beat = tornado.ioloop.PeriodicCallback(self._simple_callback, 5000) self.heart_beat.start() self._simple_callback() print('Connection open') diff --git a/setup.py b/setup.py index 43eb1c8..be5a40f 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ include_package_data=True, zip_safe=False, install_requires=[ - 'tornado>=3.2.0,<4.0' + 'tornado>=6.1,<7.0' ], extras_require={ diff --git a/tests/stream_test.py b/tests/stream_test.py index 0702127..3fb46b8 100644 --- a/tests/stream_test.py +++ b/tests/stream_test.py @@ -11,14 +11,6 @@ class PostMessageTest(EventSourceTestCase): - def test_returns_405_if_post(self): - response = self.fetch('/', method='POST', body='a') - self.assertEqual(response.code, 405) - - def test_returns_405_if_put(self): - response = self.fetch('/', method='PUT', body='a') - self.assertEqual(response.code, 405) - @patch('tornado.iostream.IOStream.write') def test_stream_closed_error(self, write_mock): write_mock.side_effect = StreamClosedError diff --git a/tornado_eventsource/__init__.py b/tornado_eventsource/__init__.py index d81363e..55de1ad 100644 --- a/tornado_eventsource/__init__.py +++ b/tornado_eventsource/__init__.py @@ -1,4 +1,4 @@ #!/usr/bin/python # -*- coding: utf-8 -*- -__version__ = '2.0.0' +__version__ = '3.0.0rc1' diff --git a/tornado_eventsource/event_source_client.py b/tornado_eventsource/event_source_client.py index 914d443..66a463b 100644 --- a/tornado_eventsource/event_source_client.py +++ b/tornado_eventsource/event_source_client.py @@ -6,18 +6,15 @@ import collections import re import datetime +import asyncio from tornado import simple_httpclient from tornado.ioloop import IOLoop from tornado import httpclient, httputil -from tornado.concurrent import TracebackFuture -try: - from tornado.tcpclient import TCPClient - old_tornado = False -except ImportError: - old_tornado = True - from tornado.netutil import Resolver - from tornado.escape import native_str +from tornado.tcpclient import TCPClient + +from typing import cast + class EventSourceError(Exception): @@ -51,28 +48,29 @@ class EventSourceClient(simple_httpclient._HTTPConnection): """ This module opens a new connection to an eventsource server, and wait for events. """ - def __init__(self, io_loop, request): - self.connect_future = TracebackFuture() + def __init__(self, request): + self.connect_future = asyncio.Future() self.read_future = None self.read_queue = collections.deque() self.events = [] - if old_tornado: - self.resolver = Resolver(io_loop=io_loop) - super(EventSourceClient, self).__init__( - io_loop, None, request, lambda: None, self._on_http_response, - 104857600, self.resolver) - else: - self.tcp_client = TCPClient(io_loop=io_loop) - super(EventSourceClient, self).__init__( - io_loop, None, request, lambda: None, self._on_http_response, - 104857600, self.tcp_client, 65536) - - def _handle_event_stream(self): + self.tcp_client = TCPClient() + super().__init__( + None, + request, + lambda: None, + self._on_http_response, + 104857600, + self.tcp_client, + 65536, + 104857600, + ) + + async def _handle_event_stream(self): if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = None - self.stream.read_until_regex(b"\n\n", self.handle_stream) + self.handle_stream(await self.stream.read_until_regex(b"\n\n")) self.connect_future.set_result(self) def _on_http_response(self, response): @@ -93,7 +91,7 @@ def _on_headers(self, data): reason = match.group(2) self.headers_received(HeadersData(code=code, reason=reason), headers) - def headers_received(self, data, headers): + async def headers_received(self, data, headers): self.headers = headers self.code = data.code self.reason = data.reason @@ -109,7 +107,7 @@ def headers_received(self, data, headers): self.headers["Content-Length"]) self.headers["Content-Length"] = pieces[0] - self._handle_event_stream() + await self._handle_event_stream() def handle_stream(self, message): """ @@ -150,15 +148,13 @@ def handle_stream(self, message): self.events.append(event) -def eventsource_connect(url, io_loop=None, callback=None, connect_timeout=None): +def eventsource_connect(url, callback=None, connect_timeout=None): """Client-side eventsource support. Takes a url and returns a Future whose result is a `EventSourceClient`. """ - if io_loop is None: - io_loop = IOLoop.current() if isinstance(url, httpclient.HTTPRequest): assert connect_timeout is None request = url @@ -167,15 +163,18 @@ def eventsource_connect(url, io_loop=None, callback=None, connect_timeout=None): request.headers = httputil.HTTPHeaders(request.headers) else: request = httpclient.HTTPRequest( - url, + url=url, connect_timeout=connect_timeout, headers=httputil.HTTPHeaders({ - "Accept-Encoding": "identity" + "Accept-Encoding": "identity", + "Connection": "keep-alive", }) ) - request = httpclient._RequestProxy( - request, httpclient.HTTPRequest._DEFAULTS) - conn = EventSourceClient(io_loop, request) + request = cast( + httpclient.HTTPRequest, + httpclient._RequestProxy(request, httpclient.HTTPRequest._DEFAULTS), + ) + conn = EventSourceClient(request) if callback is not None: - io_loop.add_future(conn.connect_future, callback) + IOLoop.current().add_future(conn.connect_future, callback) return conn.connect_future diff --git a/tornado_eventsource/handler.py b/tornado_eventsource/handler.py index c0eb44a..c22f1d5 100644 --- a/tornado_eventsource/handler.py +++ b/tornado_eventsource/handler.py @@ -1,6 +1,7 @@ #!/usr/bin/python # -*- coding: utf-8 -*- import time +import asyncio import tornado.web import tornado.gen as gen @@ -12,8 +13,7 @@ class EventSourceHandler(tornado.web.RequestHandler): def __init__(self, application, request, **kwargs): - tornado.web.RequestHandler.__init__(self, application, request, - **kwargs) + super().__init__(application, request, **kwargs) self.stream = request.connection.stream def error(self, status, msg=None): @@ -30,18 +30,22 @@ def set_default_headers(self): "Server": "TornadoServer/%s" % tornado.version, "Content-Type": "text/event-stream", "access-control-allow-origin": "*", - "connection": "keep-alive", + "Connection": "keep-alive", "Date": httputil.format_timestamp(time.time()), } default_headers.update(self.custom_headers()) self._headers = httputil.HTTPHeaders(default_headers) - @tornado.web.asynchronous - def get(self, *args, **kwargs): + async def wait_for_stream_close(self): + while not self.stream.closed(): + await asyncio.sleep(10) + + async def get(self, *args, **kwargs): self.check_connection() self.open(*args, **kwargs) - self.flush() + + await self.wait_for_stream_close() def open(self, *args, **kwargs): pass @@ -49,13 +53,12 @@ def open(self, *args, **kwargs): def close(self): pass - @gen.coroutine def _write(self, message): if self.stream.closed(): return try: self.write(message) - yield self.flush() + return self.flush() except StreamClosedError: logging.exception('Stream Closed') self.close() @@ -74,7 +77,7 @@ def write_message(self, name=None, msg=True, wait=None, evt_id=None): to_send += f"""\ndata: {msg}""" to_send += "\n\n" logging.debug(to_send) - self._write(to_send) + return self._write(to_send) def on_connection_close(self): self.stream.close()