Skip to content

Commit

Permalink
Adding support to Tornado 6.1
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermef committed Dec 28, 2020
1 parent 5cff197 commit b5a2f6f
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 55 deletions.
3 changes: 1 addition & 2 deletions example/simple_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

class MainHandler(tornado_eventsource.handler.EventSourceHandler):
def open(self):
ioloop = tornado.ioloop.IOLoop.instance()
self.heart_beat = tornado.ioloop.PeriodicCallback(self._simple_callback, 5000, ioloop)
self.heart_beat = tornado.ioloop.PeriodicCallback(self._simple_callback, 5000)
self.heart_beat.start()
self._simple_callback()
print('Connection open')
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
include_package_data=True,
zip_safe=False,
install_requires=[
'tornado>=3.2.0,<4.0'
'tornado>=6.1,<7.0'
],

extras_require={
Expand Down
8 changes: 0 additions & 8 deletions tests/stream_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@

class PostMessageTest(EventSourceTestCase):

def test_returns_405_if_post(self):
response = self.fetch('/', method='POST', body='a')
self.assertEqual(response.code, 405)

def test_returns_405_if_put(self):
response = self.fetch('/', method='PUT', body='a')
self.assertEqual(response.code, 405)

@patch('tornado.iostream.IOStream.write')
def test_stream_closed_error(self, write_mock):
write_mock.side_effect = StreamClosedError
Expand Down
2 changes: 1 addition & 1 deletion tornado_eventsource/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

__version__ = '2.0.0'
__version__ = '3.0.0rc1'
67 changes: 33 additions & 34 deletions tornado_eventsource/event_source_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,15 @@
import collections
import re
import datetime
import asyncio

from tornado import simple_httpclient
from tornado.ioloop import IOLoop
from tornado import httpclient, httputil
from tornado.concurrent import TracebackFuture
try:
from tornado.tcpclient import TCPClient
old_tornado = False
except ImportError:
old_tornado = True
from tornado.netutil import Resolver
from tornado.escape import native_str
from tornado.tcpclient import TCPClient

from typing import cast



class EventSourceError(Exception):
Expand Down Expand Up @@ -51,28 +48,29 @@ class EventSourceClient(simple_httpclient._HTTPConnection):
"""
This module opens a new connection to an eventsource server, and wait for events.
"""
def __init__(self, io_loop, request):
self.connect_future = TracebackFuture()
def __init__(self, request):
self.connect_future = asyncio.Future()
self.read_future = None
self.read_queue = collections.deque()
self.events = []

if old_tornado:
self.resolver = Resolver(io_loop=io_loop)
super(EventSourceClient, self).__init__(
io_loop, None, request, lambda: None, self._on_http_response,
104857600, self.resolver)
else:
self.tcp_client = TCPClient(io_loop=io_loop)
super(EventSourceClient, self).__init__(
io_loop, None, request, lambda: None, self._on_http_response,
104857600, self.tcp_client, 65536)

def _handle_event_stream(self):
self.tcp_client = TCPClient()
super().__init__(
None,
request,
lambda: None,
self._on_http_response,
104857600,
self.tcp_client,
65536,
104857600,
)

async def _handle_event_stream(self):
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
self.stream.read_until_regex(b"\n\n", self.handle_stream)
self.handle_stream(await self.stream.read_until_regex(b"\n\n"))
self.connect_future.set_result(self)

def _on_http_response(self, response):
Expand All @@ -93,7 +91,7 @@ def _on_headers(self, data):
reason = match.group(2)
self.headers_received(HeadersData(code=code, reason=reason), headers)

def headers_received(self, data, headers):
async def headers_received(self, data, headers):
self.headers = headers
self.code = data.code
self.reason = data.reason
Expand All @@ -109,7 +107,7 @@ def headers_received(self, data, headers):
self.headers["Content-Length"])
self.headers["Content-Length"] = pieces[0]

self._handle_event_stream()
await self._handle_event_stream()

def handle_stream(self, message):
"""
Expand Down Expand Up @@ -150,15 +148,13 @@ def handle_stream(self, message):
self.events.append(event)


def eventsource_connect(url, io_loop=None, callback=None, connect_timeout=None):
def eventsource_connect(url, callback=None, connect_timeout=None):
"""Client-side eventsource support.
Takes a url and returns a Future whose result is a
`EventSourceClient`.
"""
if io_loop is None:
io_loop = IOLoop.current()
if isinstance(url, httpclient.HTTPRequest):
assert connect_timeout is None
request = url
Expand All @@ -167,15 +163,18 @@ def eventsource_connect(url, io_loop=None, callback=None, connect_timeout=None):
request.headers = httputil.HTTPHeaders(request.headers)
else:
request = httpclient.HTTPRequest(
url,
url=url,
connect_timeout=connect_timeout,
headers=httputil.HTTPHeaders({
"Accept-Encoding": "identity"
"Accept-Encoding": "identity",
"Connection": "keep-alive",
})
)
request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS)
conn = EventSourceClient(io_loop, request)
request = cast(
httpclient.HTTPRequest,
httpclient._RequestProxy(request, httpclient.HTTPRequest._DEFAULTS),
)
conn = EventSourceClient(request)
if callback is not None:
io_loop.add_future(conn.connect_future, callback)
IOLoop.current().add_future(conn.connect_future, callback)
return conn.connect_future
21 changes: 12 additions & 9 deletions tornado_eventsource/handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
import time
import asyncio

import tornado.web
import tornado.gen as gen
Expand All @@ -12,8 +13,7 @@
class EventSourceHandler(tornado.web.RequestHandler):

def __init__(self, application, request, **kwargs):
tornado.web.RequestHandler.__init__(self, application, request,
**kwargs)
super().__init__(application, request, **kwargs)
self.stream = request.connection.stream

def error(self, status, msg=None):
Expand All @@ -30,32 +30,35 @@ def set_default_headers(self):
"Server": "TornadoServer/%s" % tornado.version,
"Content-Type": "text/event-stream",
"access-control-allow-origin": "*",
"connection": "keep-alive",
"Connection": "keep-alive",
"Date": httputil.format_timestamp(time.time()),
}
default_headers.update(self.custom_headers())
self._headers = httputil.HTTPHeaders(default_headers)

@tornado.web.asynchronous
def get(self, *args, **kwargs):
async def wait_for_stream_close(self):
while not self.stream.closed():
await asyncio.sleep(10)

async def get(self, *args, **kwargs):
self.check_connection()

self.open(*args, **kwargs)
self.flush()

await self.wait_for_stream_close()

def open(self, *args, **kwargs):
pass

def close(self):
pass

@gen.coroutine
def _write(self, message):
if self.stream.closed():
return
try:
self.write(message)
yield self.flush()
return self.flush()
except StreamClosedError:
logging.exception('Stream Closed')
self.close()
Expand All @@ -74,7 +77,7 @@ def write_message(self, name=None, msg=True, wait=None, evt_id=None):
to_send += f"""\ndata: {msg}"""
to_send += "\n\n"
logging.debug(to_send)
self._write(to_send)
return self._write(to_send)

def on_connection_close(self):
self.stream.close()
Expand Down

0 comments on commit b5a2f6f

Please sign in to comment.