Skip to content

Commit

Permalink
Add more tests for array types
Browse files Browse the repository at this point in the history
  • Loading branch information
hashhar committed Jan 21, 2024
1 parent 24a04d4 commit 5746d5f
Showing 1 changed file with 102 additions and 20 deletions.
122 changes: 102 additions & 20 deletions tests/integration/test_types_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_bigint(trino_connection):
def test_real(trino_connection):
SqlTest(trino_connection) \
.add_field(sql="CAST(null AS REAL)", python=None) \
.add_field(sql="CAST('NaN' AS REAL)", python=math.nan) \
.add_field(sql="CAST('NaN' AS REAL)", python=math.nan, has_nan=True) \
.add_field(sql="CAST('-Infinity' AS REAL)", python=-math.inf) \
.add_field(sql="CAST(3.4028235E38 AS REAL)", python=3.4028235e+38) \
.add_field(sql="CAST(1.4E-45 AS REAL)", python=1.4e-45) \
Expand All @@ -82,7 +82,7 @@ def test_real(trino_connection):
def test_double(trino_connection):
SqlTest(trino_connection) \
.add_field(sql="CAST(null AS DOUBLE)", python=None) \
.add_field(sql="CAST('NaN' AS DOUBLE)", python=math.nan) \
.add_field(sql="CAST('NaN' AS DOUBLE)", python=math.nan, has_nan=True) \
.add_field(sql="CAST('-Infinity' AS DOUBLE)", python=-math.inf) \
.add_field(sql="CAST(1.7976931348623157E308 AS DOUBLE)", python=1.7976931348623157e+308) \
.add_field(sql="CAST(4.9E-324 AS DOUBLE)", python=5e-324) \
Expand Down Expand Up @@ -747,11 +747,43 @@ def test_interval(trino_connection):


def test_array(trino_connection):
    """Verify that Trino ARRAY values are returned as the expected Python lists.

    Three separate queries are issued through ``SqlTest``: one for arrays of
    primitive element types, one for temporal element types and one for
    structural (nested) element types. Fields whose expected value contains a
    NaN are registered with ``has_nan=True`` so that they are compared with the
    NaN-tolerant logic of ``SqlTest`` instead of plain ``==``.

    :param trino_connection: connection object handed to ``SqlTest``
        (presumably a pytest fixture; its definition is outside this view).
    """
    # Primitive element types: varchar, boolean, integer, real/double (including
    # NaN and +/-Infinity), decimal, char, varbinary and JSON.
    # CHAR(3) values are space-padded to their declared length, hence 'a  '.
    SqlTest(trino_connection) \
        .add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None) \
        .add_field(sql="ARRAY['a', 'b', null]", python=['a', 'b', None]) \
        .add_field(sql="ARRAY[]", python=[]) \
        .add_field(sql="ARRAY[true, false, null]", python=[True, False, None]) \
        .add_field(sql="ARRAY[1, 2, null]", python=[1, 2, None]) \
        .add_field(
            sql="ARRAY[CAST('NaN' AS REAL), CAST('-Infinity' AS REAL), CAST(3.4028235E38 AS REAL), CAST(1.4E-45 AS REAL), CAST('Infinity' AS REAL), null]",
            python=[math.nan, -math.inf, 3.4028235e+38, 1.4e-45, math.inf, None],
            has_nan=True) \
        .add_field(
            sql="ARRAY[CAST('NaN' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST(1.7976931348623157E308 AS DOUBLE), CAST(4.9E-324 AS DOUBLE), CAST('Infinity' AS DOUBLE), null]",
            python=[math.nan, -math.inf, 1.7976931348623157e+308, 5e-324, math.inf, None],
            has_nan=True) \
        .add_field(sql="ARRAY[1.2, 2.4, null]", python=[Decimal("1.2"), Decimal("2.4"), None]) \
        .add_field(sql="ARRAY[CAST(4.9E-324 AS DOUBLE), null]", python=[5e-324, None]) \
        .add_field(sql="ARRAY[CAST('hello' AS VARCHAR), null]", python=["hello", None]) \
        .add_field(sql="ARRAY[CAST('a' AS CHAR(3)), null]", python=['a  ', None]) \
        .add_field(sql="ARRAY[X'', X'65683F', null]", python=[b'', b'eh?', None]) \
        .add_field(sql="ARRAY[JSON 'null', JSON '{}', null]", python=['null', '{}', None]) \
        .execute()

    # Temporal element types: DATE, TIME and TIMESTAMP, each of the latter two
    # both without and with an explicit +05:30 offset.
    SqlTest(trino_connection) \
        .add_field(sql="ARRAY[DATE '1970-01-01', null]", python=[date(1970, 1, 1), None]) \
        .add_field(sql="ARRAY[TIME '01:01:01', null]", python=[time(1, 1, 1), None]) \
        .add_field(sql="ARRAY[TIME '01:01:01 +05:30', null]", python=[time(1, 1, 1, tzinfo=create_timezone("+05:30")), None]) \
        .add_field(sql="ARRAY[TIMESTAMP '1970-01-01 01:01:01', null]", python=[datetime(1970, 1, 1, 1, 1, 1), None]) \
        .add_field(sql="ARRAY[TIMESTAMP '1970-01-01 01:01:01 +05:30', null]", python=[datetime(1970, 1, 1, 1, 1, 1, tzinfo=create_timezone("+05:30")), None]) \
        .execute()

    # Structural element types: nested ARRAY (-> list), MAP (-> dict) and ROW (-> tuple).
    SqlTest(trino_connection) \
        .add_field(sql="ARRAY[ARRAY[1, null], ARRAY[2, 3], null]", python=[[1, None], [2, 3], None]) \
        .add_field(
            sql="ARRAY[MAP(ARRAY['foo', 'bar', 'baz'], ARRAY['one', 'two', null]), MAP(), null]",
            python=[{"foo": "one", "bar": "two", "baz": None}, {}, None]) \
        .add_field(sql="ARRAY[ROW(1, 2), ROW(1, null), null]", python=[(1, 2), (1, None), None]) \
        .execute()


Expand Down Expand Up @@ -806,30 +838,80 @@ class SqlTest:
def __init__(self, trino_connection):
self.cur = trino_connection.cursor(legacy_primitive_types=False)
self.sql_args = []
self.expected_result = []
self.expected_results = []
self.has_nan = []

def add_field(self, sql, python):
def add_field(self, sql, python, has_nan=False):
self.sql_args.append(sql)
self.expected_result.append(python)
self.expected_results.append(python)
self.has_nan.append(has_nan)
return self

def execute(self):
sql = 'SELECT ' + ',\n'.join(self.sql_args)

self.cur.execute(sql)
actual_result = self.cur.fetchall()
self._compare_results(actual_result[0], self.expected_result)

def _compare_results(self, actual, expected):
assert len(actual) == len(expected)

for idx, actual_val in enumerate(actual):
expected_val = expected[idx]
if type(actual_val) == float and math.isnan(actual_val) \
and type(expected_val) == float and math.isnan(expected_val):
continue

assert actual_val == expected_val
actual_results = self.cur.fetchall()
self._compare_results(actual_results[0], self.expected_results)

def _are_equal_ignoring_nan(self, actual, expected) -> bool:
if isinstance(actual, float) and math.isnan(actual) \
and isinstance(expected, float) and math.isnan(expected):
# Consider NaNs equal since we only want to make sure values round-trip
return True
return actual == expected

def _compare_results(self, actual_results, expected_results):
assert len(actual_results) == len(expected_results)

for idx, actual in enumerate(actual_results):
expected = expected_results[idx]
if not self.has_nan[idx]:
assert actual == expected
else:
# We need to consider NaNs in a collection equal since we only want to make sure values round-trip.
# collections compare identity first instead of value so:
# >>> from math import nan
# >>> [nan] == [nan]
# True
# >>> [nan] == [float("nan")]
# False
# >>> [float("nan")] == [float("nan")]
# False
# We create the NaNs using float("nan") which means PyTest's assert
# will always fail on collections containing nan.
if (isinstance(actual, list) and isinstance(expected, list)) \
or (isinstance(actual, set) and isinstance(expected, set)) \
or (isinstance(actual, tuple) and isinstance(expected, tuple)):
for i, _ in enumerate(actual):
if not self._are_equal_ignoring_nan(actual[i], expected[i]):
# Will fail, here to provide useful assertion message
assert actual == expected
elif isinstance(actual, dict) and isinstance(expected, dict):
for actual_key, actual_value in actual.items():
# Note that Trino disallows multiple NaN keys in a MAP, so we don't consider the case where
# multiple NaN keys exist in either dict.
if math.isnan(actual_key):
expected_has_nan_key = False
for expected_key, expected_value in expected.items():
if math.isnan(expected_key):
expected_has_nan_key = True
# Found the other NaN key. Let's compare the values from both dicts.
if not self._are_equal_ignoring_nan(actual_value, expected_value):
# Will fail, here to provide useful assertion message
assert actual == expected
# If expected has no NaN keys then the dicts cannot be equal since actual has a NaN key.
if not expected_has_nan_key:
# Will fail, here to provide useful assertion message
assert actual == expected
else:
if not self._are_equal_ignoring_nan(actual.get(actual_key), expected.get(actual_key)):
# Will fail, here to provide useful assertion message
assert actual == expected
else:
if not self._are_equal_ignoring_nan(actual, expected):
# Will fail, here to provide useful assertion message
assert actual == expected


class SqlExpectFailureTest:
Expand Down

0 comments on commit 5746d5f

Please sign in to comment.