diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py index 1728217d..927989ec 100644 --- a/tests/integration/test_types_integration.py +++ b/tests/integration/test_types_integration.py @@ -71,7 +71,7 @@ def test_bigint(trino_connection): def test_real(trino_connection): SqlTest(trino_connection) \ .add_field(sql="CAST(null AS REAL)", python=None) \ - .add_field(sql="CAST('NaN' AS REAL)", python=math.nan) \ + .add_field(sql="CAST('NaN' AS REAL)", python=math.nan, has_nan=True) \ .add_field(sql="CAST('-Infinity' AS REAL)", python=-math.inf) \ .add_field(sql="CAST(3.4028235E38 AS REAL)", python=3.4028235e+38) \ .add_field(sql="CAST(1.4E-45 AS REAL)", python=1.4e-45) \ @@ -82,7 +82,7 @@ def test_real(trino_connection): def test_double(trino_connection): SqlTest(trino_connection) \ .add_field(sql="CAST(null AS DOUBLE)", python=None) \ - .add_field(sql="CAST('NaN' AS DOUBLE)", python=math.nan) \ + .add_field(sql="CAST('NaN' AS DOUBLE)", python=math.nan, has_nan=True) \ .add_field(sql="CAST('-Infinity' AS DOUBLE)", python=-math.inf) \ .add_field(sql="CAST(1.7976931348623157E308 AS DOUBLE)", python=1.7976931348623157e+308) \ .add_field(sql="CAST(4.9E-324 AS DOUBLE)", python=5e-324) \ @@ -747,11 +747,43 @@ def test_interval(trino_connection): def test_array(trino_connection): + # primitive types SqlTest(trino_connection) \ .add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None) \ - .add_field(sql="ARRAY['a', 'b', null]", python=['a', 'b', None]) \ + .add_field(sql="ARRAY[]", python=[]) \ + .add_field(sql="ARRAY[true, false, null]", python=[True, False, None]) \ + .add_field(sql="ARRAY[1, 2, null]", python=[1, 2, None]) \ + .add_field( + sql="ARRAY[CAST('NaN' AS REAL), CAST('-Infinity' AS REAL), CAST(3.4028235E38 AS REAL), CAST(1.4E-45 AS REAL), CAST('Infinity' AS REAL), null]", + python=[math.nan, -math.inf, 3.4028235e+38, 1.4e-45, math.inf, None], + has_nan=True) \ + .add_field( + sql="ARRAY[CAST('NaN' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST(1.7976931348623157E308 AS DOUBLE), CAST(4.9E-324 AS DOUBLE), CAST('Infinity' AS DOUBLE), null]", + python=[math.nan, -math.inf, 1.7976931348623157e+308, 5e-324, math.inf, None], + has_nan=True) \ .add_field(sql="ARRAY[1.2, 2.4, null]", python=[Decimal("1.2"), Decimal("2.4"), None]) \ - .add_field(sql="ARRAY[CAST(4.9E-324 AS DOUBLE), null]", python=[5e-324, None]) \ + .add_field(sql="ARRAY[CAST('hello' AS VARCHAR), null]", python=["hello", None]) \ + .add_field(sql="ARRAY[CAST('a' AS CHAR(3)), null]", python=['a ', None]) \ + .add_field(sql="ARRAY[X'', X'65683F', null]", python=[b'', b'eh?', None]) \ + .add_field(sql="ARRAY[JSON 'null', JSON '{}', null]", python=['null', '{}', None]) \ + .execute() + + # temporal types + SqlTest(trino_connection) \ + .add_field(sql="ARRAY[DATE '1970-01-01', null]", python=[date(1970, 1, 1), None]) \ + .add_field(sql="ARRAY[TIME '01:01:01', null]", python=[time(1, 1, 1), None]) \ + .add_field(sql="ARRAY[TIME '01:01:01 +05:30', null]", python=[time(1, 1, 1, tzinfo=create_timezone("+05:30")), None]) \ + .add_field(sql="ARRAY[TIMESTAMP '1970-01-01 01:01:01', null]", python=[datetime(1970, 1, 1, 1, 1, 1), None]) \ + .add_field(sql="ARRAY[TIMESTAMP '1970-01-01 01:01:01 +05:30', null]", python=[datetime(1970, 1, 1, 1, 1, 1, tzinfo=create_timezone("+05:30")), None]) \ + .execute() + + # structural types + SqlTest(trino_connection) \ + .add_field(sql="ARRAY[ARRAY[1, null], ARRAY[2, 3], null]", python=[[1, None], [2, 3], None]) \ + .add_field( + sql="ARRAY[MAP(ARRAY['foo', 'bar', 'baz'], ARRAY['one', 'two', null]), MAP(), null]", + python=[{"foo": "one", "bar": "two", "baz": None}, {}, None]) \ + .add_field(sql="ARRAY[ROW(1, 2), ROW(1, null), null]", python=[(1, 2), (1, None), None]) \ .execute() @@ -806,30 +838,80 @@ class SqlTest: def __init__(self, trino_connection): self.cur = trino_connection.cursor(legacy_primitive_types=False) self.sql_args = [] - self.expected_result = [] + self.expected_results = [] + self.has_nan = [] - def add_field(self, sql, python): + def add_field(self, sql, python, has_nan=False): self.sql_args.append(sql) - self.expected_result.append(python) + self.expected_results.append(python) + self.has_nan.append(has_nan) return self def execute(self): sql = 'SELECT ' + ',\n'.join(self.sql_args) self.cur.execute(sql) - actual_result = self.cur.fetchall() - self._compare_results(actual_result[0], self.expected_result) - - def _compare_results(self, actual, expected): - assert len(actual) == len(expected) - - for idx, actual_val in enumerate(actual): - expected_val = expected[idx] - if type(actual_val) == float and math.isnan(actual_val) \ - and type(expected_val) == float and math.isnan(expected_val): - continue - - assert actual_val == expected_val + actual_results = self.cur.fetchall() + self._compare_results(actual_results[0], self.expected_results) + + def _are_equal_ignoring_nan(self, actual, expected) -> bool: + if isinstance(actual, float) and math.isnan(actual) \ + and isinstance(expected, float) and math.isnan(expected): + # Consider NaNs equal since we only want to make sure values round-trip + return True + return actual == expected + + def _compare_results(self, actual_results, expected_results): + assert len(actual_results) == len(expected_results) + + for idx, actual in enumerate(actual_results): + expected = expected_results[idx] + if not self.has_nan[idx]: + assert actual == expected + else: + # We need to consider NaNs in a collection equal since we only want to make sure values round-trip. + # collections compare identity first instead of value so: + # >>> from math import nan + # >>> [nan] == [nan] + # True + # >>> [nan] == [float("nan")] + # False + # >>> [float("nan")] == [float("nan")] + # False + # We create the NaNs using float("nan") which means PyTest's assert + # will always fail on collections containing nan. + if (isinstance(actual, list) and isinstance(expected, list)) \ + or (isinstance(actual, set) and isinstance(expected, set)) \ + or (isinstance(actual, tuple) and isinstance(expected, tuple)): + for i, _ in enumerate(actual): + if not self._are_equal_ignoring_nan(actual[i], expected[i]): + # Will fail, here to provide useful assertion message + assert actual == expected + elif isinstance(actual, dict) and isinstance(expected, dict): + for actual_key, actual_value in actual.items(): + # Note that Trino disallows multiple NaN keys in a MAP, so we don't consider the case where + # multiple NaN keys exist in either dict. + if math.isnan(actual_key): + expected_has_nan_key = False + for expected_key, expected_value in expected.items(): + if math.isnan(expected_key): + expected_has_nan_key = True + # Found the other NaN key. Let's compare the values from both dicts. + if not self._are_equal_ignoring_nan(actual_value, expected_value): + # Will fail, here to provide useful assertion message + assert actual == expected + # If expected has no NaN keys then the dicts cannot be equal since actual has a NaN key. + if not expected_has_nan_key: + # Will fail, here to provide useful assertion message + assert actual == expected + else: + if not self._are_equal_ignoring_nan(actual.get(actual_key), expected.get(actual_key)): + # Will fail, here to provide useful assertion message + assert actual == expected + else: + if not self._are_equal_ignoring_nan(actual, expected): + # Will fail, here to provide useful assertion message + assert actual == expected class SqlExpectFailureTest: