From 5606503c45766130c399df00cd2ab19aaeb6f503 Mon Sep 17 00:00:00 2001 From: MithunR Date: Wed, 28 Aug 2024 21:43:51 -0700 Subject: [PATCH] Fix collection_ops_tests for Spark 4.0. Fixes #11011. This commit fixes the failures in `collection_ops_tests` on Spark 4.0. On all versions of Spark, when a Sequence is collected with rows that exceed MAX_INT, an exception is thrown indicating that the collected Sequence/array is larger than permissible. The different versions of Spark vary in the contents of the exception message. On Spark 4, one sees that the error message now contains more information than all prior versions, including: 1. The name of the op causing the error 2. The errant sequence size This commit introduces a shim to make this new information available in the exception. Note that this shim does not fit cleanly in RapidsErrorUtils, because there are differences within major Spark versions. For instance, Spark 3.4.0-1 have a different message as compared to 3.4.2 and 3.4.3. Likewise, the differences in 3.5.0, 3.5.1, 3.5.2. --- .../src/main/python/collection_ops_test.py | 11 ++++-- .../sql/rapids/collectionOperations.scala | 13 ++++--- .../spark/rapids/shims/GetSequenceSize.scala | 7 +++- .../spark/rapids/shims/GetSequenceSize.scala | 5 +-- .../sql/rapids/shims/SequenceSizeError.scala | 35 +++++++++++++++++++ .../sql/rapids/shims/SequenceSizeError.scala | 28 +++++++++++++++ 6 files changed, 89 insertions(+), 10 deletions(-) create mode 100644 sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala create mode 100644 sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala diff --git a/integration_tests/src/main/python/collection_ops_test.py b/integration_tests/src/main/python/collection_ops_test.py index 099eb28c0535..9731caba78b1 100644 --- a/integration_tests/src/main/python/collection_ops_test.py +++ b/integration_tests/src/main/python/collection_ops_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error from data_gen import * from pyspark.sql.types import * + +from src.main.python.spark_session import is_before_spark_400 from string_test import mk_str_gen import pyspark.sql.functions as f import pyspark.sql.utils @@ -326,8 +328,11 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen): @pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn) @allow_non_gpu(*non_utc_allow) def test_sequence_too_long_sequence(stop_gen): - msg = "Too long sequence" if is_before_spark_334() or (not is_before_spark_340() and is_before_spark_342()) \ - or is_spark_350() else "Unsuccessful try to create array with" + msg = "Too long sequence" if is_before_spark_334() \ + or (not is_before_spark_340() and is_before_spark_342()) \ + or is_spark_350() \ + else "Can't create array" if not is_before_spark_400() \ + else "Unsuccessful try to create array with" assert_gpu_and_cpu_error( # To avoid OOM, reduce the row number to 1, it is enough to verify this case. lambda spark:unary_op_df(spark, stop_gen, 1).selectExpr( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala index 7543d113bfb6..51590bbde281 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids import java.util.Optional import ai.rapids.cudf -import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar, SegmentedReductionAggregation, Table} +import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, ReductionAggregation, Scalar, SegmentedReductionAggregation, Table} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked @@ -1535,7 +1535,8 @@ object GpuSequenceUtil { def computeSequenceSize( start: ColumnVector, stop: ColumnVector, - step: ColumnVector): ColumnVector = { + step: ColumnVector, + functionName: String): ColumnVector = { checkSequenceInputs(start, stop, step) val actualSize = GetSequenceSize(start, stop, step) val sizeAsLong = withResource(actualSize) { _ => @@ -1557,7 +1558,11 @@ object GpuSequenceUtil { // check max size withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen => withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid => - require(isAllValidTrue(allValid), GetSequenceSize.TOO_LONG_SEQUENCE) + withResource(sizeAsLong.reduce(ReductionAggregation.max())) { maxSizeScalar => + require(isAllValidTrue(allValid), + GetSequenceSize.TOO_LONG_SEQUENCE(maxSizeScalar.getLong.asInstanceOf[Int], + functionName)) + } } } // cast to int and return @@ -1597,7 +1602,7 @@ case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expr val steps = stepGpuColOpt.map(_.getBase.incRefCount()) .getOrElse(defaultStepsFunc(startCol, stopCol)) closeOnExcept(steps) { _ => - (computeSequenceSize(startCol, stopCol, steps), steps) + (computeSequenceSize(startCol, stopCol, steps, prettyName), steps) } } diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala index 32ca03974bf1..e00aa26baad3 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala @@ -42,7 +42,12 @@ import com.nvidia.spark.rapids.Arm._ import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH object GetSequenceSize { - val TOO_LONG_SEQUENCE = s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH" + def TOO_LONG_SEQUENCE(sequenceLength: Int, functionName: String) = { + // For these Spark versions, the sequence length and function name + // do not appear in the exception message. + s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH" + } + /** * Compute the size of each sequence according to 'start', 'stop' and 'step'. * A row (Row[start, stop, step]) contains at least one null element will produce diff --git a/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala b/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala index aba0f4654835..939566109334 100644 --- a/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala +++ b/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala @@ -28,11 +28,12 @@ import ai.rapids.cudf._ import com.nvidia.spark.rapids.Arm._ import org.apache.spark.sql.rapids.{AddOverflowChecks, SubtractOverflowChecks} +import org.apache.spark.sql.rapids.shims.SequenceSizeError import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH object GetSequenceSize { - val TOO_LONG_SEQUENCE = "Unsuccessful try to create array with elements exceeding the array " + - s"size limit $MAX_ROUNDED_ARRAY_LENGTH" + def TOO_LONG_SEQUENCE(sequenceLength: Int, functionName: String): String = + SequenceSizeError.getTooLongSequenceErrorString(sequenceLength, functionName) /** * Compute the size of each sequence according to 'start', 'stop' and 'step'. * A row (Row[start, stop, step]) contains at least one null element will produce diff --git a/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala b/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala new file mode 100644 index 000000000000..eadcc497318d --- /dev/null +++ b/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "334"} +{"spark": "342"} +{"spark": "343"} +{"spark": "351"} +{"spark": "352"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.shims + +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + +object SequenceSizeError { + def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = { + // The errant function's name does not feature in the exception message + // prior to Spark 4.0. Neither does the attempted allocation size. + "Unsuccessful try to create array with elements exceeding the array " + + s"size limit $MAX_ROUNDED_ARRAY_LENGTH" + } +} diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala new file mode 100644 index 000000000000..aede65d02fa0 --- /dev/null +++ b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "400"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.shims + +import org.apache.spark.sql.errors.QueryExecutionErrors + +object SequenceSizeError { + def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = { + QueryExecutionErrors.createArrayWithElementsExceedLimitError(functionName, sequenceSize).getMessage + } +}