Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for polling. #18

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 30 additions & 40 deletions presto/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,43 +423,15 @@ def process(self, http_response):
)


class PrestoResult(object):
class PrestoQuery(object):
"""
Represent the result of a Presto query as an iterator on rows.
Represent the execution of a SQL statement by Presto.

This class implements the iterator protocol as a generator type
Results of the query can be extracted by iterating over this class, since it
implements the iterator protocol as a generator type
https://docs.python.org/3/library/stdtypes.html#generator-types
"""

def __init__(self, query, rows=None):
self._query = query
self._rows = rows or []
self._rownumber = 0

@property
def rownumber(self):
# type: () -> int
return self._rownumber

def __iter__(self):
# Initial fetch from the first POST request
for row in self._rows:
self._rownumber += 1
yield row
self._rows = None

# Subsequent fetches from GET requests until next_uri is empty.
while not self._query.is_finished():
rows = self._query.fetch()
for row in rows:
self._rownumber += 1
logger.debug("row {}".format(row))
yield row


class PrestoQuery(object):
"""Represent the execution of a SQL statement by Presto."""

def __init__(
self,
request, # type: PrestoRequest
Expand All @@ -476,7 +448,10 @@ def __init__(
self._cancelled = False
self._request = request
self._sql = sql
self._result = PrestoResult(self)

self._rownumber = 0
self._rows = []
self._rowsoffset = 0

@property
def columns(self):
Expand All @@ -490,10 +465,6 @@ def stats(self):
def warnings(self):
return self._warnings

@property
def result(self):
return self._result

def execute(self):
# type: () -> PrestoResult
"""Initiate a Presto query by sending the SQL statement
Expand All @@ -514,10 +485,10 @@ def execute(self):
self._warnings = getattr(status, "warnings", [])
if status.next_uri is None:
self._finished = True
self._result = PrestoResult(self, status.rows)
return self._result
self._rows = status.rows
return self

def fetch(self):
def _fetch(self):
# type: () -> List[List[Any]]
"""Continue fetching data for the current query_id"""
response = self._request.get(self._request.next_uri)
Expand All @@ -530,6 +501,14 @@ def fetch(self):
self._finished = True
return status.rows

def poll(self):
# type: () -> Dict
"""Retrieve the current status of a presto query, caching any results."""
if not self.query_id or self._finished:
return self.stats
self._rows.extend(self._fetch())
return self.stats

def cancel(self):
# type: () -> None
"""Cancel the current query"""
Expand All @@ -549,3 +528,14 @@ def cancel(self):
def is_finished(self):
# type: () -> bool
return self._finished

def __iter__(self):
while self._rows or not self.is_finished():
for row in self._rows[self._rowsoffset:]:
self._rownumber += 1
self._rowsoffset +=1
logger.debug('row {}'.format(row))
yield row
self._rows = []
self._rowsoffset = 0
self.poll()
30 changes: 16 additions & 14 deletions presto/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,18 @@ def warnings(self):
return self._query.warnings
return None

def poll(self):
return self._query.poll()

def setinputsizes(self, sizes):
raise presto.exceptions.NotSupportedError

def setoutputsize(self, size, column):
raise presto.exceptions.NotSupportedError

def execute(self, operation, params=None):
self._query = presto.client.PrestoQuery(self._request, sql=operation)
result = self._query.execute()
self._iterator = iter(result)
return result
self._query = presto.client.PrestoQuery(self._request, sql=operation).execute()
return self._query

def executemany(self, operation, seq_of_params):
raise presto.exceptions.NotSupportedError
Expand All @@ -250,13 +251,10 @@ def fetchone(self):
An Error (or subclass) exception is raised if the previous call to
.execute*() did not produce any result set or no call was issued yet.
"""

try:
return next(self._iterator)
except StopIteration:
result = self.fetchmany(1)
if len(result) != 1:
return None
except presto.exceptions.HttpError as err:
raise presto.exceptions.OperationalError(str(err))
return result[0]

def fetchmany(self, size=None):
# type: (Optional[int]) -> List[List[Any]]
Expand Down Expand Up @@ -284,16 +282,20 @@ def fetchmany(self, size=None):
size = self.arraysize

result = []
iterator = iter(self._query)

for _ in range(size):
row = self.fetchone()
if row is None:
try:
result.append(next(iterator))
except StopIteration:
break
result.append(row)
except prestodb.exceptions.HttpError as err:
raise prestodb.exceptions.OperationalError(str(err))

return result

def genall(self):
return self._query.result
return self._query

def fetchall(self):
# type: () -> List[List[Any]]
Expand Down