diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 6b6c1ebf828..0bbddb460fa 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -16525,7 +16525,7 @@ are limited. -PS
only a single character is allowed;
Literal value only
+PS
Literal value only
diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 6ca0e1a1967..0ed2c4cc41a 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -77,7 +77,9 @@ def test_split_positive_limit(): @pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'), (mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'), - (mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn) + (mk_str_gen('([123]{0,3}\\^?){0,5}'), '^'), + (mk_str_gen('([XYZ]{0,3}XYZ?){0,5}'), 'XYZ'), + (mk_str_gen('([DEF]{0,3}DELIM?){0,5}'), 'DELIM')], ids=idfn) def test_substring_index(data_gen,delim): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, data_gen).select( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 7fa18b2b782..9429df6b83b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2983,8 +2983,7 @@ object GpuOverrides extends Logging { "substring_index operator", ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), - ParamCheck("delim", TypeSig.lit(TypeEnum.STRING) - .withPsNote(TypeEnum.STRING, "only a single character is allowed"), TypeSig.STRING), + ParamCheck("delim", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), ParamCheck("count", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), (in, conf, p, r) => new SubstringIndexMeta(in, conf, p, r)), expr[StringRepeat]( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index a435988686d..4e243c79736 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -22,11 +22,12 @@ import java.util.{Locale, Optional} import scala.collection.mutable.ArrayBuffer -import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexProgram, RoundMode, Scalar, Table} +import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexProgram, RoundMode, Scalar} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.jni.CastStrings +import com.nvidia.spark.rapids.jni.GpuSubstringIndexUtils import com.nvidia.spark.rapids.jni.RegexRewriteUtils import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl} @@ -1584,63 +1585,16 @@ class SubstringIndexMeta( override val parent: Option[RapidsMeta[_, _, _]], rule: DataFromReplacementRule) extends TernaryExprMeta[SubstringIndex](expr, conf, parent, rule) { - private var regexp: String = _ - override def tagExprForGpu(): Unit = { - val delim = GpuOverrides.extractStringLit(expr.delimExpr).getOrElse("") - if (delim == null || delim.length != 1) { - willNotWorkOnGpu("only a single character deliminator is supported") - } - - val count = GpuOverrides.extractLit(expr.countExpr) - if (canThisBeReplaced) { - val c = count.get.value.asInstanceOf[Integer] - this.regexp = GpuSubstringIndex.makeExtractRe(delim, c) - } - } override def convertToGpu( column: Expression, delim: Expression, - count: Expression): GpuExpression = GpuSubstringIndex(column, this.regexp, delim, count) + count: Expression): GpuExpression = GpuSubstringIndex(column, delim, count) } -object GpuSubstringIndex { - def makeExtractRe(delim: String, count: Integer): String = { - if (delim.length != 1) { - throw new IllegalStateException("NOT SUPPORTED") - } - val quotedDelim = CudfRegexp.cudfQuote(delim.charAt(0)) - val notDelim = CudfRegexp.notCharSet(delim.charAt(0)) - // substring_index has a deliminator and a count. If the count is positive then - // you get back a substring from 0 until the Nth deliminator is found - // If the count is negative it goes in reverse - if (count == 0) { - // Count is zero so return a null regexp as a special case - null - } else if (count == 1) { - // If the count is 1 we want to match everything from the beginning of the string until we - // find the first occurrence of the deliminator or the end of the string - "\\A(" + notDelim + "*)" - } else if (count > 0) { - // If the count is > 1 we first match 0 up to count - 1 occurrences of the patten - // `not the deliminator 0 or more times followed by the deliminator` - // After that we go back to matching everything until we find the deliminator or the end of - // the string - "\\A((?:" + notDelim + "*" + quotedDelim + "){0," + (count - 1) + "}" + notDelim + "*)" - } else if (count == -1) { - // A -1 looks like 1 but we start looking at the end of the string - "(" + notDelim + "*)\\Z" - } else { //count < 0 - // All others look like a positive count, but again we are matching starting at the end of - // the string instead of the beginning - "((?:" + notDelim + "*" + quotedDelim + "){0," + ((-count) - 1) + "}" + notDelim + "*)\\Z" - } - } -} case class GpuSubstringIndex(strExpr: Expression, - regexp: String, ignoredDelimExpr: Expression, ignoredCountExpr: Expression) extends GpuTernaryExpressionArgsAnyScalarScalar with ImplicitCastInputTypes { @@ -1654,22 +1608,13 @@ case class GpuSubstringIndex(strExpr: Expression, override def prettyName: String = "substring_index" - // This is a bit hacked up at the moment. We are going to use a regular expression to extract - // a single value. It only works if the delim is a single character. A full version of - // substring_index for the GPU has been requested at https://github.com/rapidsai/cudf/issues/5158 - // spark-rapids plugin issue https://github.com/NVIDIA/spark-rapids/issues/8750 override def doColumnar(str: GpuColumnVector, delim: GpuScalar, count: GpuScalar): ColumnVector = { - if (regexp == null) { - withResource(str.getBase.isNull) { isNull => - withResource(Scalar.fromString("")) { emptyString => - isNull.ifElse(str.getBase, emptyString) - } - } + if (delim.isValid && count.isValid) { + GpuSubstringIndexUtils.substringIndex(str.getBase, delim.getBase, + count.getValue.asInstanceOf[Int]) } else { - withResource(str.getBase.extractRe(new RegexProgram(regexp))) { table: Table => - table.getColumn(0).incRefCount() - } + GpuColumnVector.columnVectorFromNull(str.getRowCount.toInt, StringType) } } diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala index ddaa58ff2f2..634d74ff272 100644 --- a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala @@ -74,7 +74,6 @@ class RapidsTestSettings extends BackendTestSettings { .exclude("SPARK-36229 conv should return result equal to -1 in base of toBase", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11142")) enableSuite[RapidsRegexpExpressionsSuite] enableSuite[RapidsStringExpressionsSuite] - .exclude("string substring_index function", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/8750")) .exclude("SPARK-22550: Elt should not generate codes beyond 64KB", WONT_FIX_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) .exclude("SPARK-22603: FormatString should not generate codes beyond 64KB", WONT_FIX_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775")) enableSuite[RapidsStringFunctionsSuite]