diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 6b6c1ebf828..0bbddb460fa 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -16525,7 +16525,7 @@ are limited.
|
|
|
-PS only a single character is allowed; Literal value only |
+PS Literal value only |
|
|
|
diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py
index 6ca0e1a1967..0ed2c4cc41a 100644
--- a/integration_tests/src/main/python/string_test.py
+++ b/integration_tests/src/main/python/string_test.py
@@ -77,7 +77,9 @@ def test_split_positive_limit():
@pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'),
(mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'),
- (mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn)
+ (mk_str_gen('([123]{0,3}\\^?){0,5}'), '^'),
+ (mk_str_gen('([XYZ]{0,3}XYZ?){0,5}'), 'XYZ'),
+ (mk_str_gen('([DEF]{0,3}DELIM?){0,5}'), 'DELIM')], ids=idfn)
def test_substring_index(data_gen,delim):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 7fa18b2b782..9429df6b83b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -2983,8 +2983,7 @@ object GpuOverrides extends Logging {
"substring_index operator",
ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING,
Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING),
- ParamCheck("delim", TypeSig.lit(TypeEnum.STRING)
- .withPsNote(TypeEnum.STRING, "only a single character is allowed"), TypeSig.STRING),
+ ParamCheck("delim", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
ParamCheck("count", TypeSig.lit(TypeEnum.INT), TypeSig.INT))),
(in, conf, p, r) => new SubstringIndexMeta(in, conf, p, r)),
expr[StringRepeat](
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index a435988686d..4e243c79736 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -22,11 +22,12 @@ import java.util.{Locale, Optional}
import scala.collection.mutable.ArrayBuffer
-import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexProgram, RoundMode, Scalar, Table}
+import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexProgram, RoundMode, Scalar}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.CastStrings
+import com.nvidia.spark.rapids.jni.GpuSubstringIndexUtils
import com.nvidia.spark.rapids.jni.RegexRewriteUtils
import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl}
@@ -1584,63 +1585,16 @@ class SubstringIndexMeta(
override val parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends TernaryExprMeta[SubstringIndex](expr, conf, parent, rule) {
- private var regexp: String = _
- override def tagExprForGpu(): Unit = {
- val delim = GpuOverrides.extractStringLit(expr.delimExpr).getOrElse("")
- if (delim == null || delim.length != 1) {
- willNotWorkOnGpu("only a single character deliminator is supported")
- }
-
- val count = GpuOverrides.extractLit(expr.countExpr)
- if (canThisBeReplaced) {
- val c = count.get.value.asInstanceOf[Integer]
- this.regexp = GpuSubstringIndex.makeExtractRe(delim, c)
- }
- }
override def convertToGpu(
column: Expression,
delim: Expression,
- count: Expression): GpuExpression = GpuSubstringIndex(column, this.regexp, delim, count)
+ count: Expression): GpuExpression = GpuSubstringIndex(column, delim, count)
}
-object GpuSubstringIndex {
- def makeExtractRe(delim: String, count: Integer): String = {
- if (delim.length != 1) {
- throw new IllegalStateException("NOT SUPPORTED")
- }
- val quotedDelim = CudfRegexp.cudfQuote(delim.charAt(0))
- val notDelim = CudfRegexp.notCharSet(delim.charAt(0))
- // substring_index has a deliminator and a count. If the count is positive then
- // you get back a substring from 0 until the Nth deliminator is found
- // If the count is negative it goes in reverse
- if (count == 0) {
- // Count is zero so return a null regexp as a special case
- null
- } else if (count == 1) {
- // If the count is 1 we want to match everything from the beginning of the string until we
- // find the first occurrence of the deliminator or the end of the string
- "\\A(" + notDelim + "*)"
- } else if (count > 0) {
- // If the count is > 1 we first match 0 up to count - 1 occurrences of the patten
- // `not the deliminator 0 or more times followed by the deliminator`
- // After that we go back to matching everything until we find the deliminator or the end of
- // the string
- "\\A((?:" + notDelim + "*" + quotedDelim + "){0," + (count - 1) + "}" + notDelim + "*)"
- } else if (count == -1) {
- // A -1 looks like 1 but we start looking at the end of the string
- "(" + notDelim + "*)\\Z"
- } else { //count < 0
- // All others look like a positive count, but again we are matching starting at the end of
- // the string instead of the beginning
- "((?:" + notDelim + "*" + quotedDelim + "){0," + ((-count) - 1) + "}" + notDelim + "*)\\Z"
- }
- }
-}
case class GpuSubstringIndex(strExpr: Expression,
- regexp: String,
ignoredDelimExpr: Expression,
ignoredCountExpr: Expression)
extends GpuTernaryExpressionArgsAnyScalarScalar with ImplicitCastInputTypes {
@@ -1654,22 +1608,13 @@ case class GpuSubstringIndex(strExpr: Expression,
override def prettyName: String = "substring_index"
- // This is a bit hacked up at the moment. We are going to use a regular expression to extract
- // a single value. It only works if the delim is a single character. A full version of
- // substring_index for the GPU has been requested at https://github.com/rapidsai/cudf/issues/5158
- // spark-rapids plugin issue https://github.com/NVIDIA/spark-rapids/issues/8750
override def doColumnar(str: GpuColumnVector, delim: GpuScalar,
count: GpuScalar): ColumnVector = {
- if (regexp == null) {
- withResource(str.getBase.isNull) { isNull =>
- withResource(Scalar.fromString("")) { emptyString =>
- isNull.ifElse(str.getBase, emptyString)
- }
- }
+ if (delim.isValid && count.isValid) {
+ GpuSubstringIndexUtils.substringIndex(str.getBase, delim.getBase,
+ count.getValue.asInstanceOf[Int])
} else {
- withResource(str.getBase.extractRe(new RegexProgram(regexp))) { table: Table =>
- table.getColumn(0).incRefCount()
- }
+ GpuColumnVector.columnVectorFromNull(str.getRowCount.toInt, StringType)
}
}
diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala
index ddaa58ff2f2..634d74ff272 100644
--- a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala
+++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala
@@ -74,7 +74,6 @@ class RapidsTestSettings extends BackendTestSettings {
.exclude("SPARK-36229 conv should return result equal to -1 in base of toBase", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11142"))
enableSuite[RapidsRegexpExpressionsSuite]
enableSuite[RapidsStringExpressionsSuite]
- .exclude("string substring_index function", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/8750"))
.exclude("SPARK-22550: Elt should not generate codes beyond 64KB", WONT_FIX_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
.exclude("SPARK-22603: FormatString should not generate codes beyond 64KB", WONT_FIX_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10775"))
enableSuite[RapidsStringFunctionsSuite]