Skip to content

Commit

Permalink
Use the PID to terminate the session (#568)
Browse files Browse the repository at this point in the history
* The first element of the result is the PID
* Debug-level logging of high-level message + SQL
* Using redshift_connector `cursor.fetchone()`  returns `(<something>,)`
* Use cursor to call `select pg_terminate_backend({pid})` directly rather than using the `SQLConnectionManager`

---------

Co-authored-by: Mike Alfare <13974384+mikealfare@users.noreply.github.com>
  • Loading branch information
dbeatty10 and mikealfare authored Oct 11, 2023
1 parent 1116e47 commit 5d4f3f5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230807-174409.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Use the PID to terminate the session
time: 2023-08-07T17:44:09.15097-06:00
custom:
Author: dbeatty10
Issue: "553"
10 changes: 6 additions & 4 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,9 @@ class RedshiftConnectionManager(SQLConnectionManager):
def _get_backend_pid(self):
sql = "select pg_backend_pid()"
_, cursor = self.add_query(sql)

res = cursor.fetchone()
return res
return res[0]

def cancel(self, connection: Connection):
try:
Expand All @@ -253,9 +254,10 @@ def cancel(self, connection: Connection):
raise

sql = f"select pg_terminate_backend({pid})"
_, cursor = self.add_query(sql)
res = cursor.fetchone()
logger.debug(f"Cancel query '{connection.name}': {res}")
cursor = connection.handle.cursor()
logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}")
logger.debug(sql)
cursor.execute(sql)

@classmethod
def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse:
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/test_redshift_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,14 +471,13 @@ def test_cancel_open_connections_single(self):
with mock.patch.object(self.adapter.connections, "add_query") as add_query:
query_result = mock.MagicMock()
cursor = mock.Mock()
cursor.fetchone.return_value = 42
cursor.fetchone.return_value = (42,)
add_query.side_effect = [(None, cursor), (None, query_result)]

self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1)
add_query.assert_has_calls(
[
call("select pg_backend_pid()"),
call("select pg_terminate_backend(42)"),
]
)

Expand Down

0 comments on commit 5d4f3f5

Please sign in to comment.