Skip to content

Commit

Permalink
Allow dbt to cancel connections (#718)
Browse files Browse the repository at this point in the history
* Save backend_pid on initiation
* Adjust OnConfigurationChangeOption import
* Use MagicMock to accomodate __getitem__ call on cursor results
* Test backend_pid usage
* Store backend_pid on connection directly

---------

Co-authored-by: Mike Alfare <13974384+mikealfare@users.noreply.github.com>
Co-authored-by: Teresa Martyny <135771777+martynydbt@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 13, 2024
1 parent deb92e5 commit fdad756
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 32 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240326-123703.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: dbt can cancel open queries upon interrupt
time: 2024-03-26T12:37:03.17481-05:00
custom:
Author: holly-evans
Issue: "705"
29 changes: 15 additions & 14 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,27 +233,26 @@ def connect():
class RedshiftConnectionManager(SQLConnectionManager):
TYPE = "redshift"

def _get_backend_pid(self):
sql = "select pg_backend_pid()"
_, cursor = self.add_query(sql)

res = cursor.fetchone()
return res[0]

def cancel(self, connection: Connection):
pid = connection.backend_pid # type: ignore
sql = f"select pg_terminate_backend({pid})"
logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}")
logger.debug(sql)

try:
pid = self._get_backend_pid()
self.add_query(sql)
except redshift_connector.InterfaceError as e:
if "is closed" in str(e):
logger.debug(f"Connection {connection.name} was already closed")
return
raise

sql = f"select pg_terminate_backend({pid})"
cursor = connection.handle.cursor()
logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}")
logger.debug(sql)
cursor.execute(sql)
@classmethod
def _get_backend_pid(cls, connection):
with connection.handle.cursor() as c:
sql = "select pg_backend_pid()"
res = c.execute(sql).fetchone()
return res[0]

@classmethod
def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse:
Expand Down Expand Up @@ -325,14 +324,16 @@ def exponential_backoff(attempt: int):
redshift_connector.DataError,
]

return cls.retry_connection(
open_connection = cls.retry_connection(
connection,
connect=connect_method_factory.get_connect_method(),
logger=logger,
retry_limit=credentials.retries,
retry_timeout=exponential_backoff,
retryable_exceptions=retryable_exceptions,
)
open_connection.backend_pid = cls._get_backend_pid(open_connection) # type: ignore
return open_connection

def execute(
self,
Expand Down
76 changes: 58 additions & 18 deletions tests/unit/test_redshift_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from unittest import mock

from dbt_common.exceptions import DbtRuntimeError
from unittest.mock import Mock, call
from unittest.mock import MagicMock, call

import agate
import dbt
Expand Down Expand Up @@ -67,7 +67,7 @@ def adapter(self):
inject_adapter(self._adapter, RedshiftPlugin)
return self._adapter

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_implicit_database_conn(self):
connection = self.adapter.acquire_connection("dummy")
connection.handle
Expand All @@ -84,7 +84,7 @@ def test_implicit_database_conn(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_region_with_database_conn(self):
self.config.method = "database"

Expand All @@ -103,7 +103,7 @@ def test_explicit_region_with_database_conn(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_iam_conn_without_profile(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand All @@ -129,7 +129,7 @@ def test_explicit_iam_conn_without_profile(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_conn_timeout_30(self):
self.config.credentials = self.config.credentials.replace(connect_timeout=30)
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -147,7 +147,7 @@ def test_conn_timeout_30(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_iam_conn_with_profile(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_explicit_iam_conn_with_profile(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_iam_serverless_with_profile(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand All @@ -201,7 +201,7 @@ def test_explicit_iam_serverless_with_profile(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_region(self):
# Successful test
self.config.credentials = self.config.credentials.replace(
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_explicit_region(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_region_failure(self):
# Failure test with no region
self.config.credentials = self.config.credentials.replace(
Expand Down Expand Up @@ -258,7 +258,7 @@ def test_explicit_region_failure(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_invalid_region(self):
# Invalid region test
self.config.credentials = self.config.credentials.replace(
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_explicit_invalid_region(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_disable(self):
self.config.credentials.sslmode = "disable"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -306,7 +306,7 @@ def test_sslmode_disable(self):
sslmode=None,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_allow(self):
self.config.credentials.sslmode = "allow"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -325,7 +325,7 @@ def test_sslmode_allow(self):
sslmode="verify-ca",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_verify_full(self):
self.config.credentials.sslmode = "verify-full"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -344,7 +344,7 @@ def test_sslmode_verify_full(self):
sslmode="verify-full",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_verify_ca(self):
self.config.credentials.sslmode = "verify-ca"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -363,7 +363,7 @@ def test_sslmode_verify_ca(self):
sslmode="verify-ca",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_prefer(self):
self.config.credentials.sslmode = "prefer"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -382,7 +382,7 @@ def test_sslmode_prefer(self):
sslmode="verify-ca",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_serverless_iam_failure(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand Down Expand Up @@ -447,6 +447,25 @@ def test_invalid_iam_no_cluster_id(self):

self.assertTrue("'cluster_id' must be provided" in context.exception.msg)

@mock.patch("redshift_connector.connect", MagicMock())
def test_connection_has_backend_pid(self):
backend_pid = 42

cursor = mock.MagicMock()
execute = cursor().__enter__().execute
execute().fetchone.return_value = (backend_pid,)
redshift_connector.connect().cursor = cursor

connection = self.adapter.acquire_connection("dummy")
connection.handle
assert connection.backend_pid == backend_pid

execute.assert_has_calls(
[
call("select pg_backend_pid()"),
]
)

def test_cancel_open_connections_empty(self):
self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)

Expand Down Expand Up @@ -475,11 +494,32 @@ def test_cancel_open_connections_single(self):
self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1)
add_query.assert_has_calls(
[
call("select pg_backend_pid()"),
call(f"select pg_terminate_backend({model.backend_pid})"),
]
)

master.handle.get_backend_pid.assert_not_called()
master.handle.backend_pid.assert_not_called()

@mock.patch("redshift_connector.connect", MagicMock())
def test_backend_pid_used_in_pg_terminate_backend(self):
with mock.patch.object(self.adapter.connections, "add_query") as add_query:
backend_pid = 42
query_result = (backend_pid,)

cursor = mock.MagicMock()
cursor().__enter__().execute().fetchone.return_value = query_result
redshift_connector.connect().cursor = cursor

connection = self.adapter.acquire_connection("dummy")
connection.handle

self.adapter.connections.cancel(connection)

add_query.assert_has_calls(
[
call(f"select pg_terminate_backend({backend_pid})"),
]
)

def test_dbname_verification_is_case_insensitive(self):
# Override adapter settings from setUp()
Expand Down

0 comments on commit fdad756

Please sign in to comment.