Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing connection with multiple hosts and unavailable replicas #579

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions aiopg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import errno
import platform
import select
import socket
import sys
import traceback
import warnings
import weakref
import ctypes

import psycopg2
from psycopg2 import extras
Expand All @@ -28,6 +30,17 @@
# to OSError.errno EBADF
WSAENOTSOCK = 10038

# Connection status from psycopg2 (psycopg2/psycopg/connection.h)
CONN_STATUS_CONNECTING = 20

# socket.socket needs the address family and type to wrap an existing fd,
# but for a connection owned by libpq neither of them is known.
# This function is used to shut down the libpq connection by fd alone.
Comment on lines +36 to +38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the problem with just using the socket object?

# Bind the C library's shutdown() so a socket can be shut down given only
# its file descriptor.
libc = ctypes.CDLL(None)
socket_shutdown = libc.shutdown
# NOTE: the ctypes attribute is ``restype`` (singular); assigning to
# ``restypes`` only creates an unused attribute on the function pointer.
socket_shutdown.restype = ctypes.c_int
# Arguments: (int fd, int how)
socket_shutdown.argtypes = (ctypes.c_int, ctypes.c_int)


async def _enable_hstore(conn):
cur = await conn.cursor()
Expand Down Expand Up @@ -111,6 +124,11 @@ def __init__(self, dsn, loop, timeout, waiter, echo, **kwargs):
self._loop = loop
self._conn = psycopg2.connect(dsn, async_=True, **kwargs)
self._dsn = self._conn.dsn
self._dns_params = self._conn.get_dsn_parameters()
self._conn_timeout = self._dns_params.get('connect_timeout')
if self._conn_timeout:
self._conn_timeout = float(self._conn_timeout)

assert self._conn.isexecuting(), "Is conn an async at all???"
self._fileno = self._conn.fileno()
self._timeout = timeout
Expand All @@ -124,10 +142,26 @@ def __init__(self, dsn, loop, timeout, waiter, echo, **kwargs):
self._notifies = asyncio.Queue(loop=loop)
self._weakref = weakref.ref(self)
self._loop.add_reader(self._fileno, self._ready, self._weakref)
self._conn_timeout_handler = None
if self._conn_timeout:
self._conn_timeout_handler = self._loop.call_later(
self._conn_timeout, self._shutdown, self._weakref)

if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))

@staticmethod
def _shutdown(weak_self):
# Make sure that we won't get stuck in a blocking read or write
# inside poll, then give libpq a chance to try another host.
# If there is an error, we'll get it from poll.
self = weak_self()

if self._conn.status == CONN_STATUS_CONNECTING:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This occasionally fails due to self being None.

Minimal fix would be to skip the rest when that is the case,

but generally there shouldn't be active timers for a deleted object.

socket_shutdown(ctypes.c_int(self._fileno),
ctypes.c_int(socket.SHUT_RDWR))
self._ready(self._weakref)

@staticmethod
def _ready(weak_self):
self = weak_self()
Expand Down Expand Up @@ -172,6 +206,28 @@ def _ready(weak_self):
if waiter is not None and not waiter.done():
waiter.set_exception(
psycopg2.OperationalError("Connection closed"))

if self._conn.status == CONN_STATUS_CONNECTING \
and state != POLL_ERROR:
# libpq could close and open new connection to the next host
old_fileno = self._fileno
self._fileno = self._conn.fileno()

if self._conn_timeout_handler:
self._conn_timeout_handler.cancel()
self._conn_timeout_handler = self._loop.call_later(
self._conn_timeout, self._shutdown, self._weakref)

with contextlib.suppress(OSError):
# if we are using select selector
self._loop.remove_reader(old_fileno)
if self._writing:
self._loop.remove_writer(old_fileno)

self._loop.add_reader(self._fileno, self._ready, weak_self)
if self._writing:
self._loop.add_writer(self._fileno, self._ready, weak_self)

if state == POLL_OK:
if self._writing:
self._loop.remove_writer(self._fileno)
Expand Down
65 changes: 65 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,71 @@ async def test_connect(connect):
assert not conn.echo


class TestMultipleHostsWithUnavailable:
    # Connects through a multi-host DSN whose first entry points at a port
    # obtained from the ``unused_port`` fixture (presumably nothing is
    # listening there), followed by the real server.

    @pytest.fixture
    def pg_params(self, pg_params, pg_server, unused_port):
        # Work on a copy so the outer ``pg_params`` fixture stays untouched.
        params = pg_params.copy()
        real_host, real_port = params['host'], params['port']
        spare_port = unused_port()

        # Comma-separated host and port lists: extra entry first, then the
        # real server.
        params.update(
            host='{extra_host},{host}'.format(
                extra_host="127.0.0.1", host=real_host),
            port='{extra_port},{port}'.format(
                extra_port=spare_port, port=real_port),
        )
        return params

    async def test_connect(self, connect):
        # The unavailable replica listed first should be skipped.
        connection = await connect()
        assert isinstance(connection, Connection)
        assert not connection._writing
        assert connection._conn is connection.raw
        assert not connection.echo


class TestMultipleHostsWithStuckConnection:
    # Connects through a multi-host DSN whose first entry is a local port
    # held by a bound socket (see ``stuck_server_port``), followed by the
    # real server, with ``connect_timeout`` set so the first host can be
    # abandoned.

    # ``pytest.yield_fixture`` is deprecated; ``pytest.fixture`` supports
    # ``yield`` directly.
    @pytest.fixture
    def stuck_server_port(self, unused_port):
        # Bind a TCP socket to a free local port without ever calling
        # listen(); the intent is a server which is not responding on SYN.
        # The context manager closes the socket on teardown and also if
        # ``unused_port()`` or ``bind()`` raises.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            port = unused_port()
            s.bind(('127.0.0.1', port))
            yield port

    @pytest.fixture
    def pg_params(self, pg_params, pg_server, stuck_server_port):
        # Work on a copy so the outer ``pg_params`` fixture stays untouched.
        pg_params = pg_params.copy()
        host = pg_params['host']
        port = pg_params['port']

        extra_host = "127.0.0.1"
        extra_port = stuck_server_port

        # Comma-separated host and port lists: stuck entry first, then the
        # real server.
        pg_params['host'] = '{extra_host},{host}'.format(
            extra_host=extra_host, host=host)
        pg_params['port'] = '{extra_port},{port}'.format(
            extra_port=extra_port, port=port)
        # Per-host connect timeout of 1 and an overall ``timeout`` of 3, so
        # there is room to move on to the second host.
        pg_params['connect_timeout'] = 1
        pg_params['timeout'] = 3

        return pg_params

    @pytest.mark.skipif(sys.platform != "linux",
                        reason='unstuck works only on linux')
    async def test_connect(self, connect):
        # We should give up on the stuck host once connect_timeout expires
        # and connect to the next host in the list.
        conn = await connect()
        assert isinstance(conn, Connection)
        assert not conn._writing
        assert conn._conn is conn.raw
        assert not conn.echo


async def test_simple_select(connect):
conn = await connect()
cur = await conn.cursor()
Expand Down