Skip to content

Commit

Permalink
Support ASGI servers for FastAPI (#2)
Browse files Browse the repository at this point in the history
* Add support for ASGI applications

* Update readme

* Bump version

* Linting
  • Loading branch information
banesullivan authored Jul 27, 2022
1 parent 32b2547 commit d37415b
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 8 deletions.
46 changes: 43 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
[![PyPI](https://img.shields.io/pypi/v/server-thread.svg?logo=python&logoColor=white)](https://pypi.org/project/server-thread/)
[![conda](https://img.shields.io/conda/vn/conda-forge/server-thread.svg?logo=conda-forge&logoColor=white)](https://anaconda.org/conda-forge/server-thread)

Launch a WSGIApplication in a background thread with werkzeug.
Launch a WSGI or ASGI Application in a background thread with werkzeug or uvicorn.

This application was created for [`localtileserver`](https://github.com/banesullivan/localtileserver)
and provides the basis for how it can launch an image tile server as a
Expand All @@ -17,9 +17,9 @@ Python packages I have created that require a background service.

## 🚀 Usage

Use the `ServerThread` with any WSGIApplication.
Use the `ServerThread` with any WSGI or ASGI Application.

Start by creating a WSGIApplication (this can be a flask app or a simple app
Start by creating a application (this can be a flask app or a simple app
like below):


Expand Down Expand Up @@ -80,3 +80,43 @@ If filing a bug report, please share a scooby `Report`:
import server_thread
print(server_thread.Report())
```


## 🚀 Examples

Minimal examples for using `server-thread` with common micro-frameworks.


### 💨 FastAPI

```py
from fastapi import FastAPI

app = FastAPI()


@app.get("/")
def root():
return {"message": "Howdy!"}


server = ServerThread(app)
requests.get(f"http://{server.host}:{server.port}/").json()
```

### ⚗️ Flask

```py
from flask import Flask

app = Flask("testapp")


@app.route("/")
def howdy():
return {"message": "Howdy!"}


server = ServerThread(app)
requests.get(f"http://{server.host}:{server.port}/").json()
```
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ requests
scooby
pytest
pytest-cov
fastapi
uvicorn
112 changes: 109 additions & 3 deletions server_thread/server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
import logging
import os
import threading
import time
from typing import Union

import uvicorn
from werkzeug.serving import make_server

logger = logging.getLogger(__name__)


def is_fastapi(app):
try:
from fastapi import FastAPI

return isinstance(app, FastAPI)
except ImportError: # pragma: no cover
pass


class ServerDownError(Exception):
"""Raised when a ServerThread is down."""

Expand Down Expand Up @@ -61,16 +72,108 @@ def shutdown_server(key: int, force: bool = False):
logger.error(f"Server for key ({key}) not found.")


class ServerBase:
def __init__(self, app, host, port, debug: bool = False):
raise NotImplementedError # pragma: no cover

def shutdown(self):
raise NotImplementedError # pragma: no cover

def __del__(self):
self.shutdown()

@property
def port(self):
raise NotImplementedError # pragma: no cover

@property
def host(self):
raise NotImplementedError # pragma: no cover

@property
def serve_forever(self):
raise NotImplementedError # pragma: no cover


class WSGIServer(ServerBase):
"""Manager for WSGI applications."""

def __init__(self, app, host, port, debug: bool = False):
self.server = make_server(host, port, app, threaded=True, passthrough_errors=debug)

def shutdown(self):
self.server.shutdown()
self.server.server_close()

@property
def port(self):
return self.server.port

@property
def host(self):
return self.server.host

@property
def serve_forever(self):
return self.server.serve_forever


class ASGIServer(ServerBase):
"""Manager for ASGI applications."""

def __init__(self, app, host, port, debug: bool = False):
config = uvicorn.Config(
app, host=host, port=port, log_level="debug" if debug else "critical"
)
self.server = uvicorn.Server(config)

@property
def sock(self):
if self.server.started:
if (
hasattr(self.server, "servers")
and len(self.server.servers) # noqa: W503
and len(self.server.servers[0].sockets) # noqa: W503
):
return self.server.servers[0].sockets[0]
else:
raise ServerDownError("Server started, but no servers present")
else:
timeout = time.time() + 10
while not self.server.started:
if time.time() > timeout:
raise ServerDownError("Server not started")
return self.sock

def shutdown(self):
self.server.should_exit = True
self.server.handle_exit(0, None)
# await self.server.shutdown()

@property
def port(self):
return self.sock.getsockname()[1]

@property
def host(self):
return self.sock.getsockname()[0]

@property
def serve_forever(self):
return self.server.run


class ServerThread(threading.Thread):
"""Launch a server as a background thread."""

def __init__(
self,
app, # WSGIApplication
app, # WSGI or ASGI Application
port: int = 0,
debug: bool = False,
start: bool = True,
host: str = "127.0.0.1",
wsgi: bool = None,
):
self._lts_initialized = False
if not isinstance(port, int):
Expand All @@ -86,7 +189,11 @@ def __init__(

if os.name == "nt" and host == "127.0.0.1":
host = "localhost"
self.srv = make_server(host, port, app, threaded=True)

if (wsgi is not None and not wsgi) or (wsgi is None and is_fastapi(app)):
self.srv = ASGIServer(app, host, port, debug=debug)
else: # Fallback to WSGI
self.srv = WSGIServer(app, host, port, debug=debug)

if hasattr(app, "app_context"):
self.ctx = app.app_context()
Expand All @@ -104,7 +211,6 @@ def __init__(
def shutdown(self):
if self._lts_initialized and self.is_alive():
self.srv.shutdown()
self.srv.server_close()
self.join()

def __del__(self):
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
long_description = ""

# major, minor, patch
version_info = 0, 1, 0
version_info = 0, 2, 0
# Nice string for the version
__version__ = ".".join(map(str, version_info))

setup(
name="server-thread",
version=__version__,
description="Launch a WSGIApplication in a background thread with werkzeug.",
description="Launch a WSGI or ASGI Application in a background thread with werkzeug or uvicorn.",
long_description=long_description,
long_description_content_type="text/markdown",
author="Bane Sullivan",
Expand All @@ -39,6 +39,7 @@
python_requires=">=3.7",
install_requires=[
"scooby",
"uvicorn",
"werkzeug",
],
)
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from fastapi import FastAPI
from flask import Flask
import pytest

Expand All @@ -11,3 +12,14 @@ def howdy():
return "howdy!"

return app


@pytest.fixture
def fastapi_app():
app = FastAPI()

@app.get("/")
def root():
return {"message": "Howdy!"}

return app
22 changes: 22 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@
def test_basic_flask_app(flask_app):
server = ServerThread(flask_app, debug=True)
requests.get(f"http://{server.host}:{server.port}/").raise_for_status()
server = ServerThread(flask_app, debug=True, wsgi=True)
requests.get(f"http://{server.host}:{server.port}/").raise_for_status()


def test_basic_fastapi_app(fastapi_app):
server = ServerThread(fastapi_app, debug=True, wsgi=False)
requests.get(f"http://{server.host}:{server.port}/").raise_for_status()


def test_fastapi_app_auto_detect(fastapi_app):
server = ServerThread(fastapi_app, debug=True)
requests.get(f"http://{server.host}:{server.port}/").raise_for_status()


def test_bad_port(flask_app):
Expand All @@ -22,3 +34,13 @@ def test_server_shutdown(flask_app):
del server
with pytest.raises(requests.ConnectionError):
requests.get(url).raise_for_status()


def test_server_shutdown_fastapi(fastapi_app):
server = ServerThread(fastapi_app, debug=True)
url = f"http://{server.host}:{server.port}/"
requests.get(url).raise_for_status()
server.shutdown()
del server
with pytest.raises(requests.ConnectionError):
requests.get(url).raise_for_status()

0 comments on commit d37415b

Please sign in to comment.