diff --git a/integration_tests/src/main/python/collection_ops_test.py b/integration_tests/src/main/python/collection_ops_test.py
index 099eb28c0535..9731caba78b1 100644
--- a/integration_tests/src/main/python/collection_ops_test.py
+++ b/integration_tests/src/main/python/collection_ops_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, NVIDIA CORPORATION.
+# Copyright (c) 2021-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
 from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
 from data_gen import *
 from pyspark.sql.types import *
+from spark_session import is_before_spark_400
 from string_test import mk_str_gen
 import pyspark.sql.functions as f
 import pyspark.sql.utils
@@ -326,8 +328,11 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
 @pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
 @allow_non_gpu(*non_utc_allow)
 def test_sequence_too_long_sequence(stop_gen):
-    msg = "Too long sequence" if is_before_spark_334() or (not is_before_spark_340() and is_before_spark_342()) \
-        or is_spark_350() else "Unsuccessful try to create array with"
+    msg = "Too long sequence" if is_before_spark_334() \
+        or (not is_before_spark_340() and is_before_spark_342()) \
+        or is_spark_350() \
+        else "Unsuccessful try to create array with" if is_before_spark_400() \
+        else "Can't create array"
     assert_gpu_and_cpu_error(
         # To avoid OOM, reduce the row number to 1, it is enough to verify this case.
         lambda spark:unary_op_df(spark, stop_gen, 1).selectExpr(
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
index 7543d113bfb6..51590bbde281 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
 import java.util.Optional
 
 import ai.rapids.cudf
-import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar, SegmentedReductionAggregation, Table}
+import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, ReductionAggregation, Scalar, SegmentedReductionAggregation, Table}
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.Arm._
 import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
@@ -1535,7 +1535,8 @@ object GpuSequenceUtil {
   def computeSequenceSize(
       start: ColumnVector,
       stop: ColumnVector,
-      step: ColumnVector): ColumnVector = {
+      step: ColumnVector,
+      functionName: String): ColumnVector = {
     checkSequenceInputs(start, stop, step)
     val actualSize = GetSequenceSize(start, stop, step)
     val sizeAsLong = withResource(actualSize) { _ =>
@@ -1557,7 +1558,11 @@ object GpuSequenceUtil {
     // check max size
     withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
       withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
-        require(isAllValidTrue(allValid), GetSequenceSize.TOO_LONG_SEQUENCE)
+        require(isAllValidTrue(allValid),
+          // The message is passed by name, so this max reduction only runs on failure.
+          withResource(sizeAsLong.reduce(ReductionAggregation.max())) { maxSizeScalar =>
+            GetSequenceSize.TOO_LONG_SEQUENCE(maxSizeScalar.getLong.toInt, functionName)
+          })
       }
     }
     // cast to int and return
@@ -1597,7 +1602,7 @@ case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expr
 
     val steps = stepGpuColOpt.map(_.getBase.incRefCount())
       .getOrElse(defaultStepsFunc(startCol, stopCol))
     closeOnExcept(steps) { _ =>
-      (computeSequenceSize(startCol, stopCol, steps), steps)
+      (computeSequenceSize(startCol, stopCol, steps, prettyName), steps)
     }
   }
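A note on the restructured `require` in `computeSequenceSize` above: Scala's `require(requirement, message)` takes its message by name, so the `ReductionAggregation.max()` reduction that builds the error text only executes when the size check actually fails. A minimal, dependency-free sketch of that behavior (all names here are illustrative, not part of the PR):

```scala
object RequireLazinessDemo {
  // Stand-in for the cuDF max reduction; prints so we can see when it runs.
  def expensiveMax(): Long = {
    println("max reduction evaluated")
    Int.MaxValue.toLong + 1
  }

  def main(args: Array[String]): Unit = {
    // Passing check: the by-name message (and expensiveMax) is never evaluated.
    require(true, s"Too long sequence found: ${expensiveMax()}")

    // Failing check: the message is built and expensiveMax runs exactly once.
    try {
      require(false, s"Too long sequence found: ${expensiveMax()}")
    } catch {
      case e: IllegalArgumentException => println(e.getMessage)
    }
  }
}
```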
diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
index 32ca03974bf1..e00aa26baad3 100644
--- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
+++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
@@ -42,7 +42,12 @@ import com.nvidia.spark.rapids.Arm._
 import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 
 object GetSequenceSize {
-  val TOO_LONG_SEQUENCE = s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
+  def TOO_LONG_SEQUENCE(sequenceLength: Int, functionName: String): String = {
+    // For these Spark versions, the sequence length and function name
+    // do not appear in the exception message.
+    s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
+  }
+
   /**
    * Compute the size of each sequence according to 'start', 'stop' and 'step'.
    * A row (Row[start, stop, step]) contains at least one null element will produce
diff --git a/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala b/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
index aba0f4654835..939566109334 100644
--- a/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
+++ b/sql-plugin/src/main/spark334/scala/com/nvidia/spark/rapids/shims/GetSequenceSize.scala
@@ -28,11 +28,12 @@ import ai.rapids.cudf._
 import com.nvidia.spark.rapids.Arm._
 
 import org.apache.spark.sql.rapids.{AddOverflowChecks, SubtractOverflowChecks}
+import org.apache.spark.sql.rapids.shims.SequenceSizeError
 import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 
 object GetSequenceSize {
-  val TOO_LONG_SEQUENCE = "Unsuccessful try to create array with elements exceeding the array " +
-    s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
+  def TOO_LONG_SEQUENCE(sequenceLength: Int, functionName: String): String =
+    SequenceSizeError.getTooLongSequenceErrorString(sequenceLength, functionName)
   /**
    * Compute the size of each sequence according to 'start', 'stop' and 'step'.
    * A row (Row[start, stop, step]) contains at least one null element will produce
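For reference, the two pre-4.0 message shapes produced by the shims above can be reproduced without the plugin on the classpath. This sketch inlines the shims' formatting; the constant mirrors Spark's `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH` (`Integer.MAX_VALUE - 15`), and both functions ignore their arguments, just as the shims do:

```scala
object SequenceSizeMessages {
  // Mirrors org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.
  val MaxRoundedArrayLength: Int = Int.MaxValue - 15

  // Spark 3.2.x-3.3.3, 3.4.0-3.4.1, 3.5.0: size and name are unused.
  def spark320Style(sequenceLength: Int, functionName: String): String =
    s"Too long sequence found. Should be <= $MaxRoundedArrayLength"

  // Spark 3.3.4, 3.4.2+, 3.5.1+ (pre-4.0): still ignores both arguments.
  def spark334Style(sequenceSize: Int, functionName: String): String =
    "Unsuccessful try to create array with elements exceeding the array " +
      s"size limit $MaxRoundedArrayLength"

  def main(args: Array[String]): Unit = {
    println(spark320Style(Int.MaxValue, "sequence"))
    println(spark334Style(Int.MaxValue, "sequence"))
  }
}
```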
diff --git a/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala b/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala
new file mode 100644
index 000000000000..eadcc497318d
--- /dev/null
+++ b/sql-plugin/src/main/spark334/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "334"}
+{"spark": "342"}
+{"spark": "343"}
+{"spark": "351"}
+{"spark": "352"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.shims
+
+import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+object SequenceSizeError {
+  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
+    // The errant function's name does not feature in the exception message
+    // prior to Spark 4.0. Neither does the attempted allocation size.
+    "Unsuccessful try to create array with elements exceeding the array " +
+      s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
+  }
+}
diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala
new file mode 100644
index 000000000000..aede65d02fa0
--- /dev/null
+++ b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeError.scala
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "400"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.rapids.shims
+
+import org.apache.spark.sql.errors.QueryExecutionErrors
+
+object SequenceSizeError {
+  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
+    // Spark 4.0's error message reports both the function name and the attempted size.
+    QueryExecutionErrors.createArrayWithElementsExceedLimitError(functionName, sequenceSize)
+      .getMessage
+  }
+}
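Finally, a hedged smoke check for the Spark 4.0 path: the expected substring comes from the updated Python test above ("Can't create array"), while the harness itself is hypothetical and assumes a Spark 4.0 runtime plus this shim on the classpath:

```scala
import org.apache.spark.sql.rapids.shims.SequenceSizeError

// Hypothetical spot check; not part of the PR.
object SequenceSizeErrorSmokeTest {
  def main(args: Array[String]): Unit = {
    val msg = SequenceSizeError.getTooLongSequenceErrorString(Int.MaxValue, "sequence")
    // Matches the substring asserted for Spark 4.0 in collection_ops_test.py.
    assert(msg.contains("Can't create array"), s"unexpected message: $msg")
    println(msg)
  }
}
```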